Changeset 16308 for branches/2845_EnhancedProgress/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs
Timestamp: 11/20/18 13:52:40
Location: branches/2845_EnhancedProgress
Files: 4 edited
Legend: unchanged context lines carry no marker, lines removed in r16308 are prefixed with "-", lines added in r16308 are prefixed with "+", and "…" marks unchanged code collapsed by the viewer.
branches/2845_EnhancedProgress
    Property svn:mergeinfo changed: /stable reverse-merged: 15587-15588; /trunk/sources removed
branches/2845_EnhancedProgress/HeuristicLab.Algorithms.DataAnalysis
    Property svn:mergeinfo changed: /stable/HeuristicLab.Algorithms.DataAnalysis reverse-merged: 15587; /trunk/sources/HeuristicLab.Algorithms.DataAnalysis removed
branches/2845_EnhancedProgress/HeuristicLab.Algorithms.DataAnalysis/3.4
    Property svn:mergeinfo deleted
branches/2845_EnhancedProgress/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs
r16307 → r16308

   #region License Information
   /* HeuristicLab
-   * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+   * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    *
    * This file is part of HeuristicLab.
…
  namespace HeuristicLab.Algorithms.DataAnalysis {
    /// <summary>
-   /// t-Distributed Stochastic Neighbor Embedding (tSNE) projects the data in a low dimensional
+   /// t-distributed stochastic neighbourhood embedding (tSNE) projects the data in a low dimensional
    /// space to allow visual cluster identification.
    /// </summary>
-   [Item("t-Distributed Stochastic Neighbor Embedding (tSNE)", "t-Distributed Stochastic Neighbor Embedding projects the data in a low " +
+   [Item("tSNE", "t-distributed stochastic neighbourhood embedding projects the data in a low " +
      "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
    [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
    [StorableClass]
…
    }

-   #region Parameter names
+   #region parameter names
    private const string DistanceFunctionParameterName = "DistanceFunction";
    private const string PerplexityParameterName = "Perplexity";
…
    private const string ClassesNameParameterName = "ClassesName";
    private const string NormalizationParameterName = "Normalization";
-   private const string RandomInitializationParameterName = "RandomInitialization";
    private const string UpdateIntervalParameterName = "UpdateInterval";
    #endregion

-   #region Result names
+   #region result names
    private const string IterationResultName = "Iteration";
    private const string ErrorResultName = "Error";
…
    #endregion

-   #region Parameter properties
+   #region parameter properties
    public IFixedValueParameter<DoubleValue> PerplexityParameter {
-     get { return (IFixedValueParameter<DoubleValue>)Parameters[PerplexityParameterName]; }
+     get { return Parameters[PerplexityParameterName] as IFixedValueParameter<DoubleValue>; }
    }
… (the getters of ThetaParameter, NewDimensionsParameter, DistanceFunctionParameter, MaxIterationsParameter, StopLyingIterationParameter, MomentumSwitchIterationParameter, InitialMomentumParameter, FinalMomentumParameter, EtaParameter, SetSeedRandomlyParameter, SeedParameter, ClassesNameParameter and NormalizationParameter change from explicit casts to 'as' casts in the same way) …
-   public IFixedValueParameter<BoolValue> RandomInitializationParameter {
-     get { return (IFixedValueParameter<BoolValue>)Parameters[RandomInitializationParameterName]; }
-   }
    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
-     get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
+     get { return Parameters[UpdateIntervalParameterName] as IFixedValueParameter<IntValue>; }
    }
    #endregion
…
      set { NormalizationParameter.Value.Value = value; }
    }
-   public bool RandomInitialization {
-     get { return RandomInitializationParameter.Value.Value; }
-     set { RandomInitializationParameter.Value.Value = value; }
-   }
    public int UpdateInterval {
      get { return UpdateIntervalParameter.Value.Value; }
…
    #endregion

-   #region Storable poperties
-   [Storable]
-   private Dictionary<string, IList<int>> dataRowIndices;
-   [Storable]
-   private TSNEStatic<double[]>.TSNEState state;
-   #endregion
-
    #region Constructors & Cloning
    [StorableConstructor]
    private TSNEAlgorithm(bool deserializing) : base(deserializing) { }

-   [StorableHook(HookType.AfterDeserialization)]
-   private void AfterDeserialization() {
-     if (!Parameters.ContainsKey(RandomInitializationParameterName))
-       Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
-     RegisterParameterEvents();
-   }
    private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
-     if (original.dataRowIndices != null)
-       dataRowIndices = new Dictionary<string, IList<int>>(original.dataRowIndices);
+     if (original.dataRowNames != null)
+       this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
+     if (original.dataRows != null)
+       this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
      if (original.state != null)
-       state = cloner.Clone(original.state);
-     RegisterParameterEvents();
-   }
-   public override IDeepCloneable Clone(Cloner cloner) {
-     return new TSNEAlgorithm(this, cloner);
-   }
+       this.state = cloner.Clone(original.state);
+     this.iter = original.iter;
+   }
+   public override IDeepCloneable Clone(Cloner cloner) { return new TSNEAlgorithm(this, cloner); }
    public TSNEAlgorithm() {
      var distances = new ItemSet<IDistance<double[]>>(ApplicationManager.Manager.GetInstances<IDistance<double[]>>());
…
      Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25)));
      Parameters.Add(new FixedValueParameter<PercentValue>(ThetaParameterName, "Value describing how much appoximated " +
        "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
        "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
        "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
        "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
        " poor performance on very small data sets (it is better to use a standard t-SNE implementation on such data).", new PercentValue(0)));
      Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis)", new IntValue(2)));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent.", new IntValue(1000)));
…
      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "The interval after which the results will be updated.", new IntValue(50)));
-     Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
-
-     UpdateIntervalParameter.Hidden = true;
+     Parameters[UpdateIntervalParameterName].Hidden = true;
+
      MomentumSwitchIterationParameter.Hidden = true;
      InitialMomentumParameter.Hidden = true;
…
      EtaParameter.Hidden = false;
      Problem = new RegressionProblem();
-     RegisterParameterEvents();
    }
    #endregion
+
+   [Storable]
+   private Dictionary<string, List<int>> dataRowNames;
+   [Storable]
+   private Dictionary<string, ScatterPlotDataRow> dataRows;
+   [Storable]
+   private TSNEStatic<double[]>.TSNEState state;
+   [Storable]
+   private int iter;

    public override void Prepare() {
      base.Prepare();
-     dataRowIndices = null;
+     dataRowNames = null;
+     dataRows = null;
      state = null;
    }

    protected override void Run(CancellationToken cancellationToken) {
      var problemData = Problem.ProblemData;
-     // set up and initialize everything if necessary
-     var wdist = DistanceFunction as WeightedEuclideanDistance;
-     if (wdist != null) wdist.Initialize(problemData);
+     // set up and initialized everything if necessary
      if (state == null) {
        if (SetSeedRandomly) Seed = new System.Random().Next();
…
        var dataset = problemData.Dataset;
        var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
-       var allindices = Problem.ProblemData.AllIndices.ToArray();
-
-       // jagged array is required to meet the static method declarations of TSNEStatic<T>
-       var data = Enumerable.Range(0, dataset.Rows).Select(x => new double[allowedInputVariables.Length]).ToArray();
-       var col = 0;
-       foreach (var s in allowedInputVariables) {
-         var row = 0;
-         foreach (var d in dataset.GetDoubleValues(s)) {
-           data[row][col] = d;
-           row++;
-         }
-         col++;
-       }
-       if (Normalization) data = NormalizeInputData(data);
-       state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, RandomInitialization);
-       SetUpResults(allindices);
-     }
-     while (state.iter < MaxIterations && !cancellationToken.IsCancellationRequested) {
-       if (state.iter % UpdateInterval == 0) Analyze(state);
+       var data = new double[dataset.Rows][];
+       for (var row = 0; row < dataset.Rows; row++)
+         data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
+
+       if (Normalization) data = NormalizeData(data);
+
+       state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta,
+         StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
+
+       SetUpResults(data);
+       iter = 0;
+     }
+     for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
+       if (iter % UpdateInterval == 0)
+         Analyze(state);
        TSNEStatic<double[]>.Iterate(state);
      }
…
    protected override void RegisterProblemEvents() {
      base.RegisterProblemEvents();
-     if (Problem == null) return;
      Problem.ProblemDataChanged += OnProblemDataChanged;
-     if (Problem.ProblemData == null) return;
-     Problem.ProblemData.Changed += OnPerplexityChanged;
-     Problem.ProblemData.Changed += OnColumnsChanged;
-     if (Problem.ProblemData.Dataset == null) return;
-     Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
-     Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
    }
    protected override void DeregisterProblemEvents() {
      base.DeregisterProblemEvents();
-     if (Problem == null) return;
      Problem.ProblemDataChanged -= OnProblemDataChanged;
-     if (Problem.ProblemData == null) return;
-     Problem.ProblemData.Changed -= OnPerplexityChanged;
-     Problem.ProblemData.Changed -= OnColumnsChanged;
-     if (Problem.ProblemData.Dataset == null) return;
-     Problem.ProblemData.Dataset.RowsChanged -= OnPerplexityChanged;
-     Problem.ProblemData.Dataset.ColumnsChanged -= OnColumnsChanged;
-   }
-
-   protected override void OnStopped() {
-     base.OnStopped();
-     //bwerth: state objects can be very large; avoid state serialization
-     state = null;
-     dataRowIndices = null;
    }

    private void OnProblemDataChanged(object sender, EventArgs args) {
      if (Problem == null || Problem.ProblemData == null) return;
-     OnPerplexityChanged(this, null);
-     OnColumnsChanged(this, null);
-     Problem.ProblemData.Changed += OnPerplexityChanged;
-     Problem.ProblemData.Changed += OnColumnsChanged;
-     if (Problem.ProblemData.Dataset == null) return;
-     Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
-     Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
      if (!Parameters.ContainsKey(ClassesNameParameterName)) return;
      ClassesNameParameter.ValidValues.Clear();
…
    }

-   private void OnColumnsChanged(object sender, EventArgs e) {
-     if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(DistanceFunctionParameterName)) return;
-     DistanceFunctionParameter.ValidValues.OfType<WeightedEuclideanDistance>().Single().AdaptToProblemData(Problem.ProblemData);
-   }
-
-   private void RegisterParameterEvents() {
-     PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
-   }
-
-   private void OnPerplexityChanged(object sender, EventArgs e) {
-     if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(PerplexityParameterName)) return;
-     PerplexityParameter.Value.Value = Math.Max(1, Math.Min((Problem.ProblemData.Dataset.Rows - 1) / 3.0, Perplexity));
-   }
    #endregion

    #region Helpers
-   private void SetUpResults(IReadOnlyList<int> allIndices) {
+   private void SetUpResults(IReadOnlyCollection<double[]> data) {
      if (Results == null) return;
      var results = Results;
-     dataRowIndices = new Dictionary<string, IList<int>>();
+     dataRowNames = new Dictionary<string, List<int>>();
+     dataRows = new Dictionary<string, ScatterPlotDataRow>();
      var problemData = Problem.ProblemData;

+     //color datapoints acording to classes variable (be it double or string)
+     if (problemData.Dataset.VariableNames.Contains(ClassesName)) {
+       if ((problemData.Dataset as Dataset).VariableHasType<string>(ClassesName)) {
+         var classes = problemData.Dataset.GetStringValues(ClassesName).ToArray();
+         for (var i = 0; i < classes.Length; i++) {
+           if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
+           dataRowNames[classes[i]].Add(i);
+         }
+       } else if ((problemData.Dataset as Dataset).VariableHasType<double>(ClassesName)) {
+         var classValues = problemData.Dataset.GetDoubleValues(ClassesName).ToArray();
+         var max = classValues.Max() + 0.1;
+         var min = classValues.Min() - 0.1;
+         const int contours = 8;
+         for (var i = 0; i < contours; i++) {
+           var contourname = GetContourName(i, min, max, contours);
+           dataRowNames.Add(contourname, new List<int>());
+           dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
+           dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
+           dataRows[contourname].VisualProperties.PointSize = i + 3;
+         }
+         for (var i = 0; i < classValues.Length; i++) {
+           dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i);
+         }
+       }
+     } else {
+       dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
+       dataRowNames.Add("Test", problemData.TestIndices.ToList());
+     }
+
      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
+     else ((IntValue)results[IterationResultName].Value).Value = 0;
+
      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
-     if (!results.ContainsKey(ScatterPlotResultName)) results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
-     if (!results.ContainsKey(DataResultName)) results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
-     if (!results.ContainsKey(ErrorPlotResultName)) {
-       var errortable = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent") {
-         VisualProperties = {
-           XAxisTitle = "UpdateIntervall",
-           YAxisTitle = "Error",
-           YAxisLogScale = true
-         }
-       };
-       errortable.Rows.Add(new DataRow("Errors"));
-       errortable.Rows["Errors"].VisualProperties.StartIndexZero = true;
-       results.Add(new Result(ErrorPlotResultName, errortable));
-     }
-
-     //color datapoints acording to classes variable (be it double, datetime or string)
-     if (!problemData.Dataset.VariableNames.Contains(ClassesName)) {
-       dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
-       dataRowIndices.Add("Test", problemData.TestIndices.ToList());
-       return;
-     }
-
-     var classificationData = problemData as ClassificationProblemData;
-     if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) {
-       var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n);
-       var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray();
-       for (var i = 0; i < classes.Length; i++) {
-         if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
-         dataRowIndices[classes[i]].Add(i);
-       }
-     } else if (((Dataset)problemData.Dataset).VariableHasType<string>(ClassesName)) {
-       var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();
-       for (var i = 0; i < classes.Length; i++) {
-         if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
-         dataRowIndices[classes[i]].Add(i);
-       }
-     } else if (((Dataset)problemData.Dataset).VariableHasType<double>(ClassesName)) {
-       var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
-       const int contours = 8;
-       Dictionary<int, string> contourMap;
-       IClusteringModel clusterModel;
-       double[][] borders;
-       CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
-       var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
-       for (var i = 0; i < contours; i++) {
-         var c = contourorder[i];
-         var contourname = contourMap[c];
-         dataRowIndices.Add(contourname, new List<int>());
-         var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
-         ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
-       }
-       var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
-       for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
-     } else if (((Dataset)problemData.Dataset).VariableHasType<DateTime>(ClassesName)) {
-       var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
-       const int contours = 8;
-       Dictionary<int, string> contourMap;
-       IClusteringModel clusterModel;
-       double[][] borders;
-       CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
-       var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
-       for (var i = 0; i < contours; i++) {
-         var c = contourorder[i];
-         var contourname = contourMap[c];
-         dataRowIndices.Add(contourname, new List<int>());
-         var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
-         row.VisualProperties.PointSize = 8;
-         ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
-       }
-       var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
-       for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
-     } else {
-       dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
-       dataRowIndices.Add("Test", problemData.TestIndices.ToList());
-     }
+     else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
+
+     if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
+     else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
+
+     var plot = results[ErrorPlotResultName].Value as DataTable;
+     if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
+
+     if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
+     plot.Rows["errors"].Values.Clear();
+     plot.Rows["errors"].VisualProperties.StartIndexZero = true;
+
+     results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
+     results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
    }
…
      var plot = results[ErrorPlotResultName].Value as DataTable;
      if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
-     var errors = plot.Rows["Errors"].Values;
+     var errors = plot.Rows["errors"].Values;
      var c = tsneState.EvaluateError();
      errors.Add(c);
…
      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();

-     var ndata = NormalizeProjectedData(tsneState.newData);
+     var ndata = Normalize(tsneState.newData);
      results[DataResultName].Value = new DoubleMatrix(ndata);
      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
-     FillScatterPlot(ndata, splot, dataRowIndices);
-   }
-
-   private static void FillScatterPlot(double[,] lowDimData, ScatterPlot plot, Dictionary<string, IList<int>> dataRowIndices) {
-     foreach (var rowName in dataRowIndices.Keys) {
-       if (!plot.Rows.ContainsKey(rowName)) {
-         plot.Rows.Add(new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
-         plot.Rows[rowName].VisualProperties.PointSize = 8;
-       }
-       plot.Rows[rowName].Points.Replace(dataRowIndices[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
-     }
-   }
-
-   private static double[,] NormalizeProjectedData(double[,] data) {
+     FillScatterPlot(ndata, splot);
+   }
+
+   private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
+     foreach (var rowName in dataRowNames.Keys) {
+       if (!plot.Rows.ContainsKey(rowName))
+         plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
+       plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
+     }
+   }
+
+   private static double[,] Normalize(double[,] data) {
      var max = new double[data.GetLength(1)];
      var min = new double[data.GetLength(1)];
…
      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
      for (var i = 0; i < data.GetLength(0); i++)
        for (var j = 0; j < data.GetLength(1); j++) {
          var v = data[i, j];
          max[j] = Math.Max(max[j], v);
          min[j] = Math.Min(min[j], v);
        }
      for (var i = 0; i < data.GetLength(0); i++) {
        for (var j = 0; j < data.GetLength(1); j++) {
          var d = max[j] - min[j];
          var s = data[i, j] - (max[j] + min[j]) / 2;  //shift data
          if (d.IsAlmost(0)) res[i, j] = data[i, j];   //no scaling possible
          else res[i, j] = s / d;                      //scale data
        }
      }
…
    }

-   private static double[][] NormalizeInputData(IReadOnlyList<IReadOnlyList<double>> data) {
+   private static double[][] NormalizeData(IReadOnlyList<double[]> data) {
      // as in tSNE implementation by van der Maaten
-     var n = data[0].Count;
+     var n = data[0].Length;
      var mean = new double[n];
      var max = new double[n];
…
      for (var i = 0; i < data.Count; i++) {
        nData[i] = new double[n];
-       for (var j = 0; j < n; j++)
-         nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
+       for (var j = 0; j < n; j++) nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
      }
      return nData;
…
    private static Color GetHeatMapColor(int contourNr, int noContours) {
-     return ConvertTotalToRgb(0, noContours, contourNr);
-   }
-
-   private static void CreateClusters(IDataset data, string target, int contours, out IClusteringModel contourCluster, out Dictionary<int, string> contourNames, out double[][] borders) {
-     var cpd = new ClusteringProblemData((Dataset)data, new[] {target});
-     contourCluster = KMeansClustering.CreateKMeansSolution(cpd, contours, 3).Model;
-
-     borders = Enumerable.Range(0, contours).Select(x => new[] {double.MaxValue, double.MinValue}).ToArray();
-     var clusters = contourCluster.GetClusterValues(cpd.Dataset, cpd.AllIndices).ToArray();
-     var targetvalues = cpd.Dataset.GetDoubleValues(target).ToArray();
-     foreach (var i in cpd.AllIndices) {
-       var cl = clusters[i] - 1;
-       var clv = targetvalues[i];
-       if (borders[cl][0] > clv) borders[cl][0] = clv;
-       if (borders[cl][1] < clv) borders[cl][1] = clv;
-     }
-
-     contourNames = new Dictionary<int, string>();
-     for (var i = 0; i < contours; i++)
-       contourNames.Add(i, "[" + borders[i][0] + ";" + borders[i][1] + "]");
-   }
-
-   private static Color ConvertTotalToRgb(double low, double high, double cell) {
-     var colorGradient = ColorGradient.Colors;
-     var range = high - low;
-     var h = Math.Min(cell / range * colorGradient.Count, colorGradient.Count - 1);
-     return colorGradient[(int)h];
+     var q = (double)contourNr / noContours; // q in [0,1]
+     var c = q < 0.5 ? Color.FromArgb((int)(q * 2 * 255), 255, 0) : Color.FromArgb(255, (int)((1 - q) * 2 * 255), 0);
+     return c;
+   }
+
+   private static string GetContourName(double value, double min, double max, int noContours) {
+     var size = (max - min) / noContours;
+     var contourNr = (int)((value - min) / size);
+     return GetContourName(contourNr, min, max, noContours);
+   }
+
+   private static string GetContourName(int i, double min, double max, int noContours) {
+     var size = (max - min) / noContours;
+     return "[" + (min + i * size) + ";" + (min + (i + 1) * size) + ")";
    }
    #endregion
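The reverted revision colours the projection scatter plot by binning a numeric classes variable into eight equally sized intervals (GetContourName) and giving each bin a green-to-red colour (GetHeatMapColor). The following is a minimal standalone sketch of those two helpers, using only the logic visible in the diff above; the wrapper class name ContourColoringSketch is purely illustrative and not part of HeuristicLab.

    using System.Drawing;

    static class ContourColoringSketch {
      // Names a bin: the range [min, max] is split into noContours equal
      // intervals and the interval containing 'value' is rendered as "[lo;hi)".
      public static string ContourName(double value, double min, double max, int noContours) {
        var size = (max - min) / noContours;
        var i = (int)((value - min) / size);
        return "[" + (min + i * size) + ";" + (min + (i + 1) * size) + ")";
      }

      // Green-to-red ramp: the red channel rises over the first half of the bins,
      // then the green channel falls over the second half.
      public static Color HeatMapColor(int contourNr, int noContours) {
        var q = (double)contourNr / noContours; // q in [0,1]
        return q < 0.5
          ? Color.FromArgb((int)(q * 2 * 255), 255, 0)
          : Color.FromArgb(255, (int)((1 - q) * 2 * 255), 0);
      }
    }

With eight contours, bin 0 comes out pure green and higher bins shade through yellow towards red, matching the increasing PointSize (i + 3) assigned to the scatter plot rows in SetUpResults.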
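Before the projected coordinates are written into the DoubleMatrix result and the scatter plot, the Normalize helper (NormalizeProjectedData in r16307) rescales every output dimension into roughly [-0.5, 0.5]. A self-contained sketch of that rescaling, assuming nothing beyond the code shown above; the method name ScaleToUnitRange is made up for illustration.

    using System;

    static class ProjectionScalingSketch {
      // Shift each column by its midpoint (max+min)/2 and divide by its range
      // (max-min); columns with ~zero range are left unchanged.
      public static double[,] ScaleToUnitRange(double[,] data) {
        int rows = data.GetLength(0), cols = data.GetLength(1);
        var max = new double[cols];
        var min = new double[cols];
        var res = new double[rows, cols];
        for (var j = 0; j < cols; j++) max[j] = min[j] = data[0, j];
        for (var i = 0; i < rows; i++)
          for (var j = 0; j < cols; j++) {
            max[j] = Math.Max(max[j], data[i, j]);
            min[j] = Math.Min(min[j], data[i, j]);
          }
        for (var i = 0; i < rows; i++)
          for (var j = 0; j < cols; j++) {
            var range = max[j] - min[j];
            var shifted = data[i, j] - (max[j] + min[j]) / 2;
            res[i, j] = Math.Abs(range) < 1e-12 ? data[i, j] : shifted / range;
          }
        return res;
      }
    }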
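The Perplexity parameter description recommends a value of floor(number of points / 3) or lower, and the OnPerplexityChanged handler removed by this changeset enforced exactly that bound. A one-method sketch of the clamp, with an illustrative helper name:

    using System;

    static class PerplexitySketch {
      // Clamp a requested perplexity for a dataset with 'rows' points:
      // at least 1 and at most (rows - 1) / 3, as in the removed OnPerplexityChanged handler.
      public static double ClampPerplexity(double requested, int rows) {
        return Math.Max(1, Math.Min((rows - 1) / 3.0, requested));
      }
    }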