1 | using System.Collections.Generic;
|
---|
2 | using System.Drawing;
|
---|
3 | using System.Linq;
|
---|
4 | using System.Threading;
|
---|
5 | using HeuristicLab.Analysis;
|
---|
6 | using HeuristicLab.Common;
|
---|
7 | using HeuristicLab.Core;
|
---|
8 | using HeuristicLab.Data;
|
---|
9 | using HeuristicLab.Optimization;
|
---|
10 | using HeuristicLab.Parameters;
|
---|
11 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
12 | using HeuristicLab.Problems.DataAnalysis;
|
---|
13 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Principal component analysis of the allowed input variables of a regression problem.
  /// Stores the projected data and the component variances as results and adds a scatter plot of the
  /// first two projected components whose points are grouped (colored) by the problem's target variable.
  /// </summary>
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysis, Priority = 95)]
  [Item("PrincipleComponentAnalysis", "Standard Principal Component Analysis")]
  public sealed class PrincipleComponentAnalysis : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string NormalizationParameterName = "Normalization";
    private const string ProjectedDataResultName = "ProjectedData";
    private const string VariancesResultName = "Variances";
    private const string ProjectedDataPlotName = "ProjectedDataPlot";
    // Number of k-means clusters used to bin a double-valued target variable into colored contours.
    private const int NumberOfContours = 8;
    // Number of k-means restarts used for that binning.
    private const int ClusteringRestarts = 3;

    public IFixedValueParameter<BoolValue> NormalizationParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[NormalizationParameterName]; }
    }

    /// <summary>
    /// Whether the data is zero-centered and scaled to a variance of 1 per variable before the analysis.
    /// </summary>
    public bool Normalization {
      get { return NormalizationParameter.Value.Value; }
      set { NormalizationParameter.Value.Value = value; }
    }

    [StorableConstructor]
    private PrincipleComponentAnalysis(bool deserializing) : base(deserializing) { }
    private PrincipleComponentAnalysis(PrincipleComponentAnalysis original, Cloner cloner) : base(original, cloner) { }
    public PrincipleComponentAnalysis() {
      Problem = new RegressionProblem();
      Parameters.Add(new FixedValueParameter<BoolValue>(NormalizationParameterName, "Whether the data should be zero centered and have variance of 1 for each variable, so different scalings are ignored.", new BoolValue(true)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new PrincipleComponentAnalysis(this, cloner);
    }

    /// <summary>
    /// Fits the PCA on all rows of the problem data using the allowed input variables, projects the same
    /// rows and publishes the projection, the variances and a scatter plot as results.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var problemData = Problem.ProblemData;
      var dataset = problemData.Dataset;
      var rows = problemData.AllIndices;
      var inputVariables = problemData.AllowedInputVariables.ToArray();

      var solution = PrincipleComponentAnalysisStatic.Create(dataset, rows, inputVariables, Normalization);
      var projectedData = solution.ProjectData(dataset, rows);

      Results.Add(new Result(ProjectedDataResultName, new DoubleMatrix(projectedData)));
      Results.Add(new Result(VariancesResultName, new DoubleArray(solution.Variances)));
      CreateScatterPlot(problemData.TargetVariable, projectedData);
    }

    #region Fancy ScatterPlot
    /// <summary>
    /// Adds a scatter plot of the first two projected components to the results. Points are grouped into
    /// plot rows by the variable <paramref name="classesName"/>: one row per distinct value of a string
    /// variable, one row per k-means contour of a double variable, or the training/test partition when the
    /// variable is missing or of another type.
    /// </summary>
    /// <param name="classesName">Name of the variable used to group (color) the points.</param>
    /// <param name="lowDimData">Projected data; row k corresponds to position k of AllIndices.</param>
    private void CreateScatterPlot(string classesName, double[,] lowDimData) {
      var problemData = Problem.ProblemData;
      var dataset = problemData.Dataset;
      var dataRowNames = new Dictionary<string, List<int>>();
      var dataRows = new Dictionary<string, ScatterPlotDataRow>();

      if (dataset.VariableNames.Contains(classesName)) {
        var concreteDataset = (Dataset)dataset;
        if (concreteDataset.VariableHasType<string>(classesName)) {
          GroupByStringValues(dataset.GetStringValues(classesName), dataRowNames);
        } else if (concreteDataset.VariableHasType<double>(classesName)) {
          GroupByContours(problemData, classesName, dataRowNames, dataRows);
        }
      }
      // Fall back to the training/test partition when no grouping variable could be used.
      // NOTE(review): training/test indices are dataset row numbers; they address lowDimData correctly only
      // if AllIndices enumerates every row starting at 0 — assumed here, as in the string grouping.
      if (dataRowNames.Count == 0) {
        dataRowNames.Add("Training", problemData.TrainingIndices.ToList());
        dataRowNames.Add("Test", problemData.TestIndices.ToList());
      }

      var plot = new ScatterPlot(ProjectedDataPlotName, "");
      Results.Add(new Result(ProjectedDataPlotName, "Plot of the projected data", plot));

      var components = lowDimData.GetLength(1);
      if (components == 0) return; // no projected component, nothing to draw
      // With a single component the points are drawn on the diagonal (x = y).
      var yColumn = components > 1 ? 1 : 0;

      foreach (var group in dataRowNames) {
        ScatterPlotDataRow row;
        if (!dataRows.TryGetValue(group.Key, out row))
          row = new ScatterPlotDataRow(group.Key, "", new List<Point2D<double>>());
        plot.Rows.Add(row);
        row.Points.Replace(group.Value.Select(i => new Point2D<double>(lowDimData[i, 0], lowDimData[i, yColumn])));
      }
    }

    // Creates one group per distinct string value, holding the (zero-based) positions of the rows with that value.
    private static void GroupByStringValues(IEnumerable<string> values, Dictionary<string, List<int>> groups) {
      var position = 0;
      foreach (var value in values) {
        List<int> members;
        if (!groups.TryGetValue(value, out members)) {
          members = new List<int>();
          groups.Add(value, members);
        }
        members.Add(position++);
      }
    }

    // Bins a double variable into k-means contours and creates one colored plot row per used contour.
    // Contours are ordered by their lower border so that color and point size grow with the variable's value.
    private static void GroupByContours(IDataAnalysisProblemData problemData, string variableName,
      Dictionary<string, List<int>> groups, Dictionary<string, ScatterPlotDataRow> rows) {
      Dictionary<int, string> contourNames;
      double[][] borders;
      int[] clusterIndices;
      CreateClusters(problemData, variableName, NumberOfContours, out contourNames, out borders, out clusterIndices);

      // Only clusters that received at least one row get a plot row: empty clusters keep their initial
      // borders, all share the label "[MaxValue;MinValue]" and would otherwise collide as dictionary keys.
      var usedClusters = new HashSet<int>(clusterIndices);
      var contourOrder = Enumerable.Range(0, NumberOfContours)
        .Where(c => usedClusters.Contains(c))
        .OrderBy(c => borders[c][0])
        .ToArray();

      for (var i = 0; i < contourOrder.Length; i++) {
        var contourName = contourNames[contourOrder[i]];
        // Distinct clusters with identical value ranges share one label and therefore one row.
        if (groups.ContainsKey(contourName)) continue;
        groups.Add(contourName, new List<int>());
        var row = new ScatterPlotDataRow(contourName, "", new List<Point2D<double>>());
        row.VisualProperties.Color = GetHeatMapColor(i, contourOrder.Length);
        row.VisualProperties.PointSize = i + 3;
        rows.Add(contourName, row);
      }

      // Position k in clusterIndices is position k in AllIndices, i.e. row k of the projected data.
      for (var k = 0; k < clusterIndices.Length; k++)
        groups[contourNames[clusterIndices[k]]].Add(k);
    }

    /// <summary>
    /// Clusters the values of <paramref name="target"/> with k-means and computes the value range of each cluster.
    /// </summary>
    /// <param name="data">Problem data providing the dataset and the rows (AllIndices) to cluster.</param>
    /// <param name="target">Name of the double-valued variable to cluster.</param>
    /// <param name="contours">Number of k-means clusters.</param>
    /// <param name="contourNames">Maps the zero-based cluster index to the label "[min;max]".</param>
    /// <param name="borders">
    /// borders[c] = {min, max} of the target values in cluster c; stays {double.MaxValue, double.MinValue}
    /// when no value updated it (empty cluster).
    /// </param>
    /// <param name="clusterIndices">Zero-based cluster index of every row, in the order of AllIndices.</param>
    private static void CreateClusters(IDataAnalysisProblemData data, string target, int contours,
      out Dictionary<int, string> contourNames, out double[][] borders, out int[] clusterIndices) {
      var model = KMeansClustering.CreateKMeansSolution(new ClusteringProblemData((Dataset)data.Dataset, new[] { target }), contours, ClusteringRestarts).Model;

      var rows = data.AllIndices.ToArray();
      // The model's cluster values are treated as 1-based; shift them to zero-based indices.
      clusterIndices = model.GetClusterValues(data.Dataset, rows).Select(c => c - 1).ToArray();
      var targetValues = data.Dataset.GetDoubleValues(target).ToArray();

      borders = new double[contours][];
      for (var c = 0; c < contours; c++)
        borders[c] = new[] { double.MaxValue, double.MinValue };

      for (var k = 0; k < rows.Length; k++) {
        var border = borders[clusterIndices[k]];
        var v = targetValues[rows[k]];
        if (v < border[0]) border[0] = v;
        if (v > border[1]) border[1] = v;
      }

      contourNames = new Dictionary<int, string>();
      for (var c = 0; c < contours; c++)
        contourNames.Add(c, "[" + borders[c][0] + ";" + borders[c][1] + "]");
    }

    /// <summary>
    /// Converts contour <paramref name="contourNr"/> of <paramref name="noContours"/> into a color by stepping
    /// around the hue circle (HSV to RGB conversion with saturation and value fixed at 0.999).
    /// </summary>
    private static Color GetHeatMapColor(int contourNr, int noContours) {
      const double saturation = 0.999;
      const double value = 0.999;

      var hue = contourNr / (double)noContours;
      // Keep the hue strictly inside (0, 1) so that h6 selects one of the six sectors below.
      if (hue > 0.999) { hue = 0.999; }
      if (hue < 0.001) { hue = 0.001; }

      var h6 = hue * 6.0;
      if (h6.IsAlmost(6.0)) { h6 = 0.0; }
      var sector = (int)h6;
      var fraction = h6 - sector;
      var p = value * (1.0 - saturation);
      var q = value * (1.0 - saturation * fraction);
      var t = value * (1.0 - saturation * (1.0 - fraction));
      switch (sector) {
        case 0: return Color.FromArgb((int)(value * 255), (int)(t * 255), (int)(p * 255));
        case 1: return Color.FromArgb((int)(q * 255), (int)(value * 255), (int)(p * 255));
        case 2: return Color.FromArgb((int)(p * 255), (int)(value * 255), (int)(t * 255));
        case 3: return Color.FromArgb((int)(p * 255), (int)(q * 255), (int)(value * 255));
        case 4: return Color.FromArgb((int)(t * 255), (int)(p * 255), (int)(value * 255));
        default: return Color.FromArgb((int)(value * 255), (int)(p * 255), (int)(q * 255));
      }
    }
    #endregion
  }
}
---|