Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearDiscriminantAnalysis.cs @ 5678

Last change on this file was in changeset 5678, checked in by gkronber, 13 years ago

#1418 Worked on calculation of thresholds for classification solutions based on discriminant functions.

File size: 9.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Optimization;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using System.Collections.Generic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic;
34using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
35using HeuristicLab.Problems.DataAnalysis.Symbolic.Classification;
36
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Linear discriminant analysis classification algorithm.
  /// Fits a linear discriminant function via Fisher LDA (ALGLIB) and derives class
  /// thresholds from the intersection points of per-class normal densities of the
  /// projected (discriminant) values.
  /// </summary>
  [Item("Linear Discriminant Analysis", "Linear discriminant analysis classification algorithm.")]
  [Creatable("Data Analysis")]
  [StorableClass]
  public sealed class LinearDiscriminantAnalysis : FixedDataAnalysisAlgorithm<IClassificationProblem> {
    private const string LinearDiscriminantAnalysisSolutionResultName = "Linear discriminant analysis solution";

    [StorableConstructor]
    private LinearDiscriminantAnalysis(bool deserializing) : base(deserializing) { }
    private LinearDiscriminantAnalysis(LinearDiscriminantAnalysis original, Cloner cloner)
      : base(original, cloner) {
    }
    public LinearDiscriminantAnalysis()
      : base() {
      Problem = new ClassificationProblem();
    }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new LinearDiscriminantAnalysis(this, cloner);
    }

    #region Fisher LDA
    protected override void Run() {
      var solution = CreateLinearDiscriminantAnalysisSolution(Problem.ProblemData);
      Results.Add(new Result(LinearDiscriminantAnalysisSolutionResultName, "The linear discriminant analysis.", solution));
    }

    /// <summary>
    /// Creates a discriminant-function classification solution for the training partition
    /// of the given problem data using Fisher's linear discriminant analysis.
    /// </summary>
    /// <param name="problemData">Classification problem data (dataset, target variable, training partition).</param>
    /// <returns>A symbolic discriminant-function classification solution.</returns>
    /// <exception cref="ArgumentException">Thrown when ALGLIB reports an error (info &lt; 1).</exception>
    public static IClassificationSolution CreateLinearDiscriminantAnalysisSolution(IClassificationProblemData problemData) {
      Dataset dataset = problemData.Dataset;
      string targetVariable = problemData.TargetVariable;
      IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;
      int samplesStart = problemData.TrainingPartitionStart.Value;
      int samplesEnd = problemData.TrainingPartitionEnd.Value;
      IEnumerable<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart);
      int nClasses = problemData.ClassNames.Count();
      // input matrix layout: one column per input variable, target variable in the last column
      double[,] inputMatrix = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables.Concat(new string[] { targetVariable }), rows);

      // alglib.fisherlda expects class indices 0..nClasses-1 in the last column,
      // so replace each class value by its index in the sorted list of class values
      int targetVariableColumn = inputMatrix.GetLength(1) - 1;
      List<double> classValues = problemData.ClassValues.OrderBy(x => x).ToList();
      for (int row = 0; row < inputMatrix.GetLength(0); row++) {
        inputMatrix[row, targetVariableColumn] = classValues.IndexOf(inputMatrix[row, targetVariableColumn]);
      }

      int info;
      double[] w;
      alglib.fisherlda(inputMatrix, inputMatrix.GetLength(0), allowedInputVariables.Count(), nClasses, out info, out w);
      if (info < 1) throw new ArgumentException("Error in calculation of linear discriminant analysis solution");

      // build the discriminant function as a symbolic expression tree:
      // w[0]*x0 + w[1]*x1 + ... + w[n-1], where the last weight is the constant offset
      ISymbolicExpressionTree tree = new SymbolicExpressionTree(new ProgramRootSymbol().CreateTreeNode());
      ISymbolicExpressionTreeNode startNode = new StartSymbol().CreateTreeNode();
      tree.Root.AddSubTree(startNode);
      ISymbolicExpressionTreeNode addition = new Addition().CreateTreeNode();
      startNode.AddSubTree(addition);

      int col = 0;
      foreach (string column in allowedInputVariables) {
        VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
        vNode.VariableName = column;
        vNode.Weight = w[col];
        addition.AddSubTree(vNode);
        col++;
      }

      ConstantTreeNode cNode = (ConstantTreeNode)new Constant().CreateTreeNode();
      cNode.Value = w[w.Length - 1];
      addition.AddSubTree(cNode);

      var model = CreateDiscriminantFunctionModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter(), problemData, rows);
      return new SymbolicDiscriminantFunctionClassificationSolution(model, problemData);
    }
    #endregion

    /// <summary>
    /// Builds a discriminant-function classification model for the given tree: estimates a
    /// normal density per class from the discriminant values on <paramref name="rows"/> and
    /// places class thresholds at the density intersection (cut) points.
    /// </summary>
    /// <param name="tree">Symbolic expression tree representing the discriminant function.</param>
    /// <param name="interpreter">Interpreter used to evaluate the tree on the dataset.</param>
    /// <param name="problemData">Classification problem data providing target and class values.</param>
    /// <param name="rows">Dataset rows on which the class densities are estimated.</param>
    private static SymbolicDiscriminantFunctionClassificationModel CreateDiscriminantFunctionModel(ISymbolicExpressionTree tree,
      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      IClassificationProblemData problemData,
      IEnumerable<int> rows) {
      string targetVariable = problemData.TargetVariable;
      List<double> originalClasses = problemData.ClassValues.ToList();
      int nClasses = problemData.Classes;
      List<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows).ToList();
      // BUGFIX: estimatedValues is positional (one entry per element of rows). The previous code
      // indexed it with absolute dataset row numbers (problemData.TrainingIndizes), which yields
      // wrong values or an IndexOutOfRangeException whenever the training partition does not
      // start at row 0. Pair rows with estimated values by position instead.
      var estimatedTargetValues = rows
        .Select((row, i) => new { EstimatedValue = estimatedValues[i], TargetValue = problemData.Dataset[targetVariable, row] })
        .ToList();

      // estimate mean and standard deviation of the discriminant values for each class
      Dictionary<double, double> classMean = new Dictionary<double, double>();
      Dictionary<double, double> classStdDev = new Dictionary<double, double>();
      foreach (var classValue in originalClasses) {
        // exact comparison is intended here: target values are the stored class values
        var estimatedValuesForClass = from x in estimatedTargetValues
                                      where x.TargetValue == classValue
                                      select x.EstimatedValue;
        double mean, variance;
        OnlineMeanAndVarianceCalculator.Calculate(estimatedValuesForClass, out mean, out variance);
        classMean[classValue] = mean;
        classStdDev[classValue] = Math.Sqrt(variance);
      }

      // candidate thresholds: intersection points of the densities of each pair of classes
      List<double> thresholds = new List<double>();
      for (int i = 0; i < nClasses - 1; i++) {
        for (int j = i + 1; j < nClasses; j++) {
          double x1, x2;
          double class0 = originalClasses[i];
          double class1 = originalClasses[j];
          CalculateCutPoints(classMean[class0], classStdDev[class0], classMean[class1], classStdDev[class1], out x1, out x2);
          // keep each cut point only once (up to epsilon)
          if (!thresholds.Any(x => x.IsAlmost(x1))) thresholds.Add(x1);
          if (!thresholds.Any(x => x.IsAlmost(x2))) thresholds.Add(x2);
        }
      }
      thresholds.Sort();
      thresholds.Insert(0, double.NegativeInfinity);
      thresholds.Add(double.PositiveInfinity);

      // for each interval between adjacent thresholds, pick the class with the highest
      // density at a representative point inside the interval
      List<double> classValuesPerInterval = new List<double>();
      for (int i = 0; i < thresholds.Count - 1; i++) {
        double m; // representative point in [thresholds[i], thresholds[i+1])
        if (double.IsNegativeInfinity(thresholds[i])) {
          m = thresholds[i + 1] - 1.0;
        } else if (double.IsPositiveInfinity(thresholds[i + 1])) {
          m = thresholds[i] + 1.0;
        } else {
          m = thresholds[i] + (thresholds[i + 1] - thresholds[i]) / 2.0;
        }

        double maxDensity = 0;
        // NOTE(review): stays -1 when every density is zero (e.g. a class with zero std.dev.
        // yields NaN densities) — confirm this sentinel is acceptable downstream
        double maxDensityClassValue = -1;
        foreach (var classValue in originalClasses) {
          double density = NormalDensity(m, classMean[classValue], classStdDev[classValue]);
          if (density > maxDensity) {
            maxDensity = density;
            maxDensityClassValue = classValue;
          }
        }
        classValuesPerInterval.Add(maxDensityClassValue);
      }

      // merge adjacent intervals that predict the same class
      List<double> filteredThresholds = new List<double>();
      List<double> filteredClassValues = new List<double>();
      filteredThresholds.Add(thresholds[0]);
      filteredClassValues.Add(classValuesPerInterval[0]);
      for (int i = 0; i < classValuesPerInterval.Count - 1; i++) {
        if (classValuesPerInterval[i] != classValuesPerInterval[i + 1]) {
          filteredThresholds.Add(thresholds[i + 1]);
          filteredClassValues.Add(classValuesPerInterval[i + 1]);
        }
      }
      filteredThresholds.Add(double.PositiveInfinity);

      return new SymbolicDiscriminantFunctionClassificationModel(tree, interpreter, filteredClassValues, filteredThresholds);
    }

    /// <summary>
    /// Normal (Gaussian) probability density at <paramref name="x"/> for mean
    /// <paramref name="mu"/> and standard deviation <paramref name="sigma"/>.
    /// Returns NaN when sigma is zero.
    /// </summary>
    private static double NormalDensity(double x, double mu, double sigma) {
      return (1.0 / Math.Sqrt(2.0 * Math.PI * sigma * sigma)) * Math.Exp(-((x - mu) * (x - mu)) / (2.0 * sigma * sigma));
    }

    /// <summary>
    /// Calculates the two intersection points of the normal densities N(m1, s1) and N(m2, s2),
    /// i.e. the roots of the quadratic obtained by equating the two log-densities.
    /// Results are NaN/infinite for degenerate inputs (e.g. s1 == s2 or zero std.dev.).
    /// </summary>
    private static void CalculateCutPoints(double m1, double s1, double m2, double s2, out double x1, out double x2) {
      // quadratic coefficient of the log-density difference; zero when s1 == s2
      double a = (s1 * s1 - s2 * s2);
      // discriminant term shared by both roots (previously computed twice;
      // the unused intermediate coefficients b and c were removed)
      double sqrtTerm = Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)));
      x1 = (m2 * s1 * s1 - m1 * s2 * s2 - sqrtTerm) / a;
      x2 = (m2 * s1 * s1 - m1 * s2 * s2 + sqrtTerm) / a;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.