Changeset 15556
- Timestamp: 12/21/17 09:14:27 (7 years ago)
- File: 1 edited
Legend:
- Unmodified (plain context lines)
- Added (prefixed with +)
- Removed (prefixed with -)
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs
--- r15551
+++ r15556

@@ -42 +42 @@
     /// </summary>
     [Item("t-Distributed Stochastic Neighbor Embedding (tSNE)", "t-Distributed Stochastic Neighbor Embedding projects the data in a low " +
-      "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
+      "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
     [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
     [StorableClass]

@@ -203 +203 @@
     #region Storable poperties
     [Storable]
-    private Dictionary<string, List<int>> dataRowNames;
-    [Storable]
-    private Dictionary<string, ScatterPlotDataRow> dataRows;
+    private Dictionary<string, IList<int>> dataRowIndices;
     [Storable]
     private TSNEStatic<double[]>.TSNEState state;
-    [Storable]
-    private int iter;
     #endregion

@@ -223 +219 @@
     }
     private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
-      if (original.dataRowNames != null)
-        dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
-      if (original.dataRows != null)
-        dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
+      if (original.dataRowIndices != null)
+        dataRowIndices = new Dictionary<string, IList<int>>(original.dataRowIndices);
       if (original.state != null)
         state = cloner.Clone(original.state);
-      iter = original.iter;
       RegisterParameterEvents();
     }

@@ -259 +252 @@
       Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));

-      Parameters[UpdateIntervalParameterName].Hidden = true;
-
+      UpdateIntervalParameter.Hidden = true;
       MomentumSwitchIterationParameter.Hidden = true;
       InitialMomentumParameter.Hidden = true;

@@ -273 +265 @@
     public override void Prepare() {
       base.Prepare();
-      dataRowNames = null;
-      dataRows = null;
+      dataRowIndices = null;
       state = null;
     }

@@ -301 +292 @@
          col++;
        }
-
        if (Normalization) data = NormalizeInputData(data);
        state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, RandomInitialization);
        SetUpResults(allindices);
-        iter = 0;
-      }
-      for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
-        if (iter % UpdateInterval == 0) Analyze(state);
+      }
+      while (state.iter < MaxIterations && !cancellationToken.IsCancellationRequested) {
+        if (state.iter % UpdateInterval == 0) Analyze(state);
         TSNEStatic<double[]>.Iterate(state);
       }

@@ -324 +313 @@
       base.RegisterProblemEvents();
       if (Problem == null) return;
+      Problem.ProblemDataChanged += OnProblemDataChanged;
+      if (Problem.ProblemData == null) return;
+      Problem.ProblemData.Changed += OnPerplexityChanged;
+      Problem.ProblemData.Changed += OnColumnsChanged;
+      if (Problem.ProblemData.Dataset == null) return;
+      Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
+      Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
+    }
+
+    protected override void DeregisterProblemEvents() {
+      base.DeregisterProblemEvents();
+      if (Problem == null) return;
       Problem.ProblemDataChanged -= OnProblemDataChanged;
-      Problem.ProblemDataChanged += OnProblemDataChanged;
       if (Problem.ProblemData == null) return;
       Problem.ProblemData.Changed -= OnPerplexityChanged;
       Problem.ProblemData.Changed -= OnColumnsChanged;
-      Problem.ProblemData.Changed += OnPerplexityChanged;
-      Problem.ProblemData.Changed += OnColumnsChanged;
       if (Problem.ProblemData.Dataset == null) return;
       Problem.ProblemData.Dataset.RowsChanged -= OnPerplexityChanged;
       Problem.ProblemData.Dataset.ColumnsChanged -= OnColumnsChanged;
-      Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
-      Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
-    }
-
-    protected override void DeregisterProblemEvents() {
-      base.DeregisterProblemEvents();
-      Problem.ProblemDataChanged -= OnProblemDataChanged;
     }

     protected override void OnStopped() {
       base.OnStopped();
+      //bwerth: state objects can be very large; avoid state serialization
       state = null;
-      dataRowNames = null;
-      dataRows = null;
+      dataRowIndices = null;
     }

@@ -354 +345 @@
       OnPerplexityChanged(this, null);
       OnColumnsChanged(this, null);
-      Problem.ProblemData.Changed -= OnPerplexityChanged;
       Problem.ProblemData.Changed += OnPerplexityChanged;
-      Problem.ProblemData.Changed -= OnColumnsChanged;
       Problem.ProblemData.Changed += OnColumnsChanged;
       if (Problem.ProblemData.Dataset == null) return;
-      Problem.ProblemData.Dataset.RowsChanged -= OnPerplexityChanged;
-      Problem.ProblemData.Dataset.ColumnsChanged -= OnColumnsChanged;
       Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
       Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;

@@ -374 +361 @@

     private void RegisterParameterEvents() {
-      PerplexityParameter.Value.ValueChanged -= OnPerplexityChanged;
       PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
     }

@@ -380 +366 @@
     private void OnPerplexityChanged(object sender, EventArgs e) {
       if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(PerplexityParameterName)) return;
-      PerplexityParameter.Value.ValueChanged -= OnPerplexityChanged;
       PerplexityParameter.Value.Value = Math.Max(1, Math.Min((Problem.ProblemData.Dataset.Rows - 1) / 3.0, Perplexity));
-      PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
     }
     #endregion

@@ -390 +374 @@
       if (Results == null) return;
       var results = Results;
-      dataRowNames = new Dictionary<string, List<int>>();
-      dataRows = new Dictionary<string, ScatterPlotDataRow>();
+      dataRowIndices = new Dictionary<string, IList<int>>();
       var problemData = Problem.ProblemData;

@@ -411 +394 @@
       }

-      //color datapoints acording to classes variable (be it double or string)
+      //color datapoints acording to classes variable (be it double, datetime or string)
       if (!problemData.Dataset.VariableNames.Contains(ClassesName)) {
-        dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
-        dataRowNames.Add("Test", problemData.TestIndices.ToList());
+        dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
+        dataRowIndices.Add("Test", problemData.TestIndices.ToList());
         return;
       }
+
       var classificationData = problemData as ClassificationProblemData;
       if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) {

@@ -422 +406 @@
         var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray();
         for (var i = 0; i < classes.Length; i++) {
-          if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
-          dataRowNames[classes[i]].Add(i);
+          if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
+          dataRowIndices[classes[i]].Add(i);
         }
       } else if (((Dataset)problemData.Dataset).VariableHasType<string>(ClassesName)) {
         var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();
         for (var i = 0; i < classes.Length; i++) {
-          if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
-          dataRowNames[classes[i]].Add(i);
+          if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
+          dataRowIndices[classes[i]].Add(i);
         }
       } else if (((Dataset)problemData.Dataset).VariableHasType<double>(ClassesName)) {

@@ -442 +426 @@
           var c = contourorder[i];
           var contourname = contourMap[c];
-          dataRowNames.Add(contourname, new List<int>());
-          dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
-          dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
+          dataRowIndices.Add(contourname, new List<int>());
+          var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
+          ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
         }
         var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
-        for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
+        for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
       } else if (((Dataset)problemData.Dataset).VariableHasType<DateTime>(ClassesName)) {
         var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));

@@ -459 +443 @@
           var c = contourorder[i];
           var contourname = contourMap[c];
-          dataRowNames.Add(contourname, new List<int>());
-          dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
-          dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
+          dataRowIndices.Add(contourname, new List<int>());
+          var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
+          row.VisualProperties.PointSize = 8;
+          ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
         }
         var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
-        for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
+        for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
       } else {
-        dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
-        dataRowNames.Add("Test", problemData.TestIndices.ToList());
+        dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
+        dataRowIndices.Add("Test", problemData.TestIndices.ToList());
       }
     }

@@ -485 +470 @@
       results[DataResultName].Value = new DoubleMatrix(ndata);
       var splot = results[ScatterPlotResultName].Value as ScatterPlot;
-      FillScatterPlot(ndata, splot);
-    }
-
-    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
-      foreach (var rowName in dataRowNames.Keys) {
+      FillScatterPlot(ndata, splot, dataRowIndices);
+    }
+
+    private static void FillScatterPlot(double[,] lowDimData, ScatterPlot plot, Dictionary<string, IList<int>> dataRowIndices) {
+      foreach (var rowName in dataRowIndices.Keys) {
         if (!plot.Rows.ContainsKey(rowName)) {
-          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
+          plot.Rows.Add(new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
           plot.Rows[rowName].VisualProperties.PointSize = 8;
         }
-        plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
+        plot.Rows[rowName].Points.Replace(dataRowIndices[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
       }
     }

@@ -504 +489 @@
       for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
       for (var i = 0; i < data.GetLength(0); i++)
-        [content of r15551 lines 506-510 not shown in this view]
+        for (var j = 0; j < data.GetLength(1); j++) {
+          var v = data[i, j];
+          max[j] = Math.Max(max[j], v);
+          min[j] = Math.Min(min[j], v);
+        }
       for (var i = 0; i < data.GetLength(0); i++) {
         for (var j = 0; j < data.GetLength(1); j++) {

@@ -532 +517 @@
       for (var i = 0; i < data.Count; i++) {
         nData[i] = new double[n];
-        for (var j = 0; j < n; j++) nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
+        for (var j = 0; j < n; j++)
+          nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
       }
       return nData;

@@ -542 +528 @@

     private static void CreateClusters(IDataset data, string target, int contours, out IClusteringModel contourCluster, out Dictionary<int, string> contourNames, out double[][] borders) {
-      var cpd = new ClusteringProblemData((Dataset)data, new[] { target});
+      var cpd = new ClusteringProblemData((Dataset)data, new[] {target});
       contourCluster = KMeansClustering.CreateKMeansSolution(cpd, contours, 3).Model;

-      borders = Enumerable.Range(0, contours).Select(x => new[] { double.MaxValue, double.MinValue}).ToArray();
+      borders = Enumerable.Range(0, contours).Select(x => new[] {double.MaxValue, double.MinValue}).ToArray();
       var clusters = contourCluster.GetClusterValues(cpd.Dataset, cpd.AllIndices).ToArray();
       var targetvalues = cpd.Dataset.GetDoubleValues(target).ToArray();
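Aside: the central refactoring above replaces the serialized dataRowNames/dataRows pair with a single dataRowIndices dictionary that maps a group (class) name to the indices of its data points; FillScatterPlot then fills one scatter-plot row per group from the projected coordinates. The following self-contained sketch illustrates that grouping pattern with plain collections; it is an illustration only, using hypothetical stand-in names and types, not the HeuristicLab API.

using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-alone sketch (not HeuristicLab code): group row indices by class
// label, then build one point series per group from the low-dimensional coordinates,
// mirroring the dataRowIndices / FillScatterPlot pattern introduced in this changeset.
public static class GroupedProjectionSketch {
  public static Dictionary<string, List<(double X, double Y)>> Fill(
      double[,] lowDimData, IReadOnlyList<string> classLabels) {
    // classLabels[i] names the group of data row i (the role of dataRowIndices above).
    var rowIndices = new Dictionary<string, IList<int>>();
    for (var i = 0; i < classLabels.Count; i++) {
      if (!rowIndices.ContainsKey(classLabels[i])) rowIndices.Add(classLabels[i], new List<int>());
      rowIndices[classLabels[i]].Add(i);
    }
    // One point series per group, filled from the projected 2D coordinates.
    return rowIndices.ToDictionary(
      kvp => kvp.Key,
      kvp => kvp.Value.Select(i => (X: lowDimData[i, 0], Y: lowDimData[i, 1])).ToList());
  }

  public static void Main() {
    var coords = new[,] { { 0.1, 0.2 }, { 1.5, 1.4 }, { 0.2, 0.3 } };
    var labels = new[] { "setosa", "virginica", "setosa" };
    foreach (var kvp in Fill(coords, labels))
      Console.WriteLine($"{kvp.Key}: {kvp.Value.Count} points");
  }
}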