Free cookie consent management tool by TermsFeed Policy Generator

source: branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNE.cs @ 14558

Last change on this file since 14558 was 14558, checked in by bwerth, 7 years ago

#2700 made TSNE compatible with the new pausable BasicAlgs, removed rescaling of scatterplots during alg to give it a more movie-esque feel

File size: 27.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20
21//Code is based on an implementation from Laurens van der Maaten
22
23/*
24*
25* Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology)
26* All rights reserved.
27*
28* Redistribution and use in source and binary forms, with or without
29* modification, are permitted provided that the following conditions are met:
30* 1. Redistributions of source code must retain the above copyright
31*    notice, this list of conditions and the following disclaimer.
32* 2. Redistributions in binary form must reproduce the above copyright
33*    notice, this list of conditions and the following disclaimer in the
34*    documentation and/or other materials provided with the distribution.
35* 3. All advertising materials mentioning features or use of this software
36*    must display the following acknowledgement:
37*    This product includes software developed by the Delft University of Technology.
38* 4. Neither the name of the Delft University of Technology nor the names of
39*    its contributors may be used to endorse or promote products derived from
40*    this software without specific prior written permission.
41*
42* THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
43* OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
44* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
45* EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
46* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
47* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
48* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
49* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
50* IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
51* OF SUCH DAMAGE.
52*
53*/
54#endregion
55
56using System;
57using System.Collections.Generic;
58using System.Linq;
59using HeuristicLab.Analysis;
60using HeuristicLab.Common;
61using HeuristicLab.Core;
62using HeuristicLab.Data;
63using HeuristicLab.Optimization;
64using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
65using HeuristicLab.Random;
66
67namespace HeuristicLab.Algorithms.DataAnalysis {
68  [StorableClass]
69  public class TSNE<T> : DeepCloneable, ITSNE<T> where T : class, IDeepCloneable {
70
    // Keys of the entries that SetUpResults/Analyze maintain in the ResultCollection.
    private const string IterationResultName = "Iteration";
    private const string ErrorResultName = "Error";
    private const string ErrorPlotResultName = "ErrorPlot";
    private const string ScatterPlotResultName = "Scatterplot";
    private const string DataResultName = "Projected Data";

    #region Properties
    [Storable]
    private IDistance<T> distance; // distance measure between two items of type T
    [Storable]
    private int maxIter; // number of gradient-descent iterations performed by Run
    [Storable]
    private int stopLyingIter; // iteration at which the P-value exaggeration (factor 12) is removed
    [Storable]
    private int momSwitchIter; // iteration at which the momentum switches to finalMomentum
    [Storable]
    // NOTE(review): implicitly private; an explicit 'private' would match the other fields
    double momentum; // momentum used at the start of the gradient descent
    [Storable]
    private double finalMomentum; // momentum used after momSwitchIter
    [Storable]
    private double eta; // learning rate of the gradient update
    [Storable]
    private IRandom random; // source of randomness for the initial solution
    [Storable]
    private ResultCollection results; // optional sink for progress reporting (may be null)
    [Storable]
    private Dictionary<string, List<int>> dataRowLookup; // scatter-plot row name -> indices of the points shown in that row
    [Storable]
    private Dictionary<string, ScatterPlotDataRow> dataRows; // preconfigured scatter-plot rows, keyed by row name
    #endregion

    #region Stopping
    // Set to false (possibly from another thread) to stop Run's main loop after the current iteration.
    public volatile bool Running;
    #endregion
105
106    #region HLConstructors & Cloning
    // Deserialization constructor: fields are restored afterwards by the persistence framework.
    [StorableConstructor]
    protected TSNE(bool deserializing) { }
109    protected TSNE(TSNE<T> original, Cloner cloner) : base(original, cloner) {
110      distance = cloner.Clone(original.distance);
111      maxIter = original.maxIter;
112      stopLyingIter = original.stopLyingIter;
113      momSwitchIter = original.momSwitchIter;
114      momentum = original.momentum;
115      finalMomentum = original.finalMomentum;
116      eta = original.eta;
117      random = cloner.Clone(random);
118      results = cloner.Clone(results);
119      dataRowLookup = original.dataRowLookup.ToDictionary(entry => entry.Key, entry => entry.Value.Select(x => x).ToList());
120      dataRows = original.dataRows.ToDictionary(entry => entry.Key, entry => cloner.Clone(entry.Value));
121    }
    // Deep-cloning entry point of the HeuristicLab cloning infrastructure.
    public override IDeepCloneable Clone(Cloner cloner) { return new TSNE<T>(this, cloner); }
123    public TSNE(IDistance<T> distance, IRandom random, ResultCollection results = null, int maxIter = 1000, int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, double finalMomentum = .8, double eta = 200.0, Dictionary<string, List<int>> dataRowLookup = null, Dictionary<string, ScatterPlotDataRow> dataRows = null) {
124      this.distance = distance;
125      this.maxIter = maxIter;
126      this.stopLyingIter = stopLyingIter;
127      this.momSwitchIter = momSwitchIter;
128      this.momentum = momentum;
129      this.finalMomentum = finalMomentum;
130      this.eta = eta;
131      this.random = random;
132      this.results = results;
133      this.dataRowLookup = dataRowLookup;
134      if (dataRows != null)
135        this.dataRows = dataRows;
136      else { this.dataRows = new Dictionary<string, ScatterPlotDataRow>(); }
137    }
138    #endregion
139
    /// <summary>
    /// Projects <paramref name="data"/> into <paramref name="newDimensions"/> dimensions with t-SNE.
    /// </summary>
    /// <param name="data">the items to project; requires data.Length - 1 >= 3 * perplexity</param>
    /// <param name="newDimensions">dimensionality of the embedding (e.g. 2 for a scatter plot)</param>
    /// <param name="perplexity">effective number of neighbors considered per point</param>
    /// <param name="theta">Barnes-Hut accuracy trade-off; 0 selects the exact O(n^2) gradient</param>
    /// <returns>the projected coordinates, one row per data point</returns>
    public double[,] Run(T[] data, int newDimensions, double perplexity, double theta) {
      var currentMomentum = momentum;
      var noDatapoints = data.Length;
      if (noDatapoints - 1 < 3 * perplexity) throw new ArgumentException("Perplexity too large for the number of data points!");
      SetUpResults(data);
      Running = true;
      // theta == 0 switches to the exact (dense) algorithm
      var exact = Math.Abs(theta) < double.Epsilon;
      var newData = new double[noDatapoints, newDimensions];  // the evolving embedding
      var dY = new double[noDatapoints, newDimensions];       // gradient
      var uY = new double[noDatapoints, newDimensions];       // accumulated update (momentum term)
      var gains = new double[noDatapoints, newDimensions];    // per-component adaptive step-size gains
      for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) gains[i, j] = 1.0;
      double[,] p = null;   // dense similarities (exact mode only)
      int[] rowP = null;    // sparse similarities in CSR-like form (approximate mode only)
      int[] colP = null;
      double[] valP = null;
      var rand = new NormalDistributedRandom(random, 0, 1);

      // Calculate similarities
      if (exact) p = CalculateExactSimilarites(data, perplexity);
      else CalculateApproximateSimilarities(data, perplexity, out rowP, out colP, out valP);

      // Lie about the P-values ("early exaggeration"; undone at stopLyingIter)
      if (exact) for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] *= 12.0;
      else for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] *= 12.0;

      // Initialize solution (randomly), tightly around the origin
      for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) newData[i, j] = rand.NextDouble() * .0001;

      // Perform main training loop; Running may be reset externally to stop early
      for (var iter = 0; iter < maxIter && Running; iter++) {
        if (exact) ComputeExactGradient(p, newData, noDatapoints, newDimensions, dY);
        else ComputeGradient(rowP, colP, valP, newData, noDatapoints, newDimensions, dY, theta);
        // Update gains: grow where the gradient's sign differs from the current update's sign, damp otherwise
        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) gains[i, j] = Math.Sign(dY[i, j]) != Math.Sign(uY[i, j]) ? gains[i, j] + .2 : gains[i, j] * .8;
        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) if (gains[i, j] < .01) gains[i, j] = .01;
        // Perform gradient update (with momentum and gains)
        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) uY[i, j] = currentMomentum * uY[i, j] - eta * gains[i, j] * dY[i, j];
        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) newData[i, j] = newData[i, j] + uY[i, j];
        // Make solution zero-mean
        ZeroMean(newData);
        // Stop lying about the P-values after a while, and switch momentum
        if (iter == stopLyingIter) {
          if (exact) for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] /= 12.0;
          else for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] /= 12.0;
        }
        if (iter == momSwitchIter) currentMomentum = finalMomentum;

        // Report progress (iteration, error, projected data) into the result collection.
        Analyze(exact, iter, p, rowP, colP, valP, newData, noDatapoints, newDimensions, theta);
      }
      return newData;
    }
    /// <summary>Convenience overload: runs t-SNE with a one-off instance and the given distance measure (no result reporting).</summary>
    public static double[,] Run<TR>(TR[] data, int newDimensions, double perplexity, double theta, IDistance<TR> distance, IRandom random) where TR : class, IDeepCloneable {
      return new TSNE<TR>(distance, random).Run(data, newDimensions, perplexity, theta);
    }
    /// <summary>Convenience overload: runs t-SNE with a distance given as a plain function.</summary>
    // NOTE(review): 'FuctionalDistance' (sic) is the adapter type's actual, misspelled name in the project.
    public static double[,] Run<TR>(TR[] data, int newDimensions, double perplexity, double theta, Func<TR, TR, double> distance, IRandom random) where TR : class, IDeepCloneable {
      return new TSNE<TR>(new FuctionalDistance<TR>(distance), random).Run(data, newDimensions, perplexity, theta);
    }
