Free cookie consent management tool by TermsFeed Policy Generator

source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/PCA/PrincipleComponentAnalysis.cs @ 15430

Last change on this file since 15430 was 15430, checked in by bwerth, 7 years ago

#2847 first implementation of M5'-regression

File size: 7.4 KB
Line 
1using System.Collections.Generic;
2using System.Drawing;
3using System.Linq;
4using System.Threading;
5using HeuristicLab.Analysis;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8using HeuristicLab.Data;
9using HeuristicLab.Optimization;
10using HeuristicLab.Parameters;
11using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
12using HeuristicLab.Problems.DataAnalysis;
13
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Algorithm item that runs a principal component analysis on the allowed input variables of a
  /// regression problem. A run publishes the projected data, the variances reported by the PCA
  /// solution and a scatter plot of the first two projected columns.
  /// </summary>
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 95)]
  // NOTE(review): the item name and description contain typos ("Principle", "Analyis"). They are
  // runtime strings and are deliberately left untouched here.
  [Item("PrincipleComponentAnalysis", "Standard Principle Component Analyis")]
  public sealed class PrincipleComponentAnalysis : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string NormalizationParameterName = "Normalization";

    // Names under which the results of a run are published.
    private const string ProjectedDataResultName = "ProjectedData";
    private const string VariancesResultName = "Variances";
    // NOTE(review): looks like a typo for "ProjectedDataPlot"; kept because it is the published result name.
    private const string ScatterPlotResultName = "PjoctedDataPlot";

    // Number of k-means clusters ("contours") used to bucket a real-valued class variable.
    private const int ContourCount = 8;

    /// <summary>Parameter that stores the normalization flag.</summary>
    public IFixedValueParameter<BoolValue> NormalizationParameter {
      // A direct cast fails right here if the parameter has an unexpected type, instead of
      // surfacing later as a NullReferenceException.
      get { return (IFixedValueParameter<BoolValue>) Parameters[NormalizationParameterName]; }
    }

    /// <summary>Current value of the normalization flag that is handed to the PCA.</summary>
    public bool Normalization {
      get { return NormalizationParameter.Value.Value; }
    }

    [StorableConstructor]
    private PrincipleComponentAnalysis(bool deserializing) : base(deserializing) { }

    private PrincipleComponentAnalysis(PrincipleComponentAnalysis original, Cloner cloner) : base(original, cloner) { }

    /// <summary>
    /// Creates the algorithm with a fresh <see cref="RegressionProblem"/> and normalization switched on.
    /// </summary>
    public PrincipleComponentAnalysis() {
      Problem = new RegressionProblem();
      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new PrincipleComponentAnalysis(this, cloner);
    }

    /// <summary>
    /// Builds the PCA on all rows of the problem data, projects the same rows and publishes the
    /// projection, the variances and the scatter plot as results.
    /// </summary>
    /// <param name="cancellationToken">Not consulted by this implementation.</param>
    protected override void Run(CancellationToken cancellationToken) {
      var problemData = Problem.ProblemData;
      var dataset = problemData.Dataset;
      var inputVariables = problemData.AllowedInputVariables.ToArray();

      var solution = PrincipleComponentAnalysisStatic.Create(dataset, problemData.AllIndices, inputVariables, Normalization);
      var projection = solution.ProjectData(dataset, problemData.AllIndices);

      Results.Add(new Result(ProjectedDataResultName, new DoubleMatrix(projection)));
      Results.Add(new Result(VariancesResultName, new DoubleArray(solution.Variances)));
      CreateScatterPlot(problemData.TargetVariable, projection);
    }

    #region Fancy ScatterPlot
    /// <summary>
    /// Publishes a scatter plot of the projected data (column 0 against column 1, or column 0
    /// against itself when only one column exists).
    /// </summary>
    /// <remarks>
    /// Points are grouped into plot rows by the variable <paramref name="classesName"/>: one row per
    /// distinct value for string variables, one row per k-means contour for double variables. If the
    /// variable is missing or has any other type, the rows are split into "Training" and "Test".
    /// </remarks>
    /// <param name="classesName">Name of the variable used to group and color the points.</param>
    /// <param name="lowDimData">Projected data, indexed as [row, component].</param>
    private void CreateScatterPlot(string classesName, double[,] lowDimData) {
      var columns = lowDimData.GetLength(1);
      // Without any projected column there is nothing to plot; this also avoids the
      // division by zero that "1 % columns" would cause.
      if (columns == 0) return;

      var problemData = Problem.ProblemData;
      var dataset = problemData.Dataset;
      var rowGroups = new Dictionary<string, List<int>>();
      var styledRows = new Dictionary<string, ScatterPlotDataRow>();

      //color datapoints according to classes variable (be it double or string)
      var grouped = false;
      if (dataset.VariableNames.Contains(classesName)) {
        var typedDataset = (Dataset) dataset;
        if (typedDataset.VariableHasType<string>(classesName)) {
          GroupRowsByStringClass(dataset, classesName, rowGroups);
          grouped = true;
        } else if (typedDataset.VariableHasType<double>(classesName)) {
          GroupRowsByContour(problemData, classesName, rowGroups, styledRows);
          grouped = true;
        }
      }
      if (!grouped) {
        // Also reached for variables of an unsupported type, which previously led to an empty plot.
        rowGroups.Add("Training", problemData.TrainingIndices.ToList());
        rowGroups.Add("Test", problemData.TestIndices.ToList());
      }

      var plot = new ScatterPlot(ScatterPlotResultName, "");
      Results.Add(new Result(ScatterPlotResultName, "Plot of the projected data", plot));

      // Same value as "1 % columns" for every columns >= 1.
      var yColumn = columns > 1 ? 1 : 0;
      foreach (var group in rowGroups) {
        ScatterPlotDataRow row;
        if (!styledRows.TryGetValue(group.Key, out row))
          row = new ScatterPlotDataRow(group.Key, "", new List<Point2D<double>>());
        plot.Rows.Add(row);
        row.Points.Replace(group.Value.Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, yColumn])));
      }
    }

    /// <summary>
    /// Adds every row position of the dataset to the group named after its string value of
    /// <paramref name="classesName"/>.
    /// </summary>
    private static void GroupRowsByStringClass(IDataset dataset, string classesName, Dictionary<string, List<int>> rowGroups) {
      var position = 0;
      foreach (var label in dataset.GetStringValues(classesName)) {
        List<int> members;
        if (!rowGroups.TryGetValue(label, out members)) {
          members = new List<int>();
          rowGroups.Add(label, members);
        }
        members.Add(position++);
      }
    }

    /// <summary>
    /// Clusters the double variable <paramref name="classesName"/> into <see cref="ContourCount"/>
    /// contours, creates one pre-styled plot row per distinct contour label (ordered by lower
    /// border) and assigns every dataset row to the group of its contour.
    /// </summary>
    private static void GroupRowsByContour(IDataAnalysisProblemData problemData, string classesName, Dictionary<string, List<int>> rowGroups, Dictionary<string, ScatterPlotDataRow> styledRows) {
      IClusteringModel clusterModel;
      Dictionary<int, string> contourNames;
      double[][] borders;
      CreateClusters(problemData, classesName, ContourCount, out clusterModel, out contourNames, out borders);

      // Cluster indices sorted by the lower border of each cluster.
      var order = borders.Select((border, index) => new {Lower = border[0], Index = index}).OrderBy(b => b.Lower).Select(b => b.Index).ToArray();
      for (var rank = 0; rank < order.Length; rank++) {
        var label = contourNames[order[rank]];
        // Clusters that ended up with identical borders (e.g. several clusters without any row)
        // share one label; adding the label twice would throw, so they share one group.
        if (rowGroups.ContainsKey(label)) continue;

        rowGroups.Add(label, new List<int>());
        var row = new ScatterPlotDataRow(label, "", new List<Point2D<double>>());
        row.VisualProperties.Color = GetHeatMapColor(rank, ContourCount);
        row.VisualProperties.PointSize = rank + 3;
        styledRows.Add(label, row);
      }

      var assignments = clusterModel.GetClusterValues(problemData.Dataset, problemData.AllIndices).ToArray();
      // Cluster values are shifted by one to obtain the contour index (presumably they are 1-based — confirm).
      for (var r = 0; r < problemData.Dataset.Rows; r++) rowGroups[contourNames[assignments[r] - 1]].Add(r);
    }

    /// <summary>
    /// Clusters the values of <paramref name="target"/> with k-means and reports, per cluster, the
    /// smallest and largest target value assigned to it.
    /// </summary>
    /// <param name="data">Problem data whose dataset is clustered.</param>
    /// <param name="target">Name of the double variable to cluster.</param>
    /// <param name="contours">Number of clusters to create.</param>
    /// <param name="contourCluster">Receives the clustering model.</param>
    /// <param name="contourNames">Receives a label "[min;max]" for every cluster index.</param>
    /// <param name="borders">
    /// Receives {min, max} per cluster index. A cluster without rows keeps the initial
    /// {double.MaxValue, double.MinValue}.
    /// </param>
    private static void CreateClusters(IDataAnalysisProblemData data, string target, int contours, out IClusteringModel contourCluster, out Dictionary<int, string> contourNames, out double[][] borders) {
      var clusteringData = new ClusteringProblemData((Dataset) data.Dataset, new[] {target});
      // NOTE(review): the third argument is presumably the number of restarts — confirm against KMeansClustering.
      contourCluster = KMeansClustering.CreateKMeansSolution(clusteringData, contours, 3).Model;

      borders = Enumerable.Range(0, contours).Select(x => new[] {double.MaxValue, double.MinValue}).ToArray();
      var clusters = contourCluster.GetClusterValues(data.Dataset, data.AllIndices).ToArray();
      var targetvalues = data.Dataset.GetDoubleValues(target).ToArray();
      foreach (var i in data.AllIndices) {
        var range = borders[clusters[i] - 1];
        var value = targetvalues[i];
        if (range[0] > value) range[0] = value;
        if (range[1] < value) range[1] = value;
      }

      contourNames = new Dictionary<int, string>();
      for (var i = 0; i < contours; i++)
        contourNames.Add(i, "[" + borders[i][0] + ";" + borders[i][1] + "]");
    }

    /// <summary>
    /// Maps a contour number to a color by converting the hue contourNr / noContours (with fixed
    /// saturation and value of 0.999) from HSV to RGB.
    /// </summary>
    /// <param name="contourNr">Index of the contour.</param>
    /// <param name="noContours">Total number of contours; the hue is the ratio of both arguments.</param>
    private static Color GetHeatMapColor(int contourNr, int noContours) {
      var hue = contourNr / (double) noContours;
      const double saturation = 0.999f;
      const double value = 0.999f;

      if (hue > 0.999f) { hue = 0.999f; }
      if (hue < 0.001f) { hue = 0.001f; }

      // hue is clamped to at most 0.999, so h6 stays below 6 and needs no wrap-around to 0.
      var h6 = hue * 6f;
      var ihue = (int) h6;
      var p = value * (1f - saturation);
      var q = value * (1f - saturation * (h6 - ihue));
      var t = value * (1f - saturation * (1f - (h6 - ihue)));
      switch (ihue) {
        case 0: return Color.FromArgb((int) (value * 255), (int) (t * 255), (int) (p * 255));
        case 1: return Color.FromArgb((int) (q * 255), (int) (value * 255), (int) (p * 255));
        case 2: return Color.FromArgb((int) (p * 255), (int) (value * 255), (int) (t * 255));
        case 3: return Color.FromArgb((int) (p * 255), (int) (q * 255), (int) (value * 255));
        case 4: return Color.FromArgb((int) (t * 255), (int) (p * 255), (int) (value * 255));
        default: return Color.FromArgb((int) (value * 255), (int) (p * 255), (int) (q * 255));
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.