- Timestamp:
- 05/18/10 12:27:28 (15 years ago)
- Location:
- trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine
- Files:
-
- 1 added
- 4 copied
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModel.cs
r3763 r3842 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-20 08Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 27 27 using System.Globalization; 28 28 using System.IO; 29 using HeuristicLab.Modeling; 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 30 using HeuristicLab.Common; 30 31 31 namespace HeuristicLab.SupportVectorMachines { 32 public class SVMModel : ItemBase { 32 namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine { 33 /// <summary> 34 /// Represents a support vector machine model. 35 /// </summary> 36 [StorableClass] 37 [Item("SupportVectorMachineModel", "Represents a support vector machine model.")] 38 public class SupportVectorMachineModel : NamedItem { 33 39 private SVM.Model model; 34 40 /// <summary> … … 49 55 } 50 56 51 public override IView CreateView() { 52 return new SVMModelView(this); 57 #region persistence 58 [Storable] 59 private string ModelAsString { 60 get { 61 using (MemoryStream stream = new MemoryStream()) { 62 SVM.Model.Write(stream, Model); 63 stream.Seek(0, System.IO.SeekOrigin.Begin); 64 StreamReader reader = new StreamReader(stream); 65 return reader.ReadToEnd(); 66 } 67 } 68 set { 69 using (MemoryStream stream = new MemoryStream(Encoding.ASCII.GetBytes(value))) { 70 model = SVM.Model.Read(stream); 71 } 72 } 53 73 } 74 [Storable] 75 private string RangeTransformAsString { 76 get { 77 using (MemoryStream stream = new MemoryStream()) { 78 SVM.RangeTransform.Write(stream, RangeTransform); 79 stream.Seek(0, System.IO.SeekOrigin.Begin); 80 StreamReader reader = new StreamReader(stream); 81 return reader.ReadToEnd(); 82 } 83 } 84 set { 85 using (MemoryStream stream = new MemoryStream(Encoding.ASCII.GetBytes(value))) { 86 RangeTransform = SVM.RangeTransform.Read(stream); 87 } 88 } 89 } 90 #endregion 54 91 55 /// <summary> 56 /// 
Clones the current instance and adds it to the dictionary <paramref name="clonedObjects"/>. 57 /// </summary> 58 /// <param name="clonedObjects">Dictionary of all already cloned objects.</param> 59 /// <returns>The cloned instance as <see cref="DoubleData"/>.</returns> 60 public override object Clone(IDictionary<Guid, object> clonedObjects) { 61 SVMModel clone = new SVMModel(); 62 clonedObjects.Add(Guid, clone); 92 public override IDeepCloneable Clone(Cloner cloner) { 93 SupportVectorMachineModel clone = (SupportVectorMachineModel)base.Clone(cloner); 63 94 // beware we are only using a shallow copy here! (gkronber) 64 clone. Model = Model;65 clone. RangeTransform = RangeTransform;95 clone.model = model; 96 clone.rangeTransform = rangeTransform; 66 97 return clone; 67 98 } 68 99 69 100 /// <summary> 70 /// Saves the current instance as <see cref="XmlNode"/> in the specified <paramref name="document"/>.101 /// Exports the <paramref name="model"/> in string representation to output stream <paramref name="s"/> 71 102 /// </summary> 72 /// <remarks>The actual model is saved in the node's inner text as string, 73 /// its format depending on the local culture info and its number format.</remarks> 74 /// <param name="name">The (tag)name of the <see cref="XmlNode"/>.</param> 75 /// <param name="document">The <see cref="XmlDocument"/> where the data is saved.</param> 76 /// <param name="persistedObjects">A dictionary of all already persisted objects. 
(Needed to avoid cycles.)</param> 77 /// <returns>The saved <see cref="XmlNode"/>.</returns> 78 public override XmlNode GetXmlNode(string name, XmlDocument document, IDictionary<Guid, IStorable> persistedObjects) { 79 XmlNode node = base.GetXmlNode(name, document, persistedObjects); 80 XmlNode model = document.CreateElement("Model"); 81 using (MemoryStream stream = new MemoryStream()) { 82 SVM.Model.Write(stream, Model); 83 stream.Seek(0, System.IO.SeekOrigin.Begin); 84 StreamReader reader = new StreamReader(stream); 85 model.InnerText = reader.ReadToEnd(); 86 node.AppendChild(model); 87 } 88 89 XmlNode rangeTransform = document.CreateElement("RangeTransform"); 90 using (MemoryStream stream = new MemoryStream()) { 91 SVM.RangeTransform.Write(stream, RangeTransform); 92 stream.Seek(0, System.IO.SeekOrigin.Begin); 93 StreamReader reader = new StreamReader(stream); 94 rangeTransform.InnerText = reader.ReadToEnd(); 95 node.AppendChild(rangeTransform); 96 } 97 98 return node; 99 } 100 /// <summary> 101 /// Loads the persisted SVM model from the specified <paramref name="node"/>. 102 /// </summary> 103 /// <remarks>The serialized SVM model must be saved in the node's inner text as a string 104 /// (see <see cref="GetXmlNode"/>).</remarks> 105 /// <param name="node">The <see cref="XmlNode"/> where the SVM model is saved.</param> 106 /// <param name="restoredObjects">A dictionary of all already restored objects. 
(Needed to avoid cycles.)</param> 107 public override void Populate(XmlNode node, IDictionary<Guid, IStorable> restoredObjects) { 108 base.Populate(node, restoredObjects); 109 XmlNode model = node.SelectSingleNode("Model"); 110 using (MemoryStream stream = new MemoryStream(Encoding.ASCII.GetBytes(model.InnerText))) { 111 Model = SVM.Model.Read(stream); 112 } 113 XmlNode rangeTransform = node.SelectSingleNode("RangeTransform"); 114 using (MemoryStream stream = new MemoryStream(Encoding.ASCII.GetBytes(rangeTransform.InnerText))) { 115 RangeTransform = SVM.RangeTransform.Read(stream); 116 } 117 } 118 119 public static void Export(SVMModel model, Stream s) { 103 /// <param name="model">The support vector regression model to export</param> 104 /// <param name="s">The output stream to export the model to</param> 105 public static void Export(SupportVectorMachineModel model, Stream s) { 120 106 StreamWriter writer = new StreamWriter(s); 121 107 writer.WriteLine("RangeTransform:"); … … 136 122 } 137 123 138 public static SVMModel Import(TextReader reader) { 139 SVMModel model = new SVMModel(); 124 /// <summary> 125 /// Imports a support vector machine model given as string representation. 126 /// </summary> 127 /// <param name="reader">The reader to retrieve the string representation from</param> 128 /// <returns>The imported support vector machine model.</returns> 129 public static SupportVectorMachineModel Import(TextReader reader) { 130 SupportVectorMachineModel model = new SupportVectorMachineModel(); 140 131 while (reader.ReadLine().Trim() != "RangeTransform:") ; // read until line "RangeTransform"; 141 132 model.RangeTransform = SVM.RangeTransform.Read(reader); -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModelCreator.cs
r3763 r3842 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-20 09Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 26 26 using HeuristicLab.Core; 27 27 using HeuristicLab.Data; 28 using HeuristicLab.DataAnalysis;29 28 using System.Threading; 29 using HeuristicLab.LibSVM; 30 using HeuristicLab.Operators; 31 using HeuristicLab.Parameters; 30 32 using SVM; 33 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 31 34 32 namespace HeuristicLab.SupportVectorMachines { 33 public class SupportVectorCreator : OperatorBase { 34 private Thread trainingThread; 35 private object locker = new object(); 36 private bool abortRequested = false; 35 namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine { 36 /// <summary> 37 /// Represents an operator that creates a support vector machine model. 38 /// </summary> 39 [StorableClass] 40 [Item("SupportVectorMachineModelCreator", "Represents an operator that creates a support vector machine model.")] 41 public class SupportVectorMachineModelCreator : SingleSuccessorOperator { 42 private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData"; 43 private const string SvmTypeParameterName = "SvmType"; 44 private const string KernelTypeParameterName = "KernelType"; 45 private const string CostParameterName = "Cost"; 46 private const string NuParameterName = "Nu"; 47 private const string GammaParameterName = "Gamma"; 48 private const string ModelParameterName = "SupportVectorMachineModel"; 37 49 38 public SupportVectorCreator() 50 #region parameter properties 51 public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter { 52 get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; } 53 } 54 public IValueParameter<StringValue> SvmTypeParameter { 55 
get { return (IValueParameter<StringValue>)Parameters[SvmTypeParameterName]; } 56 } 57 public IValueParameter<StringValue> KernelTypeParameter { 58 get { return (IValueParameter<StringValue>)Parameters[KernelTypeParameterName]; } 59 } 60 public IValueLookupParameter<DoubleValue> NuParameter { 61 get { return (IValueLookupParameter<DoubleValue>)Parameters[NuParameterName]; } 62 } 63 public IValueLookupParameter<DoubleValue> CostParameter { 64 get { return (IValueLookupParameter<DoubleValue>)Parameters[CostParameterName]; } 65 } 66 public IValueLookupParameter<DoubleValue> GammaParameter { 67 get { return (IValueLookupParameter<DoubleValue>)Parameters[GammaParameterName]; } 68 } 69 public ILookupParameter<SupportVectorMachineModel> SupportVectorMachineModelParameter { 70 get { return (ILookupParameter<SupportVectorMachineModel>)Parameters[ModelParameterName]; } 71 } 72 #endregion 73 #region properties 74 public DataAnalysisProblemData DataAnalysisProblemData { 75 get { return DataAnalysisProblemDataParameter.ActualValue; } 76 } 77 public StringValue SvmType { 78 get { return SvmTypeParameter.Value; } 79 } 80 public StringValue KernelType { 81 get { return KernelTypeParameter.Value; } 82 } 83 public DoubleValue Nu { 84 get { return NuParameter.ActualValue; } 85 } 86 public DoubleValue Cost { 87 get { return CostParameter.ActualValue; } 88 } 89 public DoubleValue Gamma { 90 get { return GammaParameter.ActualValue; } 91 } 92 #endregion 93 94 public SupportVectorMachineModelCreator() 39 95 : base() { 40 //Dataset infos 41 AddVariableInfo(new VariableInfo("Dataset", "Dataset with all samples on which to apply the function", typeof(Dataset), VariableKind.In)); 42 AddVariableInfo(new VariableInfo("TargetVariable", "Name of the target variable", typeof(StringData), VariableKind.In)); 43 AddVariableInfo(new VariableInfo("InputVariables", "List of allowed input variable names", typeof(ItemList), VariableKind.In)); 44 AddVariableInfo(new VariableInfo("SamplesStart", "Start 
index of samples in dataset to evaluate", typeof(IntData), VariableKind.In)); 45 AddVariableInfo(new VariableInfo("SamplesEnd", "End index of samples in dataset to evaluate", typeof(IntData), VariableKind.In)); 46 AddVariableInfo(new VariableInfo("MaxTimeOffset", "(optional) Maximal time offset for time-series prognosis", typeof(IntData), VariableKind.In)); 47 AddVariableInfo(new VariableInfo("MinTimeOffset", "(optional) Minimal time offset for time-series prognosis", typeof(IntData), VariableKind.In)); 48 49 //SVM parameters 50 AddVariableInfo(new VariableInfo("SVMType", "String describing which SVM type is used. Valid inputs are: C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR", 51 typeof(StringData), VariableKind.In)); 52 AddVariableInfo(new VariableInfo("SVMKernelType", "String describing which SVM kernel is used. Valid inputs are: LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED", 53 typeof(StringData), VariableKind.In)); 54 AddVariableInfo(new VariableInfo("SVMCost", "Cost parameter (C) of C-SVC, epsilon-SVR and nu-SVR", typeof(DoubleData), VariableKind.In)); 55 AddVariableInfo(new VariableInfo("SVMNu", "Nu parameter of nu-SVC, one-class SVM and nu-SVR", typeof(DoubleData), VariableKind.In)); 56 AddVariableInfo(new VariableInfo("SVMGamma", "Gamma parameter in kernel function", typeof(DoubleData), VariableKind.In)); 57 AddVariableInfo(new VariableInfo("SVMModel", "Represent the model learned by the SVM", typeof(SVMModel), VariableKind.New | VariableKind.Out)); 96 #region svm types 97 StringValue cSvcType = new StringValue("C_SVC").AsReadOnly(); 98 StringValue nuSvcType = new StringValue("NU_SVC").AsReadOnly(); 99 StringValue eSvrType = new StringValue("EPSILON_SVR").AsReadOnly(); 100 StringValue nuSvrType = new StringValue("NU_SVR").AsReadOnly(); 101 ItemSet<StringValue> allowedSvmTypes = new ItemSet<StringValue>(); 102 allowedSvmTypes.Add(cSvcType); 103 allowedSvmTypes.Add(nuSvcType); 104 allowedSvmTypes.Add(eSvrType); 105 allowedSvmTypes.Add(nuSvrType); 106 #endregion 
107 #region kernel types 108 StringValue rbfKernelType = new StringValue("RBF").AsReadOnly(); 109 StringValue linearKernelType = new StringValue("LINEAR").AsReadOnly(); 110 StringValue polynomialKernelType = new StringValue("POLY").AsReadOnly(); 111 StringValue sigmoidKernelType = new StringValue("SIGMOID").AsReadOnly(); 112 ItemSet<StringValue> allowedKernelTypes = new ItemSet<StringValue>(); 113 allowedKernelTypes.Add(rbfKernelType); 114 allowedKernelTypes.Add(linearKernelType); 115 allowedKernelTypes.Add(polynomialKernelType); 116 allowedKernelTypes.Add(sigmoidKernelType); 117 #endregion 118 Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training.")); 119 Parameters.Add(new ConstrainedValueParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", allowedSvmTypes, nuSvrType)); 120 Parameters.Add(new ConstrainedValueParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", allowedKernelTypes, rbfKernelType)); 121 Parameters.Add(new ValueLookupParameter<DoubleValue>(NuParameterName, "The value of the nu parameter nu-SVC, one-class SVM and nu-SVR.")); 122 Parameters.Add(new ValueLookupParameter<DoubleValue>(CostParameterName, "The value of the C (cost) parameter of C-SVC, epsilon-SVR and nu-SVR.")); 123 Parameters.Add(new ValueLookupParameter<DoubleValue>(GammaParameterName, "The value of the gamma parameter in the kernel function.")); 124 Parameters.Add(new LookupParameter<SupportVectorMachineModel>(ModelParameterName, "The result model generated by the SVM.")); 58 125 } 59 126 60 public override void Abort() { 61 abortRequested = true; 62 lock (locker) { 63 if (trainingThread != null && trainingThread.ThreadState == ThreadState.Running) { 64 trainingThread.Abort(); 65 } 66 } 127 public override IOperation Apply() { 128 129 SupportVectorMachineModel model = TrainModel(DataAnalysisProblemData, 130 SvmType.Value, 
KernelType.Value, 131 Cost.Value, Nu.Value, Gamma.Value); 132 SupportVectorMachineModelParameter.ActualValue = model; 133 134 return base.Apply(); 67 135 } 68 136 69 public override IOperation Apply(IScope scope) { 70 abortRequested = false; 71 Dataset dataset = GetVariableValue<Dataset>("Dataset", scope, true); 72 string targetVariable = GetVariableValue<StringData>("TargetVariable", scope, true).Data; 73 ItemList inputVariables = GetVariableValue<ItemList>("InputVariables", scope, true); 74 var inputVariableNames = from x in inputVariables 75 select ((StringData)x).Data; 76 int start = GetVariableValue<IntData>("SamplesStart", scope, true).Data; 77 int end = GetVariableValue<IntData>("SamplesEnd", scope, true).Data; 78 IntData maxTimeOffsetData = GetVariableValue<IntData>("MaxTimeOffset", scope, true, false); 79 int maxTimeOffset = maxTimeOffsetData == null ? 0 : maxTimeOffsetData.Data; 80 IntData minTimeOffsetData = GetVariableValue<IntData>("MinTimeOffset", scope, true, false); 81 int minTimeOffset = minTimeOffsetData == null ? 
0 : minTimeOffsetData.Data; 82 string svmType = GetVariableValue<StringData>("SVMType", scope, true).Data; 83 string svmKernelType = GetVariableValue<StringData>("SVMKernelType", scope, true).Data; 84 85 double cost = GetVariableValue<DoubleData>("SVMCost", scope, true).Data; 86 double nu = GetVariableValue<DoubleData>("SVMNu", scope, true).Data; 87 double gamma = GetVariableValue<DoubleData>("SVMGamma", scope, true).Data; 88 89 SVMModel modelData = null; 90 lock (locker) { 91 if (!abortRequested) { 92 trainingThread = new Thread(() => { 93 modelData = TrainModel(dataset, targetVariable, inputVariableNames, 94 start, end, minTimeOffset, maxTimeOffset, 95 svmType, svmKernelType, 96 cost, nu, gamma); 97 }); 98 trainingThread.Start(); 99 } 100 } 101 if (!abortRequested) { 102 trainingThread.Join(); 103 trainingThread = null; 104 } 105 106 107 if (!abortRequested) { 108 //persist variables in scope 109 scope.AddVariable(new Variable(scope.TranslateName("SVMModel"), modelData)); 110 return null; 111 } else { 112 return new AtomicOperation(this, scope); 113 } 114 } 115 116 public static SVMModel TrainRegressionModel( 117 Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, 118 int start, int end, 119 double cost, double nu, double gamma) { 120 return TrainModel(dataset, targetVariable, inputVariables, start, end, 0, 0, "NU_SVR", "RBF", cost, nu, gamma); 121 } 122 123 public static SVMModel TrainModel( 124 Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, 125 int start, int end, 126 int minTimeOffset, int maxTimeOffset, 137 public static SupportVectorMachineModel TrainModel( 138 DataAnalysisProblemData problemData, 127 139 string svmType, string kernelType, 128 140 double cost, double nu, double gamma) { 129 int targetVariableIndex = dataset.GetVariableIndex(targetVariable);141 int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value); 130 142 131 143 //extract SVM parameters from 
scope and set them … … 140 152 141 153 142 SVM.Problem problem = S VMHelper.CreateSVMProblem(dataset, targetVariableIndex, inputVariables, start, end, minTimeOffset, maxTimeOffset);154 SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(problemData, problemData.TrainingSamplesStart.Value, problemData.TrainingSamplesEnd.Value); 143 155 SVM.RangeTransform rangeTransform = SVM.RangeTransform.Compute(problem); 144 SVM.Problem scaledProblem = rangeTransform.Scale(problem);145 var model = new S VMModel();156 SVM.Problem scaledProblem = Scaling.Scale(rangeTransform, problem); 157 var model = new SupportVectorMachineModel(); 146 158 147 159 model.Model = SVM.Training.Train(scaledProblem, parameter); -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineModelEvaluator.cs
r3763 r3842 26 26 using HeuristicLab.Core; 27 27 using HeuristicLab.Data; 28 using HeuristicLab.DataAnalysis;29 28 using SVM; 29 using HeuristicLab.Operators; 30 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 31 using HeuristicLab.Parameters; 30 32 31 namespace HeuristicLab.SupportVectorMachines { 32 public class SupportVectorEvaluator : OperatorBase { 33 namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine { 34 [StorableClass] 35 [Item("SupportVectorMachineModelEvaluator", "Represents a operator that evaluates a support vector machine model on a data set.")] 36 public class SupportVectorMachineModelEvaluator : SingleSuccessorOperator { 37 private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData"; 38 private const string ModelParameterName = "SupportVectorMachineModel"; 39 private const string SamplesStartParameterName = "SamplesStart"; 40 private const string SamplesEndParameterName = "SamplesEnd"; 41 private const string ValuesParameterName = "Values"; 33 42 34 public SupportVectorEvaluator() 43 #region parameter properties 44 public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter { 45 get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; } 46 } 47 public IValueLookupParameter<IntValue> SamplesStartParameter { 48 get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; } 49 } 50 public IValueLookupParameter<IntValue> SamplesEndParameter { 51 get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; } 52 } 53 public ILookupParameter<SupportVectorMachineModel> SupportVectorMachineModelParameter { 54 get { return (ILookupParameter<SupportVectorMachineModel>)Parameters[ModelParameterName]; } 55 } 56 public ILookupParameter<DoubleMatrix> ValuesParameter { 57 get { return (ILookupParameter<DoubleMatrix>)Parameters[ValuesParameterName]; } 58 } 59 #endregion 60 #region 
properties 61 public DataAnalysisProblemData DataAnalysisProblemData { 62 get { return DataAnalysisProblemDataParameter.ActualValue; } 63 } 64 public SupportVectorMachineModel SupportVectorMachineModel { 65 get { return SupportVectorMachineModelParameter.ActualValue; } 66 } 67 public IntValue SamplesStart { 68 get { return SamplesStartParameter.ActualValue; } 69 } 70 public IntValue SamplesEnd { 71 get { return SamplesEndParameter.ActualValue; } 72 } 73 #endregion 74 public SupportVectorMachineModelEvaluator() 35 75 : base() { 36 //Dataset infos 37 AddVariableInfo(new VariableInfo("Dataset", "Dataset with all samples on which to apply the function", typeof(Dataset), VariableKind.In)); 38 AddVariableInfo(new VariableInfo("TargetVariable", "Name of the target variable", typeof(StringData), VariableKind.In)); 39 AddVariableInfo(new VariableInfo("InputVariables", "List of allowed input variable names", typeof(ItemList), VariableKind.In)); 40 AddVariableInfo(new VariableInfo("SamplesStart", "Start index of samples in dataset to evaluate", typeof(IntData), VariableKind.In)); 41 AddVariableInfo(new VariableInfo("SamplesEnd", "End index of samples in dataset to evaluate", typeof(IntData), VariableKind.In)); 42 AddVariableInfo(new VariableInfo("MaxTimeOffset", "(optional) Maximal allowed time offset for input variables", typeof(IntData), VariableKind.In)); 43 AddVariableInfo(new VariableInfo("MinTimeOffset", "(optional) Minimal allowed time offset for input variables", typeof(IntData), VariableKind.In)); 44 AddVariableInfo(new VariableInfo("SVMModel", "Represent the model learned by the SVM", typeof(SVMModel), VariableKind.In)); 45 AddVariableInfo(new VariableInfo("Values", "Target vs predicted values", typeof(DoubleMatrixData), VariableKind.New | VariableKind.Out)); 76 Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training.")); 77 Parameters.Add(new 
LookupParameter<SupportVectorMachineModel>(ModelParameterName, "The result model generated by the SVM.")); 78 Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first index of the data set partition on which the SVM model should be evaluated.")); 79 Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last index of the data set partition on which the SVM model should be evaluated.")); 80 Parameters.Add(new LookupParameter<DoubleMatrix>(ValuesParameterName, "A matrix of original values of the target variable and output values of the SVM model.")); 46 81 } 47 82 83 public override IOperation Apply() { 84 int targetVariableIndex = DataAnalysisProblemData.Dataset.GetVariableIndex(DataAnalysisProblemData.TargetVariable.Value); 85 int start = SamplesStart.Value; 86 int end = SamplesEnd.Value; 48 87 49 public override IOperation Apply(IScope scope) { 50 Dataset dataset = GetVariableValue<Dataset>("Dataset", scope, true); 51 ItemList inputVariables = GetVariableValue<ItemList>("InputVariables", scope, true); 52 var inputVariableNames = from x in inputVariables 53 select ((StringData)x).Data; 54 string targetVariable = GetVariableValue<StringData>("TargetVariable", scope, true).Data; 55 int targetVariableIndex = dataset.GetVariableIndex(targetVariable); 56 int start = GetVariableValue<IntData>("SamplesStart", scope, true).Data; 57 int end = GetVariableValue<IntData>("SamplesEnd", scope, true).Data; 58 IntData minTimeOffsetData = GetVariableValue<IntData>("MinTimeOffset", scope, true, false); 59 int minTimeOffset = minTimeOffsetData == null ? 0 : minTimeOffsetData.Data; 60 IntData maxTimeOffsetData = GetVariableValue<IntData>("MaxTimeOffset", scope, true, false); 61 int maxTimeOffset = maxTimeOffsetData == null ? 
0 : maxTimeOffsetData.Data; 62 SVMModel modelData = GetVariableValue<SVMModel>("SVMModel", scope, true); 63 64 SVM.Problem problem = SVMHelper.CreateSVMProblem(dataset, targetVariableIndex, inputVariableNames, start, end, minTimeOffset, maxTimeOffset); 65 SVM.Problem scaledProblem = modelData.RangeTransform.Scale(problem); 88 SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(DataAnalysisProblemData, start, end); 89 SVM.Problem scaledProblem = SupportVectorMachineModel.RangeTransform.Scale(problem); 66 90 67 91 double[,] values = new double[scaledProblem.Count, 2]; 68 92 for (int i = 0; i < scaledProblem.Count; i++) { 69 values[i, 0] = dataset.GetValue(start + i, targetVariableIndex);70 values[i, 1] = SVM.Prediction.Predict( modelData.Model, scaledProblem.X[i]);93 values[i, 0] = DataAnalysisProblemData.Dataset[start + i, targetVariableIndex]; 94 values[i, 1] = SVM.Prediction.Predict(SupportVectorMachineModel.Model, scaledProblem.X[i]); 71 95 } 72 96 73 scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName("Values"), new DoubleMatrixData(values)));74 return null;97 ValuesParameter.ActualValue = new DoubleMatrix(values); 98 return base.Apply(); 75 99 } 76 100 } -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/SupportVectorMachineUtil.cs
r3763 r3842 1 using System; 1 #region License Information 2 /* HeuristicLab 3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 * 5 * This file is part of HeuristicLab. 6 * 7 * HeuristicLab is free software: you can redistribute it and/or modify 8 * it under the terms of the GNU General Public License as published by 9 * the Free Software Foundation, either version 3 of the License, or 10 * (at your option) any later version. 11 * 12 * HeuristicLab is distributed in the hope that it will be useful, 13 * but WITHOUT ANY WARRANTY; without even the implied warranty of 14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 * GNU General Public License for more details. 16 * 17 * You should have received a copy of the GNU General Public License 18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>. 19 */ 20 #endregion 21 22 using System; 2 23 using System.Collections.Generic; 3 24 using System.Linq; … … 5 26 using HeuristicLab.Core; 6 27 using HeuristicLab.Data; 7 using HeuristicLab.DataAnalysis;8 28 using HeuristicLab.Common; 9 29 10 namespace HeuristicLab.SupportVectorMachines { 11 public class SVMHelper { 12 public static SVM.Problem CreateSVMProblem(Dataset dataset, int targetVariableIndex, IEnumerable<string> inputVariables, int start, int end, int minTimeOffset, int maxTimeOffset) { 30 namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine { 31 public class SupportVectorMachineUtil { 32 /// <summary> 33 /// Transforms <paramref name="problemData"/> into a data structure as needed by libSVM. 
34 /// </summary> 35 /// <param name="problemData">The problem data to transform</param> 36 /// <param name="start">The index of the first row of <paramref name="problemData"/> to copy to the output.</param> 37 /// <param name="end">The index after the last row of <paramref name="problemData"/> to copy to the output (exclusive).</param> 38 /// <returns>A problem data type that can be used to train a support vector machine.</returns> 39 public static SVM.Problem CreateSvmProblem(DataAnalysisProblemData problemData, int start, int end) { 13 40 int rowCount = end - start; 14 15 var targetVector = (from row in Enumerable.Range(start, rowCount) 16 let val = dataset.GetValue(row, targetVariableIndex) 17 where !double.IsNaN(val) 18 select val).ToArray(); 19 41 var targetVector = problemData.Dataset.GetVariableValues(problemData.TargetVariable.Value, start, end); 20 42 21 43 SVM.Node[][] nodes = new SVM.Node[targetVector.Length][]; 22 44 List<SVM.Node> tempRow; 23 int addedRows = 0; 24 int maxColumns = 0; 45 int maxNodeIndex = 0; 25 46 for (int row = 0; row < rowCount; row++) { 26 47 tempRow = new List<SVM.Node>(); 27 int nodeIndex = 0; 28 foreach (var inputVariable in inputVariables) { 29 ++nodeIndex; 30 int col = dataset.GetVariableIndex(inputVariable); 31 if (IsUsefulColumn(dataset, col, start, end)) { 32 for (int timeOffset = minTimeOffset; timeOffset <= maxTimeOffset; timeOffset++) { 33 int actualColumn = nodeIndex * (maxTimeOffset - minTimeOffset + 1) + (timeOffset - minTimeOffset); 34 if (start + row + timeOffset >= 0 && start + row + timeOffset < dataset.Rows) { 35 double value = dataset.GetValue(start + row + timeOffset, col); 36 if (!double.IsNaN(value)) { 37 tempRow.Add(new SVM.Node(actualColumn, value)); 38 if (actualColumn > maxColumns) maxColumns = actualColumn; 39 } 40 } 41 } 48 foreach (var inputVariable in problemData.InputVariables) { 49 int col = problemData.Dataset.GetVariableIndex(inputVariable.Value); 50 double value = problemData.Dataset[start + row, col]; 51 if
(!double.IsNaN(value)) { 52 int nodeIndex = col + 1; // make sure the smallest nodeIndex = 1 53 tempRow.Add(new SVM.Node(nodeIndex, value)); 54 if (nodeIndex > maxNodeIndex) maxNodeIndex = nodeIndex; 42 55 } 43 56 } 44 if (!double.IsNaN(dataset.GetValue(start + row, targetVariableIndex))) { 45 nodes[addedRows] = tempRow.ToArray(); 46 addedRows++; 47 } 57 nodes[row] = tempRow.OrderBy(x => x.Index).ToArray(); // make sure the values are sorted by node index 48 58 } 49 59 50 return new SVM.Problem(targetVector.Length, targetVector, nodes, maxColumns); 51 } 52 53 // checks if the column has at least two different non-NaN and non-Infinity values 54 private static bool IsUsefulColumn(Dataset dataset, int col, int start, int end) { 55 double min = double.PositiveInfinity; 56 double max = double.NegativeInfinity; 57 for (int i = start; i < end; i++) { 58 double x = dataset.GetValue(i, col); 59 if (!double.IsNaN(x) && !double.IsInfinity(x)) { 60 min = Math.Min(min, x); 61 max = Math.Max(max, x); 62 } 63 if (min != max) return true; 64 } 65 return false; 60 return new SVM.Problem(targetVector.Length, targetVector, nodes, maxNodeIndex); 66 61 } 67 62 }
Note: See TracChangeset
for help on using the changeset viewer.