Free cookie consent management tool by TermsFeed Policy Generator

source: branches/NCA/HeuristicLab.Algorithms.NCA/3.3/NCAModel.cs @ 8425

Last change on this file since 8425 was 8420, checked in by abeham, 12 years ago

#1913:

  • Worked on NCA
  • Added scatter plot view for the model to show training data when it is reduced to two dimensions

It works, but I don't think it works correctly yet. I have randomized the initial matrix, because the starting point influences the achievable quality quite a bit.

File size: 5.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.NCA {
  /// <summary>
  /// Classification model that projects input rows with a linear transformation matrix and
  /// predicts the majority class among the k nearest transformed training samples
  /// (squared Euclidean distance in the transformed space).
  /// </summary>
  [Item("NCAModel", "")]
  [StorableClass]
  public class NCAModel : NamedItem, INCAModel {

    [Storable]
    private string targetVariable;
    // Input variables in the order of the rows of transformationMatrix.
    [Storable]
    private string[] allowedInputVariables;
    // Class value of each training sample; entry a belongs to row a of transformedTrainingset.
    [Storable]
    private double[] classValues;
    /// <summary>
    /// Get a clone of the class values, or null if the model stores none.
    /// </summary>
    public double[] ClassValues {
      get { return classValues == null ? null : (double[])classValues.Clone(); }
    }
    // Number of neighbours that vote; capped at the training set size during prediction.
    [Storable]
    private int k;
    // Row j holds the weights of allowedInputVariables[j]; column i is projected dimension i.
    [Storable]
    private double[,] transformationMatrix;
    /// <summary>
    /// Get a clone of the transformation matrix, or null if none is stored.
    /// </summary>
    public double[,] TransformationMatrix {
      get { return transformationMatrix == null ? null : (double[,])transformationMatrix.Clone(); }
    }
    // Training samples already projected into the transformed space (one row per sample).
    [Storable]
    private double[,] transformedTrainingset;
    /// <summary>
    /// Get a clone of the transformed trainingset, or null if none is stored.
    /// </summary>
    public double[,] TransformedTrainingset {
      get { return transformedTrainingset == null ? null : (double[,])transformedTrainingset.Clone(); }
    }

    [StorableConstructor]
    protected NCAModel(bool deserializing) : base(deserializing) { }
    protected NCAModel(NCAModel original, Cloner cloner)
      : base(original, cloner) {
      k = original.k;
      targetVariable = original.targetVariable;
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
      if (original.transformationMatrix != null)
        this.transformationMatrix = (double[,])original.transformationMatrix.Clone();
      if (original.transformedTrainingset != null)
        this.transformedTrainingset = (double[,])original.transformedTrainingset.Clone();
    }
    /// <summary>
    /// Creates a model from an already transformed training set.
    /// </summary>
    /// <param name="transformedTrainingset">Projected training samples, one row per sample (stored by reference).</param>
    /// <param name="transformationMatrix">Projection matrix, rows = input variables, columns = projected dimensions (stored by reference).</param>
    /// <param name="k">Number of neighbours that vote.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="allowedInputVariables">Input variables in the order of the matrix rows.</param>
    /// <param name="classValues">Class value per training sample; copied if given.</param>
    public NCAModel(double[,] transformedTrainingset, double[,] transformationMatrix, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
      this.transformedTrainingset = transformedTrainingset;
      this.transformationMatrix = transformationMatrix;
      this.k = k;
      this.targetVariable = targetVariable;
      this.allowedInputVariables = allowedInputVariables.ToArray();
      if (classValues != null)
        this.classValues = (double[])classValues.Clone();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NCAModel(this, cloner);
    }

    /// <summary>
    /// Estimates the class value of each requested row by a k-nearest-neighbour vote in the transformed space.
    /// </summary>
    /// <param name="dataset">Dataset providing the values of the allowed input variables.</param>
    /// <param name="rows">Row indices of <paramref name="dataset"/> to classify.</param>
    /// <returns>
    /// Lazily, one estimated class value per row: the most frequent class among the k nearest
    /// transformed training samples (k is capped at the training set size). When two training
    /// samples are equally distant, the one with the lower index is preferred as a neighbour.
    /// NOTE(review): on a vote tie the first group returned by MaxItems wins; with ToLookup's
    /// first-appearance order this is presumably the class of the nearest tied neighbour — confirm.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown while enumerating a row if there are no neighbours to vote (k or the training set
    /// size is zero or negative) or if no class values are stored.
    /// </exception>
    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
      int trainingRows = transformedTrainingset.GetLength(0);
      int dims = transformationMatrix.GetLength(1);
      int k = Math.Min(this.k, trainingRows);
      double[] transformedRow = new double[dims];
      // Bounded buffer of the k nearest neighbours, kept sorted by ascending distance.
      // (Replaces a SortedList keyed by distance whose "d += 1e-12" de-duplication never
      // terminated for large distances or NaN.)
      double[] neighborDistances = new double[Math.Max(k, 0)];
      double[] neighborClasses = new double[Math.Max(k, 0)];

      foreach (var r in rows) {
        if (k <= 0)
          throw new InvalidOperationException("NCAModel: no neighbours available for voting (k or the number of training samples is zero).");
        if (classValues == null)
          throw new InvalidOperationException("NCAModel: no class values are stored for the training samples.");

        // Project the row: transformedRow = x * transformationMatrix.
        Array.Clear(transformedRow, 0, dims);
        for (int j = 0; j < allowedInputVariables.Length; j++) {
          double val = dataset.GetDoubleValue(allowedInputVariables[j], r);
          for (int i = 0; i < dims; i++)
            transformedRow[i] += val * transformationMatrix[j, i];
        }

        int count = 0;
        for (int a = 0; a < trainingRows; a++) {
          double d = 0;
          for (int y = 0; y < dims; y++) {
            double diff = transformedRow[y] - transformedTrainingset[a, y];
            d += diff * diff;
          }
          // NaN (e.g. missing input values) is treated as infinitely far away.
          if (double.IsNaN(d)) d = double.PositiveInfinity;

          // When full, only a strictly closer sample displaces the current k-th neighbour,
          // so among equal distances the earlier training sample is kept.
          if (count == k && d >= neighborDistances[k - 1]) continue;

          int pos = count < k ? count : k - 1;
          // Shift strictly farther neighbours right; equal distances stay in front.
          while (pos > 0 && neighborDistances[pos - 1] > d) {
            neighborDistances[pos] = neighborDistances[pos - 1];
            neighborClasses[pos] = neighborClasses[pos - 1];
            pos--;
          }
          neighborDistances[pos] = d;
          neighborClasses[pos] = classValues[a];
          if (count < k) count++;
        }

        // count == k here because k <= trainingRows; ToLookup materializes before the buffer is reused.
        yield return neighborClasses.ToLookup(x => x).MaxItems(x => x.Count()).First().Key;
      }
    }
    public NCAClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NCAClassificationSolution(problemData, this);
    }
    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
      return CreateClassificationSolution(problemData);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.