Changeset 8465 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs
- Timestamp:
- 08/10/12 14:57:21 (12 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs
r7294 r8465 33 33 /// </summary> 34 34 [StorableClass] 35 [Item("NearestNeighbourModel", "Represents a ne ural networkfor regression and classification.")]35 [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")] 36 36 public sealed class NearestNeighbourModel : NamedItem, INearestNeighbourModel { 37 37 … … 56 56 [Storable] 57 57 private int k; 58 58 59 [StorableConstructor] 59 60 private NearestNeighbourModel(bool deserializing) … … 95 96 this.classValues = (double[])original.classValues.Clone(); 96 97 } 97 public NearestNeighbourModel(alglib.nearestneighbor.kdtree kdTree, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) 98 : base() { 99 this.name = ItemName; 100 this.description = ItemDescription; 101 this.kdTree = kdTree; 98 public NearestNeighbourModel(Dataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) { 102 99 this.k = k; 103 100 this.targetVariable = targetVariable; 104 101 this.allowedInputVariables = allowedInputVariables.ToArray(); 105 if (classValues != null) 102 103 var inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, 104 allowedInputVariables.Concat(new string[] { targetVariable }), 105 rows); 106 107 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 108 throw new NotSupportedException( 109 "Nearest neighbour classification does not support NaN or infinity values in the input dataset."); 110 111 this.kdTree = new alglib.nearestneighbor.kdtree(); 112 113 var nRows = inputMatrix.GetLength(0); 114 var nFeatures = inputMatrix.GetLength(1) - 1; 115 116 if (classValues != null) { 106 117 this.classValues = (double[])classValues.Clone(); 118 int nClasses = classValues.Length; 119 // map original class values to values [0..nClasses-1] 120 var classIndices = new Dictionary<double, double>(); 121 for (int i = 0; i < nClasses; i++) 122 classIndices[classValues[i]] = i; 123 124 for (int row = 0; row < nRows; row++) { 125 inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]]; 126 } 127 } 128 alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdTree); 107 129 } 108 130 … … 140 162 141 163 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 164 if (classValues == null) throw new InvalidOperationException("No class values are defined."); 142 165 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 143 166
Note: See TracChangeset
for help on using the changeset viewer.