Changeset 14807 for branches/TSNE/HeuristicLab.Algorithms.DataAnalysis
- Timestamp:
- 03/30/17 19:06:44 (8 years ago)
- Location:
- branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs
r14788 r14807 46 46 public sealed class TSNEAlgorithm : BasicAlgorithm { 47 47 public override bool SupportsPause { 48 get { return false; }48 get { return true; } 49 49 } 50 50 public override Type ProblemType { … … 182 182 set { NormalizationParameter.Value.Value = value; } 183 183 } 184 [Storable]185 public TSNE<double[]> tsne;186 184 #endregion 187 185 … … 189 187 [StorableConstructor] 190 188 private TSNEAlgorithm(bool deserializing) : base(deserializing) { } 191 private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) { } 189 190 private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) { 191 this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames); 192 this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value)); 193 if(original.state != null) 194 this.state = cloner.Clone(original.state); 195 this.iter = original.iter; 196 } 192 197 public override IDeepCloneable Clone(Cloner cloner) { return new TSNEAlgorithm(this, cloner); } 193 198 public TSNEAlgorithm() { 194 199 Problem = new RegressionProblem(); 195 200 Parameters.Add(new ValueParameter<IDistance<double[]>>(DistanceParameterName, "The distance function used to differentiate similar from non-similar points", new EuclideanDistance())); 196 Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity- Parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25)));197 Parameters.Add(new FixedValueParameter<DoubleValue>(ThetaParameterName, "Value describing how much appoximated gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise \n CAUTION: exact calculation of forces requires building a non-sparse N*N matrix where N is the number of data points\n This may exceed memory limitations", new DoubleValue(0)));201 Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25))); 202 Parameters.Add(new FixedValueParameter<DoubleValue>(ThetaParameterName, "Value describing how much appoximated gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. CAUTION: exact calculation of forces requires building a non-sparse N*N matrix where N is the number of data points. This may exceed memory limitations.", new DoubleValue(0))); 198 203 Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis)", new IntValue(2))); 199 Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent ", new IntValue(1000)));200 Parameters.Add(new FixedValueParameter<IntValue>(StopLyingIterationParameterName, "Number of iterations after which p is no longer approximated ", new IntValue(0)));201 Parameters.Add(new FixedValueParameter<IntValue>(MomentumSwitchIterationParameterName, "Number of iterations after which the momentum in the gradient descent is switched ", new IntValue(0)));202 Parameters.Add(new FixedValueParameter<DoubleValue>(InitialMomentumParameterName, "The initial momentum in the gradient descent ", new DoubleValue(0.5)));203 Parameters.Add(new FixedValueParameter<DoubleValue>(FinalMomentumParameterName, "The final momentum ", new DoubleValue(0.8)));204 Parameters.Add(new FixedValueParameter<DoubleValue>(EtaParameterName, "Gradient descent learning rate ", new DoubleValue(200)));205 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "If the seed should be random ", new BoolValue(true)));206 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The seed used if it should not be random ", new IntValue(0)));207 Parameters.Add(new FixedValueParameter<StringValue>(ClassesParameterName, "name of the column specifying the class lables of each data point. \n if the lable column can not be found training/test is used as labels", new StringValue("none")));208 Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored ", new BoolValue(true)));204 Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent.", new IntValue(1000))); 205 Parameters.Add(new FixedValueParameter<IntValue>(StopLyingIterationParameterName, "Number of iterations after which p is no longer approximated.", new IntValue(0))); 206 Parameters.Add(new FixedValueParameter<IntValue>(MomentumSwitchIterationParameterName, "Number of iterations after which the momentum in the gradient descent is switched.", new IntValue(0))); 207 Parameters.Add(new FixedValueParameter<DoubleValue>(InitialMomentumParameterName, "The initial momentum in the gradient descent.", new DoubleValue(0.5))); 208 Parameters.Add(new FixedValueParameter<DoubleValue>(FinalMomentumParameterName, "The final momentum.", new DoubleValue(0.8))); 209 Parameters.Add(new FixedValueParameter<DoubleValue>(EtaParameterName, "Gradient descent learning rate.", new DoubleValue(200))); 210 Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "If the seed should be random.", new BoolValue(true))); 211 Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The seed used if it should not be random.", new IntValue(0))); 212 Parameters.Add(new FixedValueParameter<StringValue>(ClassesParameterName, "name of the column specifying the class lables of each data point. If the label column can not be found training/test is used as labels.", new StringValue("none"))); 213 Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true))); 209 214 210 215 MomentumSwitchIterationParameter.Hidden = true; … … 217 222 218 223 [Storable] 219 private Dictionary<string, List<int>> dataRowNames; // TODO224 private Dictionary<string, List<int>> dataRowNames; 220 225 [Storable] 221 private Dictionary<string, ScatterPlotDataRow> dataRows; // TODO 222 226 private Dictionary<string, ScatterPlotDataRow> dataRows; 227 [Storable] 228 private TSNEStatic<double[]>.TSNEState state; 229 [Storable] 230 private int iter; 231 232 public override void Prepare() { 233 base.Prepare(); 234 dataRowNames = null; 235 dataRows = null; 236 state = null; 237 } 223 238 224 239 protected override void Run(CancellationToken cancellationToken) { 225 240 var problemData = Problem.ProblemData; 226 227 // set up and run tSNE 228 if (SetSeedRandomly) Seed = new System.Random().Next(); 229 var random = new MersenneTwister((uint)Seed); 230 var dataset = problemData.Dataset; 231 var allowedInputVariables = problemData.AllowedInputVariables.ToArray(); 232 var data = new double[dataset.Rows][]; 233 for (var row = 0; row < dataset.Rows; row++) data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray(); 234 if (Normalization) data = NormalizeData(data); 235 236 var tsneState = TSNE<double[]>.CreateState(data, Distance, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta); 237 238 SetUpResults(data); 239 for (int iter = 0; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) 240 { 241 TSNE<double[]>.Iterate(tsneState); 242 Analyze(tsneState); 241 // set up and initialized everything if necessary 242 if(state == null) { 243 if(SetSeedRandomly) Seed = new System.Random().Next(); 244 var random = new MersenneTwister((uint)Seed); 245 var dataset = problemData.Dataset; 246 var allowedInputVariables = problemData.AllowedInputVariables.ToArray(); 247 var data = new double[dataset.Rows][]; 248 for(var row = 0; row < dataset.Rows; row++) 249 data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray(); 250 251 if(Normalization) data = NormalizeData(data); 252 253 state = TSNEStatic<double[]>.CreateState(data, Distance, random, NewDimensions, Perplexity, Theta, 254 StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta); 255 256 SetUpResults(data); 257 iter = 0; 258 } 259 for(; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) { 260 TSNEStatic<double[]>.Iterate(state); 261 Analyze(state); 243 262 } 244 263 } 245 264 246 265 private void SetUpResults(IReadOnlyCollection<double[]> data) { 247 if 266 if(Results == null) return; 248 267 var results = Results; 249 268 dataRowNames = new Dictionary<string, List<int>>(); … … 252 271 253 272 //color datapoints acording to classes variable (be it double or string) 254 if 255 if 273 if(problemData.Dataset.VariableNames.Contains(Classes)) { 274 if((problemData.Dataset as Dataset).VariableHasType<string>(Classes)) { 256 275 var classes = problemData.Dataset.GetStringValues(Classes).ToArray(); 257 for 258 if 276 for(var i = 0; i < classes.Length; i++) { 277 if(!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>()); 259 278 dataRowNames[classes[i]].Add(i); 260 279 } 261 } else if 280 } else if((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) { 262 281 var classValues = problemData.Dataset.GetDoubleValues(Classes).ToArray(); 263 282 var max = classValues.Max() + 0.1; // TODO consts 264 283 var min = classValues.Min() - 0.1; 265 284 const int contours = 8; 266 for 285 for(var i = 0; i < contours; i++) { 267 286 var contourname = GetContourName(i, min, max, contours); 268 287 dataRowNames.Add(contourname, new List<int>()); … … 271 290 dataRows[contourname].VisualProperties.PointSize = i + 3; 272 291 } 273 for 292 for(var i = 0; i < classValues.Length; i++) { 274 293 dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i); 275 294 } … … 280 299 } 281 300 282 if 301 if(!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0))); 283 302 else ((IntValue)results[IterationResultName].Value).Value = 0; 284 303 285 if 304 if(!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0))); 286 305 else ((DoubleValue)results[ErrorResultName].Value).Value = 0; 287 306 288 if 307 if(!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent"))); 289 308 else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent"); 290 309 291 310 var plot = results[ErrorPlotResultName].Value as DataTable; 292 if 293 294 if 311 if(plot == null) throw new ArgumentException("could not create/access error data table in results collection"); 312 313 if(!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors")); 295 314 plot.Rows["errors"].Values.Clear(); 296 315 … … 299 318 } 300 319 301 private void Analyze(TSNE <double[]>.TSNEState tsneState) {302 if 320 private void Analyze(TSNEStatic<double[]>.TSNEState tsneState) { 321 if(Results == null) return; 303 322 var results = Results; 304 323 var plot = results[ErrorPlotResultName].Value as DataTable; 305 if 324 if(plot == null) throw new ArgumentException("Could not create/access error data table in results collection."); 306 325 var errors = plot.Rows["errors"].Values; 307 326 var c = tsneState.EvaluateError(); 308 327 errors.Add(c); 309 ((IntValue)results[IterationResultName].Value).Value = tsneState.iter + 1;328 ((IntValue)results[IterationResultName].Value).Value = tsneState.iter; 310 329 ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last(); 311 330 … … 317 336 318 337 private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) { 319 foreach 320 if 338 foreach(var rowName in dataRowNames.Keys) { 339 if(!plot.Rows.ContainsKey(rowName)) 321 340 plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>())); 322 341 plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1]))); … … 328 347 var min = new double[data.GetLength(1)]; 329 348 var res = new double[data.GetLength(0), data.GetLength(1)]; 330 for 331 for 332 for 349 for(var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i]; 350 for(var i = 0; i < data.GetLength(0); i++) 351 for(var j = 0; j < data.GetLength(1); j++) { 333 352 var v = data[i, j]; 334 353 max[j] = Math.Max(max[j], v); 335 354 min[j] = Math.Min(min[j], v); 336 355 } 337 for 338 for 356 for(var i = 0; i < data.GetLength(0); i++) { 357 for(var j = 0; j < data.GetLength(1); j++) { 339 358 res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]); 340 359 } … … 348 367 var sd = new double[n]; 349 368 var nData = new double[data.Count][]; 350 for 369 for(var i = 0; i < n; i++) { 351 370 var i1 = i; 352 371 sd[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).StandardDeviation(); 353 372 mean[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).Average(); 354 373 } 355 for 374 for(var i = 0; i < data.Count; i++) { 356 375 nData[i] = new double[n]; 357 for 376 for(var j = 0; j < n; j++) nData[i][j] = (data[i][j] - mean[j]) / sd[j]; 358 377 } 359 378 return nData; -
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEStatic.cs
r14806 r14807 60 60 using HeuristicLab.Common; 61 61 using HeuristicLab.Core; 62 using HeuristicLab.Optimization; 62 63 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 63 64 using HeuristicLab.Random; … … 65 66 namespace HeuristicLab.Algorithms.DataAnalysis { 66 67 [StorableClass] 67 public class TSNE <T> {68 public class TSNEStatic<T> { 68 69 69 70 [StorableClass] … … 166 167 } 167 168 169 [StorableConstructor] 170 public TSNEState(bool deserializing) { } 168 171 public TSNEState(T[] data, IDistance<T> distance, IRandom random, int newDimensions, double perplexity, double theta, int stopLyingIter, int momSwitchIter, double momentum, double finalMomentum, double eta) { 169 172 this.distance = distance; … … 525 528 for(var i = 0; i < noElem; i++) symValP[i] /= 2.0; 526 529 } 527 528 530 } 529 531 530 public static TSNEState CreateState(T[] data, IDistance<T> distance, IRandom random, int newDimensions = 2, double perplexity = 25, double theta = 0, 531 int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, double finalMomentum = .8, double eta = 200.0 532 /// <summary> 533 /// Simple interface to tSNE 534 /// </summary> 535 /// <param name="data"></param> 536 /// <param name="distance">The distance function used to differentiate similar from non-similar points, e.g. Euclidean distance.</param> 537 /// <param name="random">Random number generator</param> 538 /// <param name="newDimensions">Dimensionality of projected space (usually 2 for easy visual analysis).</param> 539 /// <param name="perplexity">Perplexity parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower</param> 540 /// <param name="iterations">Maximum number of iterations for gradient descent.</param> 541 /// <param name="theta">Value describing how much appoximated gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. CAUTION: exact calculation of forces requires building a non-sparse N*N matrix where N is the number of data points. This may exceed memory limitations.</param> 542 /// <param name="stopLyingIter">Number of iterations after which p is no longer approximated.</param> 543 /// <param name="momSwitchIter">Number of iterations after which the momentum in the gradient descent is switched.</param> 544 /// <param name="momentum">The initial momentum in the gradient descent.</param> 545 /// <param name="finalMomentum">The final momentum in gradient descent (after momentum switch).</param> 546 /// <param name="eta">Gradient descent learning rate.</param> 547 /// <returns></returns> 548 public static double[,] Run(T[] data, IDistance<T> distance, IRandom random, 549 int newDimensions = 2, double perplexity = 25, int iterations = 1000, 550 double theta = 0, 551 int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, 552 double finalMomentum = .8, double eta = 200.0 553 ) { 554 var state = CreateState(data, distance, random, newDimensions, perplexity, 555 theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta); 556 557 for(int i = 0; i < iterations - 1; i++) { 558 Iterate(state); 559 } 560 return Iterate(state); 561 } 562 563 public static TSNEState CreateState(T[] data, IDistance<T> distance, IRandom random, 564 int newDimensions = 2, double perplexity = 25, double theta = 0, 565 int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, 566 double finalMomentum = .8, double eta = 200.0 532 567 ) { 533 568 return new TSNEState(data, distance, random, newDimensions, perplexity, theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta); … … 564 599 // Make solution zero-mean 565 600 ZeroMean(state.newData); 601 566 602 // Stop lying about the P-values after a while, and switch momentum 567 568 603 if(state.iter == state.stopLyingIter) { 569 604 if(state.exact) 570 for(var i = 0; i < state.noDatapoints; i++) for(var j = 0; j < state.noDatapoints; j++) state.p[i, j] /= 12.0; //XXX why 12? 605 for(var i = 0; i < state.noDatapoints; i++) 606 for(var j = 0; j < state.noDatapoints; j++) 607 state.p[i, j] /= 12.0; //XXX why 12? 571 608 else 572 for(var i = 0; i < state.rowP[state.noDatapoints]; i++) state.valP[i] /= 12.0; // XXX are we not scaling all values? 609 for(var i = 0; i < state.rowP[state.noDatapoints]; i++) 610 state.valP[i] /= 12.0; // XXX are we not scaling all values? 573 611 } 574 612
Note: See TracChangeset
for help on using the changeset viewer.