Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/10/17 08:37:27 (7 years ago)
Author:
bwerth
Message:

#2700 fixed displaying of randomly generated seed and some minor code simplifications

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAnalysis.cs

    r14558 r14742  
    5757      set { base.Problem = value; }
    5858    }
    59     #region Resultnames
    60     private const string ScatterPlotResultName = "Scatterplot";
    61     private const string DataResultName = "Projected Data";
    62     #endregion
    6359
    6460    #region Parameternames
     
    238234
    239235    protected override void Run(CancellationToken cancellationToken) {
    240       var data = CalculateProjectedData(Problem.ProblemData);
    241       var lowDimData = new DoubleMatrix(data);
    242     }
    243 
    244     private double[,] CalculateProjectedData(IDataAnalysisProblemData problemData) {
    245236      var dataRowNames = new Dictionary<string, List<int>>();
    246237      var rows = new Dictionary<string, ScatterPlotDataRow>();
    247 
     238      var problemData = Problem.ProblemData;
     239
     240      //color datapoints acording to Classes-Variable (be it double or string)
    248241      if (problemData.Dataset.VariableNames.Contains(Classes)) {
    249242        if ((problemData.Dataset as Dataset).VariableHasType<string>(Classes)) {
    250243          var classes = problemData.Dataset.GetStringValues(Classes).ToArray();
    251           for (int i = 0; i < classes.Length; i++) {
     244          for (var i = 0; i < classes.Length; i++) {
    252245            if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    253             dataRowNames[classes[i]].Add(i); //always succeeds
     246            dataRowNames[classes[i]].Add(i);
    254247          }
    255248        } else if ((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) {
     
    257250          var max = classValues.Max() + 0.1;
    258251          var min = classValues.Min() - 0.1;
    259           var contours = 8;
     252          const int contours = 8;
    260253          for (var i = 0; i < contours; i++) {
    261             var name = GetContourName(i, min, max, contours);
    262             dataRowNames.Add(name, new List<int>());
    263             rows.Add(name, new ScatterPlotDataRow(name, "", new List<Point2D<double>>()));
    264             rows[name].VisualProperties.Color = GetHeatMapColor(i, contours);
    265             rows[name].VisualProperties.PointSize = i + 3;
     254            var contourname = GetContourName(i, min, max, contours);
     255            dataRowNames.Add(contourname, new List<int>());
     256            rows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
     257            rows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
     258            rows[contourname].VisualProperties.PointSize = i + 3;
    266259          }
    267           for (int i = 0; i < classValues.Length; i++) {
    268             dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i); //always succeeds
     260          for (var i = 0; i < classValues.Length; i++) {
     261            dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i);
    269262          }
    270 
    271263        }
    272 
    273264      } else {
    274265        dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
     
    276267      }
    277268
    278       var random = SetSeedRandomly ? new MersenneTwister() : new MersenneTwister(Seed);
     269      //Set up and run TSNE
     270      if (SetSeedRandomly) SeedParameter.Value.Value = new System.Random().Next();
     271      var random = new MersenneTwister(Seed);
    279272      tsne = new TSNE<RealVector>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, dataRowNames, rows);
    280273      var dataset = problemData.Dataset;
     
    282275      var data = new RealVector[dataset.Rows];
    283276      for (var row = 0; row < dataset.Rows; row++) data[row] = new RealVector(allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray());
    284 
    285       if (Normalization) {
    286         data = NormalizeData(data);
    287       }
    288 
    289       return tsne.Run(data, NewDimensions, Perplexity, Theta);
    290     }
    291 
    292     private RealVector[] NormalizeData(RealVector[] data) {
     277      if (Normalization) data = NormalizeData(data);
     278      tsne.Run(data, NewDimensions, Perplexity, Theta);
     279    }
     280
     281    private static RealVector[] NormalizeData(IReadOnlyList<RealVector> data) {
    293282      var n = data[0].Length;
    294283      var mean = new double[n];
    295284      var sd = new double[n];
    296       var nData = new RealVector[data.Length];
     285      var nData = new RealVector[data.Count];
    297286      for (var i = 0; i < n; i++) {
    298287        var i1 = i;
    299         sd[i] = Enumerable.Range(0, data.Length).Select(x => data[x][i1]).StandardDeviation();
    300         mean[i] = Enumerable.Range(0, data.Length).Select(x => data[x][i1]).Average();
     288        sd[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).StandardDeviation();
     289        mean[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).Average();
    301290      }
    302       for (int i = 0; i < data.Length; i++) {
     291      for (var i = 0; i < data.Count; i++) {
    303292        nData[i] = new RealVector(n);
    304         for (int j = 0; j < n; j++) {
    305           nData[i][j] = (data[i][j] - mean[j]) / sd[j];
    306         }
     293        for (var j = 0; j < n; j++) nData[i][j] = (data[i][j] - mean[j]) / sd[j];
    307294      }
    308295      return nData;
    309 
    310 
    311     }
    312 
     296    }
    313297    private static Color GetHeatMapColor(int contourNr, int noContours) {
    314298      var q = (double)contourNr / noContours;  // q in [0,1]
Note: See TracChangeset for help on using the changeset viewer.