Free cookie consent management tool by TermsFeed Policy Generator

Changeset 14742 for branches


Ignore:
Timestamp:
03/10/17 08:37:27 (8 years ago)
Author:
bwerth
Message:

#2700 fixed displaying of randomly generated seed and some minor code simplifications

Location:
branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/SPtree.cs

    r14518 r14742  
    157157      // Ignore objects which do not belong in this quad tree
    158158      var point = new double[dimension];
    159       Buffer.BlockCopy(Data, (int)(sizeof(double) * dimension * newIndex), point, 0, (int)(sizeof(double) * dimension));
     159      Buffer.BlockCopy(Data, sizeof(double) * dimension * newIndex, point, 0, sizeof(double) * dimension);
    160160      if (!boundary.ContainsPoint(point)) return false;
    161161      cumulativeSize++;
     
    227227    public bool IsCorrect() {
    228228      var row = new double[dimension];
    229       for (var n = 0; n < size; n++)
    230         Buffer.BlockCopy(Data, (int)(sizeof(double) * dimension * n), row, 0, (int)(sizeof(double) * dimension));
     229      for (var n = 0; n < size; n++) Buffer.BlockCopy(Data, sizeof(double) * dimension * n, row, 0, sizeof(double) * dimension);
    231230      if (!boundary.ContainsPoint(row)) return false;
    232231      if (isLeaf) return true;
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNE.cs

    r14558 r14742  
    132132      this.results = results;
    133133      this.dataRowLookup = dataRowLookup;
    134       if (dataRows != null)
    135         this.dataRows = dataRows;
     134      if (dataRows != null) this.dataRows = dataRows;
    136135      else { this.dataRows = new Dictionary<string, ScatterPlotDataRow>(); }
    137136    }
     
    200199
    201200    private void SetUpResults(IReadOnlyCollection<T> data) {
    202       if (dataRowLookup == null) {
    203         dataRowLookup = new Dictionary<string, List<int>>();
    204         dataRowLookup.Add("Data", Enumerable.Range(0, data.Count).ToList());
    205       }
     201      if (dataRowLookup == null) dataRowLookup = new Dictionary<string, List<int>> { { "Data", Enumerable.Range(0, data.Count).ToList() } };
    206202      if (results == null) return;
     203
    207204      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
    208205      else ((IntValue)results[IterationResultName].Value).Value = 0;
     
    216213      var plot = results[ErrorPlotResultName].Value as DataTable;
    217214      if (plot == null) throw new ArgumentException("could not create/access Error-DataTable in Results-Collection");
    218       if (!plot.Rows.ContainsKey("errors")) {
    219         plot.Rows.Add(new DataRow("errors"));
    220       }
     215
     216      if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
    221217      plot.Rows["errors"].Values.Clear();
     218
    222219      results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
    223220      results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     
    245242    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
    246243      foreach (var rowName in dataRowLookup.Keys) {
    247         if (!plot.Rows.ContainsKey(rowName)) {
     244        if (!plot.Rows.ContainsKey(rowName))
    248245          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
    249         }
    250         //else plot.Rows[rowName].Points.Clear();
    251246        plot.Rows[rowName].Points.Replace(dataRowLookup[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
    252         //plot.Rows[rowName].Points.AddRange();
    253247      }
    254248    }
     
    503497      tree.ComputeEdgeForces(rowP, colP, valP, n, posF);
    504498      var row = new double[d];
    505       for (int n1 = 0; n1 < n; n1++) {
     499      for (var n1 = 0; n1 < n; n1++) {
    506500        Buffer.BlockCopy(negF, (sizeof(double) * n1 * d), row, 0, d);
    507501        tree.ComputeNonEdgeForces(n1, theta, row, sumQ);
  • branches/TSNE/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAnalysis.cs

    r14558 r14742  
    5757      set { base.Problem = value; }
    5858    }
    59     #region Resultnames
    60     private const string ScatterPlotResultName = "Scatterplot";
    61     private const string DataResultName = "Projected Data";
    62     #endregion
    6359
    6460    #region Parameternames
     
    238234
    239235    protected override void Run(CancellationToken cancellationToken) {
    240       var data = CalculateProjectedData(Problem.ProblemData);
    241       var lowDimData = new DoubleMatrix(data);
    242     }
    243 
    244     private double[,] CalculateProjectedData(IDataAnalysisProblemData problemData) {
    245236      var dataRowNames = new Dictionary<string, List<int>>();
    246237      var rows = new Dictionary<string, ScatterPlotDataRow>();
    247 
     238      var problemData = Problem.ProblemData;
     239
     240      //color datapoints acording to Classes-Variable (be it double or string)
    248241      if (problemData.Dataset.VariableNames.Contains(Classes)) {
    249242        if ((problemData.Dataset as Dataset).VariableHasType<string>(Classes)) {
    250243          var classes = problemData.Dataset.GetStringValues(Classes).ToArray();
    251           for (int i = 0; i < classes.Length; i++) {
     244          for (var i = 0; i < classes.Length; i++) {
    252245            if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    253             dataRowNames[classes[i]].Add(i); //always succeeds
     246            dataRowNames[classes[i]].Add(i);
    254247          }
    255248        } else if ((problemData.Dataset as Dataset).VariableHasType<double>(Classes)) {
     
    257250          var max = classValues.Max() + 0.1;
    258251          var min = classValues.Min() - 0.1;
    259           var contours = 8;
     252          const int contours = 8;
    260253          for (var i = 0; i < contours; i++) {
    261             var name = GetContourName(i, min, max, contours);
    262             dataRowNames.Add(name, new List<int>());
    263             rows.Add(name, new ScatterPlotDataRow(name, "", new List<Point2D<double>>()));
    264             rows[name].VisualProperties.Color = GetHeatMapColor(i, contours);
    265             rows[name].VisualProperties.PointSize = i + 3;
     254            var contourname = GetContourName(i, min, max, contours);
     255            dataRowNames.Add(contourname, new List<int>());
     256            rows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
     257            rows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
     258            rows[contourname].VisualProperties.PointSize = i + 3;
    266259          }
    267           for (int i = 0; i < classValues.Length; i++) {
    268             dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i); //always succeeds
     260          for (var i = 0; i < classValues.Length; i++) {
     261            dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i);
    269262          }
    270 
    271263        }
    272 
    273264      } else {
    274265        dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
     
    276267      }
    277268
    278       var random = SetSeedRandomly ? new MersenneTwister() : new MersenneTwister(Seed);
     269      //Set up and run TSNE
     270      if (SetSeedRandomly) SeedParameter.Value.Value = new System.Random().Next();
     271      var random = new MersenneTwister(Seed);
    279272      tsne = new TSNE<RealVector>(Distance, random, Results, MaxIterations, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, dataRowNames, rows);
    280273      var dataset = problemData.Dataset;
     
    282275      var data = new RealVector[dataset.Rows];
    283276      for (var row = 0; row < dataset.Rows; row++) data[row] = new RealVector(allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray());
    284 
    285       if (Normalization) {
    286         data = NormalizeData(data);
    287       }
    288 
    289       return tsne.Run(data, NewDimensions, Perplexity, Theta);
    290     }
    291 
    292     private RealVector[] NormalizeData(RealVector[] data) {
     277      if (Normalization) data = NormalizeData(data);
     278      tsne.Run(data, NewDimensions, Perplexity, Theta);
     279    }
     280
     281    private static RealVector[] NormalizeData(IReadOnlyList<RealVector> data) {
    293282      var n = data[0].Length;
    294283      var mean = new double[n];
    295284      var sd = new double[n];
    296       var nData = new RealVector[data.Length];
     285      var nData = new RealVector[data.Count];
    297286      for (var i = 0; i < n; i++) {
    298287        var i1 = i;
    299         sd[i] = Enumerable.Range(0, data.Length).Select(x => data[x][i1]).StandardDeviation();
    300         mean[i] = Enumerable.Range(0, data.Length).Select(x => data[x][i1]).Average();
     288        sd[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).StandardDeviation();
     289        mean[i] = Enumerable.Range(0, data.Count).Select(x => data[x][i1]).Average();
    301290      }
    302       for (int i = 0; i < data.Length; i++) {
     291      for (var i = 0; i < data.Count; i++) {
    303292        nData[i] = new RealVector(n);
    304         for (int j = 0; j < n; j++) {
    305           nData[i][j] = (data[i][j] - mean[j]) / sd[j];
    306         }
     293        for (var j = 0; j < n; j++) nData[i][j] = (data[i][j] - mean[j]) / sd[j];
    307294      }
    308295      return nData;
    309 
    310 
    311     }
    312 
     296    }
    313297    private static Color GetHeatMapColor(int contourNr, int noContours) {
    314298      var q = (double)contourNr / noContours;  // q in [0,1]
Note: See TracChangeset for help on using the changeset viewer.