Free cookie consent management tool by TermsFeed Policy Generator

Changeset 15556


Ignore:
Timestamp:
12/21/17 09:14:27 (7 years ago)
Author:
bwerth
Message:

#2850 reduced state of TSNEAlgorithm.cs

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs

    r15551 r15556  
    4242  /// </summary>
    4343  [Item("t-Distributed Stochastic Neighbor Embedding (tSNE)", "t-Distributed Stochastic Neighbor Embedding projects the data in a low " +
    44                 "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
     44                                                              "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
    4545  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
    4646  [StorableClass]
     
    203203    #region Storable poperties
    204204    [Storable]
    205     private Dictionary<string, List<int>> dataRowNames;
    206     [Storable]
    207     private Dictionary<string, ScatterPlotDataRow> dataRows;
     205    private Dictionary<string, IList<int>> dataRowIndices;
    208206    [Storable]
    209207    private TSNEStatic<double[]>.TSNEState state;
    210     [Storable]
    211     private int iter;
    212208    #endregion
    213209
     
    223219    }
    224220    private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
    225       if (original.dataRowNames != null)
    226         dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
    227       if (original.dataRows != null)
    228         dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
     221      if (original.dataRowIndices != null)
     222        dataRowIndices = new Dictionary<string, IList<int>>(original.dataRowIndices);
    229223      if (original.state != null)
    230224        state = cloner.Clone(original.state);
    231       iter = original.iter;
    232225      RegisterParameterEvents();
    233226    }
     
    259252      Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
    260253
    261       Parameters[UpdateIntervalParameterName].Hidden = true;
    262 
     254      UpdateIntervalParameter.Hidden = true;
    263255      MomentumSwitchIterationParameter.Hidden = true;
    264256      InitialMomentumParameter.Hidden = true;
     
    273265    public override void Prepare() {
    274266      base.Prepare();
    275       dataRowNames = null;
    276       dataRows = null;
     267      dataRowIndices = null;
    277268      state = null;
    278269    }
     
    301292          col++;
    302293        }
    303 
    304294        if (Normalization) data = NormalizeInputData(data);
    305295        state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, RandomInitialization);
    306296        SetUpResults(allindices);
    307         iter = 0;
    308       }
    309       for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
    310         if (iter % UpdateInterval == 0) Analyze(state);
     297      }
     298      while (state.iter < MaxIterations && !cancellationToken.IsCancellationRequested) {
     299        if (state.iter % UpdateInterval == 0) Analyze(state);
    311300        TSNEStatic<double[]>.Iterate(state);
    312301      }
     
    324313      base.RegisterProblemEvents();
    325314      if (Problem == null) return;
     315      Problem.ProblemDataChanged += OnProblemDataChanged;
     316      if (Problem.ProblemData == null) return;
     317      Problem.ProblemData.Changed += OnPerplexityChanged;
     318      Problem.ProblemData.Changed += OnColumnsChanged;
     319      if (Problem.ProblemData.Dataset == null) return;
     320      Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
     321      Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
     322    }
     323
     324    protected override void DeregisterProblemEvents() {
     325      base.DeregisterProblemEvents();
     326      if (Problem == null) return;
    326327      Problem.ProblemDataChanged -= OnProblemDataChanged;
    327       Problem.ProblemDataChanged += OnProblemDataChanged;
    328328      if (Problem.ProblemData == null) return;
    329329      Problem.ProblemData.Changed -= OnPerplexityChanged;
    330330      Problem.ProblemData.Changed -= OnColumnsChanged;
    331       Problem.ProblemData.Changed += OnPerplexityChanged;
    332       Problem.ProblemData.Changed += OnColumnsChanged;
    333331      if (Problem.ProblemData.Dataset == null) return;
    334332      Problem.ProblemData.Dataset.RowsChanged -= OnPerplexityChanged;
    335333      Problem.ProblemData.Dataset.ColumnsChanged -= OnColumnsChanged;
    336       Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
    337       Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
    338     }
    339 
    340     protected override void DeregisterProblemEvents() {
    341       base.DeregisterProblemEvents();
    342       Problem.ProblemDataChanged -= OnProblemDataChanged;
    343334    }
    344335
    345336    protected override void OnStopped() {
    346337      base.OnStopped();
     338      //bwerth: state objects can be very large; avoid state serialization
    347339      state = null;
    348       dataRowNames = null;
    349       dataRows = null;
     340      dataRowIndices = null;
    350341    }
    351342
     
    354345      OnPerplexityChanged(this, null);
    355346      OnColumnsChanged(this, null);
    356       Problem.ProblemData.Changed -= OnPerplexityChanged;
    357347      Problem.ProblemData.Changed += OnPerplexityChanged;
    358       Problem.ProblemData.Changed -= OnColumnsChanged;
    359348      Problem.ProblemData.Changed += OnColumnsChanged;
    360349      if (Problem.ProblemData.Dataset == null) return;
    361       Problem.ProblemData.Dataset.RowsChanged -= OnPerplexityChanged;
    362       Problem.ProblemData.Dataset.ColumnsChanged -= OnColumnsChanged;
    363350      Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
    364351      Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
     
    374361
    375362    private void RegisterParameterEvents() {
    376       PerplexityParameter.Value.ValueChanged -= OnPerplexityChanged;
    377363      PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
    378364    }
     
    380366    private void OnPerplexityChanged(object sender, EventArgs e) {
    381367      if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(PerplexityParameterName)) return;
    382       PerplexityParameter.Value.ValueChanged -= OnPerplexityChanged;
    383368      PerplexityParameter.Value.Value = Math.Max(1, Math.Min((Problem.ProblemData.Dataset.Rows - 1) / 3.0, Perplexity));
    384       PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
    385369    }
    386370    #endregion
     
    390374      if (Results == null) return;
    391375      var results = Results;
    392       dataRowNames = new Dictionary<string, List<int>>();
    393       dataRows = new Dictionary<string, ScatterPlotDataRow>();
     376      dataRowIndices = new Dictionary<string, IList<int>>();
    394377      var problemData = Problem.ProblemData;
    395378
     
    411394      }
    412395
    413       //color datapoints acording to classes variable (be it double or string)
     396      //color datapoints acording to classes variable (be it double, datetime or string)
    414397      if (!problemData.Dataset.VariableNames.Contains(ClassesName)) {
    415         dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
    416         dataRowNames.Add("Test", problemData.TestIndices.ToList());
     398        dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
     399        dataRowIndices.Add("Test", problemData.TestIndices.ToList());
    417400        return;
    418401      }
     402
    419403      var classificationData = problemData as ClassificationProblemData;
    420404      if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) {
     
    422406        var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray();
    423407        for (var i = 0; i < classes.Length; i++) {
    424           if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    425           dataRowNames[classes[i]].Add(i);
     408          if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
     409          dataRowIndices[classes[i]].Add(i);
    426410        }
    427411      } else if (((Dataset)problemData.Dataset).VariableHasType<string>(ClassesName)) {
    428412        var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();
    429413        for (var i = 0; i < classes.Length; i++) {
    430           if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    431           dataRowNames[classes[i]].Add(i);
     414          if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
     415          dataRowIndices[classes[i]].Add(i);
    432416        }
    433417      } else if (((Dataset)problemData.Dataset).VariableHasType<double>(ClassesName)) {
     
    442426          var c = contourorder[i];
    443427          var contourname = contourMap[c];
    444           dataRowNames.Add(contourname, new List<int>());
    445           dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
    446           dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
     428          dataRowIndices.Add(contourname, new List<int>());
     429          var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
     430          ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
    447431        }
    448432        var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
    449         for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
     433        for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
    450434      } else if (((Dataset)problemData.Dataset).VariableHasType<DateTime>(ClassesName)) {
    451435        var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
     
    459443          var c = contourorder[i];
    460444          var contourname = contourMap[c];
    461           dataRowNames.Add(contourname, new List<int>());
    462           dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
    463           dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
     445          dataRowIndices.Add(contourname, new List<int>());
     446          var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
     447          row.VisualProperties.PointSize = 8;
     448          ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
    464449        }
    465450        var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
    466         for (var i = 0; i < clusterdata.Rows; i++) dataRowNames[contourMap[allClusters[i] - 1]].Add(i);
     451        for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
    467452      } else {
    468         dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
    469         dataRowNames.Add("Test", problemData.TestIndices.ToList());
     453        dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
     454        dataRowIndices.Add("Test", problemData.TestIndices.ToList());
    470455      }
    471456    }
     
    485470      results[DataResultName].Value = new DoubleMatrix(ndata);
    486471      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
    487       FillScatterPlot(ndata, splot);
    488     }
    489 
    490     private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
    491       foreach (var rowName in dataRowNames.Keys) {
     472      FillScatterPlot(ndata, splot, dataRowIndices);
     473    }
     474
     475    private static void FillScatterPlot(double[,] lowDimData, ScatterPlot plot, Dictionary<string, IList<int>> dataRowIndices) {
     476      foreach (var rowName in dataRowIndices.Keys) {
    492477        if (!plot.Rows.ContainsKey(rowName)) {
    493           plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
     478          plot.Rows.Add(new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
    494479          plot.Rows[rowName].VisualProperties.PointSize = 8;
    495480        }
    496         plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     481        plot.Rows[rowName].Points.Replace(dataRowIndices[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
    497482      }
    498483    }
     
    504489      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
    505490      for (var i = 0; i < data.GetLength(0); i++)
    506         for (var j = 0; j < data.GetLength(1); j++) {
    507           var v = data[i, j];
    508           max[j] = Math.Max(max[j], v);
    509           min[j] = Math.Min(min[j], v);
    510         }
     491      for (var j = 0; j < data.GetLength(1); j++) {
     492        var v = data[i, j];
     493        max[j] = Math.Max(max[j], v);
     494        min[j] = Math.Min(min[j], v);
     495      }
    511496      for (var i = 0; i < data.GetLength(0); i++) {
    512497        for (var j = 0; j < data.GetLength(1); j++) {
     
    532517      for (var i = 0; i < data.Count; i++) {
    533518        nData[i] = new double[n];
    534         for (var j = 0; j < n; j++) nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
     519        for (var j = 0; j < n; j++)
     520          nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
    535521      }
    536522      return nData;
     
    542528
    543529    private static void CreateClusters(IDataset data, string target, int contours, out IClusteringModel contourCluster, out Dictionary<int, string> contourNames, out double[][] borders) {
    544       var cpd = new ClusteringProblemData((Dataset)data, new[] { target });
     530      var cpd = new ClusteringProblemData((Dataset)data, new[] {target});
    545531      contourCluster = KMeansClustering.CreateKMeansSolution(cpd, contours, 3).Model;
    546532
    547       borders = Enumerable.Range(0, contours).Select(x => new[] { double.MaxValue, double.MinValue }).ToArray();
     533      borders = Enumerable.Range(0, contours).Select(x => new[] {double.MaxValue, double.MinValue}).ToArray();
    548534      var clusters = contourCluster.GetClusterValues(cpd.Dataset, cpd.AllIndices).ToArray();
    549535      var targetvalues = cpd.Dataset.GetDoubleValues(target).ToArray();
Note: See TracChangeset for help on using the changeset viewer.