Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/27/17 17:27:03 (7 years ago)
Author:
gkronber
Message:

#2700: refactoring

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs

    r14785 r14788  
    2929using HeuristicLab.Core;
    3030using HeuristicLab.Data;
    31 using HeuristicLab.Encodings.RealVectorEncoding;
    3231using HeuristicLab.Optimization;
    3332using HeuristicLab.Parameters;
     
    7271    private const string ClassesParameterName = "ClassNames";
    7372    private const string NormalizationParameterName = "Normalization";
     73    #endregion
     74
     75    #region result names
     76    private const string IterationResultName = "Iteration";
     77    private const string ErrorResultName = "Error";
     78    private const string ErrorPlotResultName = "Error plot";
     79    private const string ScatterPlotResultName = "Scatterplot";
     80    private const string DataResultName = "Projected data";
    7481    #endregion
    7582
     
    209216    #endregion
    210217
    211     public override void Stop() {
    212       base.Stop();
    213       if (tsne != null) tsne.Running = false;
    214     }
     218    [Storable]
     219    private Dictionary<string, List<int>> dataRowNames;    // TODO
     220    [Storable]
     221    private Dictionary<string, ScatterPlotDataRow> dataRows; // TODO
     222
    215223
    216224    protected override void Run(CancellationToken cancellationToken) {
    217       var dataRowNames = new Dictionary<string, List<int>>();
    218       var rows = new Dictionary<string, ScatterPlotDataRow>();
     225      var problemData = Problem.ProblemData;
     226
     227      // set up and run tSNE
     228      if (SetSeedRandomly) Seed = new System.Random().Next();
     229      var random = new MersenneTwister((uint)Seed);
     230      var dataset = problemData.Dataset;
     231      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
     232      var data = new double[dataset.Rows][];
     233      for (var row = 0; row < dataset.Rows; row++) data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
     234      if (Normalization) data = NormalizeData(data);
     235
     236      var tsneState = TSNE<double[]>.CreateState(data, Distance, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
     237
     238      SetUpResults(data);
     239      for (int iter = 0; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++)
     240      {
     241        TSNE<double[]>.Iterate(tsneState);
     242        Analyze(tsneState);
     243      }
     244    }
     245
     246    private void SetUpResults(IReadOnlyCollection<double[]> data) {
     247      if (Results == null) return;
     248      var results = Results;
     249      dataRowNames = new Dictionary<string, List<int>>();
     250      dataRows = new Dictionary<string, ScatterPlotDataRow>();
    219251      var problemData = Problem.ProblemData;
    220252
     
    235267            var contourname = GetContourName(i, min, max, contours);
    236268            dataRowNames.Add(contourname, new List<int>());
    237             rows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
    238             rows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
    239             rows[contourname].VisualProperties.PointSize = i + 3;
     269            dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
     270            dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
     271            dataRows[contourname].VisualProperties.PointSize = i + 3;
    240272          }
    241273          for (var i = 0; i < classValues.Length; i++) {
     
    248280      }
    249281
    250       // set up and run tSNE
    251       if (SetSeedRandomly) Seed = new System.Random().Next();
    252       var random = new MersenneTwister((uint)Seed);
    253       tsne = new TSNE<double[]>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, dataRowNames, rows);
    254       var dataset = problemData.Dataset;
    255       var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
    256       var data = new double[dataset.Rows][];
    257       for (var row = 0; row < dataset.Rows; row++) data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
    258       if (Normalization) data = NormalizeData(data);
    259       tsne.Run(data, NewDimensions, Perplexity, Theta);
     282      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     283      else ((IntValue)results[IterationResultName].Value).Value = 0;
     284
     285      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
     286      else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
     287
     288      if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
     289      else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
     290
     291      var plot = results[ErrorPlotResultName].Value as DataTable;
     292      if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
     293
     294      if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
     295      plot.Rows["errors"].Values.Clear();
     296
     297      results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
     298      results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     299    }
     300
     301    private void Analyze(TSNE<double[]>.TSNEState tsneState) {
     302      if (Results == null) return;
     303      var results = Results;
     304      var plot = results[ErrorPlotResultName].Value as DataTable;
     305      if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
     306      var errors = plot.Rows["errors"].Values;
     307      var c = tsneState.EvaluateError();
     308      errors.Add(c);
     309      ((IntValue)results[IterationResultName].Value).Value = tsneState.iter + 1;
     310      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
     311
     312      var ndata = Normalize(tsneState.newData);
     313      results[DataResultName].Value = new DoubleMatrix(ndata);
     314      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
     315      FillScatterPlot(ndata, splot);
     316    }
     317
     318    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
     319      foreach (var rowName in dataRowNames.Keys) {
     320        if (!plot.Rows.ContainsKey(rowName))
     321          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
     322        plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     323      }
     324    }
     325
     326    private static double[,] Normalize(double[,] data) {
     327      var max = new double[data.GetLength(1)];
     328      var min = new double[data.GetLength(1)];
     329      var res = new double[data.GetLength(0), data.GetLength(1)];
     330      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
     331      for (var i = 0; i < data.GetLength(0); i++)
     332        for (var j = 0; j < data.GetLength(1); j++) {
     333          var v = data[i, j];
     334          max[j] = Math.Max(max[j], v);
     335          min[j] = Math.Min(min[j], v);
     336        }
     337      for (var i = 0; i < data.GetLength(0); i++) {
     338        for (var j = 0; j < data.GetLength(1); j++) {
     339          res[i, j] = (data[i, j] - (max[j] + min[j]) / 2) / (max[j] - min[j]);
     340        }
     341      }
     342      return res;
    260343    }
    261344
     
    276359      return nData;
    277360    }
     361
    278362    private static Color GetHeatMapColor(int contourNr, int noContours) {
    279363      var q = (double)contourNr / noContours;  // q in [0,1]
     
    281365      return c;
    282366    }
     367
    283368    private static string GetContourName(double value, double min, double max, int noContours) {
    284369      var size = (max - min) / noContours;
     
    286371      return GetContourName(contourNr, min, max, noContours);
    287372    }
     373
    288374    private static string GetContourName(int i, double min, double max, int noContours) {
    289375      var size = (max - min) / noContours;
Note: See TracChangeset for help on using the changeset viewer.