Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/18/10 12:27:28 (15 years ago)
Author:
gkronber
Message:

Added operators for support vector regression. #1009

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine
Files:
1 added
4 copied

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModel.cs

    r3763 r3842  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    2727using System.Globalization;
    2828using System.IO;
    29 using HeuristicLab.Modeling;
     29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     30using HeuristicLab.Common;
    3031
    31 namespace HeuristicLab.SupportVectorMachines {
    32   public class SVMModel : ItemBase {
     32namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
     33  /// <summary>
     34  /// Represents a support vector machine model.
     35  /// </summary>
     36  [StorableClass]
     37  [Item("SupportVectorMachineModel", "Represents a support vector machine model.")]
     38  public class SupportVectorMachineModel : NamedItem {
    3339    private SVM.Model model;
    3440    /// <summary>
     
    4955    }
    5056
    51     public override IView CreateView() {
    52       return new SVMModelView(this);
     57    #region persistence
     58    [Storable]
     59    private string ModelAsString {
     60      get {
     61        using (MemoryStream stream = new MemoryStream()) {
     62          SVM.Model.Write(stream, Model);
     63          stream.Seek(0, System.IO.SeekOrigin.Begin);
     64          StreamReader reader = new StreamReader(stream);
     65          return reader.ReadToEnd();
     66        }
     67      }
     68      set {
     69        using (MemoryStream stream = new MemoryStream(Encoding.ASCII.GetBytes(value))) {
     70          model = SVM.Model.Read(stream);
     71        }
     72      }
    5373    }
     74    [Storable]
     75    private string RangeTransformAsString {
     76      get {
     77        using (MemoryStream stream = new MemoryStream()) {
     78          SVM.RangeTransform.Write(stream, RangeTransform);
     79          stream.Seek(0, System.IO.SeekOrigin.Begin);
     80          StreamReader reader = new StreamReader(stream);
     81          return reader.ReadToEnd();
     82        }
     83      }
     84      set {
     85        using (MemoryStream stream = new MemoryStream(Encoding.ASCII.GetBytes(value))) {
     86          RangeTransform = SVM.RangeTransform.Read(stream);
     87        }
     88      }
     89    }
     90    #endregion
    5491
    55     /// <summary>
    56     /// Clones the current instance and adds it to the dictionary <paramref name="clonedObjects"/>.
    57     /// </summary>
    58     /// <param name="clonedObjects">Dictionary of all already cloned objects.</param>
    59     /// <returns>The cloned instance as <see cref="DoubleData"/>.</returns>
    60     public override object Clone(IDictionary<Guid, object> clonedObjects) {
    61       SVMModel clone = new SVMModel();
    62       clonedObjects.Add(Guid, clone);
     92    public override IDeepCloneable Clone(Cloner cloner) {
     93      SupportVectorMachineModel clone = (SupportVectorMachineModel)base.Clone(cloner);
    6394      // beware we are only using a shallow copy here! (gkronber)
    64       clone.Model = Model;
    65       clone.RangeTransform = RangeTransform;
     95      clone.model = model;
     96      clone.rangeTransform = rangeTransform;
    6697      return clone;
    6798    }
    6899
    69100    /// <summary>
    70     /// Saves the current instance as <see cref="XmlNode"/> in the specified <paramref name="document"/>.
     101    ///  Exports the <paramref name="model"/> in string representation to output stream <paramref name="s"/>
    71102    /// </summary>
    72     /// <remarks>The actual model is saved in the node's inner text as string,
    73     /// its format depending on the local culture info and its number format.</remarks>
    74     /// <param name="name">The (tag)name of the <see cref="XmlNode"/>.</param>
    75     /// <param name="document">The <see cref="XmlDocument"/> where the data is saved.</param>
    76     /// <param name="persistedObjects">A dictionary of all already persisted objects. (Needed to avoid cycles.)</param>
    77     /// <returns>The saved <see cref="XmlNode"/>.</returns>
    78     public override XmlNode GetXmlNode(string name, XmlDocument document, IDictionary<Guid, IStorable> persistedObjects) {
    79       XmlNode node = base.GetXmlNode(name, document, persistedObjects);
    80       XmlNode model = document.CreateElement("Model");
    81       using (MemoryStream stream = new MemoryStream()) {
    82         SVM.Model.Write(stream, Model);
    83         stream.Seek(0, System.IO.SeekOrigin.Begin);
    84         StreamReader reader = new StreamReader(stream);
    85         model.InnerText = reader.ReadToEnd();
    86         node.AppendChild(model);
    87       }
    88 
    89       XmlNode rangeTransform = document.CreateElement("RangeTransform");
    90       using (MemoryStream stream = new MemoryStream()) {
    91         SVM.RangeTransform.Write(stream, RangeTransform);
    92         stream.Seek(0, System.IO.SeekOrigin.Begin);
    93         StreamReader reader = new StreamReader(stream);
    94         rangeTransform.InnerText = reader.ReadToEnd();
    95         node.AppendChild(rangeTransform);
    96       }
    97 
    98       return node;
    99     }
    100     /// <summary>
    101     /// Loads the persisted SVM model from the specified <paramref name="node"/>.
    102     /// </summary>
    103     /// <remarks>The serialized SVM model must be saved in the node's inner text as a string 
    104     /// (see <see cref="GetXmlNode"/>).</remarks>
    105     /// <param name="node">The <see cref="XmlNode"/> where the SVM model is saved.</param>
    106     /// <param name="restoredObjects">A dictionary of all already restored objects. (Needed to avoid cycles.)</param>
    107     public override void Populate(XmlNode node, IDictionary<Guid, IStorable> restoredObjects) {
    108       base.Populate(node, restoredObjects);
    109       XmlNode model = node.SelectSingleNode("Model");
    110       using (MemoryStream stream = new MemoryStream(Encoding.ASCII.GetBytes(model.InnerText))) {
    111         Model = SVM.Model.Read(stream);
    112       }
    113       XmlNode rangeTransform = node.SelectSingleNode("RangeTransform");
    114       using (MemoryStream stream = new MemoryStream(Encoding.ASCII.GetBytes(rangeTransform.InnerText))) {
    115         RangeTransform = SVM.RangeTransform.Read(stream);
    116       }
    117     }
    118 
    119     public static void Export(SVMModel model, Stream s) {
     103    /// <param name="model">The support vector regression model to export</param>
     104    /// <param name="s">The output stream to export the model to</param>
     105    public static void Export(SupportVectorMachineModel model, Stream s) {
    120106      StreamWriter writer = new StreamWriter(s);
    121107      writer.WriteLine("RangeTransform:");
     
    136122    }
    137123
    138     public static SVMModel Import(TextReader reader) {
    139       SVMModel model = new SVMModel();
     124    /// <summary>
     125    /// Imports a support vector machine model given as string representation.
     126    /// </summary>
     127    /// <param name="reader">The reader to retrieve the string representation from</param>
     128    /// <returns>The imported support vector machine model.</returns>
     129    public static SupportVectorMachineModel Import(TextReader reader) {
     130      SupportVectorMachineModel model = new SupportVectorMachineModel();
    140131      while (reader.ReadLine().Trim() != "RangeTransform:") ; // read until line "RangeTransform";
    141132      model.RangeTransform = SVM.RangeTransform.Read(reader);
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModelCreator.cs

    r3763 r3842  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2009 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    2626using HeuristicLab.Core;
    2727using HeuristicLab.Data;
    28 using HeuristicLab.DataAnalysis;
    2928using System.Threading;
     29using HeuristicLab.LibSVM;
     30using HeuristicLab.Operators;
     31using HeuristicLab.Parameters;
    3032using SVM;
     33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3134
    32 namespace HeuristicLab.SupportVectorMachines {
    33   public class SupportVectorCreator : OperatorBase {
    34     private Thread trainingThread;
    35     private object locker = new object();
    36     private bool abortRequested = false;
     35namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
     36  /// <summary>
     37  /// Represents an operator that creates a support vector machine model.
     38  /// </summary>
     39  [StorableClass]
     40  [Item("SupportVectorMachineModelCreator", "Represents an operator that creates a support vector machine model.")]
     41  public class SupportVectorMachineModelCreator : SingleSuccessorOperator {
     42    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
     43    private const string SvmTypeParameterName = "SvmType";
     44    private const string KernelTypeParameterName = "KernelType";
     45    private const string CostParameterName = "Cost";
     46    private const string NuParameterName = "Nu";
     47    private const string GammaParameterName = "Gamma";
     48    private const string ModelParameterName = "SupportVectorMachineModel";
    3749
    38     public SupportVectorCreator()
     50    #region parameter properties
     51    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
     52      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
     53    }
     54    public IValueParameter<StringValue> SvmTypeParameter {
     55      get { return (IValueParameter<StringValue>)Parameters[SvmTypeParameterName]; }
     56    }
     57    public IValueParameter<StringValue> KernelTypeParameter {
     58      get { return (IValueParameter<StringValue>)Parameters[KernelTypeParameterName]; }
     59    }
     60    public IValueLookupParameter<DoubleValue> NuParameter {
     61      get { return (IValueLookupParameter<DoubleValue>)Parameters[NuParameterName]; }
     62    }
     63    public IValueLookupParameter<DoubleValue> CostParameter {
     64      get { return (IValueLookupParameter<DoubleValue>)Parameters[CostParameterName]; }
     65    }
     66    public IValueLookupParameter<DoubleValue> GammaParameter {
     67      get { return (IValueLookupParameter<DoubleValue>)Parameters[GammaParameterName]; }
     68    }
     69    public ILookupParameter<SupportVectorMachineModel> SupportVectorMachineModelParameter {
     70      get { return (ILookupParameter<SupportVectorMachineModel>)Parameters[ModelParameterName]; }
     71    }
     72    #endregion
     73    #region properties
     74    public DataAnalysisProblemData DataAnalysisProblemData {
     75      get { return DataAnalysisProblemDataParameter.ActualValue; }
     76    }
     77    public StringValue SvmType {
     78      get { return SvmTypeParameter.Value; }
     79    }
     80    public StringValue KernelType {
     81      get { return KernelTypeParameter.Value; }
     82    }
     83    public DoubleValue Nu {
     84      get { return NuParameter.ActualValue; }
     85    }
     86    public DoubleValue Cost {
     87      get { return CostParameter.ActualValue; }
     88    }
     89    public DoubleValue Gamma {
     90      get { return GammaParameter.ActualValue; }
     91    }
     92    #endregion
     93
     94    public SupportVectorMachineModelCreator()
    3995      : base() {
    40       //Dataset infos
    41       AddVariableInfo(new VariableInfo("Dataset", "Dataset with all samples on which to apply the function", typeof(Dataset), VariableKind.In));
    42       AddVariableInfo(new VariableInfo("TargetVariable", "Name of the target variable", typeof(StringData), VariableKind.In));
    43       AddVariableInfo(new VariableInfo("InputVariables", "List of allowed input variable names", typeof(ItemList), VariableKind.In));
    44       AddVariableInfo(new VariableInfo("SamplesStart", "Start index of samples in dataset to evaluate", typeof(IntData), VariableKind.In));
    45       AddVariableInfo(new VariableInfo("SamplesEnd", "End index of samples in dataset to evaluate", typeof(IntData), VariableKind.In));
    46       AddVariableInfo(new VariableInfo("MaxTimeOffset", "(optional) Maximal time offset for time-series prognosis", typeof(IntData), VariableKind.In));
    47       AddVariableInfo(new VariableInfo("MinTimeOffset", "(optional) Minimal time offset for time-series prognosis", typeof(IntData), VariableKind.In));
    48 
    49       //SVM parameters
    50       AddVariableInfo(new VariableInfo("SVMType", "String describing which SVM type is used. Valid inputs are: C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR",
    51         typeof(StringData), VariableKind.In));
    52       AddVariableInfo(new VariableInfo("SVMKernelType", "String describing which SVM kernel is used. Valid inputs are: LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED",
    53         typeof(StringData), VariableKind.In));
    54       AddVariableInfo(new VariableInfo("SVMCost", "Cost parameter (C) of C-SVC, epsilon-SVR and nu-SVR", typeof(DoubleData), VariableKind.In));
    55       AddVariableInfo(new VariableInfo("SVMNu", "Nu parameter of nu-SVC, one-class SVM and nu-SVR", typeof(DoubleData), VariableKind.In));
    56       AddVariableInfo(new VariableInfo("SVMGamma", "Gamma parameter in kernel function", typeof(DoubleData), VariableKind.In));
    57       AddVariableInfo(new VariableInfo("SVMModel", "Represent the model learned by the SVM", typeof(SVMModel), VariableKind.New | VariableKind.Out));
     96      #region svm types
     97      StringValue cSvcType = new StringValue("C_SVC").AsReadOnly();
     98      StringValue nuSvcType = new StringValue("NU_SVC").AsReadOnly();
     99      StringValue eSvrType = new StringValue("EPSILON_SVR").AsReadOnly();
     100      StringValue nuSvrType = new StringValue("NU_SVR").AsReadOnly();
     101      ItemSet<StringValue> allowedSvmTypes = new ItemSet<StringValue>();
     102      allowedSvmTypes.Add(cSvcType);
     103      allowedSvmTypes.Add(nuSvcType);
     104      allowedSvmTypes.Add(eSvrType);
     105      allowedSvmTypes.Add(nuSvrType);
     106      #endregion
     107      #region kernel types
     108      StringValue rbfKernelType = new StringValue("RBF").AsReadOnly();
     109      StringValue linearKernelType = new StringValue("LINEAR").AsReadOnly();
     110      StringValue polynomialKernelType = new StringValue("POLY").AsReadOnly();
     111      StringValue sigmoidKernelType = new StringValue("SIGMOID").AsReadOnly();
     112      ItemSet<StringValue> allowedKernelTypes = new ItemSet<StringValue>();
     113      allowedKernelTypes.Add(rbfKernelType);
     114      allowedKernelTypes.Add(linearKernelType);
     115      allowedKernelTypes.Add(polynomialKernelType);
     116      allowedKernelTypes.Add(sigmoidKernelType);
     117      #endregion
     118      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
     119      Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", allowedSvmTypes, nuSvrType));
     120      Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", allowedKernelTypes, rbfKernelType));
     121      Parameters.Add(new ValueLookupParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR."));
     122      Parameters.Add(new ValueLookupParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR."));
     123      Parameters.Add(new ValueLookupParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function."));
     124      Parameters.Add(new LookupParameter<SupportVectorMachineModel>(ModelParameterName, "The result model generated by the SVM."));
    58125    }
    59126
    60     public override void Abort() {
    61       abortRequested = true;
    62       lock (locker) {
    63         if (trainingThread != null && trainingThread.ThreadState == ThreadState.Running) {
    64           trainingThread.Abort();
    65         }
    66       }
     127    public override IOperation Apply() {
     128
     129      SupportVectorMachineModel model = TrainModel(DataAnalysisProblemData,
     130                             SvmType.Value, KernelType.Value,
     131                             Cost.Value, Nu.Value, Gamma.Value);
     132      SupportVectorMachineModelParameter.ActualValue = model;
     133
     134      return base.Apply();
    67135    }
    68136
    69     public override IOperation Apply(IScope scope) {
    70       abortRequested = false;
    71       Dataset dataset = GetVariableValue<Dataset>("Dataset", scope, true);
    72       string targetVariable = GetVariableValue<StringData>("TargetVariable", scope, true).Data;
    73       ItemList inputVariables = GetVariableValue<ItemList>("InputVariables", scope, true);
    74       var inputVariableNames = from x in inputVariables
    75                                select ((StringData)x).Data;
    76       int start = GetVariableValue<IntData>("SamplesStart", scope, true).Data;
    77       int end = GetVariableValue<IntData>("SamplesEnd", scope, true).Data;
    78       IntData maxTimeOffsetData = GetVariableValue<IntData>("MaxTimeOffset", scope, true, false);
    79       int maxTimeOffset = maxTimeOffsetData == null ? 0 : maxTimeOffsetData.Data;
    80       IntData minTimeOffsetData = GetVariableValue<IntData>("MinTimeOffset", scope, true, false);
    81       int minTimeOffset = minTimeOffsetData == null ? 0 : minTimeOffsetData.Data;
    82       string svmType = GetVariableValue<StringData>("SVMType", scope, true).Data;
    83       string svmKernelType = GetVariableValue<StringData>("SVMKernelType", scope, true).Data;
    84 
    85       double cost = GetVariableValue<DoubleData>("SVMCost", scope, true).Data;
    86       double nu = GetVariableValue<DoubleData>("SVMNu", scope, true).Data;
    87       double gamma = GetVariableValue<DoubleData>("SVMGamma", scope, true).Data;
    88 
    89       SVMModel modelData = null;
    90       lock (locker) {
    91         if (!abortRequested) {
    92           trainingThread = new Thread(() => {
    93             modelData = TrainModel(dataset, targetVariable, inputVariableNames,
    94                                    start, end, minTimeOffset, maxTimeOffset,
    95                                    svmType, svmKernelType,
    96                                    cost, nu, gamma);
    97           });
    98           trainingThread.Start();
    99         }
    100       }
    101       if (!abortRequested) {
    102         trainingThread.Join();
    103         trainingThread = null;
    104       }
    105 
    106 
    107       if (!abortRequested) {
    108         //persist variables in scope
    109         scope.AddVariable(new Variable(scope.TranslateName("SVMModel"), modelData));
    110         return null;
    111       } else {
    112         return new AtomicOperation(this, scope);
    113       }
    114     }
    115 
    116     public static SVMModel TrainRegressionModel(
    117       Dataset dataset, string targetVariable, IEnumerable<string> inputVariables,
    118       int start, int end,
    119       double cost, double nu, double gamma) {
    120       return TrainModel(dataset, targetVariable, inputVariables, start, end, 0, 0, "NU_SVR", "RBF", cost, nu, gamma);
    121     }
    122 
    123     public static SVMModel TrainModel(
    124       Dataset dataset, string targetVariable, IEnumerable<string> inputVariables,
    125       int start, int end,
    126       int minTimeOffset, int maxTimeOffset,
     137    public static SupportVectorMachineModel TrainModel(
     138      DataAnalysisProblemData problemData,
    127139      string svmType, string kernelType,
    128140      double cost, double nu, double gamma) {
    129       int targetVariableIndex = dataset.GetVariableIndex(targetVariable);
     141      int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value);
    130142
    131143      //extract SVM parameters from scope and set them
     
    140152
    141153
    142       SVM.Problem problem = SVMHelper.CreateSVMProblem(dataset, targetVariableIndex, inputVariables, start, end, minTimeOffset, maxTimeOffset);
     154      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, problemData.TrainingSamplesStart.Value, problemData.TrainingSamplesEnd.Value);
    143155      SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem);
    144       SVM.Problem scaledProblem = rangeTransform.Scale(problem);
    145       var model = new SVMModel();
     156      SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem);
     157      var model = new SupportVectorMachineModel();
    146158
    147159      model.Model = SVM.Training.Train(scaledProblem, parameter);
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModelEvaluator.cs

    r3763 r3842  
    2626using HeuristicLab.Core;
    2727using HeuristicLab.Data;
    28 using HeuristicLab.DataAnalysis;
    2928using SVM;
     29using HeuristicLab.Operators;
     30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     31using HeuristicLab.Parameters;
    3032
    31 namespace HeuristicLab.SupportVectorMachines {
    32   public class SupportVectorEvaluator : OperatorBase {
     33namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
     34  [StorableClass]
     35  [Item("SupportVectorMachineModelEvaluator", "Represents a operator that evaluates a support vector machine model on a data set.")]
     36  public class SupportVectorMachineModelEvaluator : SingleSuccessorOperator {
     37    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
     38    private const string ModelParameterName = "SupportVectorMachineModel";
     39    private const string SamplesStartParameterName = "SamplesStart";
     40    private const string SamplesEndParameterName = "SamplesEnd";
     41    private const string ValuesParameterName = "Values";
    3342
    34     public SupportVectorEvaluator()
     43    #region parameter properties
     44    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
     45      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
     46    }
     47    public IValueLookupParameter<IntValue> SamplesStartParameter {
     48      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
     49    }
     50    public IValueLookupParameter<IntValue> SamplesEndParameter {
     51      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
     52    }
     53    public ILookupParameter<SupportVectorMachineModel> SupportVectorMachineModelParameter {
     54      get { return (ILookupParameter<SupportVectorMachineModel>)Parameters[ModelParameterName]; }
     55    }
     56    public ILookupParameter<DoubleMatrix> ValuesParameter {
     57      get { return (ILookupParameter<DoubleMatrix>)Parameters[ValuesParameterName]; }
     58    }
     59    #endregion
     60    #region properties
     61    public DataAnalysisProblemData DataAnalysisProblemData {
     62      get { return DataAnalysisProblemDataParameter.ActualValue; }
     63    }
     64    public SupportVectorMachineModel SupportVectorMachineModel {
     65      get { return SupportVectorMachineModelParameter.ActualValue; }
     66    }
     67    public IntValue SamplesStart {
     68      get { return SamplesStartParameter.ActualValue; }
     69    }
     70    public IntValue SamplesEnd {
     71      get { return SamplesEndParameter.ActualValue; }
     72    }
     73    #endregion
     74    public SupportVectorMachineModelEvaluator()
    3575      : base() {
    36       //Dataset infos
    37       AddVariableInfo(new VariableInfo("Dataset", "Dataset with all samples on which to apply the function", typeof(Dataset), VariableKind.In));
    38       AddVariableInfo(new VariableInfo("TargetVariable", "Name of the target variable", typeof(StringData), VariableKind.In));
    39       AddVariableInfo(new VariableInfo("InputVariables", "List of allowed input variable names", typeof(ItemList), VariableKind.In));
    40       AddVariableInfo(new VariableInfo("SamplesStart", "Start index of samples in dataset to evaluate", typeof(IntData), VariableKind.In));
    41       AddVariableInfo(new VariableInfo("SamplesEnd", "End index of samples in dataset to evaluate", typeof(IntData), VariableKind.In));
    42       AddVariableInfo(new VariableInfo("MaxTimeOffset", "(optional) Maximal allowed time offset for input variables", typeof(IntData), VariableKind.In));
    43       AddVariableInfo(new VariableInfo("MinTimeOffset", "(optional) Minimal allowed time offset for input variables", typeof(IntData), VariableKind.In));
    44       AddVariableInfo(new VariableInfo("SVMModel", "Represent the model learned by the SVM", typeof(SVMModel), VariableKind.In));
    45       AddVariableInfo(new VariableInfo("Values", "Target vs predicted values", typeof(DoubleMatrixData), VariableKind.New | VariableKind.Out));
     76      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
     77      Parameters.Add(new LookupParameter<SupportVectorMachineModel>(ModelParameterName, "The result model generated by the SVM."));
     78      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first index of the data set partition on which the SVM model should be evaluated."));
     79      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last index of the data set partition on which the SVM model should be evaluated."));
     80      Parameters.Add(new LookupParameter<DoubleMatrix>(ValuesParameterName, "A matrix of original values of the target variable and output values of the SVM model."));
    4681    }
    4782
     83    public override IOperation Apply() {
     84      int targetVariableIndex = DataAnalysisProblemData.Dataset.GetVariableIndex(DataAnalysisProblemData.TargetVariable.Value);
     85      int start = SamplesStart.Value;
     86      int end = SamplesEnd.Value;
    4887
    49     public override IOperation Apply(IScope scope) {
    50       Dataset dataset = GetVariableValue<Dataset>("Dataset", scope, true);
    51       ItemList inputVariables = GetVariableValue<ItemList>("InputVariables", scope, true);
    52       var inputVariableNames = from x in inputVariables
    53                                select ((StringData)x).Data;
    54       string targetVariable = GetVariableValue<StringData>("TargetVariable", scope, true).Data;
    55       int targetVariableIndex = dataset.GetVariableIndex(targetVariable);
    56       int start = GetVariableValue<IntData>("SamplesStart", scope, true).Data;
    57       int end = GetVariableValue<IntData>("SamplesEnd", scope, true).Data;
    58       IntData minTimeOffsetData = GetVariableValue<IntData>("MinTimeOffset", scope, true, false);
    59       int minTimeOffset = minTimeOffsetData == null ? 0 : minTimeOffsetData.Data;
    60       IntData maxTimeOffsetData = GetVariableValue<IntData>("MaxTimeOffset", scope, true, false);
    61       int maxTimeOffset = maxTimeOffsetData == null ? 0 : maxTimeOffsetData.Data;
    62       SVMModel modelData = GetVariableValue<SVMModel>("SVMModel", scope, true);
    63 
    64       SVM.Problem problem = SVMHelper.CreateSVMProblem(dataset, targetVariableIndex, inputVariableNames, start, end, minTimeOffset, maxTimeOffset);
    65       SVM.Problem scaledProblem = modelData.RangeTransform.Scale(problem);
     88      SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(DataAnalysisProblemData, start, end);
     89      SVM.Problem scaledProblem = SupportVectorMachineModel.RangeTransform.Scale(problem);
    6690
    6791      double[,] values = new double[scaledProblem.Count, 2];
    6892      for (int i = 0; i < scaledProblem.Count; i++) {
    69         values[i, 0] = dataset.GetValue(start + i, targetVariableIndex);
    70         values[i, 1] = SVM.Prediction.Predict(modelData.Model, scaledProblem.X[i]);
     93        values[i, 0] = DataAnalysisProblemData.Dataset[start + i, targetVariableIndex];
     94        values[i, 1] = SVM.Prediction.Predict(SupportVectorMachineModel.Model, scaledProblem.X[i]);
    7195      }
    7296
    73       scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName("Values"), new DoubleMatrixData(values)));
    74       return null;
     97      ValuesParameter.ActualValue = new DoubleMatrix(values);
     98      return base.Apply();
    7599    }
    76100  }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineUtil.cs

    r3763 r3842  
    1 using System;
     1#region License Information
     2/* HeuristicLab
     3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     4 *
     5 * This file is part of HeuristicLab.
     6 *
     7 * HeuristicLab is free software: you can redistribute it and/or modify
     8 * it under the terms of the GNU General Public License as published by
     9 * the Free Software Foundation, either version 3 of the License, or
     10 * (at your option) any later version.
     11 *
     12 * HeuristicLab is distributed in the hope that it will be useful,
     13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
     14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     15 * GNU General Public License for more details.
     16 *
     17 * You should have received a copy of the GNU General Public License
     18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
     19 */
     20#endregion
     21
     22using System;
    223using System.Collections.Generic;
    324using System.Linq;
     
    526using HeuristicLab.Core;
    627using HeuristicLab.Data;
    7 using HeuristicLab.DataAnalysis;
    828using HeuristicLab.Common;
    929
    10 namespace HeuristicLab.SupportVectorMachines {
    11   public class SVMHelper {
    12     public static SVM.Problem CreateSVMProblem(Dataset dataset, int targetVariableIndex, IEnumerable<string> inputVariables, int start, int end, int minTimeOffset, int maxTimeOffset) {
     30namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine {
     31  public class SupportVectorMachineUtil {
     32    /// <summary>
     33    /// Transforms <paramref name="problemData"/> into a data structure as needed by libSVM.
     34    /// </summary>
     35    /// <param name="problemData">The problem data to transform</param>
     36    /// <param name="start">The index of the first row of <paramref name="problemData"/> to copy to the output.</param>
     37    /// <param name="end">The last of the first row of <paramref name="problemData"/> to copy to the output.</param>
     38    /// <returns>A problem data type that can be used to train a support vector machine.</returns>
     39    public static SVM.Problem CreateSvmProblem(DataAnalysisProblemData problemData, int start, int end) {
    1340      int rowCount = end - start;
    14 
    15       var targetVector = (from row in Enumerable.Range(start, rowCount)
    16                           let val = dataset.GetValue(row, targetVariableIndex)
    17                           where !double.IsNaN(val)
    18                           select val).ToArray();
    19 
     41      var targetVector = problemData.Dataset.GetVariableValues(problemData.TargetVariable.Value, start, end);
    2042
    2143      SVM.Node[][] nodes = new SVM.Node[targetVector.Length][];
    2244      List<SVM.Node> tempRow;
    23       int addedRows = 0;
    24       int maxColumns = 0;
     45      int maxNodeIndex = 0;
    2546      for (int row = 0; row < rowCount; row++) {
    2647        tempRow = new List<SVM.Node>();
    27         int nodeIndex = 0;
    28         foreach (var inputVariable in inputVariables) {
    29           ++nodeIndex;
    30           int col = dataset.GetVariableIndex(inputVariable);
    31           if (IsUsefulColumn(dataset, col, start, end)) {
    32             for (int timeOffset = minTimeOffset; timeOffset <= maxTimeOffset; timeOffset++) {
    33               int actualColumn = nodeIndex * (maxTimeOffset - minTimeOffset + 1) + (timeOffset - minTimeOffset);
    34               if (start + row + timeOffset >= 0 && start + row + timeOffset < dataset.Rows) {
    35                 double value = dataset.GetValue(start + row + timeOffset, col);
    36                 if (!double.IsNaN(value)) {
    37                   tempRow.Add(new SVM.Node(actualColumn, value));
    38                   if (actualColumn > maxColumns) maxColumns = actualColumn;
    39                 }
    40               }
    41             }
     48        foreach (var inputVariable in problemData.InputVariables) {
     49          int col = problemData.Dataset.GetVariableIndex(inputVariable.Value);
     50          double value = problemData.Dataset[start + row, col];
     51          if (!double.IsNaN(value)) {
     52            int nodeIndex = col + 1; // make sure the smallest nodeIndex = 1
     53            tempRow.Add(new SVM.Node(nodeIndex, value));
     54            if (nodeIndex > maxNodeIndex) maxNodeIndex = nodeIndex;
    4255          }
    4356        }
    44         if (!double.IsNaN(dataset.GetValue(start + row, targetVariableIndex))) {
    45           nodes[addedRows] = tempRow.ToArray();
    46           addedRows++;
    47         }
     57        nodes[row] = tempRow.OrderBy(x => x.Index).ToArray(); // make sure the values are sorted by node index
    4858      }
    4959
    50       return new SVM.Problem(targetVector.Length, targetVector, nodes, maxColumns);
    51     }
    52 
    53     // checks if the column has at least two different non-NaN and non-Infinity values
    54     private static bool IsUsefulColumn(Dataset dataset, int col, int start, int end) {
    55       double min = double.PositiveInfinity;
    56       double max = double.NegativeInfinity;
    57       for (int i = start; i < end; i++) {
    58         double x = dataset.GetValue(i, col);
    59         if (!double.IsNaN(x) && !double.IsInfinity(x)) {
    60           min = Math.Min(min, x);
    61           max = Math.Max(max, x);
    62         }
    63         if (min != max) return true;
    64       }
    65       return false;
     60      return new SVM.Problem(targetVector.Length, targetVector, nodes, maxNodeIndex);
    6661    }
    6762  }
Note: See TracChangeset for help on using the changeset viewer.