Free cookie consent management tool by TermsFeed Policy Generator

source: branches/EfficientGlobalOptimization/HeuristicLab.Algorithms.EGO/Operators/FitnessClusteringAnalyzer.cs @ 15870

Last change on this file since 15870 was 15338, checked in by bwerth, 7 years ago

#2745 fixed bug concerning new Start and StartAsync methods; passed CancellationToken to sub algorithms

File size: 6.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using System.Threading;
25using HeuristicLab.Algorithms.DataAnalysis;
26using HeuristicLab.Analysis;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Operators;
31using HeuristicLab.Optimization;
32using HeuristicLab.Parameters;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
34using HeuristicLab.Problems.DataAnalysis;
35
namespace HeuristicLab.Algorithms.EGO {
  /// <summary>
  /// Analyzer that clusters the rows of the looked-up "Dataset" by their "output" variable using k-means
  /// and stores both the clustering solution and a scatter plot produced by a t-SNE sub-run in the results.
  /// </summary>
  [Item("FitnessClusteringAnalyzer", "Clusters the samples by their output (fitness) values using k-means and visualizes the clusters in a t-SNE scatter plot.")]
  [StorableClass]
  public class FitnessClusteringAnalyzer : SingleSuccessorOperator, IAnalyzer, IStochasticOperator, IResultsOperator {
    public override bool CanChangeName => true;
    public bool EnabledByDefault => false;

    public ILookupParameter<ModifiableDataset> DatasetParameter => (ILookupParameter<ModifiableDataset>)Parameters["Dataset"];
    public ILookupParameter<ResultCollection> ResultsParameter => (ILookupParameter<ResultCollection>)Parameters["Results"];
    public IFixedValueParameter<IntValue> KParameter => (IFixedValueParameter<IntValue>)Parameters["K"];
    public IFixedValueParameter<IntValue> K2Parameter => (IFixedValueParameter<IntValue>)Parameters["K2"];
    public ILookupParameter<IRandom> RandomParameter => (ILookupParameter<IRandom>)Parameters["Random"];

    private const string SolutionName = "FitnessClustering";
    private const string PlotName = "FitnessClusterPlot";

    // Names of the dataset variables this analyzer works with.
    private const string OutputVariableName = "output";
    private const string ClusterVariableName = "cluster";

    // The analysis is skipped while the dataset has fewer rows than this.
    // NOTE(review): the rationale for 20 is not documented; with Rows >= 20 the perplexity (Rows / 3 - 1) is at least 5.
    private const int MinimumRows = 20;

    [StorableConstructor]
    protected FitnessClusteringAnalyzer(bool deserializing) : base(deserializing) { }

    protected FitnessClusteringAnalyzer(FitnessClusteringAnalyzer original, Cloner cloner) : base(original, cloner) { }

    /// <summary>Creates the analyzer and registers its parameters (K and K2 both default to 3).</summary>
    public FitnessClusteringAnalyzer() {
      Parameters.Add(new LookupParameter<ModifiableDataset>("Dataset"));
      Parameters.Add(new LookupParameter<ResultCollection>("Results", "The collection to store the results in."));
      Parameters.Add(new FixedValueParameter<IntValue>("K", "The number of clusters.", new IntValue(3)));
      // NOTE(review): K2 is only read by GetWeights (as the number of summed inverse distances), not as a cluster count;
      // the description text looks copy-pasted from K. Left unchanged here, confirm the intended meaning.
      Parameters.Add(new FixedValueParameter<IntValue>("K2", "The number of clusters.", new IntValue(3)));
      Parameters.Add(new LookupParameter<IRandom>("Random"));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new FitnessClusteringAnalyzer(this, cloner);
    }

    /// <summary>
    /// Clusters the current dataset and writes the clustering ("FitnessClustering") and the
    /// t-SNE scatter plot ("FitnessClusterPlot") to the results. Nothing is written when no dataset
    /// is available or when it has fewer rows than K or fewer than 20 rows.
    /// </summary>
    /// <returns>The operation returned by the base implementation.</returns>
    public sealed override IOperation Apply() {
      var dataset = DatasetParameter.ActualValue;
      var results = ResultsParameter.ActualValue;
      var random = RandomParameter.ActualValue;

      // Skip the analysis instead of failing when there is no dataset yet or it is still too small.
      if (dataset == null || dataset.Rows < Math.Max(KParameter.Value.Value, MinimumRows)) return base.Apply();

      var clustering = CreateClustering(dataset);
      SetResult(results, SolutionName, clustering);

      var plot = CreateTSNEPlot(clustering, dataset, random);
      SetResult(results, PlotName, plot);

      return base.Apply();
    }

    /// <summary>Updates the value of the result called <paramref name="name"/>, adding the result first if it does not exist.</summary>
    private static void SetResult(ResultCollection results, string name, IItem value) {
      if (results.ContainsKey(name)) results[name].Value = value;
      else results.Add(new Result(name, value));
    }

