Free cookie consent management tool by TermsFeed Policy Generator

source: branches/NCA/HeuristicLab.Algorithms.NCA/3.3/NCAModel.cs @ 8441

Last change on this file since 8441 was 8441, checked in by abeham, 12 years ago

#1913: added quality output

File size: 6.0 KB
RevLine 
[8412]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
[8441]25using HeuristicLab.Algorithms.DataAnalysis;
[8412]26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis;
30
namespace HeuristicLab.Algorithms.NCA {
  /// <summary>
  /// Classification model that projects input rows with a transformation matrix and
  /// classifies them by a vote among the k nearest rows of the transformed training set.
  /// </summary>
  [Item("NCAModel", "")]
  [StorableClass]
  public class NCAModel : NamedItem, INCAModel {

    [Storable]
    private string targetVariable;
    [Storable]
    private string[] allowedInputVariables;
    [Storable]
    private double[] classValues;
    /// <summary>
    /// Get a clone of the class values (null if no class values were supplied)
    /// </summary>
    public double[] ClassValues {
      get { return classValues == null ? null : (double[])classValues.Clone(); }
    }
    [Storable]
    private int k;
    [Storable]
    private double[,] transformationMatrix;
    /// <summary>
    /// Get a clone of the transformation matrix (null if none is set)
    /// </summary>
    public double[,] TransformationMatrix {
      get { return transformationMatrix == null ? null : (double[,])transformationMatrix.Clone(); }
    }
    [Storable]
    private double[,] transformedTrainingset;
    /// <summary>
    /// Get a clone of the transformed trainingset (null if none is set)
    /// </summary>
    public double[,] TransformedTrainingset {
      get { return transformedTrainingset == null ? null : (double[,])transformedTrainingset.Clone(); }
    }
    [Storable]
    private Scaling scaling;

    [StorableConstructor]
    protected NCAModel(bool deserializing) : base(deserializing) { }
    /// <summary>
    /// Copy constructor used by <see cref="Clone(Cloner)"/>; arrays are copied, the scaling is cloned through the cloner.
    /// </summary>
    protected NCAModel(NCAModel original, Cloner cloner)
      : base(original, cloner) {
      k = original.k;
      targetVariable = original.targetVariable;
      if (original.allowedInputVariables != null)
        this.allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
      if (original.transformationMatrix != null)
        this.transformationMatrix = (double[,])original.transformationMatrix.Clone();
      if (original.transformedTrainingset != null)
        this.transformedTrainingset = (double[,])original.transformedTrainingset.Clone();
      this.scaling = cloner.Clone(original.scaling);
    }
    /// <summary>
    /// Creates a new model from an already transformed training set.
    /// </summary>
    /// <param name="transformedTrainingset">Training rows after projection; stored by reference (not copied).</param>
    /// <param name="scaling">Scaling associated with the model; stored by reference.</param>
    /// <param name="transformationMatrix">Matrix indexed as [input variable, target dimension]; stored by reference.</param>
    /// <param name="k">Number of nearest neighbours that take part in the vote.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="allowedInputVariables">Input variable names, in the row order of the transformation matrix.</param>
    /// <param name="classValues">Class value of each training row; copied when not null.</param>
    public NCAModel(double[,] transformedTrainingset, Scaling scaling, double[,] transformationMatrix, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
      this.transformedTrainingset = transformedTrainingset;
      this.scaling = scaling;
      this.transformationMatrix = transformationMatrix;
      this.k = k;
      this.targetVariable = targetVariable;
      this.allowedInputVariables = allowedInputVariables.ToArray();
      if (classValues != null)
        this.classValues = (double[])classValues.Clone();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NCAModel(this, cloner);
    }

