Free cookie consent management tool by TermsFeed Policy Generator

Changeset 14512


Ignore:
Timestamp:
12/20/16 15:50:11 (7 years ago)
Author:
bwerth
Message:

#2700 worked in several comments from mkommend, extended analysis during algorithm run, added more Distances, made algorithm stoppable

Location:
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
18 added
7 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4

    • Property svn:ignore
      •  

        old new  
         1*.user
         2*.vs10x
         3.vs
         4HeuristicLab.Algorithms.DataAnalysis-3.4.csproj.user
         5HeuristicLabAlgorithmsDataAnalysisPlugin.cs
         6Plugin.cs
        17bin
        28obj
        3 HeuristicLabAlgorithmsDataAnalysisPlugin.cs
        4 HeuristicLab.Algorithms.DataAnalysis-3.4.csproj.user
        5 *.vs10x
        6 Plugin.cs
        7 *.user
        8 .vs
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r14413 r14512  
    331331    <Compile Include="Interfaces\ISupportVectorMachineSolution.cs" />
    332332    <Compile Include="Interfaces\IDataAnalysisAlgorithm.cs" />
     333    <Compile Include="Interfaces\TSNEInterfaces\IKernelFunction.cs" />
    333334    <Compile Include="kMeans\KMeansClustering.cs" />
    334335    <Compile Include="kMeans\KMeansClusteringModel.cs" />
     
    337338      <SubType>Code</SubType>
    338339    </Compile>
     340    <Compile Include="KPCA\SelfOrganizingMap.cs" />
     341    <Compile Include="KPCA\KernelFunctions\CicularKernel.cs" />
     342    <Compile Include="KPCA\KernelFunctions\GaussianKernel.cs" />
     343    <Compile Include="KPCA\KernelFunctions\InverseMultiquadraticKernel .cs" />
     344    <Compile Include="KPCA\KernelFunctions\LaplacianKernel.cs" />
     345    <Compile Include="KPCA\KernelFunctions\MultiquadraticKernel.cs" />
     346    <Compile Include="KPCA\KernelFunctions\NoKernel.cs" />
     347    <Compile Include="KPCA\KernelFunctions\PolysplineKernel.cs" />
     348    <Compile Include="KPCA\KernelFunctions\RadialBasisKernelBase.cs" />
     349    <Compile Include="KPCA\KernelFunctions\ThinPlatePolysplineKernel.cs" />
     350    <Compile Include="KPCA\KernelFunctions\TricubicKernel.cs" />
     351    <Compile Include="KPCA\KernelPrincipleComponentAnalysis.cs" />
     352    <Compile Include="KPCA\Isomap.cs" />
     353    <Compile Include="KPCA\KPCA.cs" />
     354    <Compile Include="KPCA\MatrixUtilities.cs" />
    339355    <Compile Include="Linear\AlglibUtil.cs" />
    340356    <Compile Include="Linear\Scaling.cs" />
     
    415431    <Compile Include="TSNE\Cell.cs" />
    416432    <Compile Include="TSNE\DataPoint.cs" />
     433    <Compile Include="TSNE\Distances\FuctionalDistance.cs" />
    417434    <Compile Include="TSNE\Distances\DistanceBase.cs" />
    418435    <Compile Include="TSNE\Distances\DataPointDistance.cs" />
    419436    <Compile Include="TSNE\Distances\EuclidianDistance.cs" />
    420     <Compile Include="TSNE\TSNEInterfaces\IDistance.cs" />
     437    <Compile Include="Interfaces\TSNEInterfaces\IDistance.cs" />
     438    <Compile Include="TSNE\Distances\InnerProductDistance.cs" />
    421439    <Compile Include="TSNE\TSNEAnalysis.cs" />
    422440    <Compile Include="TSNE\PriorityQueue.cs" />
    423441    <Compile Include="TSNE\SPtree.cs" />
    424442    <Compile Include="TSNE\TSNE.cs" />
    425     <Compile Include="TSNE\TSNEInterfaces\ICell.cs" />
    426     <Compile Include="TSNE\TSNEInterfaces\IDataPoint.cs" />
    427     <Compile Include="TSNE\TSNEInterfaces\IHeap.cs" />
    428     <Compile Include="TSNE\TSNEInterfaces\ISPTree.cs" />
    429     <Compile Include="TSNE\TSNEInterfaces\ITSNE.cs" />
    430     <Compile Include="TSNE\TSNEInterfaces\IVPTree.cs" />
     443    <Compile Include="Interfaces\TSNEInterfaces\ICell.cs" />
     444    <Compile Include="Interfaces\TSNEInterfaces\IDataPoint.cs" />
     445    <Compile Include="Interfaces\TSNEInterfaces\IHeap.cs" />
     446    <Compile Include="Interfaces\TSNEInterfaces\ISPTree.cs" />
     447    <Compile Include="Interfaces\TSNEInterfaces\ITSNE.cs" />
     448    <Compile Include="Interfaces\TSNEInterfaces\IVPTree.cs" />
    431449    <Compile Include="TSNE\TSNEUtils.cs" />
    432450    <Compile Include="TSNE\VPTree.cs" />
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/Distances/DataPointDistance.cs

    r14414 r14512  
    2020#endregion
    2121
    22 using HeuristicLab.Algorithms.DataAnalysis.Distances;
    2322using HeuristicLab.Common;
    2423using HeuristicLab.Core;
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/Distances/DistanceBase.cs

    r14414 r14512  
    2525using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2626
    27 namespace HeuristicLab.Algorithms.DataAnalysis.Distances {
     27namespace HeuristicLab.Algorithms.DataAnalysis {
    2828  [StorableClass]
    2929  public abstract class DistanceBase<T> : Item, IDistance<T> {
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/Distances/EuclidianDistance.cs

    r14414 r14512  
    2323using System.Collections.Generic;
    2424using System.Linq;
    25 using HeuristicLab.Algorithms.DataAnalysis.Distances;
    2625using HeuristicLab.Common;
    2726using HeuristicLab.Core;
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNE.cs

    r14414 r14512  
    5757using System.Collections.Generic;
    5858using System.Linq;
     59using HeuristicLab.Analysis;
    5960using HeuristicLab.Common;
    6061using HeuristicLab.Core;
     
    7071    private const string IterationResultName = "Iteration";
    7172    private const string ErrorResultName = "Error";
     73    private const string ErrorPlotResultName = "ErrorPlot";
     74    private const string ScatterPlotResultName = "Scatterplot";
     75    private const string DataResultName = "Projected Data";
    7276
    7377    #region Properties
     
    9094    [Storable]
    9195    private ResultCollection results;
     96    [Storable]
     97    private Dictionary<string, List<int>> dataRowLookup;
     98    [Storable]
     99    private Dictionary<string, ScatterPlotDataRow> dataRows;
     100    #endregion
     101
     102    #region Stopping
     103    public volatile bool Running;
    92104    #endregion
    93105
     
    105117      random = cloner.Clone(random);
    106118      results = cloner.Clone(results);
     119      dataRowLookup = original.dataRowLookup.ToDictionary(entry => entry.Key, entry => entry.Value.Select(x => x).ToList());
     120      dataRows = original.dataRows.ToDictionary(entry => entry.Key, entry => cloner.Clone(entry.Value));
    107121    }
    108122    public override IDeepCloneable Clone(Cloner cloner) { return new TSNE<T>(this, cloner); }
    109     public TSNE(IDistance<T> distance, IRandom random, ResultCollection results = null, int maxIter = 1000, int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, double finalMomentum = .8, double eta = 200.0) {
     123    public TSNE(IDistance<T> distance, IRandom random, ResultCollection results = null, int maxIter = 1000, int stopLyingIter = 250, int momSwitchIter = 250, double momentum = .5, double finalMomentum = .8, double eta = 200.0, Dictionary<string, List<int>> dataRowLookup = null, Dictionary<string, ScatterPlotDataRow> dataRows = null) {
    110124      this.distance = distance;
    111125      this.maxIter = maxIter;
     
    117131      this.random = random;
    118132      this.results = results;
     133      this.dataRowLookup = dataRowLookup;
     134      if (dataRows != null)
     135        this.dataRows = dataRows;
     136      else { this.dataRows = new Dictionary<string, ScatterPlotDataRow>(); }
    119137    }
    120138    #endregion
     
    124142      var noDatapoints = data.Length;
    125143      if (noDatapoints - 1 < 3 * perplexity) throw new ArgumentException("Perplexity too large for the number of data points!");
    126 
    127       if (results != null) {
    128         if (!results.ContainsKey(IterationResultName)) {
    129           results.Add(new Result(IterationResultName, new IntValue(0)));
    130         } else ((IntValue)results[IterationResultName].Value).Value = 0;
    131         if (!results.ContainsKey(ErrorResultName)) {
    132           results.Add(new Result(ErrorResultName, new DoubleValue(0)));
    133         } else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
    134       }
    135 
    136       // Determine whether we are using an exact algorithm
     144      SetUpResults(data);
     145      Running = true;
    137146      var exact = Math.Abs(theta) < double.Epsilon;
    138147      var newData = new double[noDatapoints, newDimensions];
     
    141150      var gains = new double[noDatapoints, newDimensions];
    142151      for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) gains[i, j] = 1.0;
    143 
    144       // Compute input similarities for exact t-SNE
    145152      double[,] p = null;
    146153      int[] rowP = null;
    147154      int[] colP = null;
    148155      double[] valP = null;
    149       if (exact) {
    150         // Compute similarities
    151         p = new double[noDatapoints, noDatapoints];
    152         ComputeGaussianPerplexity(data, noDatapoints, p, perplexity);
    153         // Symmetrize input similarities
    154         for (var n = 0; n < noDatapoints; n++) {
    155           for (var m = n + 1; m < noDatapoints; m++) {
    156             p[n, m] += p[m, n];
    157             p[m, n] = p[n, m];
    158           }
    159         }
    160         var sumP = .0;
    161         for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) sumP += p[i, j];
    162         for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] /= sumP;
    163       } // Compute input similarities for approximate t-SNE
    164       else {
    165         // Compute asymmetric pairwise input similarities
    166         ComputeGaussianPerplexity(data, noDatapoints, out rowP, out colP, out valP, perplexity, (int)(3 * perplexity));
    167         // Symmetrize input similarities
    168         int[] sRowP, symColP;
    169         double[] sValP;
    170         SymmetrizeMatrix(rowP, colP, valP, out sRowP, out symColP, out sValP);
    171         rowP = sRowP;
    172         colP = symColP;
    173         valP = sValP;
    174         var sumP = .0;
    175         for (var i = 0; i < rowP[noDatapoints]; i++) sumP += valP[i];
    176         for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] /= sumP;
    177       }
     156      var rand = new NormalDistributedRandom(random, 0, 1);
     157
     158      //Calculate Similarities
     159      if (exact) p = CalculateExactSimilarites(data, perplexity);
     160      else CalculateApproximateSimilarities(data, perplexity, out rowP, out colP, out valP);
    178161
    179162      // Lie about the P-values
    180       if (exact) {
    181         for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] *= 12.0;
    182       } else {
    183         for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] *= 12.0;
    184       }
    185 
    186       var rand = new NormalDistributedRandom(random, 0, 1);
     163      if (exact) for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] *= 12.0;
     164      else for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] *= 12.0;
     165
    187166      // Initialize solution (randomly)
    188       for (var i = 0; i < noDatapoints; i++)
    189         for (var j = 0; j < newDimensions; j++)
    190           newData[i, j] = rand.NextDouble() * .0001;
     167      for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) newData[i, j] = rand.NextDouble() * .0001;
    191168
    192169      // Perform main training loop
    193       for (var iter = 0; iter < maxIter; iter++) {
    194 
    195         // Compute (approximate) gradient
     170      for (var iter = 0; iter < maxIter && Running; iter++) {
    196171        if (exact) ComputeExactGradient(p, newData, noDatapoints, newDimensions, dY);
    197172        else ComputeGradient(rowP, colP, valP, newData, noDatapoints, newDimensions, dY, theta);
    198 
    199173        // Update gains
    200         for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++)
    201             gains[i, j] = Math.Sign(dY[i, j]) != Math.Sign(uY[i, j]) ? gains[i, j] + .2 : gains[i, j] * .8;
     174        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) gains[i, j] = Math.Sign(dY[i, j]) != Math.Sign(uY[i, j]) ? gains[i, j] + .2 : gains[i, j] * .8;
    202175        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) if (gains[i, j] < .01) gains[i, j] = .01;
    203 
    204176        // Perform gradient update (with momentum and gains)
    205177        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) uY[i, j] = currentMomentum * uY[i, j] - eta * gains[i, j] * dY[i, j];
    206178        for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < newDimensions; j++) newData[i, j] = newData[i, j] + uY[i, j];
    207 
    208179        // Make solution zero-mean
    209180        ZeroMean(newData);
    210 
    211181        // Stop lying about the P-values after a while, and switch momentum
    212182        if (iter == stopLyingIter) {
    213           if (exact) {
    214             for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] /= 12.0;
    215           } else {
    216             for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] /= 12.0;
    217           }
     183          if (exact) for (var i = 0; i < noDatapoints; i++) for (var j = 0; j < noDatapoints; j++) p[i, j] /= 12.0;
     184          else for (var i = 0; i < rowP[noDatapoints]; i++) valP[i] /= 12.0;
    218185        }
    219186        if (iter == momSwitchIter) currentMomentum = finalMomentum;
    220187
    221         if (results == null) continue;
    222         var errors = new List<double>();
    223         // Print out progress
    224         var c = exact
    225           ? EvaluateError(p, newData, noDatapoints, newDimensions)
    226           : EvaluateError(rowP, colP, valP, newData, theta);
    227         errors.Add(c);
    228         ((IntValue)results[IterationResultName].Value).Value = iter + 1;
    229         ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
     188        Analyze(exact, iter, p, rowP, colP, valP, newData, noDatapoints, newDimensions, theta);
    230189      }
    231190      return newData;
     
    234193      return new TSNE<TR>(distance, random).Run(data, newDimensions, perplexity, theta);
    235194    }
     195    public static double[,] Run<TR>(TR[] data, int newDimensions, double perplexity, double theta, Func<TR, TR, double> distance, IRandom random) where TR : class, IDeepCloneable {
     196      return new TSNE<TR>(new FuctionalDistance<TR>(distance), random).Run(data, newDimensions, perplexity, theta);
     197    }
    236198
    237199    #region helpers
     200
     201    private void SetUpResults(IReadOnlyCollection<T> data) {
     202      if (dataRowLookup == null) {
     203        dataRowLookup = new Dictionary<string, List<int>>();
     204        dataRowLookup.Add("Data", Enumerable.Range(0, data.Count).ToList());
     205      }
     206      if (results == null) return;
     207      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     208      else ((IntValue)results[IterationResultName].Value).Value = 0;
     209
     210      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
     211      else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
     212
     213      if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during Gradiant descent")));
     214      else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during Gradiant descent");
     215
     216      var plot = results[ErrorPlotResultName].Value as DataTable;
     217      if (plot == null) throw new ArgumentException("could not create/access Error-DataTable in Results-Collection");
     218      if (!plot.Rows.ContainsKey("errors")) {
     219        plot.Rows.Add(new DataRow("errors"));
     220      }
     221      plot.Rows["errors"].Values.Clear();
     222      results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
     223      results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     224
     225    }
     226    private void Analyze(bool exact, int iter, double[,] p, int[] rowP, int[] colP, double[] valP, double[,] newData, int noDatapoints, int newDimensions, double theta) {
     227      if (results == null) return;
     228      var plot = results[ErrorPlotResultName].Value as DataTable;
     229      if (plot == null) throw new ArgumentException("Could not create/access Error-DataTable in Results-Collection. Was it removed by some effect?");
     230      var errors = plot.Rows["errors"].Values;
     231      var c = exact
     232        ? EvaluateError(p, newData, noDatapoints, newDimensions)
     233        : EvaluateError(rowP, colP, valP, newData, theta);
     234      errors.Add(c);
     235      ((IntValue)results[IterationResultName].Value).Value = iter + 1;
     236      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
     237
     238      var ndata = Normalize(newData);
     239      results[DataResultName].Value = new DoubleMatrix(ndata);
     240      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
     241      FillScatterPlot(ndata, splot);
     242
     243
     244    }
     245    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
     246      foreach (var rowName in dataRowLookup.Keys) {
     247        if (!plot.Rows.ContainsKey(rowName)) {
     248          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
     249        } else plot.Rows[rowName].Points.Clear();
     250        plot.Rows[rowName].Points.AddRange(dataRowLookup[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     251      }
     252    }
     253    private static double[,] Normalize(double[,] data) {
     254      var max = new double[data.GetLength(1)];
     255      var min = new double[data.GetLength(1)];
     256      var res = new double[data.GetLength(0), data.GetLength(1)];
     257      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
     258      for (var i = 0; i < data.GetLength(0); i++)
     259        for (var j = 0; j < data.GetLength(1); j++) {
     260          var v = data[i, j];
     261          max[j] = Math.Max(max[j], v);
     262          min[j] = Math.Min(min[j], v);
     263        }
     264      for (var i = 0; i < data.GetLength(0); i++) {
     265        for (var j = 0; j < data.GetLength(1); j++) {
     266          res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]);
     267        }
     268      }
     269      return res;
     270    }
     271    private void CalculateApproximateSimilarities(T[] data, double perplexity, out int[] rowP, out int[] colP, out double[] valP) {
     272      // Compute asymmetric pairwise input similarities
     273      ComputeGaussianPerplexity(data, data.Length, out rowP, out colP, out valP, perplexity, (int)(3 * perplexity));
     274      // Symmetrize input similarities
     275      int[] sRowP, symColP;
     276      double[] sValP;
     277      SymmetrizeMatrix(rowP, colP, valP, out sRowP, out symColP, out sValP);
     278      rowP = sRowP;
     279      colP = symColP;
     280      valP = sValP;
     281      var sumP = .0;
     282      for (var i = 0; i < rowP[data.Length]; i++) sumP += valP[i];
     283      for (var i = 0; i < rowP[data.Length]; i++) valP[i] /= sumP;
     284    }
     285    private double[,] CalculateExactSimilarites(T[] data, double perplexity) {
     286      // Compute similarities
     287      var p = new double[data.Length, data.Length];
     288      ComputeGaussianPerplexity(data, data.Length, p, perplexity);
     289      // Symmetrize input similarities
     290      for (var n = 0; n < data.Length; n++) {
     291        for (var m = n + 1; m < data.Length; m++) {
     292          p[n, m] += p[m, n];
     293          p[m, n] = p[n, m];
     294        }
     295      }
     296      var sumP = .0;
     297      for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) sumP += p[i, j];
     298      for (var i = 0; i < data.Length; i++) for (var j = 0; j < data.Length; j++) p[i, j] /= sumP;
     299      return p;
     300    }
     301
    238302    private void ComputeGaussianPerplexity(IReadOnlyList<T> x, int n, out int[] rowP, out int[] colP, out double[] valP, double perplexity, int k) {
    239303      if (perplexity > k) throw new ArgumentException("Perplexity should be lower than K!");
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAnalysis.cs

    r14414 r14512  
    2020#endregion
    2121
    22 using System;
     22using System.Collections.Generic;
     23using System.Drawing;
    2324using System.Linq;
    2425using HeuristicLab.Analysis;
     
    2728using HeuristicLab.Data;
    2829using HeuristicLab.Encodings.RealVectorEncoding;
    29 using HeuristicLab.Optimization;
    3030using HeuristicLab.Parameters;
    3131using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     
    6060    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    6161    private const string SeedParameterName = "Seed";
     62    private const string ClassesParameterName = "ClassNames";
    6263    #endregion
    6364
     
    6768      get { return Parameters[PerplexityParameterName] as IFixedValueParameter<DoubleValue>; }
    6869    }
    69     public IFixedValueParameter<DoubleValue> ThetaParameter
    70     {
    71       get { return Parameters[ThetaParameterName] as IFixedValueParameter<DoubleValue>; }
     70    public OptionalValueParameter<DoubleValue> ThetaParameter
     71    {
     72      get { return Parameters[ThetaParameterName] as OptionalValueParameter<DoubleValue>; }
    7273    }
    7374    public IFixedValueParameter<IntValue> NewDimensionsParameter
     
    110111    {
    111112      get { return Parameters[SeedParameterName] as IFixedValueParameter<IntValue>; }
     113    }
     114    public IFixedValueParameter<StringValue> ClassesParameter
     115    {
     116      get { return Parameters[ClassesParameterName] as IFixedValueParameter<StringValue>; }
    112117    }
    113118    #endregion
     
    124129    public double Theta
    125130    {
    126       get { return ThetaParameter.Value.Value; }
     131      get { return ThetaParameter.Value == null ? 0 : ThetaParameter.Value.Value; }
    127132    }
    128133    public int NewDimensions
     
    152157    public double Eta
    153158    {
    154       get { return EtaParameter.Value.Value; }
     159      get
     160      {
     161        return EtaParameter.Value == null ? 0 : EtaParameter.Value.Value;
     162      }
    155163    }
    156164    public bool SetSeedRandomly
     
    162170      get { return (uint)SeedParameter.Value.Value; }
    163171    }
     172    public string Classes
     173    {
     174      get { return ClassesParameter.Value.Value; }
     175    }
     176
     177    [Storable]
     178    public TSNE<RealVector> tsne;
    164179    #endregion
    165180
     
    172187      Problem = new RegressionProblem();
    173188      Parameters.Add(new ValueParameter<IDistance<RealVector>>(DistanceParameterName, "The distance function used to differentiate similar from non-similar points", new EuclidianDistance()));
    174       Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-Parameter of TSNE. Comparable to k in a k-nearest neighbour algorithm", new DoubleValue(25)));
    175       Parameters.Add(new FixedValueParameter<DoubleValue>(ThetaParameterName, "Value describing how much appoximated gradients my differ from exact gradients. Set to 0 for exact calculation", new DoubleValue(0.1)));
     189      Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-Parameter of TSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended Value is Floor(number of points /3) or lower", new DoubleValue(25)));
     190      Parameters.Add(new OptionalValueParameter<DoubleValue>(ThetaParameterName, "Value describing how much appoximated gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise \n CAUTION: exact calculation of forces requires building a non-sparse N*N matrix where N is the number of data points\n This may exceed memory limitations", new DoubleValue(0.1)));
    176191      Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis", new IntValue(2)));
    177192      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent", new IntValue(1000)));
     
    183198      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "If the seed should be random", new BoolValue(true)));
    184199      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The seed used if it should not be random", new IntValue(0)));
     200      Parameters.Add(new FixedValueParameter<StringValue>(ClassesParameterName, "name of the column specifying the class lables of each data point. \n if the lable column can not be found Training/Test is used as labels", new StringValue("none")));
    185201    }
    186202    #endregion
    187203
    188204    protected override void Run() {
    189       var lowDimData = new DoubleMatrix(GetProjectedData(Problem.ProblemData));
    190       Results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", CreateScatterPlot(lowDimData, Problem.ProblemData)));
    191       Results.Add(new Result(DataResultName, "Projected Data", lowDimData));
    192     }
    193 
    194     private ScatterPlot CreateScatterPlot(DoubleMatrix lowDimData, IDataAnalysisProblemData problemData) {
    195       var plot = new ScatterPlot(DataResultName, "");
    196       Normalize(lowDimData);
    197       plot.Rows.Add(new ScatterPlotDataRow("Training", "Points of the training set", problemData.TrainingIndices.Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1]))));
    198       plot.Rows.Add(new ScatterPlotDataRow("Test", "Points of the test set", problemData.TestIndices.Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1]))));
    199       return plot;
    200     }
    201 
    202     private double[,] GetProjectedData(IDataAnalysisProblemData problemData) {
     205      var data = CalculateProjectedData(Problem.ProblemData);
     206      var lowDimData = new DoubleMatrix(data);
     207    }
     208
     209    public override void Stop() {
     210      base.Stop();
     211      if (tsne != null) tsne.Running = false;
     212    }
     213
     214    private double[,] CalculateProjectedData(IDataAnalysisProblemData problemData) {
     215      var DataRowNames = new Dictionary<string, List<int>>();
     216      var rows = new Dictionary<string, ScatterPlotDataRow>();
     217
     218      if (problemData.Dataset.VariableNames.Contains(Classes)) {
     219        if ((problemData.Dataset as Dataset).VariableHasType<string>(Classes)) {
     220          var classes = problemData.Dataset.GetStringValues(Classes).ToArray();
     221          for (int i = 0; i < classes.Length; i++) {
     222            if (!DataRowNames.ContainsKey(classes[i])) DataRowNames.Add(classes[i], new List<int>());
     223            DataRowNames[classes[i]].Add(i); //always succeeds
     224          }
     225        } else if ((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) {
     226          var classValues = problemData.Dataset.GetDoubleValues(Classes).ToArray();
     227          var max = classValues.Max() + 0.1;
     228          var min = classValues.Min() - 0.1;
     229          var contours = 8;
     230          for (int i = 0; i < contours; i++) {
     231            var name = GetContourName(i, min, max, contours);
     232            DataRowNames.Add(name, new List<int>());
     233            rows.Add(name, new ScatterPlotDataRow(name, "", new List<Point2D<double>>()));
     234            rows[name].VisualProperties.Color = GetHeatMapColor(i, contours);
     235            rows[name].VisualProperties.PointSize = i+3;
     236          }
     237          for (int i = 0; i < classValues.Length; i++) {
     238            DataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i); //always succeeds
     239          }
     240
     241        }
     242
     243
     244      } else {
     245        DataRowNames.Add("Training", problemData.TrainingIndices.ToList());
     246        DataRowNames.Add("Test", problemData.TestIndices.ToList());
     247      }
     248
    203249      var random = SetSeedRandomly ? new MersenneTwister() : new MersenneTwister(Seed);
    204       var tsne = new TSNE<RealVector>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
     250      tsne = new TSNE<RealVector>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, DataRowNames, rows);
    205251      var dataset = problemData.Dataset;
    206252      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
     
    210256    }
    211257
    212     private static void Normalize(DoubleMatrix data) {
    213       var max = new double[data.Columns];
    214       var min = new double[data.Columns];
    215       for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
    216       for (var i = 0; i < data.Rows; i++)
    217         for (var j = 0; j < data.Columns; j++) {
    218           var v = data[i, j];
    219           max[j] = Math.Max(max[j], v);
    220           min[j] = Math.Min(min[j], v);
    221         }
    222       for (var i = 0; i < data.Rows; i++) {
    223         for (var j = 0; j < data.Columns; j++) {
    224           data[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]);
    225         }
    226       }
    227 
    228     }
     258    private static Color GetHeatMapColor(int contourNr, int noContours) {
     259      var q = (double)contourNr / noContours;  // q in [0,1]
     260      var c = q < 0.5 ? Color.FromArgb((int)(q * 2 * 255), 255, 0) : Color.FromArgb(255, (int)((1 - q) * 2 * 255), 0);
     261      return c;
     262    }
     263    private static string GetContourName(double value, double min, double max, int noContours) {
     264      var size = (max - min) / noContours;
     265      var contourNr = (int)((value - min) / size);
     266      return GetContourName(contourNr, min, max, noContours);
     267    }
     268    private static string GetContourName(int i, double min, double max, int noContours) {
     269      var size = (max - min) / noContours;
     270      return "[" + (min + i * size) + ";" + (min + (i + 1) * size) + ")";
     271    }
     272
    229273  }
    230274}
Note: See TracChangeset for help on using the changeset viewer.