Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/06/18 18:15:29 (6 years ago)
Author:
jkarder
Message:

#2839:

Location:
branches/2839_HiveProjectManagement
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/2839_HiveProjectManagement

  • branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis

  • branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis/3.4

    • Property svn:mergeinfo set to (toggle deleted branches)
      /stable/HeuristicLab.Algorithms.DataAnalysis/3.4mergedeligible
      /trunk/HeuristicLab.Algorithms.DataAnalysis/3.4mergedeligible
      /branches/1721-RandomForestPersistence/HeuristicLab.Algorithms.DataAnalysis/3.410321-10322
      /branches/Async/HeuristicLab.Algorithms.DataAnalysis/3.413329-15286
      /branches/Benchmarking/sources/HeuristicLab.Algorithms.DataAnalysis/3.46917-7005
      /branches/ClassificationModelComparison/HeuristicLab.Algorithms.DataAnalysis/3.49070-13099
      /branches/CloningRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.44656-4721
      /branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.45471-5808
      /branches/DataAnalysis SolutionEnsembles/HeuristicLab.Algorithms.DataAnalysis/3.45815-6180
      /branches/DataAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.44458-4459,​4462,​4464
      /branches/DataPreprocessing/HeuristicLab.Algorithms.DataAnalysis/3.410085-11101
      /branches/GP.Grammar.Editor/HeuristicLab.Algorithms.DataAnalysis/3.46284-6795
      /branches/GP.Symbols (TimeLag, Diff, Integral)/HeuristicLab.Algorithms.DataAnalysis/3.45060
      /branches/HeuristicLab.DatasetRefactor/sources/HeuristicLab.Algorithms.DataAnalysis/3.411570-12508
      /branches/HeuristicLab.Problems.Orienteering/HeuristicLab.Algorithms.DataAnalysis/3.411130-12721
      /branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.413819-14091
      /branches/HeuristicLab.TimeSeries/HeuristicLab.Algorithms.DataAnalysis/3.48116-8789
      /branches/LogResidualEvaluator/HeuristicLab.Algorithms.DataAnalysis/3.410202-10483
      /branches/NET40/sources/HeuristicLab.Algorithms.DataAnalysis/3.45138-5162
      /branches/ParallelEngine/HeuristicLab.Algorithms.DataAnalysis/3.45175-5192
      /branches/ProblemInstancesRegressionAndClassification/HeuristicLab.Algorithms.DataAnalysis/3.47773-7810
      /branches/QAPAlgorithms/HeuristicLab.Algorithms.DataAnalysis/3.46350-6627
      /branches/Restructure trunk solution/HeuristicLab.Algorithms.DataAnalysis/3.46828
      /branches/SpectralKernelForGaussianProcesses/HeuristicLab.Algorithms.DataAnalysis/3.410204-10479
      /branches/SuccessProgressAnalysis/HeuristicLab.Algorithms.DataAnalysis/3.45370-5682
      /branches/Trunk/HeuristicLab.Algorithms.DataAnalysis/3.46829-6865
      /branches/VNS/HeuristicLab.Algorithms.DataAnalysis/3.45594-5752
      /branches/Weighted TSNE/3.415451-15531
      /branches/histogram/HeuristicLab.Algorithms.DataAnalysis/3.45959-6341
      /branches/symbreg-factors-2650/HeuristicLab.Algorithms.DataAnalysis/3.414232-14825
      /trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.415377-15681
  • branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs

    r15234 r16057  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    3838namespace HeuristicLab.Algorithms.DataAnalysis {
    3939  /// <summary>
    40   /// t-distributed stochastic neighbourhood embedding (tSNE) projects the data in a low dimensional
     40  /// t-Distributed Stochastic Neighbor Embedding (tSNE) projects the data in a low dimensional
    4141  /// space to allow visual cluster identification.
    4242  /// </summary>
    43   [Item("tSNE", "t-distributed stochastic neighbourhood embedding projects the data in a low " +
    44                 "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
     43  [Item("t-Distributed Stochastic Neighbor Embedding (tSNE)", "t-Distributed Stochastic Neighbor Embedding projects the data in a low " +
     44                                                              "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
    4545  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
    4646  [StorableClass]
     
    5757    }
    5858
    59     #region parameter names
     59    #region Parameter names
    6060    private const string DistanceFunctionParameterName = "DistanceFunction";
    6161    private const string PerplexityParameterName = "Perplexity";
     
    7272    private const string ClassesNameParameterName = "ClassesName";
    7373    private const string NormalizationParameterName = "Normalization";
     74    private const string RandomInitializationParameterName = "RandomInitialization";
    7475    private const string UpdateIntervalParameterName = "UpdateInterval";
    7576    #endregion
    7677
    77     #region result names
     78    #region Result names
    7879    private const string IterationResultName = "Iteration";
    7980    private const string ErrorResultName = "Error";
     
    8384    #endregion
    8485
    85     #region parameter properties
     86    #region Parameter properties
    8687    public IFixedValueParameter<DoubleValue> PerplexityParameter {
    87       get { return Parameters[PerplexityParameterName] as IFixedValueParameter<DoubleValue>; }
     88      get { return (IFixedValueParameter<DoubleValue>)Parameters[PerplexityParameterName]; }
    8889    }
    8990    public IFixedValueParameter<PercentValue> ThetaParameter {
    90       get { return Parameters[ThetaParameterName] as IFixedValueParameter<PercentValue>; }
     91      get { return (IFixedValueParameter<PercentValue>)Parameters[ThetaParameterName]; }
    9192    }
    9293    public IFixedValueParameter<IntValue> NewDimensionsParameter {
    93       get { return Parameters[NewDimensionsParameterName] as IFixedValueParameter<IntValue>; }
     94      get { return (IFixedValueParameter<IntValue>)Parameters[NewDimensionsParameterName]; }
    9495    }
    9596    public IConstrainedValueParameter<IDistance<double[]>> DistanceFunctionParameter {
    96       get { return Parameters[DistanceFunctionParameterName] as IConstrainedValueParameter<IDistance<double[]>>; }
     97      get { return (IConstrainedValueParameter<IDistance<double[]>>)Parameters[DistanceFunctionParameterName]; }
    9798    }
    9899    public IFixedValueParameter<IntValue> MaxIterationsParameter {
    99       get { return Parameters[MaxIterationsParameterName] as IFixedValueParameter<IntValue>; }
     100      get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
    100101    }
    101102    public IFixedValueParameter<IntValue> StopLyingIterationParameter {
    102       get { return Parameters[StopLyingIterationParameterName] as IFixedValueParameter<IntValue>; }
     103      get { return (IFixedValueParameter<IntValue>)Parameters[StopLyingIterationParameterName]; }
    103104    }
    104105    public IFixedValueParameter<IntValue> MomentumSwitchIterationParameter {
    105       get { return Parameters[MomentumSwitchIterationParameterName] as IFixedValueParameter<IntValue>; }
     106      get { return (IFixedValueParameter<IntValue>)Parameters[MomentumSwitchIterationParameterName]; }
    106107    }
    107108    public IFixedValueParameter<DoubleValue> InitialMomentumParameter {
    108       get { return Parameters[InitialMomentumParameterName] as IFixedValueParameter<DoubleValue>; }
     109      get { return (IFixedValueParameter<DoubleValue>)Parameters[InitialMomentumParameterName]; }
    109110    }
    110111    public IFixedValueParameter<DoubleValue> FinalMomentumParameter {
    111       get { return Parameters[FinalMomentumParameterName] as IFixedValueParameter<DoubleValue>; }
     112      get { return (IFixedValueParameter<DoubleValue>)Parameters[FinalMomentumParameterName]; }
    112113    }
    113114    public IFixedValueParameter<DoubleValue> EtaParameter {
    114       get { return Parameters[EtaParameterName] as IFixedValueParameter<DoubleValue>; }
     115      get { return (IFixedValueParameter<DoubleValue>)Parameters[EtaParameterName]; }
    115116    }
    116117    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    117       get { return Parameters[SetSeedRandomlyParameterName] as IFixedValueParameter<BoolValue>; }
     118      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    118119    }
    119120    public IFixedValueParameter<IntValue> SeedParameter {
    120       get { return Parameters[SeedParameterName] as IFixedValueParameter<IntValue>; }
     121      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    121122    }
    122123    public IConstrainedValueParameter<StringValue> ClassesNameParameter {
    123       get { return Parameters[ClassesNameParameterName] as IConstrainedValueParameter<StringValue>; }
     124      get { return (IConstrainedValueParameter<StringValue>)Parameters[ClassesNameParameterName]; }
    124125    }
    125126    public IFixedValueParameter<BoolValue> NormalizationParameter {
    126       get { return Parameters[NormalizationParameterName] as IFixedValueParameter<BoolValue>; }
     127      get { return (IFixedValueParameter<BoolValue>)Parameters[NormalizationParameterName]; }
     128    }
     129    public IFixedValueParameter<BoolValue> RandomInitializationParameter {
     130      get { return (IFixedValueParameter<BoolValue>)Parameters[RandomInitializationParameterName]; }
    127131    }
    128132    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
    129       get { return Parameters[UpdateIntervalParameterName] as IFixedValueParameter<IntValue>; }
     133      get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
    130134    }
    131135    #endregion
     
    187191      set { NormalizationParameter.Value.Value = value; }
    188192    }
    189 
     193    public bool RandomInitialization {
     194      get { return RandomInitializationParameter.Value.Value; }
     195      set { RandomInitializationParameter.Value.Value = value; }
     196    }
    190197    public int UpdateInterval {
    191198      get { return UpdateIntervalParameter.Value.Value; }
     
    194201    #endregion
    195202
     203    #region Storable poperties
     204    [Storable]
     205    private Dictionary<string, IList<int>> dataRowIndices;
     206    [Storable]
     207    private TSNEStatic<double[]>.TSNEState state;
     208    #endregion
     209
    196210    #region Constructors & Cloning
    197211    [StorableConstructor]
    198212    private TSNEAlgorithm(bool deserializing) : base(deserializing) { }
    199213
     214    [StorableHook(HookType.AfterDeserialization)]
     215    private void AfterDeserialization() {
     216      if (!Parameters.ContainsKey(RandomInitializationParameterName))
     217        Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
     218      RegisterParameterEvents();
     219    }
    200220    private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
    201       if (original.dataRowNames != null)
    202         this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
    203       if (original.dataRows != null)
    204         this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
     221      if (original.dataRowIndices != null)
     222        dataRowIndices = new Dictionary<string, IList<int>>(original.dataRowIndices);
    205223      if (original.state != null)
    206         this.state = cloner.Clone(original.state);
    207       this.iter = original.iter;
    208     }
    209     public override IDeepCloneable Clone(Cloner cloner) { return new TSNEAlgorithm(this, cloner); }
     224        state = cloner.Clone(original.state);
     225      RegisterParameterEvents();
     226    }
     227    public override IDeepCloneable Clone(Cloner cloner) {
     228      return new TSNEAlgorithm(this, cloner);
     229    }
    210230    public TSNEAlgorithm() {
    211231      var distances = new ItemSet<IDistance<double[]>>(ApplicationManager.Manager.GetInstances<IDistance<double[]>>());
     
    213233      Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25)));
    214234      Parameters.Add(new FixedValueParameter<PercentValue>(ThetaParameterName, "Value describing how much appoximated " +
    215                                                                               "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
    216                                                                               "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
    217                                                                               "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
    218                                                                               "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
    219                                                                               " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0)));
     235                                                                               "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
     236                                                                               "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
     237                                                                               "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
     238                                                                               "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
     239                                                                               " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0)));
    220240      Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis)", new IntValue(2)));
    221241      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent.", new IntValue(1000)));
     
    230250      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
    231251      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "The interval after which the results will be updated.", new IntValue(50)));
    232       Parameters[UpdateIntervalParameterName].Hidden = true;
    233 
     252      Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
     253
     254      UpdateIntervalParameter.Hidden = true;
    234255      MomentumSwitchIterationParameter.Hidden = true;
    235256      InitialMomentumParameter.Hidden = true;
     
    238259      EtaParameter.Hidden = false;
    239260      Problem = new RegressionProblem();
    240     }
    241     #endregion
    242 
    243     [Storable]
    244     private Dictionary<string, List<int>> dataRowNames;
    245     [Storable]
    246     private Dictionary<string, ScatterPlotDataRow> dataRows;
    247     [Storable]
    248     private TSNEStatic<double[]>.TSNEState state;
    249     [Storable]
    250     private int iter;
     261      RegisterParameterEvents();
     262    }
     263    #endregion
    251264
    252265    public override void Prepare() {
    253266      base.Prepare();
    254       dataRowNames = null;
    255       dataRows = null;
     267      dataRowIndices = null;
    256268      state = null;
    257269    }
     
    259271    protected override void Run(CancellationToken cancellationToken) {
    260272      var problemData = Problem.ProblemData;
    261       // set up and initialized everything if necessary
     273      // set up and initialize everything if necessary
     274      var wdist = DistanceFunction as WeightedEuclideanDistance;
     275      if (wdist != null) wdist.Initialize(problemData);
    262276      if (state == null) {
    263277        if (SetSeedRandomly) Seed = new System.Random().Next();
     
    265279        var dataset = problemData.Dataset;
    266280        var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
    267         var data = new double[dataset.Rows][];
    268         for (var row = 0; row < dataset.Rows; row++)
    269           data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
    270 
    271         if (Normalization) data = NormalizeData(data);
    272 
    273         state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta,
    274           StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
    275 
    276         SetUpResults(data);
    277         iter = 0;
    278       }
    279       for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
    280         if (iter % UpdateInterval == 0)
    281           Analyze(state);
     281        var allindices = Problem.ProblemData.AllIndices.ToArray();
     282
     283        // jagged array is required to meet the static method declarations of TSNEStatic<T>
     284        var data = Enumerable.Range(0, dataset.Rows).Select(x => new double[allowedInputVariables.Length]).ToArray();
     285        var col = 0;
     286        foreach (var s in allowedInputVariables) {
     287          var row = 0;
     288          foreach (var d in dataset.GetDoubleValues(s)) {
     289            data[row][col] = d;
     290            row++;
     291          }
     292          col++;
     293        }
     294        if (Normalization) data = NormalizeInputData(data);
     295        state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, RandomInitialization);
     296        SetUpResults(allindices);
     297      }
     298      while (state.iter < MaxIterations && !cancellationToken.IsCancellationRequested) {
     299        if (state.iter % UpdateInterval == 0) Analyze(state);
    282300        TSNEStatic<double[]>.Iterate(state);
    283301      }
     
    294312    protected override void RegisterProblemEvents() {
    295313      base.RegisterProblemEvents();
     314      if (Problem == null) return;
    296315      Problem.ProblemDataChanged += OnProblemDataChanged;
    297     }
     316      if (Problem.ProblemData == null) return;
     317      Problem.ProblemData.Changed += OnPerplexityChanged;
     318      Problem.ProblemData.Changed += OnColumnsChanged;
     319      if (Problem.ProblemData.Dataset == null) return;
     320      Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
     321      Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
     322    }
     323
    298324    protected override void DeregisterProblemEvents() {
    299325      base.DeregisterProblemEvents();
     326      if (Problem == null) return;
    300327      Problem.ProblemDataChanged -= OnProblemDataChanged;
     328      if (Problem.ProblemData == null) return;
     329      Problem.ProblemData.Changed -= OnPerplexityChanged;
     330      Problem.ProblemData.Changed -= OnColumnsChanged;
     331      if (Problem.ProblemData.Dataset == null) return;
     332      Problem.ProblemData.Dataset.RowsChanged -= OnPerplexityChanged;
     333      Problem.ProblemData.Dataset.ColumnsChanged -= OnColumnsChanged;
     334    }
     335
     336    protected override void OnStopped() {
     337      base.OnStopped();
     338      //bwerth: state objects can be very large; avoid state serialization
     339      state = null;
     340      dataRowIndices = null;
    301341    }
    302342
    303343    private void OnProblemDataChanged(object sender, EventArgs args) {
    304344      if (Problem == null || Problem.ProblemData == null) return;
     345      OnPerplexityChanged(this, null);
     346      OnColumnsChanged(this, null);
     347      Problem.ProblemData.Changed += OnPerplexityChanged;
     348      Problem.ProblemData.Changed += OnColumnsChanged;
     349      if (Problem.ProblemData.Dataset == null) return;
     350      Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
     351      Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
    305352      if (!Parameters.ContainsKey(ClassesNameParameterName)) return;
    306353      ClassesNameParameter.ValidValues.Clear();
     
    308355    }
    309356
     357    private void OnColumnsChanged(object sender, EventArgs e) {
     358      if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(DistanceFunctionParameterName)) return;
     359      DistanceFunctionParameter.ValidValues.OfType<WeightedEuclideanDistance>().Single().AdaptToProblemData(Problem.ProblemData);
     360    }
     361
     362    private void RegisterParameterEvents() {
     363      PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
     364    }
     365
     366    private void OnPerplexityChanged(object sender, EventArgs e) {
     367      if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(PerplexityParameterName)) return;
     368      PerplexityParameter.Value.Value = Math.Max(1, Math.Min((Problem.ProblemData.Dataset.Rows - 1) / 3.0, Perplexity));
     369    }
    310370    #endregion
    311371
    312372    #region Helpers
    313     private void SetUpResults(IReadOnlyCollection<double[]> data) {
     373    private void SetUpResults(IReadOnlyList<int> allIndices) {
    314374      if (Results == null) return;
    315375      var results = Results;
    316       dataRowNames = new Dictionary<string, List<int>>();
    317       dataRows = new Dictionary<string, ScatterPlotDataRow>();
     376      dataRowIndices = new Dictionary<string, IList<int>>();
    318377      var problemData = Problem.ProblemData;
    319378
    320       //color datapoints acording to classes variable (be it double or string)
    321       if (problemData.Dataset.VariableNames.Contains(ClassesName)) {
    322         if ((problemData.Dataset as Dataset).VariableHasType<string>(ClassesName)) {
    323           var classes = problemData.Dataset.GetStringValues(ClassesName).ToArray();
    324           for (var i = 0; i < classes.Length; i++) {
    325             if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
    326             dataRowNames[classes[i]].Add(i);
     379      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     380      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
     381      if (!results.ContainsKey(ScatterPlotResultName)) results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
     382      if (!results.ContainsKey(DataResultName)) results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     383      if (!results.ContainsKey(ErrorPlotResultName)) {
     384        var errortable = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent") {
     385          VisualProperties = {
     386            XAxisTitle = "UpdateIntervall",
     387            YAxisTitle = "Error",
     388            YAxisLogScale = true
    327389          }
    328         } else if ((problemData.Dataset as Dataset).VariableHasType<double>(ClassesName)) {
    329           var classValues = problemData.Dataset.GetDoubleValues(ClassesName).ToArray();
    330           var max = classValues.Max() + 0.1;
    331           var min = classValues.Min() - 0.1;
    332           const int contours = 8;
    333           for (var i = 0; i < contours; i++) {
    334             var contourname = GetContourName(i, min, max, contours);
    335             dataRowNames.Add(contourname, new List<int>());
    336             dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
    337             dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
    338             dataRows[contourname].VisualProperties.PointSize = i + 3;
    339           }
    340           for (var i = 0; i < classValues.Length; i++) {
    341             dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i);
    342           }
    343         }
     390        };
     391        errortable.Rows.Add(new DataRow("Errors"));
     392        errortable.Rows["Errors"].VisualProperties.StartIndexZero = true;
     393        results.Add(new Result(ErrorPlotResultName, errortable));
     394      }
     395
     396      //color datapoints acording to classes variable (be it double, datetime or string)
     397      if (!problemData.Dataset.VariableNames.Contains(ClassesName)) {
     398        dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
     399        dataRowIndices.Add("Test", problemData.TestIndices.ToList());
     400        return;
     401      }
     402
     403      var classificationData = problemData as ClassificationProblemData;
     404      if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) {
     405        var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n);
     406        var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray();
     407        for (var i = 0; i < classes.Length; i++) {
     408          if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
     409          dataRowIndices[classes[i]].Add(i);
     410        }
     411      } else if (((Dataset)problemData.Dataset).VariableHasType<string>(ClassesName)) {
     412        var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();
     413        for (var i = 0; i < classes.Length; i++) {
     414          if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
     415          dataRowIndices[classes[i]].Add(i);
     416        }
     417      } else if (((Dataset)problemData.Dataset).VariableHasType<double>(ClassesName)) {
     418        var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
     419        const int contours = 8;
     420        Dictionary<int, string> contourMap;
     421        IClusteringModel clusterModel;
     422        double[][] borders;
     423        CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
     424        var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
     425        for (var i = 0; i < contours; i++) {
     426          var c = contourorder[i];
     427          var contourname = contourMap[c];
     428          dataRowIndices.Add(contourname, new List<int>());
     429          var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
     430          ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
     431        }
     432        var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
     433        for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
     434      } else if (((Dataset)problemData.Dataset).VariableHasType<DateTime>(ClassesName)) {
     435        var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
     436        const int contours = 8;
     437        Dictionary<int, string> contourMap;
     438        IClusteringModel clusterModel;
     439        double[][] borders;
     440        CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
     441        var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
     442        for (var i = 0; i < contours; i++) {
     443          var c = contourorder[i];
     444          var contourname = contourMap[c];
     445          dataRowIndices.Add(contourname, new List<int>());
     446          var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
     447          row.VisualProperties.PointSize = 8;
     448          ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
     449        }
     450        var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
     451        for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
    344452      } else {
    345         dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
    346         dataRowNames.Add("Test", problemData.TestIndices.ToList());
    347       }
    348 
    349       if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
    350       else ((IntValue)results[IterationResultName].Value).Value = 0;
    351 
    352       if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
    353       else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
    354 
    355       if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
    356       else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
    357 
    358       var plot = results[ErrorPlotResultName].Value as DataTable;
    359       if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
    360 
    361       if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
    362       plot.Rows["errors"].Values.Clear();
    363       plot.Rows["errors"].VisualProperties.StartIndexZero = true;
    364 
    365       results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
    366       results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
     453        dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
     454        dataRowIndices.Add("Test", problemData.TestIndices.ToList());
     455      }
    367456    }
    368457
     
    372461      var plot = results[ErrorPlotResultName].Value as DataTable;
    373462      if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
    374       var errors = plot.Rows["errors"].Values;
     463      var errors = plot.Rows["Errors"].Values;
    375464      var c = tsneState.EvaluateError();
    376465      errors.Add(c);
     
    378467      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
    379468
    380       var ndata = Normalize(tsneState.newData);
     469      var ndata = NormalizeProjectedData(tsneState.newData);
    381470      results[DataResultName].Value = new DoubleMatrix(ndata);
    382471      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
    383       FillScatterPlot(ndata, splot);
    384     }
    385 
    386     private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
    387       foreach (var rowName in dataRowNames.Keys) {
    388         if (!plot.Rows.ContainsKey(rowName))
    389           plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
    390         plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
    391       }
    392     }
    393 
    394     private static double[,] Normalize(double[,] data) {
     472      FillScatterPlot(ndata, splot, dataRowIndices);
     473    }
     474
     475    private static void FillScatterPlot(double[,] lowDimData, ScatterPlot plot, Dictionary<string, IList<int>> dataRowIndices) {
     476      foreach (var rowName in dataRowIndices.Keys) {
     477        if (!plot.Rows.ContainsKey(rowName)) {
     478          plot.Rows.Add(new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
     479          plot.Rows[rowName].VisualProperties.PointSize = 8;
     480        }
     481        plot.Rows[rowName].Points.Replace(dataRowIndices[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     482      }
     483    }
     484
     485    private static double[,] NormalizeProjectedData(double[,] data) {
    395486      var max = new double[data.GetLength(1)];
    396487      var min = new double[data.GetLength(1)];
     
    398489      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
    399490      for (var i = 0; i < data.GetLength(0); i++)
    400         for (var j = 0; j < data.GetLength(1); j++) {
    401           var v = data[i, j];
    402           max[j] = Math.Max(max[j], v);
    403           min[j] = Math.Min(min[j], v);
    404         }
     491      for (var j = 0; j < data.GetLength(1); j++) {
     492        var v = data[i, j];
     493        max[j] = Math.Max(max[j], v);
     494        min[j] = Math.Min(min[j], v);
     495      }
    405496      for (var i = 0; i < data.GetLength(0); i++) {
    406497        for (var j = 0; j < data.GetLength(1); j++) {
    407498          var d = max[j] - min[j];
    408           var s = data[i, j] - (max[j] + min[j]) / 2;  //shift data
    409           if (d.IsAlmost(0)) res[i, j] = data[i, j];   //no scaling possible
    410           else res[i, j] = s / d;  //scale data
     499          var s = data[i, j] - (max[j] + min[j]) / 2; //shift data
     500          if (d.IsAlmost(0)) res[i, j] = data[i, j]; //no scaling possible
     501          else res[i, j] = s / d; //scale data
    411502        }
    412503      }
     
    414505    }
    415506
    416     private static double[][] NormalizeData(IReadOnlyList<double[]> data) {
     507    private static double[][] NormalizeInputData(IReadOnlyList<IReadOnlyList<double>> data) {
    417508      // as in tSNE implementation by van der Maaten
    418       var n = data[0].Length;
     509      var n = data[0].Count;
    419510      var mean = new double[n];
    420511      var max = new double[n];
     
    426517      for (var i = 0; i < data.Count; i++) {
    427518        nData[i] = new double[n];
    428         for (var j = 0; j < n; j++) nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
     519        for (var j = 0; j < n; j++)
     520          nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
    429521      }
    430522      return nData;
     
    432524
    433525    private static Color GetHeatMapColor(int contourNr, int noContours) {
    434       var q = (double)contourNr / noContours;  // q in [0,1]
    435       var c = q < 0.5 ? Color.FromArgb((int)(q * 2 * 255), 255, 0) : Color.FromArgb(255, (int)((1 - q) * 2 * 255), 0);
    436       return c;
    437     }
    438 
    439     private static string GetContourName(double value, double min, double max, int noContours) {
    440       var size = (max - min) / noContours;
    441       var contourNr = (int)((value - min) / size);
    442       return GetContourName(contourNr, min, max, noContours);
    443     }
    444 
    445     private static string GetContourName(int i, double min, double max, int noContours) {
    446       var size = (max - min) / noContours;
    447       return "[" + (min + i * size) + ";" + (min + (i + 1) * size) + ")";
     526      return ConvertTotalToRgb(0, noContours, contourNr);
     527    }
     528
     529    private static void CreateClusters(IDataset data, string target, int contours, out IClusteringModel contourCluster, out Dictionary<int, string> contourNames, out double[][] borders) {
     530      var cpd = new ClusteringProblemData((Dataset)data, new[] {target});
     531      contourCluster = KMeansClustering.CreateKMeansSolution(cpd, contours, 3).Model;
     532
     533      borders = Enumerable.Range(0, contours).Select(x => new[] {double.MaxValue, double.MinValue}).ToArray();
     534      var clusters = contourCluster.GetClusterValues(cpd.Dataset, cpd.AllIndices).ToArray();
     535      var targetvalues = cpd.Dataset.GetDoubleValues(target).ToArray();
     536      foreach (var i in cpd.AllIndices) {
     537        var cl = clusters[i] - 1;
     538        var clv = targetvalues[i];
     539        if (borders[cl][0] > clv) borders[cl][0] = clv;
     540        if (borders[cl][1] < clv) borders[cl][1] = clv;
     541      }
     542
     543      contourNames = new Dictionary<int, string>();
     544      for (var i = 0; i < contours; i++)
     545        contourNames.Add(i, "[" + borders[i][0] + ";" + borders[i][1] + "]");
     546    }
     547
     548    private static Color ConvertTotalToRgb(double low, double high, double cell) {
     549      var colorGradient = ColorGradient.Colors;
     550      var range = high - low;
     551      var h = Math.Min(cell / range * colorGradient.Count, colorGradient.Count - 1);
     552      return colorGradient[(int)h];
    448553    }
    449554    #endregion
Note: See TracChangeset for help on using the changeset viewer.