198
199    #region helpers
200
201    private void SetUpResults(IReadOnlyCollection<T> data) {
202      if (dataRowLookup == null) {
203        dataRowLookup = new Dictionary<string, List<int>>();
204        dataRowLookup.Add("Data", Enumerable.Range(0, data.Count).ToList());
205      }
206      if (results == null) return;
207      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
208      else ((IntValue)results[IterationResultName].Value).Value = 0;
209
210      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
211      else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
212
213      if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during Gradiant descent")));
214      else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during Gradiant descent");
215
216      var plot = results[ErrorPlotResultName].Value as DataTable;
217      if (plot == null) throw new ArgumentException("could not create/access Error-DataTable in Results-Collection");
218      if (!plot.Rows.ContainsKey("errors")) {
219        plot.Rows.Add(new DataRow("errors"));
220      }
221      plot.Rows["errors"].Values.Clear();
222      results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
223      results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
224
225    }
    /// <summary>
    /// Reports the state of iteration <paramref name="iter"/> into the result collection (if one is present):
    /// appends the current error to the error plot, updates the iteration/error values, and refreshes
    /// the projected-data matrix and scatter plot.
    /// </summary>
    private void Analyze(bool exact, int iter, double[,] p, int[] rowP, int[] colP, double[] valP, double[,] newData, int noDatapoints, int newDimensions, double theta) {
      if (results == null) return;
      var plot = results[ErrorPlotResultName].Value as DataTable;
      if (plot == null) throw new ArgumentException("Could not create/access Error-DataTable in Results-Collection. Was it removed by some effect?");
      var errors = plot.Rows["errors"].Values;
      // exact mode evaluates the dense KL divergence, otherwise the Barnes-Hut approximation
      var c = exact
        ? EvaluateError(p, newData, noDatapoints, newDimensions)
        : EvaluateError(rowP, colP, valP, newData, theta);
      errors.Add(c);
      ((IntValue)results[IterationResultName].Value).Value = iter + 1;
      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();

      // rescale each dimension to [-0.5, 0.5] for display before publishing
      var ndata = Normalize(newData);
      results[DataResultName].Value = new DoubleMatrix(ndata);
      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
      FillScatterPlot(ndata, splot);
    }
