
source: branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNE.cs @ 14512

Last change on this file: r14512, checked in by bwerth, 7 years ago

#2700 incorporated several review comments from mkommend, extended the analysis during the algorithm run, added more Distances, made the algorithm stoppable

File size: 26.9 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */

//Code is based on an implementation from Laurens van der Maaten

/*
*
* Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
* 3. All advertising materials mentioning features or use of this software
*    must display the following acknowledgement:
*    This product includes software developed by the Delft University of Technology.
* 4. Neither the name of the Delft University of Technology nor the names of
*    its contributors may be used to endorse or promote products derived from
*    this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
* OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
* EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
* IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
* OF SUCH DAMAGE.
*
*/
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Optimization;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Random;

namespace HeuristicLab.Algorithms.DataAnalysis {
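  /// <summary>
  /// t-SNE (t-distributed stochastic neighbor embedding) projects data points into a low-dimensional space
  /// while preserving local neighborhood structure. Pairwise input similarities are computed from the
  /// configured IDistance&lt;T&gt;, and the embedding is optimized by gradient descent on the
  /// Kullback-Leibler divergence, either exactly or approximately via a space-partitioning tree (SPTree,
  /// controlled by theta).
  /// </summary>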
  [StorableClass]
  public class TSNE<T> : Item, ITSNE<T> where T : class, IDeepCloneable {

    private const string IterationResultName = "Iteration";
    private const string ErrorResultName = "Error";
    private const string ErrorPlotResultName = "ErrorPlot";
    private const string ScatterPlotResultName = "Scatterplot";
    private const string DataResultName = "Projected Data";

    #region Properties
    [Storable]
    private IDistance<T> distance;
    [Storable]
    private int maxIter;
    [Storable]
    private int stopLyingIter;
    [Storable]
    private int momSwitchIter;
    [Storable]
    private double momentum;
    [Storable]
    private double finalMomentum;
    [Storable]
    private double eta;
    [Storable]
    private IRandom random;
    [Storable]
    private ResultCollection results;
    [Storable]
    private Dictionary<string, List<int>> dataRowLookup;
    [Storable]
    private Dictionary<string, ScatterPlotDataRow> dataRows;
    #endregion

    #region Stopping
    public volatile bool Running;
    #endregion

    #region HLConstructors & Cloning
    [StorableConstructor]
    protected TSNE(bool deserializing) : base(deserializing) { }
    protected TSNE(TSNE<T> original, Cloner cloner) : base(original, cloner) {
      distance = cloner.Clone(original.distance);
      maxIter = original.maxIter;
      stopLyingIter = original.stopLyingIter;
      momSwitchIter = original.momSwitchIter;
      momentum = original.momentum;
      finalMomentum = original.finalMomentum;
      eta = original.eta;
      random = cloner.Clone(original.random);
      results = cloner.Clone(original.results);
      dataRowLookup = original.dataRowLookup.ToDictionary(entry => entry.Key, entry => entry.Value.ToList());
      dataRows = original.dataRows.ToDictionary(entry => entry.Key, entry => cloner.Clone(entry.Value));
    }
    public override IDeepCloneable Clone(Cloner cloner) { return new TSNE<T>(this, cloner); }
    public TSNE(IDistance<T> distance, IRandom random, ResultCollection results = null, int maxIter = 1000, int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, double finalMomentum = .8, double eta = 200.0, Dictionary<string, List<int>> dataRowLookup = null, Dictionary<string, ScatterPlotDataRow> dataRows = null) {
      this.distance = distance;
      this.maxIter = maxIter;
      this.stopLyingIter = stopLyingIter;
      this.momSwitchIter = momSwitchIter;
      this.momentum = momentum;
      this.finalMomentum = finalMomentum;
      this.eta = eta;
      this.random = random;
      this.results = results;
      this.dataRowLookup = dataRowLookup;
      this.dataRows = dataRows ?? new Dictionary<string, ScatterPlotDataRow>();
    }
    #endregion

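    /// <summary>
    /// Projects <paramref name="data"/> into <paramref name="newDimensions"/> dimensions. The run has three phases:
    /// (1) pairwise input similarities are computed (exactly when theta == 0, otherwise sparsely via k-nearest neighbors),
    /// (2) the similarities are temporarily exaggerated ("lying" about the P-values) so that clusters separate early, and
    /// (3) the embedding is refined by momentum-based gradient descent until maxIter is reached or Running is set to false.
    /// </summary>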
    public double[,] Run(T[] data, int newDimensions, double perplexity, double theta) {
      var currentMomentum = momentum;
      var noDatapoints = data.Length;
      if (noDatapoints - 1 < 3 * perplexity) throw new ArgumentException("Perplexity too large for the number of data points!");
      SetUpResults(data);
      Running = true;
      var exact = Math.Abs(theta) < double.Epsilon;
      var newData = new double[noDatapoints, newDimensions];
      var dY = new double[noDatapoints, newDimensions];
      var uY = new double[noDatapoints, newDimensions];
      var gains = new double[noDatapoints, newDimensions];
      for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) gains[i, j] = 1.0;
      double[,] p = null;
      int[] rowP = null;
      int[] colP = null;
      double[] valP = null;
      var rand = new NormalDistributedRandom(random, 0, 1);

      // Calculate similarities
      if (exact) p = CalculateExactSimilarities(data, perplexity);
      else CalculateApproximateSimilarities(data, perplexity, out rowP, out colP, out valP);

      // Lie about the P-values
      if (exact) for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] *= 12.0;
      else for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] *= 12.0;

      // Initialize solution (randomly)
      for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) newData[i, j] = rand.NextDouble() * .0001;

      // Perform main training loop
      for (var iter = 0; iter < maxIter && Running; iter++) {
        if (exact) ComputeExactGradient(p, newData, noDatapoints, newDimensions, dY);
        else ComputeGradient(rowP, colP, valP, newData, noDatapoints, newDimensions, dY, theta);
        // Update gains
        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) gains[i, j] = Math.Sign(dY[i, j]) != Math.Sign(uY[i, j]) ? gains[i, j] + .2 : gains[i, j] * .8;
        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) if (gains[i, j] < .01) gains[i, j] = .01;
        // Perform gradient update (with momentum and gains)
        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) uY[i, j] = currentMomentum * uY[i, j] - eta * gains[i, j] * dY[i, j];
        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) newData[i, j] = newData[i, j] + uY[i, j];
        // Make solution zero-mean
        ZeroMean(newData);
        // Stop lying about the P-values after a while, and switch momentum
        if (iter == stopLyingIter) {
          if (exact) for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] /= 12.0;
          else for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] /= 12.0;
        }
        if (iter == momSwitchIter) currentMomentum = finalMomentum;

        Analyze(exact, iter, p, rowP, colP, valP, newData, noDatapoints, newDimensions, theta);
      }
      return newData;
    }
    public static double[,] Run<TR>(TR[] data, int newDimensions, double perplexity, double theta, IDistance<TR> distance, IRandom random) where TR : class, IDeepCloneable {
      return new TSNE<TR>(distance, random).Run(data, newDimensions, perplexity, theta);
    }
    public static double[,] Run<TR>(TR[] data, int newDimensions, double perplexity, double theta, Func<TR, TR, double> distance, IRandom random) where TR : class, IDeepCloneable {
      return new TSNE<TR>(new FuctionalDistance<TR>(distance), random).Run(data, newDimensions, perplexity, theta);
    }
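
    // Example usage (a sketch only; RealVector is assumed here as the item type and the Euclidean
    // lambda is an illustrative distance -- the concrete types available depend on the loaded plugins):
    //
    //   var rnd = new MersenneTwister(42);
    //   var projected = TSNE<RealVector>.Run(vectors, newDimensions: 2, perplexity: 25, theta: 0.5,
    //     distance: (a, b) => Math.Sqrt(a.Zip(b, (x, y) => (x - y) * (x - y)).Sum()),
    //     random: rnd);
    //   // projected[i, 0] and projected[i, 1] are the 2D coordinates of data point i.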

    #region helpers

    private void SetUpResults(IReadOnlyCollection<T> data) {
      if (dataRowLookup == null) {
        dataRowLookup = new Dictionary<string, List<int>>();
        dataRowLookup.Add("Data", Enumerable.Range(0, data.Count).ToList());
      }
      if (results == null) return;
      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
      else ((IntValue)results[IterationResultName].Value).Value = 0;

      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
      else ((DoubleValue)results[ErrorResultName].Value).Value = 0;

      if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
      else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");

      var plot = results[ErrorPlotResultName].Value as DataTable;
      if (plot == null) throw new ArgumentException("Could not create/access error DataTable in the results collection.");
      if (!plot.Rows.ContainsKey("errors")) {
        plot.Rows.Add(new DataRow("errors"));
      }
      plot.Rows["errors"].Values.Clear();
      if (!results.ContainsKey(ScatterPlotResultName)) results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
      if (!results.ContainsKey(DataResultName)) results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
    }
    private void Analyze(bool exact, int iter, double[,] p, int[] rowP, int[] colP, double[] valP, double[,] newData, int noDatapoints, int newDimensions, double theta) {
      if (results == null) return;
      var plot = results[ErrorPlotResultName].Value as DataTable;
      if (plot == null) throw new ArgumentException("Could not create/access error DataTable in the results collection. Was it removed by some effect?");
      var errors = plot.Rows["errors"].Values;
      var c = exact
        ? EvaluateError(p, newData, noDatapoints, newDimensions)
        : EvaluateError(rowP, colP, valP, newData, theta);
      errors.Add(c);
      ((IntValue)results[IterationResultName].Value).Value = iter + 1;
      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();

      var ndata = Normalize(newData);
      results[DataResultName].Value = new DoubleMatrix(ndata);
      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
      FillScatterPlot(ndata, splot);
    }
    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
      foreach (var rowName in dataRowLookup.Keys) {
        if (!plot.Rows.ContainsKey(rowName)) {
          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
        } else plot.Rows[rowName].Points.Clear();
        plot.Rows[rowName].Points.AddRange(dataRowLookup[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
      }
    }
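    // Rescales each column of the projected data linearly to the interval [-0.5, 0.5] (centered on the
    // midrange), which keeps the scatter plot axes comparable between iterations.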
    private static double[,] Normalize(double[,] data) {
      var max = new double[data.GetLength(1)];
      var min = new double[data.GetLength(1)];
      var res = new double[data.GetLength(0), data.GetLength(1)];
      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
      for (var i = 0; i < data.GetLength(0); i++)
        for (var j = 0; j < data.GetLength(1); j++) {
          var v = data[i, j];
          max[j] = Math.Max(max[j], v);
          min[j] = Math.Min(min[j], v);
        }
      for (var i = 0; i < data.GetLength(0); i++) {
        for (var j = 0; j < data.GetLength(1); j++) {
          res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]);
        }
      }
      return res;
    }
    private void CalculateApproximateSimilarities(T[] data, double perplexity, out int[] rowP, out int[] colP, out double[] valP) {
      // Compute asymmetric pairwise input similarities
      ComputeGaussianPerplexity(data, data.Length, out rowP, out colP, out valP, perplexity, (int)(3 * perplexity));
      // Symmetrize input similarities
      int[] sRowP, symColP;
      double[] sValP;
      SymmetrizeMatrix(rowP, colP, valP, out sRowP, out symColP, out sValP);
      rowP = sRowP;
      colP = symColP;
      valP = sValP;
      var sumP = .0;
      for (var i = 0; i < rowP[data.Length]; i++) sumP += valP[i];
      for (var i = 0; i < rowP[data.Length]; i++) valP[i] /= sumP;
    }
    private double[,] CalculateExactSimilarities(T[] data, double perplexity) {
      // Compute similarities
      var p = new double[data.Length, data.Length];
      ComputeGaussianPerplexity(data, data.Length, p, perplexity);
      // Symmetrize input similarities
      for (var n = 0; n < data.Length; n++) {
        for (var m = n + 1; m < data.Length; m++) {
          p[n, m] += p[m, n];
          p[m, n] = p[n, m];
        }
      }
      var sumP = .0;
      for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) sumP += p[i, j];
      for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) p[i, j] /= sumP;
      return p;
    }

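    // For each point, a binary search over the kernel precision beta adjusts the Gaussian bandwidth until the
    // entropy h of the conditional distribution matches log(perplexity), i.e. the effective number of neighbors
    // equals the requested perplexity. Only the k nearest neighbors (found via the VP-tree) get non-zero similarity.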
    private void ComputeGaussianPerplexity(IReadOnlyList<T> x, int n, out int[] rowP, out int[] colP, out double[] valP, double perplexity, int k) {
      if (perplexity > k) throw new ArgumentException("Perplexity should be lower than K!");

      // Allocate the memory we need
      rowP = new int[n + 1];
      colP = new int[n * k];
      valP = new double[n * k];
      var curP = new double[n - 1];
      rowP[0] = 0;
      for (var i = 0; i < n; i++) rowP[i + 1] = rowP[i] + k;

      // Build VP-tree on data set
      var tree = new VPTree<IDataPoint<T>>(new DataPointDistance<T>(distance));
      var objX = new List<IDataPoint<T>>();
      for (var i = 0; i < n; i++) objX.Add(new DataPoint<T>(i, x[i]));
      tree.Create(objX);

      // Loop over all points to find nearest neighbors
      var indices = new List<IDataPoint<T>>();
      var distances = new List<double>();
      for (var i = 0; i < n; i++) {

        // Find nearest neighbors
        indices.Clear();
        distances.Clear();
        tree.Search(objX[i], k + 1, out indices, out distances);

        // Initialize some variables for binary search
        var found = false;
        var beta = 1.0;
        var minBeta = -double.MaxValue;
        var maxBeta = double.MaxValue;
        const double tol = 1e-5;

        // Iterate until we find a good perplexity
        var iter = 0; double sumP = 0;
        while (!found && iter < 200) {

          // Compute Gaussian kernel row
          for (var m = 0; m < k; m++) curP[m] = Math.Exp(-beta * distances[m + 1]);

          // Compute entropy of current row
          sumP = double.Epsilon;
          for (var m = 0; m < k; m++) sumP += curP[m];
          var h = .0;
          for (var m = 0; m < k; m++) h += beta * (distances[m + 1] * curP[m]);
          h = h / sumP + Math.Log(sumP);

          // Evaluate whether the entropy is within the tolerance level
          var hdiff = h - Math.Log(perplexity);
          if (hdiff < tol && -hdiff < tol) {
            found = true;
          } else {
            if (hdiff > 0) {
              minBeta = beta;
              if (maxBeta.IsAlmost(double.MaxValue) || maxBeta.IsAlmost(double.MinValue))
                beta *= 2.0;
              else
                beta = (beta + maxBeta) / 2.0;
            } else {
              maxBeta = beta;
              if (minBeta.IsAlmost(double.MinValue) || minBeta.IsAlmost(double.MaxValue))
                beta /= 2.0;
              else
                beta = (beta + minBeta) / 2.0;
            }
          }

          // Update iteration counter
          iter++;
        }

        // Row-normalize current row of P and store in matrix
        for (var m = 0; m < k; m++) curP[m] /= sumP;
        for (var m = 0; m < k; m++) {
          colP[rowP[i] + m] = indices[m + 1].Index;
          valP[rowP[i] + m] = curP[m];
        }
      }
    }
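    // Dense variant: fills the full n x n conditional probability matrix P using the same binary search
    // over beta; used only in the exact case (theta == 0).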
    private void ComputeGaussianPerplexity(T[] x, int n, double[,] p, double perplexity) {
      // Compute the pairwise distance matrix using the configured distance
      var dd = ComputeDistances(x);

      // Compute the Gaussian kernel row by row
      for (var i = 0; i < n; i++) {
        // Initialize some variables
        var found = false;
        var beta = 1.0;
        var minBeta = -double.MaxValue;
        var maxBeta = double.MaxValue;
        const double tol = 1e-5;
        double sumP = 0;

        // Iterate until we find a good perplexity
        var iter = 0;
        while (!found && iter < 200) {

          // Compute Gaussian kernel row
          for (var m = 0; m < n; m++) p[i, m] = Math.Exp(-beta * dd[i][m]);
          p[i, i] = double.Epsilon;

          // Compute entropy of current row
          sumP = double.Epsilon;
          for (var m = 0; m < n; m++) sumP += p[i, m];
          var h = 0.0;
          for (var m = 0; m < n; m++) h += beta * (dd[i][m] * p[i, m]);
          h = h / sumP + Math.Log(sumP);

          // Evaluate whether the entropy is within the tolerance level
          var hdiff = h - Math.Log(perplexity);
          if (hdiff < tol && -hdiff < tol) {
            found = true;
          } else {
            if (hdiff > 0) {
              minBeta = beta;
              if (maxBeta.IsAlmost(double.MaxValue) || maxBeta.IsAlmost(double.MinValue))
                beta *= 2.0;
              else
                beta = (beta + maxBeta) / 2.0;
            } else {
              maxBeta = beta;
              if (minBeta.IsAlmost(double.MinValue) || minBeta.IsAlmost(double.MaxValue))
                beta /= 2.0;
              else
                beta = (beta + minBeta) / 2.0;
            }
          }

          // Update iteration counter
          iter++;
        }

        // Row normalize P
        for (var m = 0; m < n; m++) p[i, m] /= sumP;
      }
    }
    private double[][] ComputeDistances(T[] x) {
      return x.Select(m => x.Select(n => distance.Get(m, n)).ToArray()).ToArray();
    }
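    // Exact gradient of the KL divergence C = sum_ij p_ij * log(p_ij / q_ij) with the Student-t kernel
    // q_ij proportional to 1 / (1 + ||y_i - y_j||^2):
    //   dC/dy_i = sum_j (p_ij - q_ij) * (y_i - y_j) / (1 + ||y_i - y_j||^2)
    // up to a constant factor that is absorbed into the learning rate eta.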
    private static void ComputeExactGradient(double[,] p, double[,] y, int n, int d, double[,] dC) {
      // Make sure the current gradient contains zeros
      for (var i = 0; i < n; i++) for (var j = 0; j < d; j++) dC[i, j] = 0.0;

      // Compute the squared Euclidean distance matrix
      var dd = new double[n, n];
      ComputeSquaredEuclideanDistance(y, n, d, dd);

      // Compute Q-matrix and normalization sum
      var q = new double[n, n];
      var sumQ = .0;
      for (var n1 = 0; n1 < n; n1++) {
        for (var m = 0; m < n; m++) {
          if (n1 == m) continue;
          q[n1, m] = 1 / (1 + dd[n1, m]);
          sumQ += q[n1, m];
        }
      }

      // Perform the computation of the gradient
      for (var n1 = 0; n1 < n; n1++) {
        for (var m = 0; m < n; m++) {
          if (n1 == m) continue;
          var mult = (p[n1, m] - q[n1, m] / sumQ) * q[n1, m];
          for (var d1 = 0; d1 < d; d1++) {
            dC[n1, d1] += (y[n1, d1] - y[m, d1]) * mult;
          }
        }
      }
    }
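    // Fills dd with dd[i, m] = ||x_i - x_m||^2. (The initial dataSums pass is redundant here since every
    // entry is overwritten by the explicit difference loop below.)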
    private static void ComputeSquaredEuclideanDistance(double[,] x, int n, int d, double[,] dd) {
      var dataSums = new double[n];
      for (var i = 0; i < n; i++) {
        for (var j = 0; j < d; j++) {
          dataSums[i] += x[i, j] * x[i, j];
        }
      }
      for (var i = 0; i < n; i++) {
        for (var m = 0; m < n; m++) {
          dd[i, m] = dataSums[i] + dataSums[m];
        }
      }
      for (var i = 0; i < n; i++) {
        dd[i, i] = 0.0;
        for (var m = i + 1; m < n; m++) {
          dd[i, m] = 0.0;
          for (var j = 0; j < d; j++) {
            dd[i, m] += (x[i, j] - x[m, j]) * (x[i, j] - x[m, j]);
          }
          dd[m, i] = dd[i, m];
        }
      }
    }
    private static void ComputeGradient(int[] rowP, int[] colP, double[] valP, double[,] y, int n, int d, double[,] dC, double theta) {
      var tree = new SPTree(y);
      double[] sumQ = { 0 };
      var posF = new double[n, d];
      var negF = new double[n, d];
      tree.ComputeEdgeForces(rowP, colP, valP, n, posF);
      var row = new double[d];
      for (var n1 = 0; n1 < n; n1++) {
        // Accumulate the repulsive (non-edge) forces for point n1 into a temporary row buffer and
        // write the result back into negF (BlockCopy counts are given in bytes).
        Array.Clear(row, 0, row.Length);
        tree.ComputeNonEdgeForces(n1, theta, row, sumQ);
        Buffer.BlockCopy(row, 0, negF, sizeof(double) * n1 * d, sizeof(double) * d);
      }

      // Compute final t-SNE gradient
      for (var i = 0; i < n; i++)
        for (var j = 0; j < d; j++) {
          dC[i, j] = posF[i, j] - negF[i, j] / sumQ[0];
        }
    }
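    // Exact error: the KL divergence C = sum_ij p_ij * log(p_ij / q_ij) between the input similarities P
    // and the embedding similarities Q (Student-t kernel, normalized by sumQ).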
    private static double EvaluateError(double[,] p, double[,] y, int n, int d) {
      // Compute the squared Euclidean distance matrix
      var dd = new double[n, n];
      var q = new double[n, n];
      ComputeSquaredEuclideanDistance(y, n, d, dd);

      // Compute Q-matrix and normalization sum
      var sumQ = double.Epsilon;
      for (var n1 = 0; n1 < n; n1++) {
        for (var m = 0; m < n; m++) {
          if (n1 != m) {
            q[n1, m] = 1 / (1 + dd[n1, m]);
            sumQ += q[n1, m];
          } else q[n1, m] = double.Epsilon;
        }
      }
      for (var i = 0; i < n; i++) for (var j = 0; j < n; j++) q[i, j] /= sumQ;

      // Sum t-SNE error
      var c = .0;
      for (var i = 0; i < n; i++)
        for (var j = 0; j < n; j++) {
          c += p[i, j] * Math.Log((p[i, j] + float.Epsilon) / (q[i, j] + float.Epsilon));
        }
      return c;
    }
    private static double EvaluateError(IReadOnlyList<int> rowP, IReadOnlyList<int> colP, IReadOnlyList<double> valP, double[,] y, double theta) {
      // Get estimate of normalization term
      var n = y.GetLength(0);
      var d = y.GetLength(1);
      var tree = new SPTree(y);
      var buff = new double[d];
      double[] sumQ = { 0 };
      for (var i = 0; i < n; i++) tree.ComputeNonEdgeForces(i, theta, buff, sumQ);

      // Loop over all edges to compute t-SNE error
      var c = .0;
      for (var k = 0; k < n; k++) {
        for (var i = rowP[k]; i < rowP[k + 1]; i++) {
          var q = .0;
          for (var j = 0; j < d; j++) buff[j] = y[k, j];
          for (var j = 0; j < d; j++) buff[j] -= y[colP[i], j];
          for (var j = 0; j < d; j++) q += buff[j] * buff[j];
          q = 1.0 / (1.0 + q) / sumQ[0];
          c += valP[i] * Math.Log((valP[i] + float.Epsilon) / (q + float.Epsilon));
        }
      }
      return c;
    }
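    // Symmetrizes the sparse conditional similarities, i.e. computes P_sym = (P + P^T) / 2 in compressed
    // sparse row form (rowP/colP/valP), merging entries that appear in both directions.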
    private static void SymmetrizeMatrix(IReadOnlyList<int> rowP, IReadOnlyList<int> colP, IReadOnlyList<double> valP, out int[] symRowP, out int[] symColP, out double[] symValP) {
      // Count number of elements and row counts of symmetric matrix
      var n = rowP.Count - 1;
      var rowCounts = new int[n];
      for (var j = 0; j < n; j++) {
        for (var i = rowP[j]; i < rowP[j + 1]; i++) {

          // Check whether element (colP[i], j) is present
          var present = false;
          for (var m = rowP[colP[i]]; m < rowP[colP[i] + 1]; m++) {
            if (colP[m] == j) present = true;
          }
          if (present) rowCounts[j]++;
          else {
            rowCounts[j]++;
            rowCounts[colP[i]]++;
          }
        }
      }
      var noElem = 0;
      for (var i = 0; i < n; i++) noElem += rowCounts[i];

      // Allocate memory for symmetrized matrix
      symRowP = new int[n + 1];
      symColP = new int[noElem];
      symValP = new double[noElem];

      // Construct new row indices for symmetric matrix
      symRowP[0] = 0;
      for (var i = 0; i < n; i++) symRowP[i + 1] = symRowP[i] + rowCounts[i];

      // Fill the result matrix
      var offset = new int[n];
      for (var j = 0; j < n; j++) {
        for (var i = rowP[j]; i < rowP[j + 1]; i++) {                                  // considering element (j, colP[i])

          // Check whether element (colP[i], j) is present
          var present = false;
          for (var m = rowP[colP[i]]; m < rowP[colP[i] + 1]; m++) {
            if (colP[m] != j) continue;
            present = true;
            if (j > colP[i]) continue; // make sure we do not add elements twice
            symColP[symRowP[j] + offset[j]] = colP[i];
            symColP[symRowP[colP[i]] + offset[colP[i]]] = j;
            symValP[symRowP[j] + offset[j]] = valP[i] + valP[m];
            symValP[symRowP[colP[i]] + offset[colP[i]]] = valP[i] + valP[m];
          }

          // If (colP[i], j) is not present, there is no addition involved
          if (!present) {
            symColP[symRowP[j] + offset[j]] = colP[i];
            symColP[symRowP[colP[i]] + offset[colP[i]]] = j;
            symValP[symRowP[j] + offset[j]] = valP[i];
            symValP[symRowP[colP[i]] + offset[colP[i]]] = valP[i];
          }

          // Update offsets
          if (present && (j > colP[i])) continue;
          offset[j]++;
          if (colP[i] != j) offset[colP[i]]++;
        }
      }

      // Divide the result by two
      for (var i = 0; i < noElem; i++) symValP[i] /= 2.0;
    }
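    // Centers the embedding: subtracts the column-wise mean so the solution stays zero-mean after every
    // gradient step (prevents the point cloud from drifting).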
    private static void ZeroMean(double[,] x) {
      // Compute data mean
      var n = x.GetLength(0);
      var d = x.GetLength(1);
      var mean = new double[d];
      for (var i = 0; i < n; i++) {
        for (var j = 0; j < d; j++) {
          mean[j] += x[i, j];
        }
      }
      for (var i = 0; i < d; i++) {
        mean[i] /= n;
      }
      // Subtract data mean
      for (var i = 0; i < n; i++) {
        for (var j = 0; j < d; j++) {
          x[i, j] -= mean[j];
        }
      }
    }
    #endregion
  }
}