Free cookie consent management tool by TermsFeed Policy Generator

source: branches/NCA/HeuristicLab.Algorithms.NCA/3.3/NCAModel.cs @ 8412

Last change on this file since 8412 was 8412, checked in by abeham, 12 years ago

#1913: imported branch (non-functional right now)

File size: 5.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.NCA {
  /// <summary>
  /// Classification model that projects the allowed input variables with a linear transformation
  /// matrix and classifies each row by a k-nearest-neighbor vote against the projected training set.
  /// </summary>
  [Item("NCAModel", "")]
  [StorableClass]
  public class NCAModel : NamedItem, INCAModel {

    // Name of the target variable.
    [Storable]
    private string targetVariable;
    // Input variables; the j-th variable is multiplied with row j of the transformation matrix.
    [Storable]
    private string[] allowedInputVariables;
    // Class value of each training sample, indexed like the rows of transformedTrainingset.
    [Storable]
    private double[] classValues;
    // Number of nearest neighbors that vote on a classification.
    [Storable]
    private int k;
    // Projection matrix indexed as [input variable, transformed dimension].
    [Storable]
    private double[,] transformationMatrix;
    /// <summary>
    /// Get a clone of the transformation matrix
    /// </summary>
    public double[,] TransformationMatrix {
      get { return (double[,])transformationMatrix.Clone(); }
    }
    // Training samples already projected with the transformation matrix: [sample, transformed dimension].
    [Storable]
    private double[,] transformedTrainingset;
    /// <summary>
    /// Get a clone of the transformed trainingset
    /// </summary>
    public double[,] TransformedTrainingset {
      get { return (double[,])transformedTrainingset.Clone(); }
    }

    [StorableConstructor]
    protected NCAModel(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Cloning constructor; creates deep copies of all array fields of <paramref name="original"/>.
    /// </summary>
    protected NCAModel(NCAModel original, Cloner cloner)
      : base(original, cloner) {
      k = original.k;
      targetVariable = original.targetVariable;
      if (original.allowedInputVariables != null)
        allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      if (original.classValues != null)
        this.classValues = (double[])original.classValues.Clone();
      if (original.transformationMatrix != null)
        this.transformationMatrix = (double[,])original.transformationMatrix.Clone();
      if (original.transformedTrainingset != null)
        this.transformedTrainingset = (double[,])original.transformedTrainingset.Clone();
    }

    /// <summary>
    /// Creates a new model.
    /// </summary>
    /// <param name="transformedTrainingset">The projected training samples, [sample, transformed dimension].
    /// Stored by reference, not copied.</param>
    /// <param name="transformationMatrix">The projection matrix, [input variable, transformed dimension].
    /// Stored by reference, not copied.</param>
    /// <param name="k">The number of nearest neighbors that vote on a classification.</param>
    /// <param name="targetVariable">The name of the target variable.</param>
    /// <param name="allowedInputVariables">The input variables, in the row order of <paramref name="transformationMatrix"/>.</param>
    /// <param name="classValues">The class value of each training sample (copied). Must be provided for
    /// <see cref="GetEstimatedClassValues"/> to work.</param>
    public NCAModel(double[,] transformedTrainingset, double[,] transformationMatrix, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
      this.transformedTrainingset = transformedTrainingset;
      this.transformationMatrix = transformationMatrix;
      this.k = k;
      this.targetVariable = targetVariable;
      this.allowedInputVariables = allowedInputVariables.ToArray();
      if (classValues != null)
        this.classValues = (double[])classValues.Clone();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NCAModel(this, cloner);
    }

    /// <summary>
    /// Estimates the class value of each given row by a k-nearest-neighbor vote in the transformed space.
    /// </summary>
    /// <remarks>
    /// Each row's allowed input variables are projected with the transformation matrix and compared with
    /// every transformed training sample by squared Euclidean distance. The k closest samples vote with
    /// their class values; k is capped at the training-set size and at least 0.
    /// On equal distance the sample with the lower training index ranks first. On a tied vote the class
    /// whose first vote comes from the nearer neighbor wins.
    /// Training samples whose distance is NaN (e.g. due to missing input values) are ignored. A row for
    /// which no neighbor remains yields <see cref="double.NaN"/>.
    /// </remarks>
    /// <param name="dataset">The dataset that contains the rows.</param>
    /// <param name="rows">The indices of the rows to classify.</param>
    /// <returns>The estimated class value for each row, evaluated lazily.</returns>
    /// <exception cref="InvalidOperationException">Raised on enumeration when the model was created
    /// without class values.</exception>
    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
      if (classValues == null)
        throw new InvalidOperationException("The model does not contain the class values of the training samples.");

      int trainingSize = transformedTrainingset.GetLength(0);
      int neighbors = Math.Max(0, Math.Min(this.k, trainingSize));
      int dimensions = transformationMatrix.GetLength(1);
      double[] transformedRow = new double[dimensions];
      // The nearest neighbors found so far for the current row, sorted by ascending distance.
      // A fixed buffer replaces a distance-keyed SortedList: making keys unique by adding a tiny
      // epsilon never terminates for large, infinite or NaN distances.
      double[] nearestDistances = new double[neighbors];
      double[] nearestClasses = new double[neighbors];

      foreach (var r in rows) {
        Array.Clear(transformedRow, 0, dimensions);
        int j = 0;
        foreach (var v in allowedInputVariables) {
          double val = dataset.GetDoubleValue(v, r);
          for (int i = 0; i < dimensions; i++)
            transformedRow[i] += val * transformationMatrix[j, i];
          j++;
        }

        int found = 0;
        for (int a = 0; a < trainingSize; a++) {
          double d = 0;
          for (int y = 0; y < dimensions; y++) {
            double diff = transformedRow[y] - transformedTrainingset[a, y];
            d += diff * diff;
          }
          // A NaN distance cannot be ranked against the others.
          if (double.IsNaN(d)) continue;

          int pos;
          if (found < neighbors) pos = found++;
          // Strict comparison: at equal distance the earlier training sample is kept.
          else if (neighbors > 0 && d < nearestDistances[neighbors - 1]) pos = neighbors - 1;
          else continue;

          // Shift strictly farther neighbors back by one slot (dropping the farthest when full).
          while (pos > 0 && nearestDistances[pos - 1] > d) {
            nearestDistances[pos] = nearestDistances[pos - 1];
            nearestClasses[pos] = nearestClasses[pos - 1];
            pos--;
          }
          nearestDistances[pos] = d;
          nearestClasses[pos] = classValues[a];
        }
        yield return MajorityVote(nearestClasses, found);
      }
    }

    // Returns the most frequent value among classes[0..count-1], or double.NaN when count is 0.
    // Ties go to the value whose first occurrence has the lowest index, i.e. the nearest neighbor.
    private static double MajorityVote(double[] classes, int count) {
      var votes = new Dictionary<double, int>();
      for (int i = 0; i < count; i++) {
        int c;
        votes.TryGetValue(classes[i], out c);
        votes[classes[i]] = c + 1;
      }
      double winner = double.NaN;
      int winnerVotes = 0;
      for (int i = 0; i < count; i++) {
        int c = votes[classes[i]];
        if (c > winnerVotes) {
          winner = classes[i];
          winnerVotes = c;
        }
      }
      return winner;
    }

    public NCAClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new NCAClassificationSolution(problemData, this);
    }
    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
      return CreateClassificationSolution(problemData);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.