Changeset 16057 for branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs
- Timestamp:
- 08/06/18 18:15:29 (6 years ago)
- Location:
- branches/2839_HiveProjectManagement
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2839_HiveProjectManagement
- Property svn:mergeinfo changed
-
branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
-
branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis/3.4
-
Property
svn:mergeinfo
set to
(toggle deleted branches)
/stable/HeuristicLab.Algorithms.DataAnalysis/3.4 merged eligible /trunk/HeuristicLab.Algorithms.DataAnalysis/3.4 merged eligible /branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.4 10321-10322 /branches/Async/HeuristicLab.Algorithms.DataAnalysis/3.4 13329-15286 /branches/Benchmarking/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 6917-7005 /branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.4 9070-13099 /branches/CloningRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4 4656-4721 /branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4 5471-5808 /branches/DataAnalysis SolutionEnsembles/HeuristicLab.Algorithms.DataAnalysis/3.4 5815-6180 /branches/DataAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.4 4458-4459,4462,4464 /branches/DataPreprocessing/HeuristicLab.Algorithms.DataAnalysis/3.4 10085-11101 /branches/GP.Grammar.Editor/HeuristicLab.Algorithms.DataAnalysis/3.4 6284-6795 /branches/GP.Symbols (TimeLag, Diff, Integral)/HeuristicLab.Algorithms.DataAnalysis/3.4 5060 /branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 11570-12508 /branches/HeuristicLab.Problems.Orienteering/HeuristicLab.Algorithms.DataAnalysis/3.4 11130-12721 /branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4 13819-14091 /branches/HeuristicLab.TimeSeries/HeuristicLab.Algorithms.DataAnalysis/3.4 8116-8789 /branches/LogResidualEvaluator/HeuristicLab.Algorithms.DataAnalysis/3.4 10202-10483 /branches/NET40/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 5138-5162 /branches/ParallelEngine/HeuristicLab.Algorithms.DataAnalysis/3.4 5175-5192 /branches/ProblemInstancesRegressionAndClassification/HeuristicLab.Algorithms.DataAnalysis/3.4 7773-7810 /branches/QAPAlgorithms/HeuristicLab.Algorithms.DataAnalysis/3.4 6350-6627 /branches/Restructure trunk solution/HeuristicLab.Algorithms.DataAnalysis/3.4 6828 /branches/SpectralKernelForGaussianProcesses/HeuristicLab.Algorithms.DataAnalysis/3.4 10204-10479 /branches/SuccessProgressAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.4 5370-5682 /branches/Trunk/HeuristicLab.Algorithms.DataAnalysis/3.4 6829-6865 /branches/VNS/HeuristicLab.Algorithms.DataAnalysis/3.4 5594-5752 /branches/Weighted TSNE/3.4 15451-15531 /branches/histogram/HeuristicLab.Algorithms.DataAnalysis/3.4 5959-6341 /branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.4 14232-14825 /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4 15377-15681
-
Property
svn:mergeinfo
set to
(toggle deleted branches)
-
branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs
r15234 r16057 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-201 6Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 38 38 namespace HeuristicLab.Algorithms.DataAnalysis { 39 39 /// <summary> 40 /// t- distributed stochastic neighbourhood embedding (tSNE) projects the data in a low dimensional40 /// t-Distributed Stochastic Neighbor Embedding (tSNE) projects the data in a low dimensional 41 41 /// space to allow visual cluster identification. 42 42 /// </summary> 43 [Item("t SNE", "t-distributed stochastic neighbourhood embedding projects the data in a low " +44 "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]43 [Item("t-Distributed Stochastic Neighbor Embedding (tSNE)", "t-Distributed Stochastic Neighbor Embedding projects the data in a low " + 44 "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")] 45 45 [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)] 46 46 [StorableClass] … … 57 57 } 58 58 59 #region parameter names59 #region Parameter names 60 60 private const string DistanceFunctionParameterName = "DistanceFunction"; 61 61 private const string PerplexityParameterName = "Perplexity"; … … 72 72 private const string ClassesNameParameterName = "ClassesName"; 73 73 private const string NormalizationParameterName = "Normalization"; 74 private const string RandomInitializationParameterName = "RandomInitialization"; 74 75 private const string UpdateIntervalParameterName = "UpdateInterval"; 75 76 #endregion 76 77 77 #region result names78 #region Result names 78 79 private const string IterationResultName = "Iteration"; 79 80 private const string ErrorResultName = "Error"; … … 83 84 #endregion 84 85 85 #region parameter properties86 #region Parameter properties 86 87 public IFixedValueParameter<DoubleValue> PerplexityParameter { 87 get { return Parameters[PerplexityParameterName] as IFixedValueParameter<DoubleValue>; }88 get { return (IFixedValueParameter<DoubleValue>)Parameters[PerplexityParameterName]; } 88 89 } 89 90 public IFixedValueParameter<PercentValue> ThetaParameter { 90 get { return Parameters[ThetaParameterName] as IFixedValueParameter<PercentValue>; }91 get { return (IFixedValueParameter<PercentValue>)Parameters[ThetaParameterName]; } 91 92 } 92 93 public IFixedValueParameter<IntValue> NewDimensionsParameter { 93 get { return Parameters[NewDimensionsParameterName] as IFixedValueParameter<IntValue>; }94 get { return (IFixedValueParameter<IntValue>)Parameters[NewDimensionsParameterName]; } 94 95 } 95 96 public IConstrainedValueParameter<IDistance<double[]>> DistanceFunctionParameter { 96 get { return Parameters[DistanceFunctionParameterName] as IConstrainedValueParameter<IDistance<double[]>>; }97 get { return (IConstrainedValueParameter<IDistance<double[]>>)Parameters[DistanceFunctionParameterName]; } 97 98 } 98 99 public IFixedValueParameter<IntValue> MaxIterationsParameter { 99 get { return Parameters[MaxIterationsParameterName] as IFixedValueParameter<IntValue>; }100 get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; } 100 101 } 101 102 public IFixedValueParameter<IntValue> StopLyingIterationParameter { 102 get { return Parameters[StopLyingIterationParameterName] as IFixedValueParameter<IntValue>; }103 get { return (IFixedValueParameter<IntValue>)Parameters[StopLyingIterationParameterName]; } 103 104 } 104 105 public IFixedValueParameter<IntValue> MomentumSwitchIterationParameter { 105 get { return Parameters[MomentumSwitchIterationParameterName] as IFixedValueParameter<IntValue>; }106 get { return (IFixedValueParameter<IntValue>)Parameters[MomentumSwitchIterationParameterName]; } 106 107 } 107 108 public IFixedValueParameter<DoubleValue> InitialMomentumParameter { 108 get { return Parameters[InitialMomentumParameterName] as IFixedValueParameter<DoubleValue>; }109 get { return (IFixedValueParameter<DoubleValue>)Parameters[InitialMomentumParameterName]; } 109 110 } 110 111 public IFixedValueParameter<DoubleValue> FinalMomentumParameter { 111 get { return Parameters[FinalMomentumParameterName] as IFixedValueParameter<DoubleValue>; }112 get { return (IFixedValueParameter<DoubleValue>)Parameters[FinalMomentumParameterName]; } 112 113 } 113 114 public IFixedValueParameter<DoubleValue> EtaParameter { 114 get { return Parameters[EtaParameterName] as IFixedValueParameter<DoubleValue>; }115 get { return (IFixedValueParameter<DoubleValue>)Parameters[EtaParameterName]; } 115 116 } 116 117 public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter { 117 get { return Parameters[SetSeedRandomlyParameterName] as IFixedValueParameter<BoolValue>; }118 get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; } 118 119 } 119 120 public IFixedValueParameter<IntValue> SeedParameter { 120 get { return Parameters[SeedParameterName] as IFixedValueParameter<IntValue>; }121 get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; } 121 122 } 122 123 public IConstrainedValueParameter<StringValue> ClassesNameParameter { 123 get { return Parameters[ClassesNameParameterName] as IConstrainedValueParameter<StringValue>; }124 get { return (IConstrainedValueParameter<StringValue>)Parameters[ClassesNameParameterName]; } 124 125 } 125 126 public IFixedValueParameter<BoolValue> NormalizationParameter { 126 get { return Parameters[NormalizationParameterName] as IFixedValueParameter<BoolValue>; } 127 get { return (IFixedValueParameter<BoolValue>)Parameters[NormalizationParameterName]; } 128 } 129 public IFixedValueParameter<BoolValue> RandomInitializationParameter { 130 get { return (IFixedValueParameter<BoolValue>)Parameters[RandomInitializationParameterName]; } 127 131 } 128 132 public IFixedValueParameter<IntValue> UpdateIntervalParameter { 129 get { return Parameters[UpdateIntervalParameterName] as IFixedValueParameter<IntValue>; }133 get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; } 130 134 } 131 135 #endregion … … 187 191 set { NormalizationParameter.Value.Value = value; } 188 192 } 189 193 public bool RandomInitialization { 194 get { return RandomInitializationParameter.Value.Value; } 195 set { RandomInitializationParameter.Value.Value = value; } 196 } 190 197 public int UpdateInterval { 191 198 get { return UpdateIntervalParameter.Value.Value; } … … 194 201 #endregion 195 202 203 #region Storable poperties 204 [Storable] 205 private Dictionary<string, IList<int>> dataRowIndices; 206 [Storable] 207 private TSNEStatic<double[]>.TSNEState state; 208 #endregion 209 196 210 #region Constructors & Cloning 197 211 [StorableConstructor] 198 212 private TSNEAlgorithm(bool deserializing) : base(deserializing) { } 199 213 214 [StorableHook(HookType.AfterDeserialization)] 215 private void AfterDeserialization() { 216 if (!Parameters.ContainsKey(RandomInitializationParameterName)) 217 Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true))); 218 RegisterParameterEvents(); 219 } 200 220 private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) { 201 if (original.dataRowNames != null) 202 this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames); 203 if (original.dataRows != null) 204 this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value)); 221 if (original.dataRowIndices != null) 222 dataRowIndices = new Dictionary<string, IList<int>>(original.dataRowIndices); 205 223 if (original.state != null) 206 this.state = cloner.Clone(original.state); 207 this.iter = original.iter; 208 } 209 public override IDeepCloneable Clone(Cloner cloner) { return new TSNEAlgorithm(this, cloner); } 224 state = cloner.Clone(original.state); 225 RegisterParameterEvents(); 226 } 227 public override IDeepCloneable Clone(Cloner cloner) { 228 return new TSNEAlgorithm(this, cloner); 229 } 210 230 public TSNEAlgorithm() { 211 231 var distances = new ItemSet<IDistance<double[]>>(ApplicationManager.Manager.GetInstances<IDistance<double[]>>()); … … 213 233 Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25))); 214 234 Parameters.Add(new FixedValueParameter<PercentValue>(ThetaParameterName, "Value describing how much appoximated " + 215 "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +216 "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +217 "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +218 "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +219 " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0)));235 "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " + 236 "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " + 237 "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " + 238 "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" + 239 " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0))); 220 240 Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis)", new IntValue(2))); 221 241 Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent.", new IntValue(1000))); … … 230 250 Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true))); 231 251 Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "The interval after which the results will be updated.", new IntValue(50))); 232 Parameters[UpdateIntervalParameterName].Hidden = true; 233 252 Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true))); 253 254 UpdateIntervalParameter.Hidden = true; 234 255 MomentumSwitchIterationParameter.Hidden = true; 235 256 InitialMomentumParameter.Hidden = true; … … 238 259 EtaParameter.Hidden = false; 239 260 Problem = new RegressionProblem(); 240 } 241 #endregion 242 243 [Storable] 244 private Dictionary<string, List<int>> dataRowNames; 245 [Storable] 246 private Dictionary<string, ScatterPlotDataRow> dataRows; 247 [Storable] 248 private TSNEStatic<double[]>.TSNEState state; 249 [Storable] 250 private int iter; 261 RegisterParameterEvents(); 262 } 263 #endregion 251 264 252 265 public override void Prepare() { 253 266 base.Prepare(); 254 dataRowNames = null; 255 dataRows = null; 267 dataRowIndices = null; 256 268 state = null; 257 269 } … … 259 271 protected override void Run(CancellationToken cancellationToken) { 260 272 var problemData = Problem.ProblemData; 261 // set up and initialized everything if necessary 273 // set up and initialize everything if necessary 274 var wdist = DistanceFunction as WeightedEuclideanDistance; 275 if (wdist != null) wdist.Initialize(problemData); 262 276 if (state == null) { 263 277 if (SetSeedRandomly) Seed = new System.Random().Next(); … … 265 279 var dataset = problemData.Dataset; 266 280 var allowedInputVariables = problemData.AllowedInputVariables.ToArray(); 267 var data = new double[dataset.Rows][]; 268 for (var row = 0; row < dataset.Rows; row++) 269 data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray(); 270 271 if (Normalization) data = NormalizeData(data); 272 273 state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta, 274 StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta); 275 276 SetUpResults(data); 277 iter = 0; 278 } 279 for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) { 280 if (iter % UpdateInterval == 0) 281 Analyze(state); 281 var allindices = Problem.ProblemData.AllIndices.ToArray(); 282 283 // jagged array is required to meet the static method declarations of TSNEStatic<T> 284 var data = Enumerable.Range(0, dataset.Rows).Select(x => new double[allowedInputVariables.Length]).ToArray(); 285 var col = 0; 286 foreach (var s in allowedInputVariables) { 287 var row = 0; 288 foreach (var d in dataset.GetDoubleValues(s)) { 289 data[row][col] = d; 290 row++; 291 } 292 col++; 293 } 294 if (Normalization) data = NormalizeInputData(data); 295 state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, RandomInitialization); 296 SetUpResults(allindices); 297 } 298 while (state.iter < MaxIterations && !cancellationToken.IsCancellationRequested) { 299 if (state.iter % UpdateInterval == 0) Analyze(state); 282 300 TSNEStatic<double[]>.Iterate(state); 283 301 } … … 294 312 protected override void RegisterProblemEvents() { 295 313 base.RegisterProblemEvents(); 314 if (Problem == null) return; 296 315 Problem.ProblemDataChanged += OnProblemDataChanged; 297 } 316 if (Problem.ProblemData == null) return; 317 Problem.ProblemData.Changed += OnPerplexityChanged; 318 Problem.ProblemData.Changed += OnColumnsChanged; 319 if (Problem.ProblemData.Dataset == null) return; 320 Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged; 321 Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged; 322 } 323 298 324 protected override void DeregisterProblemEvents() { 299 325 base.DeregisterProblemEvents(); 326 if (Problem == null) return; 300 327 Problem.ProblemDataChanged -= OnProblemDataChanged; 328 if (Problem.ProblemData == null) return; 329 Problem.ProblemData.Changed -= OnPerplexityChanged; 330 Problem.ProblemData.Changed -= OnColumnsChanged; 331 if (Problem.ProblemData.Dataset == null) return; 332 Problem.ProblemData.Dataset.RowsChanged -= OnPerplexityChanged; 333 Problem.ProblemData.Dataset.ColumnsChanged -= OnColumnsChanged; 334 } 335 336 protected override void OnStopped() { 337 base.OnStopped(); 338 //bwerth: state objects can be very large; avoid state serialization 339 state = null; 340 dataRowIndices = null; 301 341 } 302 342 303 343 private void OnProblemDataChanged(object sender, EventArgs args) { 304 344 if (Problem == null || Problem.ProblemData == null) return; 345 OnPerplexityChanged(this, null); 346 OnColumnsChanged(this, null); 347 Problem.ProblemData.Changed += OnPerplexityChanged; 348 Problem.ProblemData.Changed += OnColumnsChanged; 349 if (Problem.ProblemData.Dataset == null) return; 350 Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged; 351 Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged; 305 352 if (!Parameters.ContainsKey(ClassesNameParameterName)) return; 306 353 ClassesNameParameter.ValidValues.Clear(); … … 308 355 } 309 356 357 private void OnColumnsChanged(object sender, EventArgs e) { 358 if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(DistanceFunctionParameterName)) return; 359 DistanceFunctionParameter.ValidValues.OfType<WeightedEuclideanDistance>().Single().AdaptToProblemData(Problem.ProblemData); 360 } 361 362 private void RegisterParameterEvents() { 363 PerplexityParameter.Value.ValueChanged += OnPerplexityChanged; 364 } 365 366 private void OnPerplexityChanged(object sender, EventArgs e) { 367 if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(PerplexityParameterName)) return; 368 PerplexityParameter.Value.Value = Math.Max(1, Math.Min((Problem.ProblemData.Dataset.Rows - 1) / 3.0, Perplexity)); 369 } 310 370 #endregion 311 371 312 372 #region Helpers 313 private void SetUpResults(IReadOnly Collection<double[]> data) {373 private void SetUpResults(IReadOnlyList<int> allIndices) { 314 374 if (Results == null) return; 315 375 var results = Results; 316 dataRowNames = new Dictionary<string, List<int>>(); 317 dataRows = new Dictionary<string, ScatterPlotDataRow>(); 376 dataRowIndices = new Dictionary<string, IList<int>>(); 318 377 var problemData = Problem.ProblemData; 319 378 320 //color datapoints acording to classes variable (be it double or string) 321 if (problemData.Dataset.VariableNames.Contains(ClassesName)) { 322 if ((problemData.Dataset as Dataset).VariableHasType<string>(ClassesName)) { 323 var classes = problemData.Dataset.GetStringValues(ClassesName).ToArray(); 324 for (var i = 0; i < classes.Length; i++) { 325 if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>()); 326 dataRowNames[classes[i]].Add(i); 379 if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0))); 380 if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0))); 381 if (!results.ContainsKey(ScatterPlotResultName)) results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, ""))); 382 if (!results.ContainsKey(DataResultName)) results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix())); 383 if (!results.ContainsKey(ErrorPlotResultName)) { 384 var errortable = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent") { 385 VisualProperties = { 386 XAxisTitle = "UpdateIntervall", 387 YAxisTitle = "Error", 388 YAxisLogScale = true 327 389 } 328 } else if ((problemData.Dataset as Dataset).VariableHasType<double>(ClassesName)) { 329 var classValues = problemData.Dataset.GetDoubleValues(ClassesName).ToArray(); 330 var max = classValues.Max() + 0.1; 331 var min = classValues.Min() - 0.1; 332 const int contours = 8; 333 for (var i = 0; i < contours; i++) { 334 var contourname = GetContourName(i, min, max, contours); 335 dataRowNames.Add(contourname, new List<int>()); 336 dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>())); 337 dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours); 338 dataRows[contourname].VisualProperties.PointSize = i + 3; 339 } 340 for (var i = 0; i < classValues.Length; i++) { 341 dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i); 342 } 343 } 390 }; 391 errortable.Rows.Add(new DataRow("Errors")); 392 errortable.Rows["Errors"].VisualProperties.StartIndexZero = true; 393 results.Add(new Result(ErrorPlotResultName, errortable)); 394 } 395 396 //color datapoints acording to classes variable (be it double, datetime or string) 397 if (!problemData.Dataset.VariableNames.Contains(ClassesName)) { 398 dataRowIndices.Add("Training", problemData.TrainingIndices.ToList()); 399 dataRowIndices.Add("Test", problemData.TestIndices.ToList()); 400 return; 401 } 402 403 var classificationData = problemData as ClassificationProblemData; 404 if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) { 405 var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n); 406 var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray(); 407 for (var i = 0; i < classes.Length; i++) { 408 if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>()); 409 dataRowIndices[classes[i]].Add(i); 410 } 411 } else if (((Dataset)problemData.Dataset).VariableHasType<string>(ClassesName)) { 412 var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray(); 413 for (var i = 0; i < classes.Length; i++) { 414 if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>()); 415 dataRowIndices[classes[i]].Add(i); 416 } 417 } else if (((Dataset)problemData.Dataset).VariableHasType<double>(ClassesName)) { 418 var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList())); 419 const int contours = 8; 420 Dictionary<int, string> contourMap; 421 IClusteringModel clusterModel; 422 double[][] borders; 423 CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders); 424 var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray(); 425 for (var i = 0; i < contours; i++) { 426 var c = contourorder[i]; 427 var contourname = contourMap[c]; 428 dataRowIndices.Add(contourname, new List<int>()); 429 var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}}; 430 ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row); 431 } 432 var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray(); 433 for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i); 434 } else if (((Dataset)problemData.Dataset).VariableHasType<DateTime>(ClassesName)) { 435 var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList())); 436 const int contours = 8; 437 Dictionary<int, string> contourMap; 438 IClusteringModel clusterModel; 439 double[][] borders; 440 CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders); 441 var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray(); 442 for (var i = 0; i < contours; i++) { 443 var c = contourorder[i]; 444 var contourname = contourMap[c]; 445 dataRowIndices.Add(contourname, new List<int>()); 446 var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}}; 447 row.VisualProperties.PointSize = 8; 448 ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row); 449 } 450 var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray(); 451 for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i); 344 452 } else { 345 dataRowNames.Add("Training", problemData.TrainingIndices.ToList()); 346 dataRowNames.Add("Test", problemData.TestIndices.ToList()); 347 } 348 349 if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0))); 350 else ((IntValue)results[IterationResultName].Value).Value = 0; 351 352 if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0))); 353 else ((DoubleValue)results[ErrorResultName].Value).Value = 0; 354 355 if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent"))); 356 else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent"); 357 358 var plot = results[ErrorPlotResultName].Value as DataTable; 359 if (plot == null) throw new ArgumentException("could not create/access error data table in results collection"); 360 361 if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors")); 362 plot.Rows["errors"].Values.Clear(); 363 plot.Rows["errors"].VisualProperties.StartIndexZero = true; 364 365 results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, ""))); 366 results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix())); 453 dataRowIndices.Add("Training", problemData.TrainingIndices.ToList()); 454 dataRowIndices.Add("Test", problemData.TestIndices.ToList()); 455 } 367 456 } 368 457 … … 372 461 var plot = results[ErrorPlotResultName].Value as DataTable; 373 462 if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection."); 374 var errors = plot.Rows[" errors"].Values;463 var errors = plot.Rows["Errors"].Values; 375 464 var c = tsneState.EvaluateError(); 376 465 errors.Add(c); … … 378 467 ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last(); 379 468 380 var ndata = Normalize (tsneState.newData);469 var ndata = NormalizeProjectedData(tsneState.newData); 381 470 results[DataResultName].Value = new DoubleMatrix(ndata); 382 471 var splot = results[ScatterPlotResultName].Value as ScatterPlot; 383 FillScatterPlot(ndata, splot); 384 } 385 386 private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) { 387 foreach (var rowName in dataRowNames.Keys) { 388 if (!plot.Rows.ContainsKey(rowName)) 389 plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>())); 390 plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1]))); 391 } 392 } 393 394 private static double[,] Normalize(double[,] data) { 472 FillScatterPlot(ndata, splot, dataRowIndices); 473 } 474 475 private static void FillScatterPlot(double[,] lowDimData, ScatterPlot plot, Dictionary<string, IList<int>> dataRowIndices) { 476 foreach (var rowName in dataRowIndices.Keys) { 477 if (!plot.Rows.ContainsKey(rowName)) { 478 plot.Rows.Add(new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>())); 479 plot.Rows[rowName].VisualProperties.PointSize = 8; 480 } 481 plot.Rows[rowName].Points.Replace(dataRowIndices[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1]))); 482 } 483 } 484 485 private static double[,] NormalizeProjectedData(double[,] data) { 395 486 var max = new double[data.GetLength(1)]; 396 487 var min = new double[data.GetLength(1)]; … … 398 489 for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i]; 399 490 for (var i = 0; i < data.GetLength(0); i++) 400 401 402 403 404 491 for (var j = 0; j < data.GetLength(1); j++) { 492 var v = data[i, j]; 493 max[j] = Math.Max(max[j], v); 494 min[j] = Math.Min(min[j], v); 495 } 405 496 for (var i = 0; i < data.GetLength(0); i++) { 406 497 for (var j = 0; j < data.GetLength(1); j++) { 407 498 var d = max[j] - min[j]; 408 var s = data[i, j] - (max[j] + min[j]) / 2; 409 if (d.IsAlmost(0)) res[i, j] = data[i, j]; 410 else res[i, j] = s / d; 499 var s = data[i, j] - (max[j] + min[j]) / 2; //shift data 500 if (d.IsAlmost(0)) res[i, j] = data[i, j]; //no scaling possible 501 else res[i, j] = s / d; //scale data 411 502 } 412 503 } … … 414 505 } 415 506 416 private static double[][] Normalize Data(IReadOnlyList<double[]> data) {507 private static double[][] NormalizeInputData(IReadOnlyList<IReadOnlyList<double>> data) { 417 508 // as in tSNE implementation by van der Maaten 418 var n = data[0]. Length;509 var n = data[0].Count; 419 510 var mean = new double[n]; 420 511 var max = new double[n]; … … 426 517 for (var i = 0; i < data.Count; i++) { 427 518 nData[i] = new double[n]; 428 for (var j = 0; j < n; j++) nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j]; 519 for (var j = 0; j < n; j++) 520 nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j]; 429 521 } 430 522 return nData; … … 432 524 433 525 private static Color GetHeatMapColor(int contourNr, int noContours) { 434 var q = (double)contourNr / noContours; // q in [0,1] 435 var c = q < 0.5 ? Color.FromArgb((int)(q * 2 * 255), 255, 0) : Color.FromArgb(255, (int)((1 - q) * 2 * 255), 0); 436 return c; 437 } 438 439 private static string GetContourName(double value, double min, double max, int noContours) { 440 var size = (max - min) / noContours; 441 var contourNr = (int)((value - min) / size); 442 return GetContourName(contourNr, min, max, noContours); 443 } 444 445 private static string GetContourName(int i, double min, double max, int noContours) { 446 var size = (max - min) / noContours; 447 return "[" + (min + i * size) + ";" + (min + (i + 1) * size) + ")"; 526 return ConvertTotalToRgb(0, noContours, contourNr); 527 } 528 529 private static void CreateClusters(IDataset data, string target, int contours, out IClusteringModel contourCluster, out Dictionary<int, string> contourNames, out double[][] borders) { 530 var cpd = new ClusteringProblemData((Dataset)data, new[] {target}); 531 contourCluster = KMeansClustering.CreateKMeansSolution(cpd, contours, 3).Model; 532 533 borders = Enumerable.Range(0, contours).Select(x => new[] {double.MaxValue, double.MinValue}).ToArray(); 534 var clusters = contourCluster.GetClusterValues(cpd.Dataset, cpd.AllIndices).ToArray(); 535 var targetvalues = cpd.Dataset.GetDoubleValues(target).ToArray(); 536 foreach (var i in cpd.AllIndices) { 537 var cl = clusters[i] - 1; 538 var clv = targetvalues[i]; 539 if (borders[cl][0] > clv) borders[cl][0] = clv; 540 if (borders[cl][1] < clv) borders[cl][1] = clv; 541 } 542 543 contourNames = new Dictionary<int, string>(); 544 for (var i = 0; i < contours; i++) 545 contourNames.Add(i, "[" + borders[i][0] + ";" + borders[i][1] + "]"); 546 } 547 548 private static Color ConvertTotalToRgb(double low, double high, double cell) { 549 var colorGradient = ColorGradient.Colors; 550 var range = high - low; 551 var h = Math.Min(cell / range * colorGradient.Count, colorGradient.Count - 1); 552 return colorGradient[(int)h]; 448 553 } 449 554 #endregion
Note: See TracChangeset
for help on using the changeset viewer.