Changeset r14512 in branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4

Ignore: (whitespace display setting of the changeset viewer)
Timestamp:
12/20/16 15:50:11 (7 years ago)
Author:
bwerth
Message:

#2700 worked in several comments from mkommend, extended analysis during algorithm run, added more Distances, made algorithm stoppable

Location:
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4

    • Property svn:ignore
      •  

        old new  
         1*.user
         2*.vs10x
         3.vs
         4HeuristicLab.Algorithms.DataAnalysis-3.4.csproj.user
         5HeuristicLabAlgorithmsDataAnalysisPlugin.cs
         6Plugin.cs
        17bin
        28obj
        3 HeuristicLabAlgorithmsDataAnalysisPlugin.cs
        4 HeuristicLab.Algorithms.DataAnalysis-3.4.csproj.user
        5 *.vs10x
        6 Plugin.cs
        7 *.user
        8 .vs
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNE.cs

    r14414 r14512  
    5757using System.Collections.Generic;
    5858using System.Linq;
     59using HeuristicLab.Analysis;
    5960using HeuristicLab.Common;
    6061using HeuristicLab.Core;
     
    7071    private const string IterationResultName = "Iteration";
    7172    private const string ErrorResultName = "Error";
     73    private const string ErrorPlotResultName = "ErrorPlot";
     74    private const string ScatterPlotResultName = "Scatterplot";
     75    private const string DataResultName = "Projected Data";
    7276
    7377    #region Properties
     
    9094    [Storable]
    9195    private ResultCollection results;
     96    [Storable]
     97    private Dictionary<string, List<int>> dataRowLookup;
     98    [Storable]
     99    private Dictionary<string, ScatterPlotDataRow> dataRows;
     100    #endregion
     101
     102    #region Stopping
     103    public volatile bool Running;
    92104    #endregion
    93105
     
    105117      random = cloner.Clone(random);
    106118      results = cloner.Clone(results);
     119      dataRowLookup = original.dataRowLookup.ToDictionary(entry => entry.Key, entry => entry.Value.Select(x => x).ToList());
     120      dataRows = original.dataRows.ToDictionary(entry => entry.Key, entry => cloner.Clone(entry.Value));
    107121    }
    108122    public override IDeepCloneable Clone(Cloner cloner) { return new TSNE<T>(this, cloner); }
    109     public TSNE(IDistance<T> distance, IRandom random, ResultCollection results = null, int maxIter = 1000, int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, double finalMomentum = .8, double eta = 200.0) {
     123    public TSNE(IDistance<T> distance, IRandom random, ResultCollection results = null, int maxIter = 1000, int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, double finalMomentum = .8, double eta = 200.0, Dictionary<string, List<int>> dataRowLookup = null, Dictionary<string, ScatterPlotDataRow> dataRows = null) {
    110124      this.distance = distance;
    111125      this.maxIter = maxIter;
     
    117131      this.random = random;
    118132      this.results = results;
     133      this.dataRowLookup = dataRowLookup;
     134      if (dataRows != null)
     135        this.dataRows = dataRows;
     136      else { this.dataRows = new Dictionary<string, ScatterPlotDataRow>(); }
    119137    }
    120138    #endregion
     
    124142      var noDatapoints = data.Length;
    125143      if (noDatapoints - 1 < 3 * perplexity) throw new ArgumentException("Perplexity too large for the number of data points!");
    126 
    127       if (results != null) {
    128         if (!results.ContainsKey(IterationResultName)) {
    129           results.Add(new Result(IterationResultName, new IntValue(0)));
    130         } else ((IntValue)results[IterationResultName].Value).Value = 0;
    131         if (!results.ContainsKey(ErrorResultName)) {
    132           results.Add(new Result(ErrorResultName, new DoubleValue(0)));
    133         } else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
    134       }
    135 
    136       // Determine whether we are using an exact algorithm
     144      SetUpResults(data);
     145      Running = true;
    137146      var exact = Math.Abs(theta) < double.Epsilon;
    138147      var newData = new double[noDatapoints, newDimensions];
     
    141150      var gains = new double[noDatapoints, newDimensions];
    142151      for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) gains[i, j] = 1.0;
    143 
    144       // Compute input similarities for exact t-SNE
    145152      double[,] p = null;
    146153      int[] rowP = null;
    147154      int[] colP = null;
    148155      double[] valP = null;
    149       if (exact) {
    150         // Compute similarities
    151         p = new double[noDatapoints, noDatapoints];
    152         ComputeGaussianPerplexity(data, noDatapoints, p, perplexity);
    153         // Symmetrize input similarities
    154         for (var n = 0; n < noDatapoints; n++) {
    155           for (var m = n + 1; m < noDatapoints; m++) {
    156             p[n, m] += p[m, n];
    157             p[m, n] = p[n, m];
    158           }
    159         }
    160         var sumP = .0;
    161         for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) sumP += p[i, j];
    162         for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] /= sumP;
    163       } // Compute input similarities for approximate t-SNE
    164       else {
    165         // Compute asymmetric pairwise input similarities
    166         ComputeGaussianPerplexity(data, noDatapoints, out rowP, out colP, out valP, perplexity, (int)(3 * perplexity));
    167         // Symmetrize input similarities
    168         int[] sRowP, symColP;
    169         double[] sValP;
    170         SymmetrizeMatrix(rowP, colP, valP, out sRowP, out symColP, out sValP);
    171         rowP = sRowP;
    172         colP = symColP;
    173         valP = sValP;
    174         var sumP = .0;
    175         for (var i = 0; i < rowP[noDatapoints]; i++) sumP += valP[i];
    176         for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] /= sumP;
    177       }
     156      var rand = new NormalDistributedRandom(random, 0, 1);
     157
     158      //Calculate Similarities
     159      if (exact) p = CalculateExactSimilarites(data, perplexity);
     160      else CalculateApproximateSimilarities(data, perplexity, out rowP, out colP, out valP);
    178161
    179162      // Lie about the P-values
    180       if (exact) {
    181         for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] *= 12.0;
    182       } else {
    183         for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] *= 12.0;
    184       }
    185 
    186       var rand = new NormalDistributedRandom(random, 0, 1);
     163      if (exact) for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] *= 12.0;
     164      else for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] *= 12.0;
     165
    187166      // Initialize solution (randomly)
    188       for (var i = 0; i < noDatapoints; i++)
    189         for (var j = 0; j < newDimensions; j++)
    190           newData[i, j] = rand.NextDouble() * .0001;
     167      for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) newData[i, j] = rand.NextDouble() * .0001;
    191168
    192169      // Perform main training loop
    193       for (var iter = 0; iter < maxIter; iter++) {
    194 
    195         // Compute (approximate) gradient
     170      for (var iter = 0; iter < maxIter && Running; iter++) {
    196171        if (exact) ComputeExactGradient(p, newData, noDatapoints, newDimensions, dY);
    197172        else ComputeGradient(rowP, colP, valP, newData, noDatapoints, newDimensions, dY, theta);
    198 
    199173        // Update gains
    200         for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++)
    201             gains[i, j] = Math.Sign(dY[i, j]) != Math.Sign(uY[i, j]) ? gains[i, j] + .2 : gains[i, j] * .8;
     174        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) gains[i, j] = Math.Sign(dY[i, j]) != Math.Sign(uY[i, j]) ? gains[i, j] + .2 : gains[i, j] * .8;
    202175        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) if (gains[i, j] < .01) gains[i, j] = .01;
    203 
    204176        // Perform gradient update (with momentum and gains)
    205177        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) uY[i, j] = currentMomentum * uY[i, j] - eta * gains[i, j] * dY[i, j];
    206178        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) newData[i, j] = newData[i, j] + uY[i, j];
    207 
    208179        // Make solution zero-mean
    209180        ZeroMean(newData);
    210 
    211181        // Stop lying about the P-values after a while, and switch momentum
    212182        if (iter == stopLyingIter) {
    213           if (exact) {
    214             for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] /= 12.0;
    215           } else {
    216             for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] /= 12.0;
    217           }
     183          if (exact) for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] /= 12.0;
     184          else for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] /= 12.0;
    218185        }
    219186        if (iter == momSwitchIter) currentMomentum = finalMomentum;
    220187
    221         if (results == null) continue;
    222         var errors = new List<double>();
    223         // Print out progress
    224         var c = exact
    225           ? EvaluateError(p, newData, noDatapoints, newDimensions)
    226           : EvaluateError(rowP, colP, valP, newData, theta);
    227         errors.Add(c);
    228         ((IntValue)results[IterationResultName].Value).Value = iter + 1;
    229         ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
     188        Analyze(exact, iter, p, rowP, colP, valP, newData, noDatapoints, newDimensions, theta);
    230189      }
    231190      return newData;
     
    234193      return new TSNE<TR>(distance, random).Run(data, newDimensions, perplexity, theta);
    235194    }
     195    public static double[,] Run<TR>(TR[] data, int newDimensions, double perplexity, double theta, Func<TR, TR, double> distance, IRandom random) where TR : class, IDeepCloneable {
     196      return new TSNE<TR>(new FuctionalDistance<TR>(distance), random).Run(data, newDimensions, perplexity, theta);
     197    }
    236198
    237199    #region helpers
     200
     201    private void SetUpResults(IReadOnlyCollection<T> data) {
     202      if (dataRowLookup == null) {
     203        dataRowLookup = new Dictionary<string, List<int>>();
     204        dataRowLookup.Add("Data", Enumerable.Range(0, data.Count).ToList());
     205      }
     206      if (results == null) return;
     207      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     208      else ((IntValue)results[IterationResultName].Value).Value = 0;
     209
     210      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
     211      else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
     212
     213      if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during Gradiant descent")));
     214      else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during Gradiant descent");
     215
     216      var plot = results[ErrorPlotResultName].Value as DataTable;
     217      if (plot == null) throw new ArgumentException("could not create/access Error-DataTable in Results-Collection");
     218      if (!plot.Rows.ContainsKey("errors")) {
     219        plot.Rows.Add(new DataRow("errors"));
     220      }
     221      plot.Rows["errors"].Values.Clear();
     222      results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
     223      results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     224
     225    }
     226    private void Analyze(bool exact, int iter, double[,] p, int[] rowP, int[] colP, double[] valP, double[,] newData, int noDatapoints, int newDimensions, double theta) {
     227      if (results == null) return;
     228      var plot = results[ErrorPlotResultName].Value as DataTable;
     229      if (plot == null) throw new ArgumentException("Could not create/access Error-DataTable in Results-Collection. Was it removed by some effect?");
     230      var errors = plot.Rows["errors"].Values;
     231      var c = exact
     232        ? EvaluateError(p, newData, noDatapoints, newDimensions)
     233        : EvaluateError(rowP, colP, valP, newData, theta);
     234      errors.Add(c);
     235      ((IntValue)results[IterationResultName].Value).Value = iter + 1;
     236      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
     237
     238      var ndata = Normalize(newData);
     239      results[DataResultName].Value = new DoubleMatrix(ndata);
     240      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
     241      FillScatterPlot(ndata, splot);
     242
     243
     244    }
     245    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
     246      foreach (var rowName in dataRowLookup.Keys) {
     247        if (!plot.Rows.ContainsKey(rowName)) {
     248          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
     249        } else plot.Rows[rowName].Points.Clear();
     250        plot.Rows[rowName].Points.AddRange(dataRowLookup[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     251      }
     252    }
     253    private static double[,] Normalize(double[,] data) {
     254      var max = new double[data.GetLength(1)];
     255      var min = new double[data.GetLength(1)];
     256      var res = new double[data.GetLength(0), data.GetLength(1)];
     257      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
     258      for (var i = 0; i < data.GetLength(0); i++)
     259        for (var j = 0; j < data.GetLength(1); j++) {
     260          var v = data[i, j];
     261          max[j] = Math.Max(max[j], v);
     262          min[j] = Math.Min(min[j], v);
     263        }
     264      for (var i = 0; i < data.GetLength(0); i++) {
     265        for (var j = 0; j < data.GetLength(1); j++) {
     266          res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]);
     267        }
     268      }
     269      return res;
     270    }
     271    private void CalculateApproximateSimilarities(T[] data, double perplexity, out int[] rowP, out int[] colP, out double[] valP) {
     272      // Compute asymmetric pairwise input similarities
     273      ComputeGaussianPerplexity(data, data.Length, out rowP, out colP, out valP, perplexity, (int)(3 * perplexity));
     274      // Symmetrize input similarities
     275      int[] sRowP, symColP;
     276      double[] sValP;
     277      SymmetrizeMatrix(rowP, colP, valP, out sRowP, out symColP, out sValP);
     278      rowP = sRowP;
     279      colP = symColP;
     280      valP = sValP;
     281      var sumP = .0;
     282      for (var i = 0; i < rowP[data.Length]; i++) sumP += valP[i];
     283      for (var i = 0; i < rowP[data.Length]; i++) valP[i] /= sumP;
     284    }
     285    private double[,] CalculateExactSimilarites(T[] data, double perplexity) {
     286      // Compute similarities
     287      var p = new double[data.Length, data.Length];
     288      ComputeGaussianPerplexity(data, data.Length, p, perplexity);
     289      // Symmetrize input similarities
     290      for (var n = 0; n < data.Length; n++) {
     291        for (var m = n + 1; m < data.Length; m++) {
     292          p[n, m] += p[m, n];
     293          p[m, n] = p[n, m];
     294        }
     295      }
     296      var sumP = .0;
     297      for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) sumP += p[i, j];
     298      for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) p[i, j] /= sumP;
     299      return p;
     300    }
     301
    238302    private void ComputeGaussianPerplexity(IReadOnlyList<T> x, int n, out int[] rowP, out int[] colP, out double[] valP, double perplexity, int k) {
    239303      if (perplexity > k) throw new ArgumentException("Perplexity should be lower than K!");
Note: See TracChangeset for help on using the changeset viewer.