Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/20/18 13:52:40 (6 years ago)
Author:
pfleck
Message:

#2845 reverted the last merge (r16307) because some revisions were missing

Location:
branches/2845_EnhancedProgress
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/2845_EnhancedProgress

  • branches/2845_EnhancedProgress/HeuristicLab.Algorithms.DataAnalysis

  • branches/2845_EnhancedProgress/HeuristicLab.Algorithms.DataAnalysis/3.4

    • Property svn:mergeinfo deleted
  • branches/2845_EnhancedProgress/HeuristicLab.Algorithms.DataAnalysis/3.4/TSNE/TSNEAlgorithm.cs

    r16307 r16308  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    3838namespace HeuristicLab.Algorithms.DataAnalysis {
    3939  /// <summary>
    40   /// t-Distributed Stochastic Neighbor Embedding (tSNE) projects the data in a low dimensional
     40  /// t-distributed stochastic neighbourhood embedding (tSNE) projects the data in a low dimensional
    4141  /// space to allow visual cluster identification.
    4242  /// </summary>
    43   [Item("t-Distributed Stochastic Neighbor Embedding (tSNE)", "t-Distributed Stochastic Neighbor Embedding projects the data in a low " +
    44                                                               "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
     43  [Item("tSNE", "t-distributed stochastic neighbourhood embedding projects the data in a low " +
     44                "dimensional space to allow visual cluster identification. Implemented similar to: https://lvdmaaten.github.io/tsne/#implementations (Barnes-Hut t-SNE). Described in : https://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf")]
    4545  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 100)]
    4646  [StorableClass]
     
    5757    }
    5858
    59     #region Parameter names
     59    #region parameter names
    6060    private const string DistanceFunctionParameterName = "DistanceFunction";
    6161    private const string PerplexityParameterName = "Perplexity";
     
    7272    private const string ClassesNameParameterName = "ClassesName";
    7373    private const string NormalizationParameterName = "Normalization";
    74     private const string RandomInitializationParameterName = "RandomInitialization";
    7574    private const string UpdateIntervalParameterName = "UpdateInterval";
    7675    #endregion
    7776
    78     #region Result names
     77    #region result names
    7978    private const string IterationResultName = "Iteration";
    8079    private const string ErrorResultName = "Error";
     
    8483    #endregion
    8584
    86     #region Parameter properties
     85    #region parameter properties
    8786    public IFixedValueParameter<DoubleValue> PerplexityParameter {
    88       get { return (IFixedValueParameter<DoubleValue>)Parameters[PerplexityParameterName]; }
     87      get { return Parameters[PerplexityParameterName] as IFixedValueParameter<DoubleValue>; }
    8988    }
    9089    public IFixedValueParameter<PercentValue> ThetaParameter {
    91       get { return (IFixedValueParameter<PercentValue>)Parameters[ThetaParameterName]; }
     90      get { return Parameters[ThetaParameterName] as IFixedValueParameter<PercentValue>; }
    9291    }
    9392    public IFixedValueParameter<IntValue> NewDimensionsParameter {
    94       get { return (IFixedValueParameter<IntValue>)Parameters[NewDimensionsParameterName]; }
     93      get { return Parameters[NewDimensionsParameterName] as IFixedValueParameter<IntValue>; }
    9594    }
    9695    public IConstrainedValueParameter<IDistance<double[]>> DistanceFunctionParameter {
    97       get { return (IConstrainedValueParameter<IDistance<double[]>>)Parameters[DistanceFunctionParameterName]; }
     96      get { return Parameters[DistanceFunctionParameterName] as IConstrainedValueParameter<IDistance<double[]>>; }
    9897    }
    9998    public IFixedValueParameter<IntValue> MaxIterationsParameter {
    100       get { return (IFixedValueParameter<IntValue>)Parameters[MaxIterationsParameterName]; }
     99      get { return Parameters[MaxIterationsParameterName] as IFixedValueParameter<IntValue>; }
    101100    }
    102101    public IFixedValueParameter<IntValue> StopLyingIterationParameter {
    103       get { return (IFixedValueParameter<IntValue>)Parameters[StopLyingIterationParameterName]; }
     102      get { return Parameters[StopLyingIterationParameterName] as IFixedValueParameter<IntValue>; }
    104103    }
    105104    public IFixedValueParameter<IntValue> MomentumSwitchIterationParameter {
    106       get { return (IFixedValueParameter<IntValue>)Parameters[MomentumSwitchIterationParameterName]; }
     105      get { return Parameters[MomentumSwitchIterationParameterName] as IFixedValueParameter<IntValue>; }
    107106    }
    108107    public IFixedValueParameter<DoubleValue> InitialMomentumParameter {
    109       get { return (IFixedValueParameter<DoubleValue>)Parameters[InitialMomentumParameterName]; }
     108      get { return Parameters[InitialMomentumParameterName] as IFixedValueParameter<DoubleValue>; }
    110109    }
    111110    public IFixedValueParameter<DoubleValue> FinalMomentumParameter {
    112       get { return (IFixedValueParameter<DoubleValue>)Parameters[FinalMomentumParameterName]; }
     111      get { return Parameters[FinalMomentumParameterName] as IFixedValueParameter<DoubleValue>; }
    113112    }
    114113    public IFixedValueParameter<DoubleValue> EtaParameter {
    115       get { return (IFixedValueParameter<DoubleValue>)Parameters[EtaParameterName]; }
     114      get { return Parameters[EtaParameterName] as IFixedValueParameter<DoubleValue>; }
    116115    }
    117116    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
    118       get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
     117      get { return Parameters[SetSeedRandomlyParameterName] as IFixedValueParameter<BoolValue>; }
    119118    }
    120119    public IFixedValueParameter<IntValue> SeedParameter {
    121       get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
     120      get { return Parameters[SeedParameterName] as IFixedValueParameter<IntValue>; }
    122121    }
    123122    public IConstrainedValueParameter<StringValue> ClassesNameParameter {
    124       get { return (IConstrainedValueParameter<StringValue>)Parameters[ClassesNameParameterName]; }
     123      get { return Parameters[ClassesNameParameterName] as IConstrainedValueParameter<StringValue>; }
    125124    }
    126125    public IFixedValueParameter<BoolValue> NormalizationParameter {
    127       get { return (IFixedValueParameter<BoolValue>)Parameters[NormalizationParameterName]; }
    128     }
    129     public IFixedValueParameter<BoolValue> RandomInitializationParameter {
    130       get { return (IFixedValueParameter<BoolValue>)Parameters[RandomInitializationParameterName]; }
     126      get { return Parameters[NormalizationParameterName] as IFixedValueParameter<BoolValue>; }
    131127    }
    132128    public IFixedValueParameter<IntValue> UpdateIntervalParameter {
    133       get { return (IFixedValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
     129      get { return Parameters[UpdateIntervalParameterName] as IFixedValueParameter<IntValue>; }
    134130    }
    135131    #endregion
     
    191187      set { NormalizationParameter.Value.Value = value; }
    192188    }
    193     public bool RandomInitialization {
    194       get { return RandomInitializationParameter.Value.Value; }
    195       set { RandomInitializationParameter.Value.Value = value; }
    196     }
     189
    197190    public int UpdateInterval {
    198191      get { return UpdateIntervalParameter.Value.Value; }
     
    201194    #endregion
    202195
    203     #region Storable properties
    204     [Storable]
    205     private Dictionary<string, IList<int>> dataRowIndices;
    206     [Storable]
    207     private TSNEStatic<double[]>.TSNEState state;
    208     #endregion
    209 
    210196    #region Constructors & Cloning
    211197    [StorableConstructor]
    212198    private TSNEAlgorithm(bool deserializing) : base(deserializing) { }
    213199
    214     [StorableHook(HookType.AfterDeserialization)]
    215     private void AfterDeserialization() {
    216       if (!Parameters.ContainsKey(RandomInitializationParameterName))
    217         Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
    218       RegisterParameterEvents();
    219     }
    220200    private TSNEAlgorithm(TSNEAlgorithm original, Cloner cloner) : base(original, cloner) {
    221       if (original.dataRowIndices != null)
    222         dataRowIndices = new Dictionary<string, IList<int>>(original.dataRowIndices);
     201      if (original.dataRowNames != null)
     202        this.dataRowNames = new Dictionary<string, List<int>>(original.dataRowNames);
     203      if (original.dataRows != null)
     204        this.dataRows = original.dataRows.ToDictionary(kvp => kvp.Key, kvp => cloner.Clone(kvp.Value));
    223205      if (original.state != null)
    224         state = cloner.Clone(original.state);
    225       RegisterParameterEvents();
    226     }
    227     public override IDeepCloneable Clone(Cloner cloner) {
    228       return new TSNEAlgorithm(this, cloner);
    229     }
     206        this.state = cloner.Clone(original.state);
     207      this.iter = original.iter;
     208    }
     209    public override IDeepCloneable Clone(Cloner cloner) { return new TSNEAlgorithm(this, cloner); }
    230210    public TSNEAlgorithm() {
    231211      var distances = new ItemSet<IDistance<double[]>>(ApplicationManager.Manager.GetInstances<IDistance<double[]>>());
     
    233213      Parameters.Add(new FixedValueParameter<DoubleValue>(PerplexityParameterName, "Perplexity-parameter of tSNE. Comparable to k in a k-nearest neighbour algorithm. Recommended value is floor(number of points /3) or lower", new DoubleValue(25)));
    234214      Parameters.Add(new FixedValueParameter<PercentValue>(ThetaParameterName, "Value describing how much appoximated " +
    235                                                                                "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
    236                                                                                "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
    237                                                                                "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
    238                                                                                "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
    239                                                                                " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0)));
     215                                                                              "gradients my differ from exact gradients. Set to 0 for exact calculation and in [0,1] otherwise. " +
     216                                                                              "Appropriate values for theta are between 0.1 and 0.7 (default = 0.5). CAUTION: exact calculation of " +
     217                                                                              "forces requires building a non-sparse N*N matrix where N is the number of data points. This may " +
     218                                                                              "exceed memory limitations. The function is designed to run on large (N > 5000) data sets. It may give" +
     219                                                                              " poor performance on very small data sets(it is better to use a standard t - SNE implementation on such data).", new PercentValue(0)));
    240220      Parameters.Add(new FixedValueParameter<IntValue>(NewDimensionsParameterName, "Dimensionality of projected space (usually 2 for easy visual analysis)", new IntValue(2)));
    241221      Parameters.Add(new FixedValueParameter<IntValue>(MaxIterationsParameterName, "Maximum number of iterations for gradient descent.", new IntValue(1000)));
     
    250230      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
    251231      Parameters.Add(new FixedValueParameter<IntValue>(UpdateIntervalParameterName, "The interval after which the results will be updated.", new IntValue(50)));
    252       Parameters.Add(new FixedValueParameter<BoolValue>(RandomInitializationParameterName, "Wether data points should be randomly initialized or according to the first 2 dimensions", new BoolValue(true)));
    253 
    254       UpdateIntervalParameter.Hidden = true;
     232      Parameters[UpdateIntervalParameterName].Hidden = true;
     233
    255234      MomentumSwitchIterationParameter.Hidden = true;
    256235      InitialMomentumParameter.Hidden = true;
     
    259238      EtaParameter.Hidden = false;
    260239      Problem = new RegressionProblem();
    261       RegisterParameterEvents();
    262     }
    263     #endregion
     240    }
     241    #endregion
     242
     243    [Storable]
     244    private Dictionary<string, List<int>> dataRowNames;
     245    [Storable]
     246    private Dictionary<string, ScatterPlotDataRow> dataRows;
     247    [Storable]
     248    private TSNEStatic<double[]>.TSNEState state;
     249    [Storable]
     250    private int iter;
    264251
    265252    public override void Prepare() {
    266253      base.Prepare();
    267       dataRowIndices = null;
     254      dataRowNames = null;
     255      dataRows = null;
    268256      state = null;
    269257    }
     
    271259    protected override void Run(CancellationToken cancellationToken) {
    272260      var problemData = Problem.ProblemData;
    273       // set up and initialize everything if necessary
    274       var wdist = DistanceFunction as WeightedEuclideanDistance;
    275       if (wdist != null) wdist.Initialize(problemData);
     261      // set up and initialized everything if necessary
    276262      if (state == null) {
    277263        if (SetSeedRandomly) Seed = new System.Random().Next();
     
    279265        var dataset = problemData.Dataset;
    280266        var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
    281         var allindices = Problem.ProblemData.AllIndices.ToArray();
    282 
    283         // jagged array is required to meet the static method declarations of TSNEStatic<T>
    284         var data = Enumerable.Range(0, dataset.Rows).Select(x => new double[allowedInputVariables.Length]).ToArray();
    285         var col = 0;
    286         foreach (var s in allowedInputVariables) {
    287           var row = 0;
    288           foreach (var d in dataset.GetDoubleValues(s)) {
    289             data[row][col] = d;
    290             row++;
    291           }
    292           col++;
    293         }
    294         if (Normalization) data = NormalizeInputData(data);
    295         state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta, StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta, RandomInitialization);
    296         SetUpResults(allindices);
    297       }
    298       while (state.iter < MaxIterations && !cancellationToken.IsCancellationRequested) {
    299         if (state.iter % UpdateInterval == 0) Analyze(state);
     267        var data = new double[dataset.Rows][];
     268        for (var row = 0; row < dataset.Rows; row++)
     269          data[row] = allowedInputVariables.Select(col => dataset.GetDoubleValue(col, row)).ToArray();
     270
     271        if (Normalization) data = NormalizeData(data);
     272
     273        state = TSNEStatic<double[]>.CreateState(data, DistanceFunction, random, NewDimensions, Perplexity, Theta,
     274          StopLyingIteration, MomentumSwitchIteration, InitialMomentum, FinalMomentum, Eta);
     275
     276        SetUpResults(data);
     277        iter = 0;
     278      }
     279      for (; iter < MaxIterations && !cancellationToken.IsCancellationRequested; iter++) {
     280        if (iter % UpdateInterval == 0)
     281          Analyze(state);
    300282        TSNEStatic<double[]>.Iterate(state);
    301283      }
     
    312294    protected override void RegisterProblemEvents() {
    313295      base.RegisterProblemEvents();
    314       if (Problem == null) return;
    315296      Problem.ProblemDataChanged += OnProblemDataChanged;
    316       if (Problem.ProblemData == null) return;
    317       Problem.ProblemData.Changed += OnPerplexityChanged;
    318       Problem.ProblemData.Changed += OnColumnsChanged;
    319       if (Problem.ProblemData.Dataset == null) return;
    320       Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
    321       Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
    322     }
    323 
     297    }
    324298    protected override void DeregisterProblemEvents() {
    325299      base.DeregisterProblemEvents();
    326       if (Problem == null) return;
    327300      Problem.ProblemDataChanged -= OnProblemDataChanged;
    328       if (Problem.ProblemData == null) return;
    329       Problem.ProblemData.Changed -= OnPerplexityChanged;
    330       Problem.ProblemData.Changed -= OnColumnsChanged;
    331       if (Problem.ProblemData.Dataset == null) return;
    332       Problem.ProblemData.Dataset.RowsChanged -= OnPerplexityChanged;
    333       Problem.ProblemData.Dataset.ColumnsChanged -= OnColumnsChanged;
    334     }
    335 
    336     protected override void OnStopped() {
    337       base.OnStopped();
    338       //bwerth: state objects can be very large; avoid state serialization
    339       state = null;
    340       dataRowIndices = null;
    341301    }
    342302
    343303    private void OnProblemDataChanged(object sender, EventArgs args) {
    344304      if (Problem == null || Problem.ProblemData == null) return;
    345       OnPerplexityChanged(this, null);
    346       OnColumnsChanged(this, null);
    347       Problem.ProblemData.Changed += OnPerplexityChanged;
    348       Problem.ProblemData.Changed += OnColumnsChanged;
    349       if (Problem.ProblemData.Dataset == null) return;
    350       Problem.ProblemData.Dataset.RowsChanged += OnPerplexityChanged;
    351       Problem.ProblemData.Dataset.ColumnsChanged += OnColumnsChanged;
    352305      if (!Parameters.ContainsKey(ClassesNameParameterName)) return;
    353306      ClassesNameParameter.ValidValues.Clear();
     
    355308    }
    356309
    357     private void OnColumnsChanged(object sender, EventArgs e) {
    358       if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(DistanceFunctionParameterName)) return;
    359       DistanceFunctionParameter.ValidValues.OfType<WeightedEuclideanDistance>().Single().AdaptToProblemData(Problem.ProblemData);
    360     }
    361 
    362     private void RegisterParameterEvents() {
    363       PerplexityParameter.Value.ValueChanged += OnPerplexityChanged;
    364     }
    365 
    366     private void OnPerplexityChanged(object sender, EventArgs e) {
    367       if (Problem == null || Problem.ProblemData == null || Problem.ProblemData.Dataset == null || !Parameters.ContainsKey(PerplexityParameterName)) return;
    368       PerplexityParameter.Value.Value = Math.Max(1, Math.Min((Problem.ProblemData.Dataset.Rows - 1) / 3.0, Perplexity));
    369     }
    370310    #endregion
    371311
    372312    #region Helpers
    373     private void SetUpResults(IReadOnlyList<int> allIndices) {
     313    private void SetUpResults(IReadOnlyCollection<double[]> data) {
    374314      if (Results == null) return;
    375315      var results = Results;
    376       dataRowIndices = new Dictionary<string, IList<int>>();
     316      dataRowNames = new Dictionary<string, List<int>>();
     317      dataRows = new Dictionary<string, ScatterPlotDataRow>();
    377318      var problemData = Problem.ProblemData;
    378319
     320      //color datapoints according to classes variable (be it double or string)
     321      if (problemData.Dataset.VariableNames.Contains(ClassesName)) {
     322        if ((problemData.Dataset as Dataset).VariableHasType<string>(ClassesName)) {
     323          var classes = problemData.Dataset.GetStringValues(ClassesName).ToArray();
     324          for (var i = 0; i < classes.Length; i++) {
     325            if (!dataRowNames.ContainsKey(classes[i])) dataRowNames.Add(classes[i], new List<int>());
     326            dataRowNames[classes[i]].Add(i);
     327          }
     328        } else if ((problemData.Dataset as Dataset).VariableHasType<double>(ClassesName)) {
     329          var classValues = problemData.Dataset.GetDoubleValues(ClassesName).ToArray();
     330          var max = classValues.Max() + 0.1;
     331          var min = classValues.Min() - 0.1;
     332          const int contours = 8;
     333          for (var i = 0; i < contours; i++) {
     334            var contourname = GetContourName(i, min, max, contours);
     335            dataRowNames.Add(contourname, new List<int>());
     336            dataRows.Add(contourname, new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()));
     337            dataRows[contourname].VisualProperties.Color = GetHeatMapColor(i, contours);
     338            dataRows[contourname].VisualProperties.PointSize = i + 3;
     339          }
     340          for (var i = 0; i < classValues.Length; i++) {
     341            dataRowNames[GetContourName(classValues[i], min, max, contours)].Add(i);
     342          }
     343        }
     344      } else {
     345        dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
     346        dataRowNames.Add("Test", problemData.TestIndices.ToList());
     347      }
     348
    379349      if (!results.ContainsKey(IterationResultName)) results.Add(new Result(IterationResultName, new IntValue(0)));
     350      else ((IntValue)results[IterationResultName].Value).Value = 0;
     351
    380352      if (!results.ContainsKey(ErrorResultName)) results.Add(new Result(ErrorResultName, new DoubleValue(0)));
    381       if (!results.ContainsKey(ScatterPlotResultName)) results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
    382       if (!results.ContainsKey(DataResultName)) results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
    383       if (!results.ContainsKey(ErrorPlotResultName)) {
    384         var errortable = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent") {
    385           VisualProperties = {
    386             XAxisTitle = "UpdateIntervall",
    387             YAxisTitle = "Error",
    388             YAxisLogScale = true
    389           }
    390         };
    391         errortable.Rows.Add(new DataRow("Errors"));
    392         errortable.Rows["Errors"].VisualProperties.StartIndexZero = true;
    393         results.Add(new Result(ErrorPlotResultName, errortable));
    394       }
    395 
    396       //color datapoints according to classes variable (be it double, datetime or string)
    397       if (!problemData.Dataset.VariableNames.Contains(ClassesName)) {
    398         dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
    399         dataRowIndices.Add("Test", problemData.TestIndices.ToList());
    400         return;
    401       }
    402 
    403       var classificationData = problemData as ClassificationProblemData;
    404       if (classificationData != null && classificationData.TargetVariable.Equals(ClassesName)) {
    405         var classNames = classificationData.ClassValues.Zip(classificationData.ClassNames, (v, n) => new {v, n}).ToDictionary(x => x.v, x => x.n);
    406         var classes = classificationData.Dataset.GetDoubleValues(classificationData.TargetVariable, allIndices).Select(v => classNames[v]).ToArray();
    407         for (var i = 0; i < classes.Length; i++) {
    408           if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
    409           dataRowIndices[classes[i]].Add(i);
    410         }
    411       } else if (((Dataset)problemData.Dataset).VariableHasType<string>(ClassesName)) {
    412         var classes = problemData.Dataset.GetStringValues(ClassesName, allIndices).ToArray();
    413         for (var i = 0; i < classes.Length; i++) {
    414           if (!dataRowIndices.ContainsKey(classes[i])) dataRowIndices.Add(classes[i], new List<int>());
    415           dataRowIndices[classes[i]].Add(i);
    416         }
    417       } else if (((Dataset)problemData.Dataset).VariableHasType<double>(ClassesName)) {
    418         var clusterdata = new Dataset(problemData.Dataset.DoubleVariables, problemData.Dataset.DoubleVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
    419         const int contours = 8;
    420         Dictionary<int, string> contourMap;
    421         IClusteringModel clusterModel;
    422         double[][] borders;
    423         CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
    424         var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
    425         for (var i = 0; i < contours; i++) {
    426           var c = contourorder[i];
    427           var contourname = contourMap[c];
    428           dataRowIndices.Add(contourname, new List<int>());
    429           var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
    430           ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
    431         }
    432         var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
    433         for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
    434       } else if (((Dataset)problemData.Dataset).VariableHasType<DateTime>(ClassesName)) {
    435         var clusterdata = new Dataset(problemData.Dataset.DateTimeVariables, problemData.Dataset.DateTimeVariables.Select(v => problemData.Dataset.GetDoubleValues(v, allIndices).ToList()));
    436         const int contours = 8;
    437         Dictionary<int, string> contourMap;
    438         IClusteringModel clusterModel;
    439         double[][] borders;
    440         CreateClusters(clusterdata, ClassesName, contours, out clusterModel, out contourMap, out borders);
    441         var contourorder = borders.Select((x, i) => new {x, i}).OrderBy(x => x.x[0]).Select(x => x.i).ToArray();
    442         for (var i = 0; i < contours; i++) {
    443           var c = contourorder[i];
    444           var contourname = contourMap[c];
    445           dataRowIndices.Add(contourname, new List<int>());
    446           var row = new ScatterPlotDataRow(contourname, "", new List<Point2D<double>>()) {VisualProperties = {Color = GetHeatMapColor(i, contours), PointSize = 8}};
    447           row.VisualProperties.PointSize = 8;
    448           ((ScatterPlot)results[ScatterPlotResultName].Value).Rows.Add(row);
    449         }
    450         var allClusters = clusterModel.GetClusterValues(clusterdata, Enumerable.Range(0, clusterdata.Rows)).ToArray();
    451         for (var i = 0; i < clusterdata.Rows; i++) dataRowIndices[contourMap[allClusters[i] - 1]].Add(i);
    452       } else {
    453         dataRowIndices.Add("Training", problemData.TrainingIndices.ToList());
    454         dataRowIndices.Add("Test", problemData.TestIndices.ToList());
    455       }
     353      else ((DoubleValue)results[ErrorResultName].Value).Value = 0;
     354
     355      if (!results.ContainsKey(ErrorPlotResultName)) results.Add(new Result(ErrorPlotResultName, new DataTable(ErrorPlotResultName, "Development of errors during gradient descent")));
     356      else results[ErrorPlotResultName].Value = new DataTable(ErrorPlotResultName, "Development of errors during gradient descent");
     357
     358      var plot = results[ErrorPlotResultName].Value as DataTable;
     359      if (plot == null) throw new ArgumentException("could not create/access error data table in results collection");
     360
     361      if (!plot.Rows.ContainsKey("errors")) plot.Rows.Add(new DataRow("errors"));
     362      plot.Rows["errors"].Values.Clear();
     363      plot.Rows["errors"].VisualProperties.StartIndexZero = true;
     364
     365      results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", new ScatterPlot(DataResultName, "")));
     366      results.Add(new Result(DataResultName, "Projected Data", new DoubleMatrix()));
    456367    }
    457368
     
    461372      var plot = results[ErrorPlotResultName].Value as DataTable;
    462373      if (plot == null) throw new ArgumentException("Could not create/access error data table in results collection.");
    463       var errors = plot.Rows["Errors"].Values;
     374      var errors = plot.Rows["errors"].Values;
    464375      var c = tsneState.EvaluateError();
    465376      errors.Add(c);
     
    467378      ((DoubleValue)results[ErrorResultName].Value).Value = errors.Last();
    468379
    469       var ndata = NormalizeProjectedData(tsneState.newData);
     380      var ndata = Normalize(tsneState.newData);
    470381      results[DataResultName].Value = new DoubleMatrix(ndata);
    471382      var splot = results[ScatterPlotResultName].Value as ScatterPlot;
    472       FillScatterPlot(ndata, splot, dataRowIndices);
    473     }
    474 
    475     private static void FillScatterPlot(double[,] lowDimData, ScatterPlot plot, Dictionary<string, IList<int>> dataRowIndices) {
    476       foreach (var rowName in dataRowIndices.Keys) {
    477         if (!plot.Rows.ContainsKey(rowName)) {
    478           plot.Rows.Add(new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
    479           plot.Rows[rowName].VisualProperties.PointSize = 8;
    480         }
    481         plot.Rows[rowName].Points.Replace(dataRowIndices[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
    482       }
    483     }
    484 
    485     private static double[,] NormalizeProjectedData(double[,] data) {
     383      FillScatterPlot(ndata, splot);
     384    }
     385
     386    private void FillScatterPlot(double[,] lowDimData, ScatterPlot plot) {
     387      foreach (var rowName in dataRowNames.Keys) {
     388        if (!plot.Rows.ContainsKey(rowName))
     389          plot.Rows.Add(dataRows.ContainsKey(rowName) ? dataRows[rowName] : new ScatterPlotDataRow(rowName, "", new List<Point2D<double>>()));
     390        plot.Rows[rowName].Points.Replace(dataRowNames[rowName].Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, 1])));
     391      }
     392    }
     393
     394    private static double[,] Normalize(double[,] data) {
    486395      var max = new double[data.GetLength(1)];
    487396      var min = new double[data.GetLength(1)];
     
    489398      for (var i = 0; i < max.Length; i++) max[i] = min[i] = data[0, i];
    490399      for (var i = 0; i < data.GetLength(0); i++)
    491       for (var j = 0; j < data.GetLength(1); j++) {
    492         var v = data[i, j];
    493         max[j] = Math.Max(max[j], v);
    494         min[j] = Math.Min(min[j], v);
    495       }
     400        for (var j = 0; j < data.GetLength(1); j++) {
     401          var v = data[i, j];
     402          max[j] = Math.Max(max[j], v);
     403          min[j] = Math.Min(min[j], v);
     404        }
    496405      for (var i = 0; i < data.GetLength(0); i++) {
    497406        for (var j = 0; j < data.GetLength(1); j++) {
    498407          var d = max[j] - min[j];
    499           var s = data[i, j] - (max[j] + min[j]) / 2; //shift data
    500           if (d.IsAlmost(0)) res[i, j] = data[i, j]; //no scaling possible
    501           else res[i, j] = s / d; //scale data
     408          var s = data[i, j] - (max[j] + min[j]) / 2;  //shift data
     409          if (d.IsAlmost(0)) res[i, j] = data[i, j];   //no scaling possible
     410          else res[i, j] = s / d;  //scale data
    502411        }
    503412      }
     
    505414    }
    506415
    507     private static double[][] NormalizeInputData(IReadOnlyList<IReadOnlyList<double>> data) {
     416    private static double[][] NormalizeData(IReadOnlyList<double[]> data) {
    508417      // as in tSNE implementation by van der Maaten
    509       var n = data[0].Count;
     418      var n = data[0].Length;
    510419      var mean = new double[n];
    511420      var max = new double[n];
     
    517426      for (var i = 0; i < data.Count; i++) {
    518427        nData[i] = new double[n];
    519         for (var j = 0; j < n; j++)
    520           nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
     428        for (var j = 0; j < n; j++) nData[i][j] = max[j].IsAlmost(0) ? data[i][j] - mean[j] : (data[i][j] - mean[j]) / max[j];
    521429      }
    522430      return nData;
     
    524432
    525433    private static Color GetHeatMapColor(int contourNr, int noContours) {
    526       return ConvertTotalToRgb(0, noContours, contourNr);
    527     }
    528 
    529     private static void CreateClusters(IDataset data, string target, int contours, out IClusteringModel contourCluster, out Dictionary<int, string> contourNames, out double[][] borders) {
    530       var cpd = new ClusteringProblemData((Dataset)data, new[] {target});
    531       contourCluster = KMeansClustering.CreateKMeansSolution(cpd, contours, 3).Model;
    532 
    533       borders = Enumerable.Range(0, contours).Select(x => new[] {double.MaxValue, double.MinValue}).ToArray();
    534       var clusters = contourCluster.GetClusterValues(cpd.Dataset, cpd.AllIndices).ToArray();
    535       var targetvalues = cpd.Dataset.GetDoubleValues(target).ToArray();
    536       foreach (var i in cpd.AllIndices) {
    537         var cl = clusters[i] - 1;
    538         var clv = targetvalues[i];
    539         if (borders[cl][0] > clv) borders[cl][0] = clv;
    540         if (borders[cl][1] < clv) borders[cl][1] = clv;
    541       }
    542 
    543       contourNames = new Dictionary<int, string>();
    544       for (var i = 0; i < contours; i++)
    545         contourNames.Add(i, "[" + borders[i][0] + ";" + borders[i][1] + "]");
    546     }
    547 
    548     private static Color ConvertTotalToRgb(double low, double high, double cell) {
    549       var colorGradient = ColorGradient.Colors;
    550       var range = high - low;
    551       var h = Math.Min(cell / range * colorGradient.Count, colorGradient.Count - 1);
    552       return colorGradient[(int)h];
     434      var q = (double)contourNr / noContours;  // q in [0,1]
     435      var c = q < 0.5 ? Color.FromArgb((int)(q * 2 * 255), 255, 0) : Color.FromArgb(255, (int)((1 - q) * 2 * 255), 0);
     436      return c;
     437    }
     438
     439    private static string GetContourName(double value, double min, double max, int noContours) {
     440      var size = (max - min) / noContours;
     441      var contourNr = (int)((value - min) / size);
     442      return GetContourName(contourNr, min, max, noContours);
     443    }
     444
     445    private static string GetContourName(int i, double min, double max, int noContours) {
     446      var size = (max - min) / noContours;
     447      return "[" + (min + i * size) + ";" + (min + (i + 1) * size) + ")";
    553448    }
    554449    #endregion
Note: See TracChangeset for help on using the changeset viewer.