245    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
246      foreach (var rowName in dataRowLookup.Keys) {
247        if (!plot.Rows.ContainsKey(rowName)) {
248          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
249        }
250        //else plot.Rows[rowName].Points.Clear();
251        plot.Rows[rowName].Points.Replace(dataRowLookup[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
252        //plot.Rows[rowName].Points.AddRange();
253      }
254    }
255    private static double[,] Normalize(double[,] data) {
256      var max = new double[data.GetLength(1)];
257      var min = new double[data.GetLength(1)];
258      var res = new double[data.GetLength(0), data.GetLength(1)];
259      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
260      for (var i = 0; i < data.GetLength(0); i++)
261        for (var j = 0; j < data.GetLength(1); j++) {
262          var v = data[i, j];
263          max[j] = Math.Max(max[j], v);
264          min[j] = Math.Min(min[j], v);
265        }
266      for (var i = 0; i < data.GetLength(0); i++) {
267        for (var j = 0; j < data.GetLength(1); j++) {
268          res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]);
269        }
270      }
271      return res;
272    }
    /// <summary>
    /// Computes the symmetrized and normalized sparse input similarities used by the Barnes-Hut variant.
    /// The result is in CSR-like form: rowP holds row start indices, colP/valP hold column indices and values.
    /// </summary>
    private void CalculateApproximateSimilarities(T[] data, double perplexity, out int[] rowP, out int[] colP, out double[] valP) {
      // Compute asymmetric pairwise input similarities (k = 3 * perplexity nearest neighbors)
      ComputeGaussianPerplexity(data, data.Length, out rowP, out colP, out valP, perplexity, (int)(3 * perplexity));
      // Symmetrize input similarities
      int[] sRowP, symColP;
      double[] sValP;
      SymmetrizeMatrix(rowP, colP, valP, out sRowP, out symColP, out sValP);
      rowP = sRowP;
      colP = symColP;
      valP = sValP;
      // Normalize so all stored similarities sum to one (rowP[data.Length] is the number of stored entries)
      var sumP = .0;
      for (var i = 0; i < rowP[data.Length]; i++) sumP += valP[i];
      for (var i = 0; i < rowP[data.Length]; i++) valP[i] /= sumP;
    }
