Changeset 14807


Ignore:
Timestamp:
03/30/17 19:06:44 (5 years ago)
Author:
gkronber
Message:

#2700: support clone, persistence and pause/resume

Location:
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs

    r14788 r14807  
    4646  public sealed class TSNEAlgorithm : BasicAlgorithm {
    4747    public override bool SupportsPause {
    48       get { return false; }
     48      get { return true; }
    4949    }
    5050    public override Type ProblemType {
     
    182182      set { NormalizationParameter.Value.Value = value; }
    183183    }
    184     [Storable]
    185     public TSNE<double[]> tsne;
    186184    #endregion
    187185
     
    189187    [StorableConstructor]
    190188    private TSNEAlgorithm(bool deserializing) : base(deserializing) { }
    191     private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) { }
     189
     190    private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
     191      this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
     192      this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
     193      if(original.state != null)
     194        this.state = cloner.Clone(original.state);
     195      this.iter = original.iter;
     196    }
    192197    public override IDeepCloneable Clone(Cloner cloner) { return new TSNEAlgorithm(this, cloner); }
    193198    public TSNEAlgorithm() {
    194199      Problem = new RegressionProblem();
    195200      Parameters.Add(new ValueParameter<IDistance<double[]>>(DistanceParameterName, "The distance function used to differentiate similar from non-similar points", new EuclideanDistance()));
    196       Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-Parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25)));
    197       Parameters.Add(new FixedValueParameter<DoubleValue>(ThetaParameterName, "Value describing how much appoximated gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise \n CAUTION: exact calculation of forces requires building a non-sparse N*N matrix where N is the number of data points\n This may exceed memory limitations", new DoubleValue(0)));
     201      Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25)));
     202      Parameters.Add(new FixedValueParameter<DoubleValue>(ThetaParameterName, "Value describing how much appoximated gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. CAUTION: exact calculation of forces requires building a non-sparse N*N matrix where N is the number of data points. This may exceed memory limitations.", new DoubleValue(0)));
    198203      Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis)", new IntValue(2)));
    199       Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent", new IntValue(1000)));
    200       Parameters.Add(new FixedValueParameter<IntValue>(StopLyingIterationParameterName, "Number of iterations after which p is no longer approximated", new IntValue(0)));
    201       Parameters.Add(new FixedValueParameter<IntValue>(MomentumSwitchIterationParameterName, "Number of iterations after which the momentum in the gradient descent is switched", new IntValue(0)));
    202       Parameters.Add(new FixedValueParameter<DoubleValue>(InitialMomentumParameterName, "The initial momentum in the gradient descent", new DoubleValue(0.5)));
    203       Parameters.Add(new FixedValueParameter<DoubleValue>(FinalMomentumParameterName, "The final momentum", new DoubleValue(0.8)));
    204       Parameters.Add(new FixedValueParameter<DoubleValue>(EtaParameterName, "Gradient descent learning rate", new DoubleValue(200)));
    205       Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "If the seed should be random", new BoolValue(true)));
    206       Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The seed used if it should not be random", new IntValue(0)));
    207       Parameters.Add(new FixedValueParameter<StringValue>(ClassesParameterName, "name of the column specifying the class lables of each data point. \n if the lable column can not be found training/test is used as labels", new StringValue("none")));
    208       Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored", new BoolValue(true)));
     204      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent.", new IntValue(1000)));
     205      Parameters.Add(new FixedValueParameter<IntValue>(StopLyingIterationParameterName, "Number of iterations after which p is no longer approximated.", new IntValue(0)));
     206      Parameters.Add(new FixedValueParameter<IntValue>(MomentumSwitchIterationParameterName, "Number of iterations after which the momentum in the gradient descent is switched.", new IntValue(0)));
     207      Parameters.Add(new FixedValueParameter<DoubleValue>(InitialMomentumParameterName, "The initial momentum in the gradient descent.", new DoubleValue(0.5)));
     208      Parameters.Add(new FixedValueParameter<DoubleValue>(FinalMomentumParameterName, "The final momentum.", new DoubleValue(0.8)));
     209      Parameters.Add(new FixedValueParameter<DoubleValue>(EtaParameterName, "Gradient descent learning rate.", new DoubleValue(200)));
     210      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "If the seed should be random.", new BoolValue(true)));
     211      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The seed used if it should not be random.", new IntValue(0)));
     212      Parameters.Add(new FixedValueParameter<StringValue>(ClassesParameterName, "name of the column specifying the class lables of each data point. If the label column can not be found training/test is used as labels.", new StringValue("none")));
     213      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
    209214
    210215      MomentumSwitchIterationParameter.Hidden = true;
     
    217222
    218223    [Storable]
    219     private Dictionary<string, List<int>> dataRowNames;    // TODO
     224    private Dictionary<string, List<int>> dataRowNames;
    220225    [Storable]
    221     private Dictionary<string, ScatterPlotDataRow> dataRows; // TODO
    222 
     226    private Dictionary<string, ScatterPlotDataRow> dataRows;
     227    [Storable]
     228    private TSNEStatic<double[]>.TSNEState state;
     229    [Storable]
     230    private int iter;
     231
     232    public override void Prepare() {
     233      base.Prepare();
     234      dataRowNames = null;
     235      dataRows = null;
     236      state = null;
     237    }
    223238
    224239    protected override void Run(CancellationToken cancellationToken) {
    225240      var problemData = Problem.ProblemData;
    226 
    227       // set up and run tSNE
    228       if (SetSeedRandomly) Seed = new System.Random().Next();
    229       var random = new MersenneTwister((uint)Seed);
    230       var dataset = problemData.Dataset;
    231       var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
    232       var data = new double[dataset.Rows][];
    233       for (var row = 0; row < dataset.Rows; row++) data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
    234       if (Normalization) data = NormalizeData(data);
    235 
    236       var tsneState = TSNE<double[]>.CreateState(data, Distance, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
    237 
    238       SetUpResults(data);
    239       for (int iter = 0; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++)
    240       {
    241         TSNE<double[]>.Iterate(tsneState);
    242         Analyze(tsneState);
     241      // set up and initialized everything if necessary
     242      if(state == null) {
     243        if(SetSeedRandomly) Seed = new System.Random().Next();
     244        var random = new MersenneTwister((uint)Seed);
     245        var dataset = problemData.Dataset;
     246        var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
     247        var data = new double[dataset.Rows][];
     248        for(var row = 0; row < dataset.Rows; row++)
     249          data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
     250
     251        if(Normalization) data = NormalizeData(data);
     252
     253        state = TSNEStatic<double[]>.CreateState(data, Distance, random, NewDimensions, Perplexity, Theta,
     254          StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
     255
     256        SetUpResults(data);
     257        iter = 0;
     258      }
     259      for(; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
     260        TSNEStatic<double[]>.Iterate(state);
     261        Analyze(state);
    243262      }
    244263    }
    245264
    246265    private void SetUpResults(IReadOnlyCollection<double[]> data) {
    247       if (Results == null) return;
     266      if(Results == null) return;
    248267      var results = Results;
    249268      dataRowNames = new Dictionary<string, List<int>>();
     
    252271
    253272      //color datapoints acording to classes variable (be it double or string)
    254       if (problemData.Dataset.VariableNames.Contains(Classes)) {
    255         if ((problemData.Dataset as Dataset).VariableHasType<string>(Classes)) {
     273      if(problemData.Dataset.VariableNames.Contains(Classes)) {
     274        if((problemData.Dataset as Dataset).VariableHasType<string>(Classes)) {
    256275          var classes = problemData.Dataset.GetStringValues(Classes).ToArray();
    257           for (var i = 0; i < classes.Length; i++) {
    258             if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
     276          for(var i = 0; i < classes.Length; i++) {
     277            if(!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    259278            dataRowNames[classes[i]].Add(i);
    260279          }
    261         } else if ((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) {
     280        } else if((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) {
    262281          var classValues = problemData.Dataset.GetDoubleValues(Classes).ToArray();
    263282          var max = classValues.Max() + 0.1;     // TODO consts
    264283          var min = classValues.Min() - 0.1;
    265284          const int contours = 8;
    266           for (var i = 0; i < contours; i++) {
     285          for(var i = 0; i < contours; i++) {
    267286            var contourname = GetContourName(i, min, max, contours);
    268287            dataRowNames.Add(contourname, new List<int>());
     
    271290            dataRows[contourname].VisualProperties.PointSize = i + 3;
    272291          }
    273           for (var i = 0; i < classValues.Length; i++) {
     292          for(var i = 0; i < classValues.Length; i++) {
    274293            dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i);
    275294          }
     
    280299      }
    281300
    282       if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     301      if(!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
    283302      else ((IntValue)results[IterationResultName].Value).Value = 0;
    284303
    285       if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
     304      if(!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
    286305      else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
    287306
    288       if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
     307      if(!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
    289308      else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
    290309
    291310      var plot = results[ErrorPlotResultName].Value as DataTable;
    292       if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
    293 
    294       if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
     311      if(plot == null) throw new ArgumentException("could not create/access error data table in results collection");
     312
     313      if(!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
    295314      plot.Rows["errors"].Values.Clear();
    296315
     
    299318    }
    300319
    301     private void Analyze(TSNE<double[]>.TSNEState tsneState) {
    302       if (Results == null) return;
     320    private void Analyze(TSNEStatic<double[]>.TSNEState tsneState) {
     321      if(Results == null) return;
    303322      var results = Results;
    304323      var plot = results[ErrorPlotResultName].Value as DataTable;
    305       if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
     324      if(plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
    306325      var errors = plot.Rows["errors"].Values;
    307326      var c = tsneState.EvaluateError();
    308327      errors.Add(c);
    309       ((IntValue)results[IterationResultName].Value).Value = tsneState.iter + 1;
     328      ((IntValue)results[IterationResultName].Value).Value = tsneState.iter;
    310329      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
    311330
     
    317336
    318337    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
    319       foreach (var rowName in dataRowNames.Keys) {
    320         if (!plot.Rows.ContainsKey(rowName))
     338      foreach(var rowName in dataRowNames.Keys) {
     339        if(!plot.Rows.ContainsKey(rowName))
    321340          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
    322341        plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     
    328347      var min = new double[data.GetLength(1)];
    329348      var res = new double[data.GetLength(0), data.GetLength(1)];
    330       for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
    331       for (var i = 0; i < data.GetLength(0); i++)
    332         for (var j = 0; j < data.GetLength(1); j++) {
     349      for(var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
     350      for(var i = 0; i < data.GetLength(0); i++)
     351        for(var j = 0; j < data.GetLength(1); j++) {
    333352          var v = data[i, j];
    334353          max[j] = Math.Max(max[j], v);
    335354          min[j] = Math.Min(min[j], v);
    336355        }
    337       for (var i = 0; i < data.GetLength(0); i++) {
    338         for (var j = 0; j < data.GetLength(1); j++) {
     356      for(var i = 0; i < data.GetLength(0); i++) {
     357        for(var j = 0; j < data.GetLength(1); j++) {
    339358          res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]);
    340359        }
     
    348367      var sd = new double[n];
    349368      var nData = new double[data.Count][];
    350       for (var i = 0; i < n; i++) {
     369      for(var i = 0; i < n; i++) {
    351370        var i1 = i;
    352371        sd[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).StandardDeviation();
    353372        mean[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).Average();
    354373      }
    355       for (var i = 0; i < data.Count; i++) {
     374      for(var i = 0; i < data.Count; i++) {
    356375        nData[i] = new double[n];
    357         for (var j = 0; j < n; j++) nData[i][j] = (data[i][j] - mean[j]) / sd[j];
     376        for(var j = 0; j < n; j++) nData[i][j] = (data[i][j] - mean[j]) / sd[j];
    358377      }
    359378      return nData;
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEStatic.cs

    r14806 r14807  
    6060using HeuristicLab.Common;
    6161using HeuristicLab.Core;
     62using HeuristicLab.Optimization;
    6263using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    6364using HeuristicLab.Random;
     
    6566namespace HeuristicLab.Algorithms.DataAnalysis {
    6667  [StorableClass]
    67   public class TSNE<T> {
     68  public class TSNEStatic<T> {
    6869
    6970    [StorableClass]
     
    166167      }
    167168
     169      [StorableConstructor]
     170      public TSNEState(bool deserializing)  { }
    168171      public TSNEState(T[] data, IDistance<T> distance, IRandom random, int newDimensions, double perplexity, double theta, int stopLyingIter, int momSwitchIter, double momentum, double finalMomentum, double eta) {
    169172        this.distance = distance;
     
    525528        for(var i = 0; i < noElem; i++) symValP[i] /= 2.0;
    526529      }
    527 
    528530    }
    529531
    530     public static TSNEState CreateState(T[] data, IDistance<T> distance, IRandom random, int newDimensions = 2, double perplexity = 25, double theta = 0,
    531       int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, double finalMomentum = .8, double eta = 200.0
     532    /// <summary>
     533    /// Simple interface to tSNE
     534    /// </summary>
     535    /// <param name="data"></param>
     536    /// <param name="distance">The distance function used to differentiate similar from non-similar points, e.g. Euclidean distance.</param>
     537    /// <param name="random">Random number generator</param>
     538    /// <param name="newDimensions">Dimensionality of projected space (usually 2 for easy visual analysis).</param>
     539    /// <param name="perplexity">Perplexity parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower</param>
     540    /// <param name="iterations">Maximum number of iterations for gradient descent.</param>
     541    /// <param name="theta">Value describing how much appoximated gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. CAUTION: exact calculation of forces requires building a non-sparse N*N matrix where N is the number of data points. This may exceed memory limitations.</param>
     542    /// <param name="stopLyingIter">Number of iterations after which p is no longer approximated.</param>
     543    /// <param name="momSwitchIter">Number of iterations after which the momentum in the gradient descent is switched.</param>
     544    /// <param name="momentum">The initial momentum in the gradient descent.</param>
     545    /// <param name="finalMomentum">The final momentum in gradient descent (after momentum switch).</param>
     546    /// <param name="eta">Gradient descent learning rate.</param>
     547    /// <returns></returns>
     548    public static double[,] Run(T[] data, IDistance<T> distance, IRandom random,
     549      int newDimensions = 2, double perplexity = 25, int iterations = 1000,
     550      double theta = 0,
     551      int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5,
     552      double finalMomentum = .8, double eta = 200.0
     553      ) {
     554      var state = CreateState(data, distance, random, newDimensions, perplexity,
     555        theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta);
     556
     557      for(int i = 0; i < iterations - 1; i++) {
     558        Iterate(state);
     559      }
     560      return Iterate(state);
     561    }
     562
     563    public static TSNEState CreateState(T[] data, IDistance<T> distance, IRandom random,
     564      int newDimensions = 2, double perplexity = 25, double theta = 0,
     565      int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5,
     566      double finalMomentum = .8, double eta = 200.0
    532567      ) {
    533568      return new TSNEState(data, distance, random, newDimensions, perplexity, theta, stopLyingIter, momSwitchIter, momentum, finalMomentum, eta);
     
    564599      // Make solution zero-mean
    565600      ZeroMean(state.newData);
     601
    566602      // Stop lying about the P-values after a while, and switch momentum
    567 
    568603      if(state.iter == state.stopLyingIter) {
    569604        if(state.exact)
    570           for(var i = 0; i < state.noDatapoints; i++) for(var j = 0; j < state.noDatapoints; j++) state.p[i, j] /= 12.0;                                   //XXX why 12?
     605          for(var i = 0; i < state.noDatapoints; i++)
     606            for(var j = 0; j < state.noDatapoints; j++)
     607              state.p[i, j] /= 12.0;                                   //XXX why 12?
    571608        else
    572           for(var i = 0; i < state.rowP[state.noDatapoints]; i++) state.valP[i] /= 12.0;                       // XXX are we not scaling all values?
     609          for(var i = 0; i < state.rowP[state.noDatapoints]; i++)
     610            state.valP[i] /= 12.0;                       // XXX are we not scaling all values?
    573611      }
    574612
Note: See TracChangeset for help on using the changeset viewer.