Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
12/22/16 10:08:25 (7 years ago)
Author:
bwerth
Message:

#2700 TSNEAnalysis is now a BasicAlg, hid some Parameters, added optional data normalization to make TSNE scaling-invariant

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAnalysis.cs

    r14512 r14518  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Drawing;
    2425using System.Linq;
     26using System.Threading;
    2527using HeuristicLab.Analysis;
    2628using HeuristicLab.Common;
     
    2830using HeuristicLab.Data;
    2931using HeuristicLab.Encodings.RealVectorEncoding;
     32using HeuristicLab.Optimization;
    3033using HeuristicLab.Parameters;
    3134using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     
    4043  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
    4144  [StorableClass]
    42   public sealed class TSNEAnalysis : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    43 
     45  public sealed class TSNEAnalysis : BasicAlgorithm {
     46
     47    public override Type ProblemType
     48    {
     49      get { return typeof(IDataAnalysisProblem); }
     50    }
     51    public new IDataAnalysisProblem Problem
     52    {
     53      get { return (IDataAnalysisProblem)base.Problem; }
     54      set { base.Problem = value; }
     55    }
    4456    #region Resultnames
    4557    private const string ScatterPlotResultName = "Scatterplot";
     
    6173    private const string SeedParameterName = "Seed";
    6274    private const string ClassesParameterName = "ClassNames";
     75    private const string NormalizationParameterName = "Normalization";
    6376    #endregion
    6477
     
    115128    {
    116129      get { return Parameters[ClassesParameterName] as IFixedValueParameter<StringValue>; }
     130    }
     131    public IFixedValueParameter<BoolValue> NormalizationParameter
     132    {
     133      get { return Parameters[NormalizationParameterName] as IFixedValueParameter<BoolValue>; }
    117134    }
    118135    #endregion
     
    174191      get { return ClassesParameter.Value.Value; }
    175192    }
    176 
     193    public bool Normalization
     194    {
     195      get { return NormalizationParameter.Value.Value; }
     196    }
    177197    [Storable]
    178198    public TSNE<RealVector> tsne;
     
    191211      Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis", new IntValue(2)));
    192212      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent", new IntValue(1000)));
    193       Parameters.Add(new FixedValueParameter<IntValue>(StopLyingIterationParameterName, "Number of iterations after which p is no longer approximated", new IntValue(250)));
    194       Parameters.Add(new FixedValueParameter<IntValue>(MomentumSwitchIterationParameterName, "Number of iterations after which the momentum in the gradient descent is switched", new IntValue(250)));
     213      Parameters.Add(new FixedValueParameter<IntValue>(StopLyingIterationParameterName, "Number of iterations after which p is no longer approximated", new IntValue(0)));
     214      Parameters.Add(new FixedValueParameter<IntValue>(MomentumSwitchIterationParameterName, "Number of iterations after which the momentum in the gradient descent is switched", new IntValue(0)));
    195215      Parameters.Add(new FixedValueParameter<DoubleValue>(InitialMomentumParameterName, "The initial momentum in the gradient descent", new DoubleValue(0.5)));
    196216      Parameters.Add(new FixedValueParameter<DoubleValue>(FinalMomentumParameterName, "The final momentum", new DoubleValue(0.8)));
     
    199219      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The seed used if it should not be random", new IntValue(0)));
    200220      Parameters.Add(new FixedValueParameter<StringValue>(ClassesParameterName, "name of the column specifying the class lables of each data point. \n if the lable column can not be found Training/Test is used as labels", new StringValue("none")));
    201     }
    202     #endregion
    203 
    204     protected override void Run() {
    205       var data = CalculateProjectedData(Problem.ProblemData);
    206       var lowDimData = new DoubleMatrix(data);
    207     }
     221      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Wether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored", new BoolValue(true)));
     222
     223      MomentumSwitchIterationParameter.Hidden = true;
     224      InitialMomentumParameter.Hidden = true;
     225      FinalMomentumParameter.Hidden = true;
     226      StopLyingIterationParameter.Hidden = true;
     227      EtaParameter.Hidden = true;
     228    }
     229    #endregion
    208230
    209231    public override void Stop() {
     
    212234    }
    213235
     236    protected override void Run(CancellationToken cancellationToken) {
     237      var data = CalculateProjectedData(Problem.ProblemData);
     238      var lowDimData = new DoubleMatrix(data);
     239    }
     240
    214241    private double[,] CalculateProjectedData(IDataAnalysisProblemData problemData) {
    215       var DataRowNames = new Dictionary<string, List<int>>();
     242      var dataRowNames = new Dictionary<string, List<int>>();
    216243      var rows = new Dictionary<string, ScatterPlotDataRow>();
    217244
     
    220247          var classes = problemData.Dataset.GetStringValues(Classes).ToArray();
    221248          for (int i = 0; i < classes.Length; i++) {
    222             if (!DataRowNames.ContainsKey(classes[i])) DataRowNames.Add(classes[i], new List<int>());
    223             DataRowNames[classes[i]].Add(i); //always succeeds
     249            if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
     250            dataRowNames[classes[i]].Add(i); //always succeeds
    224251          }
    225252        } else if ((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) {
     
    228255          var min = classValues.Min() - 0.1;
    229256          var contours = 8;
    230           for (int i = 0; i < contours; i++) {
     257          for (var i = 0; i < contours; i++) {
    231258            var name = GetContourName(i, min, max, contours);
    232             DataRowNames.Add(name, new List<int>());
     259            dataRowNames.Add(name, new List<int>());
    233260            rows.Add(name, new ScatterPlotDataRow(name, "", new List<Point2D<double>>()));
    234261            rows[name].VisualProperties.Color = GetHeatMapColor(i, contours);
    235             rows[name].VisualProperties.PointSize = i+3;
     262            rows[name].VisualProperties.PointSize = i + 3;
    236263          }
    237264          for (int i = 0; i < classValues.Length; i++) {
    238             DataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i); //always succeeds
     265            dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i); //always succeeds
    239266          }
    240267
    241268        }
    242269
    243 
    244270      } else {
    245         DataRowNames.Add("Training", problemData.TrainingIndices.ToList());
    246         DataRowNames.Add("Test", problemData.TestIndices.ToList());
     271        dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
     272        dataRowNames.Add("Test", problemData.TestIndices.ToList());
    247273      }
    248274
    249275      var random = SetSeedRandomly ? new MersenneTwister() : new MersenneTwister(Seed);
    250       tsne = new TSNE<RealVector>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, DataRowNames, rows);
     276      tsne = new TSNE<RealVector>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, dataRowNames, rows);
    251277      var dataset = problemData.Dataset;
    252278      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
    253279      var data = new RealVector[dataset.Rows];
    254280      for (var row = 0; row < dataset.Rows; row++) data[row] = new RealVector(allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray());
     281
     282      if (Normalization) {
     283        data = NormalizeData(data);
     284      }
     285
    255286      return tsne.Run(data, NewDimensions, Perplexity, Theta);
     287    }
     288
     289    private RealVector[] NormalizeData(RealVector[] data) {
     290      var n = data[0].Length;
     291      var mean = new double[n];
     292      var sd = new double[n];
     293      var nData = new RealVector[data.Length];
     294      for (var i = 0; i < n; i++) {
     295        var i1 = i;
     296        sd[i] = Enumerable.Range(0, data.Length).Select(x => data[x][i1]).StandardDeviation();
     297        mean[i] = Enumerable.Range(0, data.Length).Select(x => data[x][i1]).Average();
     298      }
     299      for (int i = 0; i < data.Length; i++) {
     300        nData[i] = new RealVector(n);
     301        for (int j = 0; j < n; j++) {
     302          nData[i][j] = (data[i][j] - mean[j]) / sd[j];
     303        }
     304      }
     305      return nData;
     306
     307
    256308    }
    257309
Note: See TracChangeset for help on using the changeset viewer.