Changeset 8465
- Timestamp:
- 08/10/12 14:57:21 (12 years ago)
- Location:
- trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourClassification.cs
r8139 r8465 21 21 22 22 using System; 23 using System.Collections.Generic;24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 25 using HeuristicLab.Core; 27 26 using HeuristicLab.Data; 28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;29 27 using HeuristicLab.Optimization; 28 using HeuristicLab.Parameters; 30 29 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 31 30 using HeuristicLab.Problems.DataAnalysis; 32 using HeuristicLab.Problems.DataAnalysis.Symbolic;33 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;34 using HeuristicLab.Parameters;35 31 36 32 namespace HeuristicLab.Algorithms.DataAnalysis { … … 84 80 85 81 public static IClassificationSolution CreateNearestNeighbourClassificationSolution(IClassificationProblemData problemData, int k) { 86 Dataset dataset = problemData.Dataset; 87 string targetVariable = problemData.TargetVariable; 88 IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables; 89 IEnumerable<int> rows = problemData.TrainingIndices; 90 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows); 91 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 92 throw new NotSupportedException("Nearest neighbour classification does not support NaN or infinity values in the input dataset."); 82 var problemDataClone = (IClassificationProblemData)problemData.Clone(); 83 return new NearestNeighbourClassificationSolution(problemDataClone, Train(problemDataClone, k)); 84 } 93 85 94 alglib.nearestneighbor.kdtree kdtree = new alglib.nearestneighbor.kdtree(); 95 96 int nRows = inputMatrix.GetLength(0); 97 int nFeatures = inputMatrix.GetLength(1) - 1; 98 double[] classValues = dataset.GetDoubleValues(targetVariable).Distinct().OrderBy(x => x).ToArray(); 99 int nClasses = classValues.Count(); 100 // map original class values to values [0..nClasses-1] 101 Dictionary<double, double> classIndices = new Dictionary<double, double>(); 102 for (int i = 0; i < nClasses; i++) { 103 classIndices[classValues[i]] = i; 104 } 105 for (int row = 0; row < nRows; row++) { 106 inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]]; 107 } 108 alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdtree); 109 var problemDataClone = (IClassificationProblemData) problemData.Clone(); 110 return new NearestNeighbourClassificationSolution(problemDataClone, new NearestNeighbourModel(kdtree, k, targetVariable, allowedInputVariables, problemDataClone.ClassValues.ToArray())); 86 public static INearestNeighbourModel Train(IClassificationProblemData problemData, int k) { 87 return new NearestNeighbourModel(problemData.Dataset, 88 problemData.TrainingIndices, 89 k, 90 problemData.TargetVariable, 91 problemData.AllowedInputVariables, 92 problemData.ClassValues.ToArray()); 111 93 } 112 94 #endregion -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs
r7294 r8465 33 33 /// </summary> 34 34 [StorableClass] 35 [Item("NearestNeighbourModel", "Represents a ne ural networkfor regression and classification.")]35 [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")] 36 36 public sealed class NearestNeighbourModel : NamedItem, INearestNeighbourModel { 37 37 … … 56 56 [Storable] 57 57 private int k; 58 58 59 [StorableConstructor] 59 60 private NearestNeighbourModel(bool deserializing) … … 95 96 this.classValues = (double[])original.classValues.Clone(); 96 97 } 97 public NearestNeighbourModel(alglib.nearestneighbor.kdtree kdTree, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) 98 : base() { 99 this.name = ItemName; 100 this.description = ItemDescription; 101 this.kdTree = kdTree; 98 public NearestNeighbourModel(Dataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null) { 102 99 this.k = k; 103 100 this.targetVariable = targetVariable; 104 101 this.allowedInputVariables = allowedInputVariables.ToArray(); 105 if (classValues != null) 102 103 var inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, 104 allowedInputVariables.Concat(new string[] { targetVariable }), 105 rows); 106 107 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 108 throw new NotSupportedException( 109 "Nearest neighbour classification does not support NaN or infinity values in the input dataset."); 110 111 this.kdTree = new alglib.nearestneighbor.kdtree(); 112 113 var nRows = inputMatrix.GetLength(0); 114 var nFeatures = inputMatrix.GetLength(1) - 1; 115 116 if (classValues != null) { 106 117 this.classValues = (double[])classValues.Clone(); 118 int nClasses = classValues.Length; 119 // map original class values to values [0..nClasses-1] 120 var classIndices = new Dictionary<double, double>(); 121 for (int i = 0; i < nClasses; i++) 122 classIndices[classValues[i]] = i; 123 124 for (int row = 0; row < nRows; row++) { 125 inputMatrix[row, nFeatures] = classIndices[inputMatrix[row, nFeatures]]; 126 } 127 } 128 alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdTree); 107 129 } 108 130 … … 140 162 141 163 public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) { 164 if (classValues == null) throw new InvalidOperationException("No class values are defined."); 142 165 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 143 166 -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourRegression.cs
r8139 r8465 21 21 22 22 using System; 23 using System.Collections.Generic;24 using System.Linq;25 23 using HeuristicLab.Common; 26 24 using HeuristicLab.Core; 27 25 using HeuristicLab.Data; 28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;29 26 using HeuristicLab.Optimization; 27 using HeuristicLab.Parameters; 30 28 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 31 29 using HeuristicLab.Problems.DataAnalysis; 32 using HeuristicLab.Problems.DataAnalysis.Symbolic;33 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;34 using HeuristicLab.Parameters;35 30 36 31 namespace HeuristicLab.Algorithms.DataAnalysis { … … 84 79 85 80 public static IRegressionSolution CreateNearestNeighbourRegressionSolution(IRegressionProblemData problemData, int k) { 86 Dataset dataset = problemData.Dataset; 87 string targetVariable = problemData.TargetVariable; 88 IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables; 89 IEnumerable<int> rows = problemData.TrainingIndices; 90 double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows); 91 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) 92 throw new NotSupportedException("Nearest neighbour regression does not support NaN or infinity values in the input dataset."); 81 var clonedProblemData = (IRegressionProblemData)problemData.Clone(); 82 return new NearestNeighbourRegressionSolution(clonedProblemData, Train(problemData, k)); 83 } 93 84 94 alglib.nearestneighbor.kdtree kdtree = new alglib.nearestneighbor.kdtree(); 95 96 int nRows = inputMatrix.GetLength(0); 97 98 alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdtree); 99 100 return new NearestNeighbourRegressionSolution((IRegressionProblemData)problemData.Clone(), new NearestNeighbourModel(kdtree, k, targetVariable, allowedInputVariables)); 85 public static INearestNeighbourModel Train(IRegressionProblemData problemData, int k) { 86 return new NearestNeighbourModel(problemData.Dataset, 87 problemData.TrainingIndices, 88 k, 89 problemData.TargetVariable, 90 problemData.AllowedInputVariables); 101 91 } 102 92 #endregion
Note: See TracChangeset
for help on using the changeset viewer.