    /// <summary>
    /// Runs a t-SNE sub-algorithm on a copy of <paramref name="data"/> extended by a "cluster" variable
    /// holding the cluster value of every row, and returns the first scatter plot found in its results.
    /// </summary>
    /// <param name="clustering">The clustering whose cluster values are added as the "cluster" variable.</param>
    /// <param name="data">The dataset to visualize; it is cloned and not modified.</param>
    /// <param name="random">Source of the seed passed to the sub-algorithm run.</param>
    /// <exception cref="InvalidOperationException">Thrown if the sub-algorithm run yields no scatter plot.</exception>
    private ScatterPlot CreateTSNEPlot(KMeansClusteringSolution clustering, ModifiableDataset data, IRandom random) {
      // Work on a clone so the looked-up dataset does not accumulate a "cluster" column.
      var clusteredData = (ModifiableDataset)data.Clone();
      clusteredData.AddVariable(ClusterVariableName, clustering.ClusterValues.Select(x => (double)x));

      var prob = new RegressionProblem {
        ProblemData = new RegressionProblemData(clusteredData, new[] { OutputVariableName }, ClusterVariableName)
      };
      var tsne = new TSNEAlgorithm {
        Perplexity = data.Rows / 3 - 1, // integer division
        Problem = prob
      };
      // Select the "cluster" variable as class name; stays unset (null) if it is not among the valid values.
      tsne.ClassesNameParameter.Value = tsne.ClassesNameParameter.ValidValues.FirstOrDefault(x => x.Value.Equals(ClusterVariableName));

      // NOTE(review): no cancellation token is available here, so the sub-run cannot be cancelled from outside.
      var res = EgoUtilities.SyncRunSubAlgorithm(tsne, random.Next(), CancellationToken.None);
      var plot = res.Select(r => r.Value).OfType<ScatterPlot>().FirstOrDefault();
      if (plot == null) throw new InvalidOperationException("The t-SNE sub-algorithm did not produce a scatter plot.");
      return plot;
    }

    /// <summary>
    /// Creates a k-means clustering solution with K clusters, using "output" as the only input variable
    /// and all rows of <paramref name="dataset"/> as training partition (the test partition is empty).
    /// </summary>
    private KMeansClusteringSolution CreateClustering(ModifiableDataset dataset) {
      var pd = new ClusteringProblemData(dataset, new[] { OutputVariableName });
      pd.TestPartition.Start = dataset.Rows;
      pd.TestPartition.End = dataset.Rows;
      pd.TrainingPartition.Start = 0;
      pd.TrainingPartition.End = dataset.Rows;
      return KMeansClustering.CreateKMeansSolution(pd, KParameter.Value.Value, 1);
    }

    /// <summary>
    /// Computes one weight per row from the variables whose names start with "input": the sum of the K2
    /// smallest inverse Euclidean distances to the other rows, or 1.0 for every row if K2 is not positive.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this method is not called anywhere in this file. Because the inverse distances are sorted
    /// ascending, the K2 *farthest* rows contribute; rows with identical inputs yield an inverse distance of
    /// infinity. Both behaviours are kept as they were, confirm whether they are intended before using the method.
    /// </remarks>
    /// <exception cref="NotSupportedException">Thrown if an input value is NaN or infinite.</exception>
    private double[] GetWeights(ModifiableDataset dataset) {
      // Ordinal comparison: variable names are identifiers, not culture-dependent text.
      var inputNames = dataset.VariableNames.Where(x => x.StartsWith("input", StringComparison.Ordinal));
      var inputMatrix = dataset.ToArray(inputNames, Enumerable.Range(0, dataset.Rows));
      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
        throw new NotSupportedException("k-Means clustering does not support NaN or infinity values in the input dataset.");

      var k2 = K2Parameter.Value.Value;
      var indices = Enumerable.Range(0, inputMatrix.GetLength(0)).ToArray();
      if (k2 <= 0) return indices.Select(_ => 1.0).ToArray();

      return indices.Select(i =>
        indices.Where(j => j != i)
               .Select(j => 1 / Math.Sqrt(EuclideanSquared(inputMatrix, inputMatrix, i, j)))
               .OrderBy(x => x)
               .Take(k2)
               .Sum()
      ).ToArray();
    }

    /// <summary>
    /// Returns the squared Euclidean distance between row <paramref name="row1"/> of <paramref name="input"/>
    /// and row <paramref name="row2"/> of <paramref name="input2"/>. The number of columns is taken from
    /// <paramref name="input"/>, so <paramref name="input2"/> must have at least as many columns.
    /// </summary>
    private static double EuclideanSquared(double[,] input, double[,] input2, int row1, int row2) {
      var sum = 0.0;
      for (var i = 0; i < input.GetLength(1); i++) {
        var d = input[row1, i] - input2[row2, i];
        sum += d * d;
      }
      return sum;
    }

  }
}
Note: See TracBrowser for help on using the repository browser.