- Timestamp:
- 03/27/17 17:27:03 (7 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs
r14785 r14788 29 29 using HeuristicLab.Core; 30 30 using HeuristicLab.Data; 31 using HeuristicLab.Encodings.RealVectorEncoding;32 31 using HeuristicLab.Optimization; 33 32 using HeuristicLab.Parameters; … … 72 71 private const string ClassesParameterName = "ClassNames"; 73 72 private const string NormalizationParameterName = "Normalization"; 73 #endregion 74 75 #region result names 76 private const string IterationResultName = "Iteration"; 77 private const string ErrorResultName = "Error"; 78 private const string ErrorPlotResultName = "Error plot"; 79 private const string ScatterPlotResultName = "Scatterplot"; 80 private const string DataResultName = "Projected data"; 74 81 #endregion 75 82 … … 209 216 #endregion 210 217 211 public override void Stop() { 212 base.Stop(); 213 if (tsne != null) tsne.Running = false; 214 } 218 [Storable] 219 private Dictionary<string, List<int>> dataRowNames; // TODO 220 [Storable] 221 private Dictionary<string, ScatterPlotDataRow> dataRows; // TODO 222 215 223 216 224 protected override void Run(CancellationToken cancellationToken) { 217 var dataRowNames = new Dictionary<string, List<int>>(); 218 var rows = new Dictionary<string, ScatterPlotDataRow>(); 225 var problemData = Problem.ProblemData; 226 227 // set up and run tSNE 228 if (SetSeedRandomly) Seed = new System.Random().Next(); 229 var random = new MersenneTwister((uint)Seed); 230 var dataset = problemData.Dataset; 231 var allowedInputVariables = problemData.AllowedInputVariables.ToArray(); 232 var data = new double[dataset.Rows][]; 233 for (var row = 0; row < dataset.Rows; row++) data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray(); 234 if (Normalization) data = NormalizeData(data); 235 236 var tsneState = TSNE<double[]>.CreateState(data, Distance, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta); 237 238 SetUpResults(data); 239 for (int iter = 0; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) 240 { 241 TSNE<double[]>.Iterate(tsneState); 242 Analyze(tsneState); 243 } 244 } 245 246 private void SetUpResults(IReadOnlyCollection<double[]> data) { 247 if (Results == null) return; 248 var results = Results; 249 dataRowNames = new Dictionary<string, List<int>>(); 250 dataRows = new Dictionary<string, ScatterPlotDataRow>(); 219 251 var problemData = Problem.ProblemData; 220 252 … … 235 267 var contourname = GetContourName(i, min, max, contours); 236 268 dataRowNames.Add(contourname, new List<int>()); 237 rows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));238 rows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);239 rows[contourname].VisualProperties.PointSize = i + 3;269 dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>())); 270 dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours); 271 dataRows[contourname].VisualProperties.PointSize = i + 3; 240 272 } 241 273 for (var i = 0; i < classValues.Length; i++) { … … 248 280 } 249 281 250 // set up and run tSNE 251 if (SetSeedRandomly) Seed = new System.Random().Next(); 252 var random = new MersenneTwister((uint)Seed); 253 tsne = new TSNE<double[]>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, dataRowNames, rows); 254 var dataset = problemData.Dataset; 255 var allowedInputVariables = problemData.AllowedInputVariables.ToArray(); 256 var data = new double[dataset.Rows][]; 257 for (var row = 0; row < dataset.Rows; row++) data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray(); 258 if (Normalization) data = NormalizeData(data); 259 tsne.Run(data, NewDimensions, Perplexity, Theta); 282 if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0))); 283 else ((IntValue)results[IterationResultName].Value).Value = 0; 284 285 if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0))); 286 else ((DoubleValue)results[ErrorResultName].Value).Value = 0; 287 288 if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent"))); 289 else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent"); 290 291 var plot = results[ErrorPlotResultName].Value as DataTable; 292 if (plot == null) throw new ArgumentException("could not create/access error data table in results collection"); 293 294 if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors")); 295 plot.Rows["errors"].Values.Clear(); 296 297 results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, ""))); 298 results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix())); 299 } 300 301 private void Analyze(TSNE<double[]>.TSNEState tsneState) { 302 if (Results == null) return; 303 var results = Results; 304 var plot = results[ErrorPlotResultName].Value as DataTable; 305 if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection."); 306 var errors = plot.Rows["errors"].Values; 307 var c = tsneState.EvaluateError(); 308 errors.Add(c); 309 ((IntValue)results[IterationResultName].Value).Value = tsneState.iter + 1; 310 ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last(); 311 312 var ndata = Normalize(tsneState.newData); 313 results[DataResultName].Value = new DoubleMatrix(ndata); 314 var splot = results[ScatterPlotResultName].Value as ScatterPlot; 315 FillScatterPlot(ndata, splot); 316 } 317 318 private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) { 319 foreach (var rowName in dataRowNames.Keys) { 320 if (!plot.Rows.ContainsKey(rowName)) 321 plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>())); 322 plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1]))); 323 } 324 } 325 326 private static double[,] Normalize(double[,] data) { 327 var max = new double[data.GetLength(1)]; 328 var min = new double[data.GetLength(1)]; 329 var res = new double[data.GetLength(0), data.GetLength(1)]; 330 for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i]; 331 for (var i = 0; i < data.GetLength(0); i++) 332 for (var j = 0; j < data.GetLength(1); j++) { 333 var v = data[i, j]; 334 max[j] = Math.Max(max[j], v); 335 min[j] = Math.Min(min[j], v); 336 } 337 for (var i = 0; i < data.GetLength(0); i++) { 338 for (var j = 0; j < data.GetLength(1); j++) { 339 res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]); 340 } 341 } 342 return res; 260 343 } 261 344 … … 276 359 return nData; 277 360 } 361 278 362 private static Color GetHeatMapColor(int contourNr, int noContours) { 279 363 var q = (double)contourNr / noContours; // q in [0,1] … … 281 365 return c; 282 366 } 367 283 368 private static string GetContourName(double value, double min, double max, int noContours) { 284 369 var size = (max - min) / noContours; … … 286 371 return GetContourName(contourNr, min, max, noContours); 287 372 } 373 288 374 private static string GetContourName(int i, double min, double max, int noContours) { 289 375 var size = (max - min) / noContours;
Note: See TracChangeset
for help on using the changeset viewer.