Changeset 14308 for stable/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs
- Timestamp:
- 09/26/16 18:25:24 (8 years ago)
- Location:
- stable
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
stable
- Property svn:mergeinfo changed
/trunk/sources merged: 14235
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Algorithms.DataAnalysis merged: 14235
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs
r14186 r14308 58 58 [Storable] 59 59 private int k; 60 [Storable(DefaultValue = null)] 61 private double[] weights; // not set for old versions loaded from disk 62 [Storable(DefaultValue = null)] 63 private double[] offsets; // not set for old versions loaded from disk 60 64 61 65 [StorableConstructor] … … 93 97 94 98 k = original.k; 99 isCompatibilityLoaded = original.IsCompatibilityLoaded; 100 if (!IsCompatibilityLoaded) { 101 weights = new double[original.weights.Length]; 102 Array.Copy(original.weights, weights, weights.Length); 103 offsets = new double[original.offsets.Length]; 104 Array.Copy(original.offsets, this.offsets, this.offsets.Length); 105 } 95 106 allowedInputVariables = (string[])original.allowedInputVariables.Clone(); 96 107 if (original.classValues != null) 97 108 this.classValues = (double[])original.classValues.Clone(); 98 109 } 99 public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)110 public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null) 100 111 : base(targetVariable) { 101 112 Name = ItemName; … … 103 114 this.k = k; 104 115 this.allowedInputVariables = allowedInputVariables.ToArray(); 105 106 var inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, 107 allowedInputVariables.Concat(new string[] { targetVariable }), 108 rows); 116 double[,] inputMatrix; 117 if (IsCompatibilityLoaded) { 118 // no scaling 119 inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, 120 this.allowedInputVariables.Concat(new string[] { targetVariable }), 121 rows); 122 } else { 123 this.offsets = this.allowedInputVariables 124 .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1) 125 .Concat(new double[] { 0 }) // no offset for target variable 126 .ToArray(); 127 if (weights == null) { 128 // automatic determination of weights (all features should have variance = 1) 129 this.weights = this.allowedInputVariables 130 .Select(name => 1.0 / dataset.GetDoubleValues(name, rows).StandardDeviationPop()) 131 .Concat(new double[] { 1.0 }) // no scaling for target variable 132 .ToArray(); 133 } else { 134 // user specified weights (+ 1 for target) 135 this.weights = weights.Concat(new double[] { 1.0 }).ToArray(); 136 if (this.weights.Length - 1 != this.allowedInputVariables.Length) 137 throw new ArgumentException("The number of elements in the weight vector must match the number of input variables"); 138 } 139 inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights); 140 } 109 141 110 142 if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x))) … … 132 164 } 133 165 166 private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) { 167 var x = new double[rows.Count(), variables.Count()]; 168 var colIdx = 0; 169 foreach (var variableName in variables) { 170 var rowIdx = 0; 171 foreach (var val in dataset.GetDoubleValues(variableName, rows)) { 172 x[rowIdx, colIdx] = (val + offsets[colIdx]) * factors[colIdx]; 173 rowIdx++; 174 } 175 colIdx++; 176 } 177 return x; 178 } 179 134 180 public override IDeepCloneable Clone(Cloner cloner) { 135 181 return new NearestNeighbourModel(this, cloner); … … 137 183 138 184 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 139 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 185 double[,] inputData; 186 if (IsCompatibilityLoaded) { 187 inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 188 } else { 189 inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights); 190 } 140 191 141 192 int n = inputData.GetLength(0); 142 193 int columns = inputData.GetLength(1); 143 194 double[] x = new double[columns]; 144 double[] y = new double[1];145 195 double[] dists = new double[k]; 146 196 double[,] neighbours = new double[k, columns + 1]; … … 152 202 int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false); 153 203 alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists); 154 alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours); 204 alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours); // gkronber: this call changes the kdTree data structure 155 205 156 206 double distanceWeightedValue = 0.0; … … 166 216 public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) { 167 217 if (classValues == null) throw new InvalidOperationException("No class values are defined."); 168 double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 169 218 double[,] inputData; 219 if (IsCompatibilityLoaded) { 220 inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows); 221 } else { 222 inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights); 223 } 170 224 int n = inputData.GetLength(0); 171 225 int columns = inputData.GetLength(1); … … 219 273 #endregion 220 274 275 276 // BackwardsCompatibility3.3 277 #region Backwards compatible code, remove with 3.4 278 279 private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true 280 [Storable(DefaultValue = true)] 281 public bool IsCompatibilityLoaded { 282 get { return isCompatibilityLoaded; } 283 set { isCompatibilityLoaded = value; } 284 } 285 #endregion 221 286 #region persistence 222 287 [Storable]
Note: See TracChangeset
for help on using the changeset viewer.