Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/18/10 12:27:28 (14 years ago)
Author:
gkronber
Message:

Added operators for support vector regression. #1009

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine
Files:
1 added
1 copied

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModelCreator.cs

    r3763 r3842  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2009 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    2626using HeuristicLab.Core;
    2727using HeuristicLab.Data;
    28 using HeuristicLab.DataAnalysis;
    2928using System.Threading;
     29using HeuristicLab.LibSVM;
     30using HeuristicLab.Operators;
     31using HeuristicLab.Parameters;
    3032using SVM;
     33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3134
    32 namespace HeuristicLab.SupportVectorMachines {
    33   public class SupportVectorCreator : OperatorBase {
    34     private Thread trainingThread;
    35     private object locker = new object();
    36     private bool abortRequested = false;
     35namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
     36  /// <summary>
     37  /// Represents an operator that creates a support vector machine model.
     38  /// </summary>
     39  [StorableClass]
     40  [Item("SupportVectorMachineModelCreator", "Represents an operator that creates a support vector machine model.")]
     41  public class SupportVectorMachineModelCreator : SingleSuccessorOperator {
     42    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
     43    private const string SvmTypeParameterName = "SvmType";
     44    private const string KernelTypeParameterName = "KernelType";
     45    private const string CostParameterName = "Cost";
     46    private const string NuParameterName = "Nu";
     47    private const string GammaParameterName = "Gamma";
     48    private const string ModelParameterName = "SupportVectorMachineModel";
    3749
    38     public SupportVectorCreator()
     50    #region parameter properties
     51    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
     52      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
     53    }
     54    public IValueParameter<StringValue> SvmTypeParameter {
     55      get { return (IValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
     56    }
     57    public IValueParameter<StringValue> KernelTypeParameter {
     58      get { return (IValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
     59    }
     60    public IValueLookupParameter<DoubleValue> NuParameter {
     61      get { return (IValueLookupParameter<DoubleValue>)Parameters[NuParameterName]; }
     62    }
     63    public IValueLookupParameter<DoubleValue> CostParameter {
     64      get { return (IValueLookupParameter<DoubleValue>)Parameters[CostParameterName]; }
     65    }
     66    public IValueLookupParameter<DoubleValue> GammaParameter {
     67      get { return (IValueLookupParameter<DoubleValue>)Parameters[GammaParameterName]; }
     68    }
     69    public ILookupParameter<SupportVectorMachineModel> SupportVectorMachineModelParameter {
     70      get { return (ILookupParameter<SupportVectorMachineModel>)Parameters[ModelParameterName]; }
     71    }
     72    #endregion
     73    #region properties
     74    public DataAnalysisProblemData DataAnalysisProblemData {
     75      get { return DataAnalysisProblemDataParameter.ActualValue; }
     76    }
     77    public StringValue SvmType {
     78      get { return SvmTypeParameter.Value; }
     79    }
     80    public StringValue KernelType {
     81      get { return KernelTypeParameter.Value; }
     82    }
     83    public DoubleValue Nu {
     84      get { return NuParameter.ActualValue; }
     85    }
     86    public DoubleValue Cost {
     87      get { return CostParameter.ActualValue; }
     88    }
     89    public DoubleValue Gamma {
     90      get { return GammaParameter.ActualValue; }
     91    }
     92    #endregion
     93
     94    public SupportVectorMachineModelCreator()
    3995      : base() {
    40       //Dataset infos
    41       AddVariableInfo(new VariableInfo("Dataset", "Dataset with all samples on which to apply the function", typeof(Dataset), VariableKind.In));
    42       AddVariableInfo(new VariableInfo("TargetVariable", "Name of the target variable", typeof(StringData), VariableKind.In));
    43       AddVariableInfo(new VariableInfo("InputVariables", "List of allowed input variable names", typeof(ItemList), VariableKind.In));
    44       AddVariableInfo(new VariableInfo("SamplesStart", "Start index of samples in dataset to evaluate", typeof(IntData), VariableKind.In));
    45       AddVariableInfo(new VariableInfo("SamplesEnd", "End index of samples in dataset to evaluate", typeof(IntData), VariableKind.In));
    46       AddVariableInfo(new VariableInfo("MaxTimeOffset", "(optional) Maximal time offset for time-series prognosis", typeof(IntData), VariableKind.In));
    47       AddVariableInfo(new VariableInfo("MinTimeOffset", "(optional) Minimal time offset for time-series prognosis", typeof(IntData), VariableKind.In));
    48 
    49       //SVM parameters
    50       AddVariableInfo(new VariableInfo("SVMType", "String describing which SVM type is used. Valid inputs are: C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR",
    51         typeof(StringData), VariableKind.In));
    52       AddVariableInfo(new VariableInfo("SVMKernelType", "String describing which SVM kernel is used. Valid inputs are: LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED",
    53         typeof(StringData), VariableKind.In));
    54       AddVariableInfo(new VariableInfo("SVMCost", "Cost parameter (C) of C-SVC, epsilon-SVR and nu-SVR", typeof(DoubleData), VariableKind.In));
    55       AddVariableInfo(new VariableInfo("SVMNu", "Nu parameter of nu-SVC, one-class SVM and nu-SVR", typeof(DoubleData), VariableKind.In));
    56       AddVariableInfo(new VariableInfo("SVMGamma", "Gamma parameter in kernel function", typeof(DoubleData), VariableKind.In));
    57       AddVariableInfo(new VariableInfo("SVMModel", "Represent the model learned by the SVM", typeof(SVMModel), VariableKind.New | VariableKind.Out));
     96      #region svm types
     97      StringValue cSvcType = new StringValue("C_SVC").AsReadOnly();
     98      StringValue nuSvcType = new StringValue("NU_SVC").AsReadOnly();
     99      StringValue eSvrType = new StringValue("EPSILON_SVR").AsReadOnly();
     100      StringValue nuSvrType = new StringValue("NU_SVR").AsReadOnly();
     101      ItemSet<StringValue> allowedSvmTypes = new ItemSet<StringValue>();
     102      allowedSvmTypes.Add(cSvcType);
     103      allowedSvmTypes.Add(nuSvcType);
     104      allowedSvmTypes.Add(eSvrType);
     105      allowedSvmTypes.Add(nuSvrType);
     106      #endregion
     107      #region kernel types
     108      StringValue rbfKernelType = new StringValue("RBF").AsReadOnly();
     109      StringValue linearKernelType = new StringValue("LINEAR").AsReadOnly();
     110      StringValue polynomialKernelType = new StringValue("POLY").AsReadOnly();
     111      StringValue sigmoidKernelType = new StringValue("SIGMOID").AsReadOnly();
     112      ItemSet<StringValue> allowedKernelTypes = new ItemSet<StringValue>();
     113      allowedKernelTypes.Add(rbfKernelType);
     114      allowedKernelTypes.Add(linearKernelType);
     115      allowedKernelTypes.Add(polynomialKernelType);
     116      allowedKernelTypes.Add(sigmoidKernelType);
     117      #endregion
     118      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
     119      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", allowedSvmTypes, nuSvrType));
     120      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", allowedKernelTypes, rbfKernelType));
     121      Parameters.Add(new ValueLookupParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR."));
     122      Parameters.Add(new ValueLookupParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR."));
     123      Parameters.Add(new ValueLookupParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function."));
     124      Parameters.Add(new LookupParameter<SupportVectorMachineModel>(ModelParameterName, "The result model generated by the SVM."));
    58125    }
    59126
    60     public override void Abort() {
    61       abortRequested = true;
    62       lock (locker) {
    63         if (trainingThread != null && trainingThread.ThreadState == ThreadState.Running) {
    64           trainingThread.Abort();
    65         }
    66       }
     127    public override IOperation Apply() {
     128
     129      SupportVectorMachineModel model = TrainModel(DataAnalysisProblemData,
     130                             SvmType.Value, KernelType.Value,
     131                             Cost.Value, Nu.Value, Gamma.Value);
     132      SupportVectorMachineModelParameter.ActualValue = model;
     133
     134      return base.Apply();
    67135    }
    68136
    69     public override IOperation Apply(IScope scope) {
    70       abortRequested = false;
    71       Dataset dataset = GetVariableValue<Dataset>("Dataset", scope, true);
    72       string targetVariable = GetVariableValue<StringData>("TargetVariable", scope, true).Data;
    73       ItemList inputVariables = GetVariableValue<ItemList>("InputVariables", scope, true);
    74       var inputVariableNames = from x in inputVariables
    75                                select ((StringData)x).Data;
    76       int start = GetVariableValue<IntData>("SamplesStart", scope, true).Data;
    77       int end = GetVariableValue<IntData>("SamplesEnd", scope, true).Data;
    78       IntData maxTimeOffsetData = GetVariableValue<IntData>("MaxTimeOffset", scope, true, false);
    79       int maxTimeOffset = maxTimeOffsetData == null ? 0 : maxTimeOffsetData.Data;
    80       IntData minTimeOffsetData = GetVariableValue<IntData>("MinTimeOffset", scope, true, false);
    81       int minTimeOffset = minTimeOffsetData == null ? 0 : minTimeOffsetData.Data;
    82       string svmType = GetVariableValue<StringData>("SVMType", scope, true).Data;
    83       string svmKernelType = GetVariableValue<StringData>("SVMKernelType", scope, true).Data;
    84 
    85       double cost = GetVariableValue<DoubleData>("SVMCost", scope, true).Data;
    86       double nu = GetVariableValue<DoubleData>("SVMNu", scope, true).Data;
    87       double gamma = GetVariableValue<DoubleData>("SVMGamma", scope, true).Data;
    88 
    89       SVMModel modelData = null;
    90       lock (locker) {
    91         if (!abortRequested) {
    92           trainingThread = new Thread(() => {
    93             modelData = TrainModel(dataset, targetVariable, inputVariableNames,
    94                                    start, end, minTimeOffset, maxTimeOffset,
    95                                    svmType, svmKernelType,
    96                                    cost, nu, gamma);
    97           });
    98           trainingThread.Start();
    99         }
    100       }
    101       if (!abortRequested) {
    102         trainingThread.Join();
    103         trainingThread = null;
    104       }
    105 
    106 
    107       if (!abortRequested) {
    108         //persist variables in scope
    109         scope.AddVariable(new Variable(scope.TranslateName("SVMModel"), modelData));
    110         return null;
    111       } else {
    112         return new AtomicOperation(this, scope);
    113       }
    114     }
    115 
    116     public static SVMModel TrainRegressionModel(
    117       Dataset dataset, string targetVariable, IEnumerable<string> inputVariables,
    118       int start, int end,
    119       double cost, double nu, double gamma) {
    120       return TrainModel(dataset, targetVariable, inputVariables, start, end, 0, 0, "NU_SVR", "RBF", cost, nu, gamma);
    121     }
    122 
    123     public static SVMModel TrainModel(
    124       Dataset dataset, string targetVariable, IEnumerable<string> inputVariables,
    125       int start, int end,
    126       int minTimeOffset, int maxTimeOffset,
     137    public static SupportVectorMachineModel TrainModel(
     138      DataAnalysisProblemData problemData,
    127139      string svmType, string kernelType,
    128140      double cost, double nu, double gamma) {
    129       int targetVariableIndex = dataset.GetVariableIndex(targetVariable);
     141      int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value);
    130142
    131143      //extract SVM parameters from scope and set them
     
    140152
    141153
    142       SVM.Problem problem = SVMHelper.CreateSVMProblem(dataset, targetVariableIndex, inputVariables, start, end, minTimeOffset, maxTimeOffset);
     154      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, problemData.TrainingSamplesStart.Value, problemData.TrainingSamplesEnd.Value);
    143155      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
    144       SVM.Problem scaledProblem = rangeTransform.Scale(problem);
    145       var model = new SVMModel();
     156      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);
     157      var model = new SupportVectorMachineModel();
    146158
    147159      model.Model = SVM.Training.Train(scaledProblem, parameter);
Note: See TracChangeset for help on using the changeset viewer.