Changeset 15455
- Timestamp:
- 11/07/17 13:15:55 (7 years ago)
- Location:
- branches/Weighted TSNE/3.4/TSNE
- Files:
-
- 1 added
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/Weighted TSNE/3.4/TSNE/TSNEAlgorithm.cs
r15451 r15455 29 29 using HeuristicLab.Core; 30 30 using HeuristicLab.Data; 31 using HeuristicLab.Encodings.RealVectorEncoding; 31 32 using HeuristicLab.Optimization; 32 33 using HeuristicLab.Parameters; … … 57 58 } 58 59 59 #region parameter names60 #region Parameter names 60 61 private const string DistanceFunctionParameterName = "DistanceFunction"; 61 62 private const string PerplexityParameterName = "Perplexity"; … … 76 77 #endregion 77 78 78 #region result names79 #region Result names 79 80 private const string IterationResultName = "Iteration"; 80 81 private const string ErrorResultName = "Error"; … … 84 85 #endregion 85 86 86 #region parameter properties87 #region Parameter properties 87 88 public IFixedValueParameter<DoubleValue> PerplexityParameter { 88 89 get { return Parameters[PerplexityParameterName] as IFixedValueParameter<DoubleValue>; } … … 202 203 #endregion 203 204 205 #region Storable poperties 206 [Storable] 207 private Dictionary<string, List<int>> dataRowNames; 208 [Storable] 209 private Dictionary<string, ScatterPlotDataRow> dataRows; 210 [Storable] 211 private TSNEStatic<double[]>.TSNEState state; 212 [Storable] 213 private int iter; 214 #endregion 215 204 216 #region Constructors & Cloning 205 217 [StorableConstructor] 206 218 private TSNEAlgorithm(bool deserializing) : base(deserializing) { } 207 219 220 [StorableHook(HookType.AfterDeserialization)] 221 private void AfterDeserialization() { 222 RegisterParameterEvents(); 223 } 208 224 private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) { 209 225 if (original.dataRowNames != null) … … 250 266 EtaParameter.Hidden = false; 251 267 Problem = new RegressionProblem(); 252 } 253 #endregion 254 255 [Storable] 256 private Dictionary<string, List<int>> dataRowNames; 257 [Storable] 258 private Dictionary<string, ScatterPlotDataRow> dataRows; 259 [Storable] 260 private TSNEStatic<double[]>.TSNEState state; 261 [Storable] 262 private int iter; 268 RegisterParameterEvents(); 269 } 270 #endregion 263 271 264 272 public override void Prepare() { … … 285 293 } 286 294 for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) { 287 if (iter % UpdateInterval == 0) 288 Analyze(state); 295 if (iter % UpdateInterval == 0) Analyze(state); 289 296 TSNEStatic<double[]>.Iterate(state); 290 297 } 291 298 Analyze(state); 292 dataRowNames = null;293 dataRows = null;294 state = null;295 299 } 296 300 … … 306 310 Problem.ProblemDataChanged += OnProblemDataChanged; 307 311 } 312 308 313 protected override void DeregisterProblemEvents() { 309 314 base.DeregisterProblemEvents(); … … 311 316 } 312 317 318 protected override void OnStopped() { 319 base.OnStopped(); 320 state = null; 321 dataRowNames = null; 322 dataRows = null; 323 } 324 313 325 private void OnProblemDataChanged(object sender, EventArgs args) { 314 326 if (Problem == null || Problem.ProblemData == null) return; 327 OnPerplexityChanged(this, null); 328 Problem.ProblemData.Changed += OnPerplexityChanged; 329 Problem.ProblemData.Changed += OnColumnsChanged; 330 Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged; 331 Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged; 315 332 if (!Parameters.ContainsKey(ClassesNameParameterName)) return; 316 333 ClassesNameParameter.ValidValues.Clear(); 317 334 foreach (var input in Problem.ProblemData.InputVariables) ClassesNameParameter.ValidValues.Add(input); 335 } 336 private void OnColumnsChanged(object sender, EventArgs e) { 337 if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(DistanceFunctionParameterName)) return; 338 DistanceFunctionParameter.ValidValues.OfType<WeightedEuclideanDistance>().Single().Weights = new RealVector(Problem.ProblemData.AllowedInputVariables.Select(x => 1.0).ToArray()); 339 } 340 341 private void RegisterParameterEvents() { 342 PerplexityParameter.Value.ValueChanged -= OnPerplexityChanged; 343 PerplexityParameter.Value.ValueChanged += OnPerplexityChanged; 344 } 345 346 private void OnPerplexityChanged(object sender, EventArgs e) { 347 if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(PerplexityParameterName)) return; 348 PerplexityParameter.Value.ValueChanged -= OnPerplexityChanged; 349 PerplexityParameter.Value.Value = Math.Max(1, Math.Min((Problem.ProblemData.Dataset.Rows - 1) / 3.0, Perplexity)); 350 PerplexityParameter.Value.ValueChanged += OnPerplexityChanged; 318 351 } 319 352 #endregion … … 327 360 var problemData = Problem.ProblemData; 328 361 362 if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0))); 363 if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0))); 364 if (!results.ContainsKey(ScatterPlotResultName)) results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, ""))); 365 if (!results.ContainsKey(DataResultName)) results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix())); 366 if (!results.ContainsKey(ErrorPlotResultName)) { 367 var errortable = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent") { 368 VisualProperties = { 369 XAxisTitle = "UpdateIntervall", 370 YAxisTitle = "Error", 371 YAxisLogScale = true 372 } 373 }; 374 errortable.Rows.Add(new DataRow("Errors")); 375 errortable.Rows["Errors"].VisualProperties.StartIndexZero = true; 376 results.Add(new Result(ErrorPlotResultName, errortable)); 377 } 378 329 379 //color datapoints acording to classes variable (be it double or string) 330 if (problemData.Dataset.VariableNames.Contains(ClassesName)) { 331 var classificationData = problemData as ClassificationProblemData; 332 if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) { 333 var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n); 334 var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray(); 335 for (var i = 0; i < classes.Length; i++) { 336 if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>()); 337 dataRowNames[classes[i]].Add(i); 338 } 380 if (!problemData.Dataset.VariableNames.Contains(ClassesName)) { 381 dataRowNames.Add("Training", problemData.TrainingIndices.ToList()); 382 dataRowNames.Add("Test", problemData.TestIndices.ToList()); 383 return; 384 } 385 var classificationData = problemData as ClassificationProblemData; 386 if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) { 387 var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n); 388 var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray(); 389 for (var i = 0; i < classes.Length; i++) { 390 if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>()); 391 dataRowNames[classes[i]].Add(i); 339 392 } 340 else if (((Dataset) problemData.Dataset).VariableHasType<string>(ClassesName)) {341 var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();342 for (var i = 0; i < classes.Length; i++) {343 if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());344 dataRowNames[classes[i]].Add(i);345 }393 } 394 else if (((Dataset) problemData.Dataset).VariableHasType<string>(ClassesName)) { 395 var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray(); 396 for (var i = 0; i < classes.Length; i++) { 397 if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>()); 398 dataRowNames[classes[i]].Add(i); 346 399 } 347 else if (((Dataset) problemData.Dataset).VariableHasType<double>(ClassesName)) { 348 var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList())); 349 const int contours = 8; 350 Dictionary<int, string> contourMap; 351 IClusteringModel clusterModel; 352 double[][] borders; 353 CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders); 354 var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray(); 355 for (var i = 0; i < contours; i++) { 356 var c = contourorder[i]; 357 var contourname = contourMap[c]; 358 dataRowNames.Add(contourname, new List<int>()); 359 dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>())); 360 dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours); 361 dataRows[contourname].VisualProperties.PointSize = i + 3; 362 } 363 var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray(); 364 for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i); 400 } 401 else if (((Dataset) problemData.Dataset).VariableHasType<double>(ClassesName)) { 402 var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList())); 403 const int contours = 8; 404 Dictionary<int, string> contourMap; 405 IClusteringModel clusterModel; 406 double[][] borders; 407 CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders); 408 var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray(); 409 for (var i = 0; i < contours; i++) { 410 var c = contourorder[i]; 411 var contourname = contourMap[c]; 412 dataRowNames.Add(contourname, new List<int>()); 413 dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>())); 414 dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours); 365 415 } 366 else if (((Dataset) problemData.Dataset).VariableHasType<DateTime>(ClassesName)) { 367 var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList())); 368 const int contours = 8; 369 Dictionary<int, string> contourMap; 370 IClusteringModel clusterModel; 371 double[][] borders; 372 CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders); 373 var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray(); 374 for (var i = 0; i < contours; i++) { 375 var c = contourorder[i]; 376 var contourname = contourMap[c]; 377 dataRowNames.Add(contourname, new List<int>()); 378 dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>())); 379 dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours); 380 dataRows[contourname].VisualProperties.PointSize = i + 3; 381 } 382 var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray(); 383 for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i); 416 var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray(); 417 for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i); 418 } 419 else if (((Dataset) problemData.Dataset).VariableHasType<DateTime>(ClassesName)) { 420 var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList())); 421 const int contours = 8; 422 Dictionary<int, string> contourMap; 423 IClusteringModel clusterModel; 424 double[][] borders; 425 CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders); 426 var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray(); 427 for (var i = 0; i < contours; i++) { 428 var c = contourorder[i]; 429 var contourname = contourMap[c]; 430 dataRowNames.Add(contourname, new List<int>()); 431 dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>())); 432 dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours); 384 433 } 385 else { 386 dataRowNames.Add("Training", problemData.TrainingIndices.ToList()); 387 dataRowNames.Add("Test", problemData.TestIndices.ToList()); 388 } 389 390 if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0))); 391 else ((IntValue) results[IterationResultName].Value).Value = 0; 392 393 if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0))); 394 else ((DoubleValue) results[ErrorResultName].Value).Value = 0; 395 396 if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent"))); 397 else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent"); 398 399 var plot = results[ErrorPlotResultName].Value as DataTable; 400 if (plot == null) throw new ArgumentException("could not create/access error data table in results collection"); 401 402 if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors")); 403 plot.Rows["errors"].Values.Clear(); 404 plot.Rows["errors"].VisualProperties.StartIndexZero = true; 405 406 results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, ""))); 407 results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix())); 434 var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray(); 435 for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i); 436 } 437 else { 438 dataRowNames.Add("Training", problemData.TrainingIndices.ToList()); 439 dataRowNames.Add("Test", problemData.TestIndices.ToList()); 408 440 } 409 441 } … … 414 446 var plot = results[ErrorPlotResultName].Value as DataTable; 415 447 if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection."); 416 var errors = plot.Rows[" errors"].Values;448 var errors = plot.Rows["Errors"].Values; 417 449 var c = tsneState.EvaluateError(); 418 450 errors.Add(c); … … 430 462 if (!plot.Rows.ContainsKey(rowName)) { 431 463 plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>())); 432 plot.Rows[rowName].VisualProperties.PointSize = 6;464 plot.Rows[rowName].VisualProperties.PointSize = 8; 433 465 } 434 466 plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1]))); … … 504 536 } 505 537 538 //taken from https://stackoverflow.com/a/17099130 506 539 private static Color HsVtoRgb(double hue, double saturation, double value) { 507 while (hue > 1 f) { hue -= 1f; }508 while (hue < 0 f) { hue += 1f; }509 while (saturation > 1 f) { saturation -= 1f; }510 while (saturation < 0 f) { saturation += 1f; }511 while (value > 1 f) { value -= 1f; }512 while (value < 0 f) { value += 1f; }513 if (hue > 0.999 f) { hue = 0.999f; }514 if (hue < 0.001 f) { hue = 0.001f; }515 if (saturation > 0.999 f) { saturation = 0.999f; }516 if (saturation < 0.001 f) { return Color.FromArgb((int) (value * 255f), (int) (value * 255f), (int) (value * 255f)); }517 if (value > 0.999 f) { value = 0.999f; }518 if (value < 0.001 f) { value = 0.001f; }519 520 var h6 = hue * 6 f;521 if (h6.IsAlmost(6 f)) { h6 = 0f; }540 while (hue > 1.0) { hue -= 1.0; } 541 while (hue < 0.0) { hue += 1.0; } 542 while (saturation > 1.0) { saturation -= 1.0; } 543 while (saturation < 0.0) { saturation += 1.0; } 544 while (value > 1.0) { value -= 1.0; } 545 while (value < 0.0) { value += 1.0; } 546 if (hue > 0.999) { hue = 0.999; } 547 if (hue < 0.001) { hue = 0.001; } 548 if (saturation > 0.999) { saturation = 0.999; } 549 if (saturation < 0.001) { return Color.FromArgb((int) (value * 255.0), (int) (value * 255.0), (int) (value * 255.0)); } 550 if (value > 0.999) { value = 0.999; } 551 if (value < 0.001) { value = 0.001; } 552 553 var h6 = hue * 6.0; 554 if (h6.IsAlmost(6.0)) { h6 = 0.0; } 522 555 var ihue = (int) h6; 523 var p = value * (1 f- saturation);524 var q = value * (1 f- saturation * (h6 - ihue));525 var t = value * (1 f - saturation * (1f- (h6 - ihue)));556 var p = value * (1.0 - saturation); 557 var q = value * (1.0 - saturation * (h6 - ihue)); 558 var t = value * (1.0 - saturation * (1.0 - (h6 - ihue))); 526 559 switch (ihue) { 527 560 case 0: -
branches/Weighted TSNE/3.4/TSNE/TSNEStatic.cs
r15451 r15455 216 216 newData[i, j] = rand.NextDouble() * .0001; 217 217 218 if (data[0] is IReadOnlyList<double> && !randomInit) { 219 for (var i = 0; i < noDatapoints; i++) 220 for (var j = 0; j < newDimensions; j++) { 221 var row = (IReadOnlyList<double>) data[i]; 222 newData[i, j] = row[j % row.Count]; 223 } 218 if (!(data[0] is IReadOnlyList<double>) || randomInit) return; 219 for (var i = 0; i < noDatapoints; i++) 220 for (var j = 0; j < newDimensions; j++) { 221 var row = (IReadOnlyList<double>) data[i]; 222 newData[i, j] = row[j % row.Count]; 224 223 } 225 224 } … … 404 403 } 405 404 } 406 407 405 private static double[][] ComputeDistances(T[] x, IDistance<T> distance) { 408 406 var res = new double[x.Length][]; … … 422 420 // return x.Select(m => x.Select(n => distance.Get(m, n)).ToArray()).ToArray(); 423 421 } 424 425 422 private static double EvaluateErrorExact(double[,] p, double[,] y, int n, int d) { 426 423 // Compute the squared Euclidean distance matrix … … 450 447 return c; 451 448 } 452 453 449 private static double EvaluateErrorApproximate(IReadOnlyList<int> rowP, IReadOnlyList<int> colP, IReadOnlyList<double> valP, double[,] y, double theta) { 454 450 // Get estimate of normalization term … … 592 588 ? state.gains[i, j] + .2 // +0.2 nd *0.8 are used in two separate implementations of tSNE -> seems to be correct 593 589 : state.gains[i, j] * .8; 594 595 590 if (state.gains[i, j] < .01) state.gains[i, j] = .01; 596 591 } -
branches/Weighted TSNE/3.4/TSNE/TSNEUtils.cs
r14414 r15455 35 35 } 36 36 37 internal static IList<T>Swap<T>(this IList<T> list, int indexA, int indexB) {37 internal static void Swap<T>(this IList<T> list, int indexA, int indexB) { 38 38 var tmp = list[indexA]; 39 39 list[indexA] = list[indexB]; 40 40 list[indexB] = tmp; 41 return list;42 41 } 43 42 44 internalstatic int Partition<T>(this IList<T> list, int left, int right, int pivotindex, IComparer<T> comparer) {43 private static int Partition<T>(this IList<T> list, int left, int right, int pivotindex, IComparer<T> comparer) { 45 44 var pivotValue = list[pivotindex]; 46 45 list.Swap(pivotindex, right); … … 67 66 /// <param name="comparer">comparer for list elemnts </param> 68 67 /// <returns></returns> 69 internal static TNthElement<T>(this IList<T> list, int left, int right, int n, IComparer<T> comparer) {68 internal static void NthElement<T>(this IList<T> list, int left, int right, int n, IComparer<T> comparer) { 70 69 while (true) { 71 if (left == right) return list[left];72 var pivotindex = left + (int) Math.Floor(new System.Random().Next() % (right - (double)left + 1));70 if (left == right) return; 71 var pivotindex = left + (int) Math.Floor(new System.Random().Next() % (right - (double) left + 1)); 73 72 pivotindex = list.Partition(left, right, pivotindex, comparer); 74 if (n == pivotindex) return list[n];73 if (n == pivotindex) return; 75 74 if (n < pivotindex) right = pivotindex - 1; 76 75 else left = pivotindex + 1;
Note: See TracChangeset
for help on using the changeset viewer.