287    private double[,] CalculateExactSimilarites(T[] data, double perplexity) {
288      // Compute similarities
289      var p = new double[data.Length, data.Length];
290      ComputeGaussianPerplexity(data, data.Length, p, perplexity);
291      // Symmetrize input similarities
292      for (var n = 0; n < data.Length; n++) {
293        for (var m = n + 1; m < data.Length; m++) {
294          p[n, m] += p[m, n];
295          p[m, n] = p[n, m];
296        }
297      }
298      var sumP = .0;
299      for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) sumP += p[i, j];
300      for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) p[i, j] /= sumP;
301      return p;
302    }
303
    /// <summary>
    /// Computes sparse (k-nearest-neighbor) conditional Gaussian similarities in CSR-like form
    /// (rowP/colP/valP, exactly k entries per row). For each point, a per-point kernel bandwidth
    /// beta is found by bisection so that the entropy of its neighbor distribution matches
    /// log(perplexity); the row is then normalized to sum to one.
    /// </summary>
    private void ComputeGaussianPerplexity(IReadOnlyList<T> x, int n, out int[] rowP, out int[] colP, out double[] valP, double perplexity, int k) {
      if (perplexity > k) throw new ArgumentException("Perplexity should be lower than K!");

      // Allocate the memory we need: k entries per row
      rowP = new int[n + 1];
      colP = new int[n * k];
      valP = new double[n * k];
      var curP = new double[n - 1];
      rowP[0] = 0;
      for (var i = 0; i < n; i++) rowP[i + 1] = rowP[i] + k;

      // Build ball tree on data set for fast nearest-neighbor queries
      var tree = new VPTree<IDataPoint<T>>(new DataPointDistance<T>(distance));
      var objX = new List<IDataPoint<T>>();
      for (var i = 0; i < n; i++) objX.Add(new DataPoint<T>(i, x[i]));
      tree.Create(objX);

      // Loop over all points to find nearest neighbors
      var indices = new List<IDataPoint<T>>();
      var distances = new List<double>();
      for (var i = 0; i < n; i++) {

        // Find k + 1 nearest neighbors; index 0 is skipped below
        // (NOTE(review): presumably the query point itself — confirm against VPTree.Search)
        indices.Clear();
        distances.Clear();
        tree.Search(objX[i], k + 1, out indices, out distances);

        // Initialize some variables for binary search
        var found = false;
        var beta = 1.0;
        var minBeta = -double.MaxValue;
        var maxBeta = double.MaxValue;
        const double tol = 1e-5;

        // Iterate until we found a good perplexity (at most 200 bisection steps)
        var iter = 0; double sumP = 0;
        while (!found && iter < 200) {

          // Compute Gaussian kernel row over the k neighbors
          for (var m = 0; m < k; m++) curP[m] = Math.Exp(-beta * distances[m + 1]);

          // Compute entropy of current row
          sumP = double.Epsilon; // guards the divisions below against a zero sum
          for (var m = 0; m < k; m++) sumP += curP[m];
          var h = .0;
          for (var m = 0; m < k; m++) h += beta * (distances[m + 1] * curP[m]);
          h = h / sumP + Math.Log(sumP);

          // Evaluate whether the entropy is within the tolerance level
          var hdiff = h - Math.Log(perplexity);
          if (hdiff < tol && -hdiff < tol) {
            found = true;
          } else {
            if (hdiff > 0) {
              // entropy too high -> sharpen the kernel (raise beta)
              minBeta = beta;
              if (maxBeta.IsAlmost(double.MaxValue) || maxBeta.IsAlmost(double.MinValue))
                beta *= 2.0;
              else
                beta = (beta + maxBeta) / 2.0;
            } else {
              // entropy too low -> flatten the kernel (lower beta)
              maxBeta = beta;
              if (minBeta.IsAlmost(double.MinValue) || minBeta.IsAlmost(double.MaxValue))
                beta /= 2.0;
              else
                beta = (beta + minBeta) / 2.0;
            }
          }

          // Update iteration counter
          iter++;
        }

        // Row-normalize current row of P and store in matrix
        for (var m = 0; m < k; m++) curP[m] /= sumP;
        for (var m = 0; m < k; m++) {
          colP[rowP[i] + m] = indices[m + 1].Index;
          valP[rowP[i] + m] = curP[m];
        }
      }
    }
    /// <summary>
    /// Fills the dense conditional similarity matrix <paramref name="p"/> (one row per point).
    /// For each row, the kernel bandwidth beta is found by bisection so that the row's entropy
    /// matches log(perplexity); the row is then normalized to sum to one.
    /// </summary>
    private void ComputeGaussianPerplexity(T[] x, int n, double[,] p, double perplexity) {
      // Compute the pairwise distance matrix.
      // NOTE(review): the original comment claimed *squared Euclidean* distances, but
      // ComputeDistances applies the configured distance measure as-is (not squared) —
      // confirm which is intended.
      var dd = ComputeDistances(x);

      // Compute the Gaussian kernel row by row
      for (var i = 0; i < n; i++) {
        // Initialize some variables for the bisection
        var found = false;
        var beta = 1.0;
        var minBeta = -double.MaxValue;
        var maxBeta = double.MaxValue;
        const double tol = 1e-5;
        double sumP = 0;

        // Iterate until we found a good perplexity (at most 200 bisection steps)
        var iter = 0;
        while (!found && iter < 200) {

          // Compute Gaussian kernel row
          for (var m = 0; m < n; m++) p[i, m] = Math.Exp(-beta * dd[i][m]);
          p[i, i] = double.Epsilon; // a point is not its own neighbor

          // Compute entropy of current row
          sumP = double.Epsilon; // guards the divisions below against a zero sum
          for (var m = 0; m < n; m++) sumP += p[i, m];
          var h = 0.0;
          for (var m = 0; m < n; m++) h += beta * (dd[i][m] * p[i, m]);
          h = h / sumP + Math.Log(sumP);

          // Evaluate whether the entropy is within the tolerance level
          var hdiff = h - Math.Log(perplexity);
          if (hdiff < tol && -hdiff < tol) {
            found = true;
          } else {
            if (hdiff > 0) {
              // entropy too high -> sharpen the kernel (raise beta)
              minBeta = beta;
              if (maxBeta.IsAlmost(double.MaxValue) || maxBeta.IsAlmost(double.MinValue))
                beta *= 2.0;
              else
                beta = (beta + maxBeta) / 2.0;
            } else {
              // entropy too low -> flatten the kernel (lower beta)
              maxBeta = beta;
              if (minBeta.IsAlmost(double.MinValue) || minBeta.IsAlmost(double.MaxValue))
                beta /= 2.0;
              else
                beta = (beta + minBeta) / 2.0;
            }
          }

          // Update iteration counter
          iter++;
        }

        // Row normalize P
        for (var m = 0; m < n; m++) p[i, m] /= sumP;
      }
    }
