Changeset 6760 for branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine
- Timestamp:
- 09/14/11 13:59:25 (13 years ago)
- Location:
- branches/PersistenceSpeedUp
- Files:
-
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/PersistenceSpeedUp
- Property svn:ignore
-
old new 12 12 *.psess 13 13 *.vsp 14 *.docstates
-
- Property svn:mergeinfo changed
- Property svn:ignore
-
branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassification.cs
r6228 r6760 35 35 /// Support vector machine classification data analysis algorithm. 36 36 /// </summary> 37 [Item("Support Vector Classification", "Support vector machine classification data analysis algorithm .")]37 [Item("Support Vector Classification", "Support vector machine classification data analysis algorithm (wrapper for libSVM).")] 38 38 [Creatable("Data Analysis")] 39 39 [StorableClass] -
branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorClassificationSolution.cs
r5809 r6760 20 20 #endregion 21 21 22 using System;23 using System.Collections.Generic;24 using System.Drawing;25 using System.Linq;26 22 using HeuristicLab.Common; 27 23 using HeuristicLab.Core; … … 49 45 public SupportVectorClassificationSolution(SupportVectorMachineModel model, IClassificationProblemData problemData) 50 46 : base(model, problemData) { 47 RecalculateResults(); 51 48 } 52 49 … … 54 51 return new SupportVectorClassificationSolution(this, cloner); 55 52 } 53 54 protected override void RecalculateResults() { 55 CalculateResults(); 56 } 56 57 } 57 58 } -
branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineModel.cs
r5861 r6760 98 98 this.targetVariable = original.targetVariable; 99 99 this.allowedInputVariables = (string[])original.allowedInputVariables.Clone(); 100 foreach (var dataset in original.cachedPredictions.Keys) { 101 this.cachedPredictions.Add(cloner.Clone(dataset), (double[])original.cachedPredictions[dataset].Clone()); 102 } 100 103 if (original.classValues != null) 101 104 this.classValues = (double[])original.classValues.Clone(); … … 123 126 return GetEstimatedValuesHelper(dataset, rows); 124 127 } 125 #endregion 128 public SupportVectorRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 129 return new SupportVectorRegressionSolution(this, problemData); 130 } 131 IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) { 132 return CreateRegressionSolution(problemData); 133 } 134 #endregion 135 126 136 #region IClassificationModel Members 127 137 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { … … 144 154 } 145 155 } 146 #endregion 156 157 public SupportVectorClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { 158 return new SupportVectorClassificationSolution(this, problemData); 159 } 160 IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) { 161 return CreateClassificationSolution(problemData); 162 } 163 #endregion 164 // cache for predictions, which is cloned but not persisted, must be cleared when the model is changed 165 private Dictionary<Dataset, double[]> cachedPredictions = new Dictionary<Dataset, double[]>(); 147 166 private IEnumerable<double> GetEstimatedValuesHelper(Dataset dataset, IEnumerable<int> rows) { 167 if (!cachedPredictions.ContainsKey(dataset)) { 168 // create an array of cached predictions which is initially filled with NaNs 169 double[] predictions = Enumerable.Repeat(double.NaN, dataset.Rows).ToArray(); 170 
CalculatePredictions(dataset, rows, predictions); 171 cachedPredictions.Add(dataset, predictions); 172 } 173 // get the array of predictions and select the subset of requested rows 174 double[] p = cachedPredictions[dataset]; 175 var requestedPredictions = from r in rows 176 select p[r]; 177 // check if the requested predictions contain NaNs 178 // (this means for the request rows some predictions have not been cached) 179 if (requestedPredictions.Any(x => double.IsNaN(x))) { 180 // updated the predictions for currently requested rows 181 CalculatePredictions(dataset, rows, p); 182 cachedPredictions[dataset] = p; 183 // now we can be sure that for the current rows all predictions are available 184 return from r in rows 185 select p[r]; 186 } else { 187 // there were no NaNs => just return the cached predictions 188 return requestedPredictions; 189 } 190 } 191 192 private void CalculatePredictions(Dataset dataset, IEnumerable<int> rows, double[] predictions) { 193 // calculate and cache predictions for the currently requested rows 148 194 SVM.Problem problem = SupportVectorMachineUtil.CreateSvmProblem(dataset, targetVariable, allowedInputVariables, rows); 149 195 SVM.Problem scaledProblem = Scaling.Scale(RangeTransform, problem); 150 196 151 foreach (var row in Enumerable.Range(0, scaledProblem.Count)) { 152 yield return SVM.Prediction.Predict(Model, scaledProblem.X[row]); 153 } 154 } 197 // row is the index in the original dataset, 198 // i is the index in the scaled dataset (containing only the necessary rows) 199 int i = 0; 200 foreach (var row in rows) { 201 predictions[row] = SVM.Prediction.Predict(Model, scaledProblem.X[i]); 202 i++; 203 } 204 } 205 155 206 #region events 156 207 public event EventHandler Changed; 157 208 private void OnChanged(EventArgs e) { 209 cachedPredictions.Clear(); 158 210 var handlers = Changed; 159 211 if (handlers != null) -
branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorMachineUtil.cs
r6002 r6760 34 34 public static SVM.Problem CreateSvmProblem(Dataset dataset, string targetVariable, IEnumerable<string> inputVariables, IEnumerable<int> rowIndices) { 35 35 double[] targetVector = 36 dataset.GetEnumeratedVariableValues(targetVariable, rowIndices) 37 .ToArray(); 36 dataset.GetDoubleValues(targetVariable, rowIndices).ToArray(); 38 37 39 38 SVM.Node[][] nodes = new SVM.Node[targetVector.Length][]; … … 46 45 int colIndex = 1; // make sure the smallest node index for SVM = 1 47 46 foreach (var inputVariable in inputVariablesList) { 48 double value = dataset [row, dataset.GetVariableIndex(inputVariable)];47 double value = dataset.GetDoubleValue(inputVariable, row); 49 48 // SVM also works with missing values 50 49 // => don't add NaN values in the dataset to the sparse SVM matrix representation -
branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorRegression.cs
r6228 r6760 35 35 /// Support vector machine regression data analysis algorithm. 36 36 /// </summary> 37 [Item("Support Vector Regression", "Support vector machine regression data analysis algorithm .")]37 [Item("Support Vector Regression", "Support vector machine regression data analysis algorithm (wrapper for libSVM).")] 38 38 [Creatable("Data Analysis")] 39 39 [StorableClass] -
branches/PersistenceSpeedUp/HeuristicLab.Algorithms.DataAnalysis/3.4/SupportVectorMachine/SupportVectorRegressionSolution.cs
r5809 r6760 20 20 #endregion 21 21 22 using System;23 using System.Collections.Generic;24 using System.Drawing;25 using System.Linq;26 22 using HeuristicLab.Common; 27 23 using HeuristicLab.Core; … … 49 45 public SupportVectorRegressionSolution(SupportVectorMachineModel model, IRegressionProblemData problemData) 50 46 : base(model, problemData) { 47 RecalculateResults(); 51 48 } 52 49
Note: See TracChangeset for help on using the changeset viewer.