    /// <summary>
    /// Estimates a class value for every given row: the row is projected with the transformation
    /// matrix, the nearest training rows (squared Euclidean distance) are collected and the class
    /// is chosen by a vote among them.
    /// </summary>
    /// <param name="dataset">Dataset that provides the input variable values.</param>
    /// <param name="rows">Rows of the dataset to classify.</param>
    /// <returns>A lazily evaluated sequence with one class value per row.</returns>
    /// <exception cref="InvalidOperationException">Thrown during enumeration when k or the training set leaves no neighbour to vote.</exception>
    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
      int trainingRows = transformedTrainingset.GetLength(0);
      // never ask for more neighbours than there are training rows
      int neighbours = Math.Max(0, Math.Min(this.k, trainingRows));
      var transformedRow = new double[transformationMatrix.GetLength(1)];
      // the nearest neighbours found so far, kept in ascending order of distance
      var nearestDistances = new double[neighbours];
      var nearestClasses = new double[neighbours];
      foreach (var r in rows) {
        if (neighbours == 0)
          throw new InvalidOperationException("Cannot estimate a class value without at least one neighbour.");
        // NOTE(review): the values are read unscaled although the model stores a Scaling;
        // confirm whether transformedTrainingset was built from scaled inputs.
        TransformRow(dataset, r, transformedRow);
        int found = 0;
        for (int a = 0; a < trainingRows; a++) {
          double d = 0;
          for (int y = 0; y < transformedRow.Length; y++) {
            double diff = transformedRow[y] - transformedTrainingset[a, y];
            d += diff * diff;
          }
          // treat an undefined distance as "infinitely far" so it cannot disturb the ordering
          if (double.IsNaN(d)) d = double.PositiveInfinity;

          int pos;
          if (found < neighbours) pos = found++;                      // still room: append
          else if (d < nearestDistances[neighbours - 1]) pos = neighbours - 1; // replaces the current worst
          else continue;                                              // not closer than the current worst
          // Stable insertion: equal distances keep their arrival order. Duplicates need no
          // key perturbation, which could loop forever when the increment is lost to rounding.
          while (pos > 0 && nearestDistances[pos - 1] > d) {
            nearestDistances[pos] = nearestDistances[pos - 1];
            nearestClasses[pos] = nearestClasses[pos - 1];
            pos--;
          }
          nearestDistances[pos] = d;
          nearestClasses[pos] = classValues[a];
        }
        // group the neighbours by class (nearest first); MaxItems presumably yields the groups
        // with the highest count, of which the first is returned
        yield return nearestClasses.Take(found).ToLookup(x => x).MaxItems(x => x.Count()).First().Key;
      }
    }
    public NCAClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NCAClassificationSolution(problemData, this);
    }
    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
      return CreateClassificationSolution(problemData);
    }

    /// <summary>
    /// Projects the given rows of the dataset with the transformation matrix.
    /// </summary>
    /// <param name="dataset">Dataset that provides the input variable values.</param>
    /// <param name="rows">Rows of the dataset to project.</param>
    /// <returns>A matrix with one row per given row and one column per column of the transformation matrix.</returns>
    public double[,] Reduce(Dataset dataset, IEnumerable<int> rows) {
      // materialize once so that sizing the result and filling it see the same rows
      var rowList = rows as IList<int> ?? rows.ToList();
      int dimensions = transformationMatrix.GetLength(1);
      var result = new double[rowList.Count, dimensions];
      var buffer = new double[dimensions];
      for (int v = 0; v < rowList.Count; v++) {
        TransformRow(dataset, rowList[v], buffer);
        for (int j = 0; j < dimensions; j++)
          result[v, j] = buffer[j];
      }
      return result;
    }

    /// <summary>
    /// Writes the projection of a single dataset row into <paramref name="target"/>:
    /// target[i] = sum over the allowed input variables j of value(j) * transformationMatrix[j, i].
    /// </summary>
    private void TransformRow(Dataset dataset, int row, double[] target) {
      Array.Clear(target, 0, target.Length);
      int j = 0;
      foreach (var variable in allowedInputVariables) {
        double val = dataset.GetDoubleValue(variable, row);
        for (int i = 0; i < target.Length; i++)
          target[i] += val * transformationMatrix[j, i];
        j++;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.