441    private double[][] ComputeDistances(T[] x) {
442      return x.Select(m => x.Select(n => distance.Get(m, n)).ToArray()).ToArray();
443    }
444    private static void ComputeExactGradient(double[,] p, double[,] y, int n, int d, double[,] dC) {
445
446      // Make sure the current gradient contains zeros
447      for (var i = 0; i < n; i++) for (var j = 0; j < d; j++) dC[i, j] = 0.0;
448
449      // Compute the squared Euclidean distance matrix
450      var dd = new double[n, n];
451      ComputeSquaredEuclideanDistance(y, n, d, dd);
452
453      // Compute Q-matrix and normalization sum
454      var q = new double[n, n];
455      var sumQ = .0;
456      for (var n1 = 0; n1 < n; n1++) {
457        for (var m = 0; m < n; m++) {
458          if (n1 == m) continue;
459          q[n1, m] = 1 / (1 + dd[n1, m]);
460          sumQ += q[n1, m];
461        }
462      }
463
464      // Perform the computation of the gradient
465      for (var n1 = 0; n1 < n; n1++) {
466        for (var m = 0; m < n; m++) {
467          if (n1 == m) continue;
468          var mult = (p[n1, m] - q[n1, m] / sumQ) * q[n1, m];
469          for (var d1 = 0; d1 < d; d1++) {
470            dC[n1, d1] += (y[n1, d1] - y[m, d1]) * mult;
471          }
472        }
473      }
474    }
475    private static void ComputeSquaredEuclideanDistance(double[,] x, int n, int d, double[,] dd) {
476      var dataSums = new double[n];
477      for (var i = 0; i < n; i++) {
478        for (var j = 0; j < d; j++) {
479          dataSums[i] += x[i, j] * x[i, j];
480        }
481      }
482      for (var i = 0; i < n; i++) {
483        for (var m = 0; m < n; m++) {
484          dd[i, m] = dataSums[i] + dataSums[m];
485        }
486      }
487      for (var i = 0; i < n; i++) {
488        dd[i, i] = 0.0;
489        for (var m = i + 1; m < n; m++) {
490          dd[i, m] = 0.0;
491          for (var j = 0; j < d; j++) {
492            dd[i, m] += (x[i, j] - x[m, j]) * (x[i, j] - x[m, j]);
493          }
494          dd[m, i] = dd[i, m];
495        }
496      }
497    }
498    private static void ComputeGradient(int[] rowP, int[] colP, double[] valP, double[,] y, int n, int d, double[,] dC, double theta) {
499      var tree = new SPTree(y);
500      double[] sumQ = { 0 };
501      var posF = new double[n, d];
502      var negF = new double[n, d];
503      tree.ComputeEdgeForces(rowP, colP, valP, n, posF);
504      var row = new double[d];
505      for (int n1 = 0; n1 < n; n1++) {
506        Buffer.BlockCopy(negF, (sizeof(double) * n1 * d), row, 0, d);
507        tree.ComputeNonEdgeForces(n1, theta, row, sumQ);
508      }
509
510      // Compute final t-SNE gradient
511      for (var i = 0; i < n; i++)
512        for (var j = 0; j < d; j++) {
513          dC[i, j] = posF[i, j] - negF[i, j] / sumQ[0];
514        }
515    }
    /// <summary>
    /// Computes the exact t-SNE cost: the Kullback-Leibler divergence between the input
    /// similarities p and the embedding similarities q (Student-t kernel), summed over all pairs.
    /// </summary>
    private static double EvaluateError(double[,] p, double[,] y, int n, int d) {
      // Compute the squared Euclidean distance matrix
      var dd = new double[n, n];
      var q = new double[n, n];
      ComputeSquaredEuclideanDistance(y, n, d, dd);

      // Compute Q-matrix and normalization sum
      var sumQ = double.Epsilon; // avoids division by zero when normalizing below
      for (var n1 = 0; n1 < n; n1++) {
        for (var m = 0; m < n; m++) {
          if (n1 != m) {
            q[n1, m] = 1 / (1 + dd[n1, m]);
            sumQ += q[n1, m];
          } else q[n1, m] = double.Epsilon; // diagonal carries no probability mass
        }
      }
      for (var i = 0; i < n; i++) for (var j = 0; j < n; j++) q[i, j] /= sumQ;

      // Sum t-SNE error; the epsilons guard against log(0)
      var c = .0;
      for (var i = 0; i < n; i++)
        for (var j = 0; j < n; j++) {
          c += p[i, j] * Math.Log((p[i, j] + float.Epsilon) / (q[i, j] + float.Epsilon));
        }
      return c;
    }
    /// <summary>
    /// Approximates the t-SNE cost for the Barnes-Hut variant: the normalization term sumQ is
    /// estimated via the space-partitioning tree, and the divergence is summed over the stored
    /// (sparse) similarity entries only.
    /// </summary>
    private static double EvaluateError(IReadOnlyList<int> rowP, IReadOnlyList<int> colP, IReadOnlyList<double> valP, double[,] y, double theta) {
      // Get estimate of normalization term
      var n = y.GetLength(0);
      var d = y.GetLength(1);
      var tree = new SPTree(y);
      var buff = new double[d];
      double[] sumQ = { 0 };
      for (var i = 0; i < n; i++) tree.ComputeNonEdgeForces(i, theta, buff, sumQ);

      // Loop over all edges to compute t-SNE error
      var c = .0;
      for (var k = 0; k < n; k++) {
        for (var i = rowP[k]; i < rowP[k + 1]; i++) {
          var q = .0;
          // squared distance between point k and its stored neighbor colP[i]
          for (var j = 0; j < d; j++) buff[j] = y[k, j];
          for (var j = 0; j < d; j++) buff[j] -= y[colP[i], j];
          for (var j = 0; j < d; j++) q += buff[j] * buff[j];
          // normalized Student-t affinity of the pair
          q = 1.0 / (1.0 + q) / sumQ[0];
          // KL divergence contribution; the epsilons guard against log(0)
          c += valP[i] * Math.Log((valP[i] + float.Epsilon) / (q + float.Epsilon));
        }
      }
      return c;
    }
    /// <summary>
    /// Symmetrizes a sparse matrix given in CSR-like form (rowP/colP/valP). In the result, every
    /// stored pair (j, c) also has its mirror (c, j); matching entries from both directions are
    /// summed, and all values are divided by two at the end.
    /// </summary>
    private static void SymmetrizeMatrix(IReadOnlyList<int> rowP, IReadOnlyList<int> colP, IReadOnlyList<double> valP, out int[] symRowP, out int[] symColP, out double[] symValP) {

      // Count number of elements and row counts of symmetric matrix
      var n = rowP.Count - 1;
      var rowCounts = new int[n];
      for (var j = 0; j < n; j++) {
        for (var i = rowP[j]; i < rowP[j + 1]; i++) {

          // Check whether element (col_P[i], n) is present
          var present = false;
          for (var m = rowP[colP[i]]; m < rowP[colP[i] + 1]; m++) {
            if (colP[m] == j) present = true;
          }
          if (present) rowCounts[j]++;
          else {
            // mirrored entry is missing: it will be created, so both rows grow
            rowCounts[j]++;
            rowCounts[colP[i]]++;
          }
        }
      }
      var noElem = 0;
      for (var i = 0; i < n; i++) noElem += rowCounts[i];

      // Allocate memory for symmetrized matrix
      symRowP = new int[n + 1];
      symColP = new int[noElem];
      symValP = new double[noElem];

      // Construct new row indices for symmetric matrix
      symRowP[0] = 0;
      for (var i = 0; i < n; i++) symRowP[i + 1] = symRowP[i] + rowCounts[i];

      // Fill the result matrix; offset[r] tracks how many slots of row r are already used
      var offset = new int[n];
      for (var j = 0; j < n; j++) {
        for (var i = rowP[j]; i < rowP[j + 1]; i++) {                                  // considering element(n, colP[i])

          // Check whether element (col_P[i], n) is present
          var present = false;
          for (var m = rowP[colP[i]]; m < rowP[colP[i] + 1]; m++) {
            if (colP[m] != j) continue;
            present = true;
            if (j > colP[i]) continue; // make sure we do not add elements twice
            // both directions exist: write the summed value into both mirrored slots
            symColP[symRowP[j] + offset[j]] = colP[i];
            symColP[symRowP[colP[i]] + offset[colP[i]]] = j;
            symValP[symRowP[j] + offset[j]] = valP[i] + valP[m];
            symValP[symRowP[colP[i]] + offset[colP[i]]] = valP[i] + valP[m];
          }

          // If (colP[i], n) is not present, there is no addition involved
          if (!present) {
            symColP[symRowP[j] + offset[j]] = colP[i];
            symColP[symRowP[colP[i]] + offset[colP[i]]] = j;
            symValP[symRowP[j] + offset[j]] = valP[i];
            symValP[symRowP[colP[i]] + offset[colP[i]]] = valP[i];
          }

          // Update offsets (skipped for the already-handled mirrored case above)
          if (present && (j > colP[i])) continue;
          offset[j]++;
          if (colP[i] != j) offset[colP[i]]++;
        }
      }

      // Divide the result by two
      for (var i = 0; i < noElem; i++) symValP[i] /= 2.0;
    }
632    private static void ZeroMean(double[,] x) {
633      // Compute data mean
634      var n = x.GetLength(0);
635      var d = x.GetLength(1);
636      var mean = new double[d];
637      for (var i = 0; i < n; i++) {
638        for (var j = 0; j < d; j++) {
639          mean[j] += x[i, j];
640        }
641      }
642      for (var i = 0; i < d; i++) {
643        mean[i] /= n;
644      }
645      // Subtract data mean
646      for (var i = 0; i < n; i++) {
647        for (var j = 0; j < d; j++) {
648          x[i, j] -= mean[j];
649        }
650      }
651    }
652    #endregion
653  }
654}
Note: See TracBrowser for help on using the repository browser.