Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/09/12 02:17:57 (12 years ago)
Author:
abeham
Message:

#1913:

  • Refactored NCAModel and NeighborhoodComponentsAnalysis algorithm
  • Model now includes a NearestNeighbourModel
  • Algorithm can now be canceled (basically by recreating the optimization loop of mincgoptimize)
  • Scaling should work properly now
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/NCA/HeuristicLab.Algorithms.NCA/3.3/NCAModel.cs

    r8441 r8454  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
     
    3534
    3635    [Storable]
    37     private string targetVariable;
    38     [Storable]
    39     private string[] allowedInputVariables;
    40     [Storable]
    41     private double[] classValues;
    42     /// <summary>
    43     /// Get a clone of the class values
    44     /// </summary>
    45     public double[] ClassValues {
    46       get { return (double[])classValues.Clone(); }
    47     }
    48     [Storable]
    49     private int k;
     36    private Scaling scaling;
    5037    [Storable]
    5138    private double[,] transformationMatrix;
    52     /// <summary>
    53     /// Get a clone of the transformation matrix
    54     /// </summary>
    5539    public double[,] TransformationMatrix {
    5640      get { return (double[,])transformationMatrix.Clone(); }
    5741    }
    5842    [Storable]
    59     private double[,] transformedTrainingset;
    60     /// <summary>
    61     /// Get a clone of the transformed trainingset
    62     /// </summary>
    63     public double[,] TransformedTrainingset {
    64       get { return (double[,])transformedTrainingset.Clone(); }
    65     }
     43    private string[] allowedInputVariables;
    6644    [Storable]
    67     private Scaling scaling;
     45    private string targetVariable;
     46    [Storable]
     47    private INearestNeighbourModel nnModel;
     48    [Storable]
     49    private Dictionary<double, double> nn2ncaClassMapping;
     50    [Storable]
     51    private Dictionary<double, double> nca2nnClassMapping;
    6852
    6953    [StorableConstructor]
     
    7155    protected NCAModel(NCAModel original, Cloner cloner)
    7256      : base(original, cloner) {
    73       k = original.k;
    74       targetVariable = original.targetVariable;
    75       allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    76       if (original.classValues != null)
    77         this.classValues = (double[])original.classValues.Clone();
    78       if (original.transformationMatrix != null)
    79         this.transformationMatrix = (double[,])original.transformationMatrix.Clone();
    80       if (original.transformedTrainingset != null)
    81         this.transformedTrainingset = (double[,])original.transformedTrainingset.Clone();
    8257      this.scaling = cloner.Clone(original.scaling);
     58      this.transformationMatrix = (double[,])original.transformationMatrix.Clone();
     59      this.allowedInputVariables = (string[])original.allowedInputVariables.Clone();
     60      this.targetVariable = original.targetVariable;
     61      this.nnModel = cloner.Clone(original.nnModel);
     62      this.nn2ncaClassMapping = original.nn2ncaClassMapping.ToDictionary(x => x.Key, y => y.Value);
     63      this.nca2nnClassMapping = original.nca2nnClassMapping.ToDictionary(x => x.Key, y => y.Value);
    8364    }
    84     public NCAModel(double[,] transformedTrainingset, Scaling scaling, double[,] transformationMatrix, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
    85       : base() {
    86       this.name = ItemName;
    87       this.description = ItemDescription;
    88       this.transformedTrainingset = transformedTrainingset;
     65    public NCAModel(int k, double[,] scaledData, Scaling scaling, double[,] transformationMatrix, string targetVariable, IEnumerable<double> targetVector, IEnumerable<string> allowedInputVariables) {
     66      Name = ItemName;
     67      Description = ItemDescription;
    8968      this.scaling = scaling;
    9069      this.transformationMatrix = transformationMatrix;
    91       this.k = k;
     70      this.allowedInputVariables = allowedInputVariables.ToArray();
    9271      this.targetVariable = targetVariable;
    93       this.allowedInputVariables = allowedInputVariables.ToArray();
    94       if (classValues != null)
    95         this.classValues = (double[])classValues.Clone();
     72
     73      nca2nnClassMapping = targetVector.Distinct().OrderBy(x => x).Select((v, i) => new { Index = (double)i, Class = v }).ToDictionary(x => x.Class, y => y.Index);
     74      nn2ncaClassMapping = nca2nnClassMapping.ToDictionary(x => x.Value, y => y.Key);
     75
     76      var transformedData = ReduceWithTarget(scaledData, targetVector.Select(x => nca2nnClassMapping[x]));
     77
     78      var kdtree = new alglib.nearestneighbor.kdtree();
     79      alglib.nearestneighbor.kdtreebuild(transformedData, transformedData.GetLength(0), transformedData.GetLength(1) - 1, 1, 2, kdtree);
     80
     81      nnModel = new NearestNeighbourModel(kdtree, k, targetVariable,
     82        Enumerable.Range(0, transformationMatrix.GetLength(1)).Select(x => x.ToString()),
     83        nn2ncaClassMapping.Keys.ToArray());
    9684    }
    9785
     
    10189
    10290    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
    103       var k = Math.Min(this.k, transformedTrainingset.GetLength(0));
    104       var transformedRow = new double[transformationMatrix.GetLength(1)];
    105       var kVotes = new SortedList<double, double>(k + 1);
    106       foreach (var r in rows) {
    107         for (int i = 0; i < transformedRow.Length; i++) transformedRow[i] = 0;
    108         int j = 0;
    109         foreach (var v in allowedInputVariables) {
    110           var values = scaling.GetScaledValues(dataset, v, rows);
    111           double val = dataset.GetDoubleValue(v, r);
    112           for (int i = 0; i < transformedRow.Length; i++)
    113             transformedRow[i] += val * transformationMatrix[j, i];
    114           j++;
     91      var unknownClasses = dataset.GetDoubleValues(targetVariable, rows).Where(x => !nca2nnClassMapping.ContainsKey(x));
     92      if (unknownClasses.Any())
     93        foreach (var uc in unknownClasses) {
     94          nca2nnClassMapping[uc] = nca2nnClassMapping.Count;
     95          nn2ncaClassMapping[nca2nnClassMapping[uc]] = uc;
    11596        }
    116         kVotes.Clear();
    117         for (int a = 0; a < transformedTrainingset.GetLength(0); a++) {
    118           double d = 0;
    119           for (int y = 0; y < transformedRow.Length; y++) {
    120             d += (transformedRow[y] - transformedTrainingset[a, y]) * (transformedRow[y] - transformedTrainingset[a, y]);
    121           }
    122           while (kVotes.ContainsKey(d)) d += 1e-12;
    123           if (kVotes.Count <= k || kVotes.Last().Key > d) {
    124             kVotes.Add(d, classValues[a]);
    125             if (kVotes.Count > k) kVotes.RemoveAt(kVotes.Count - 1);
    126           }
    127         }
    128         yield return kVotes.Values.ToLookup(x => x).MaxItems(x => x.Count()).First().Key;
    129       }
     97      var transformedData = ReduceWithTarget(dataset, rows, dataset.GetDoubleValues(targetVariable, rows).Select(x => nca2nnClassMapping[x]));
     98      var ds = new Dataset(Enumerable.Range(0, transformationMatrix.GetLength(1)).Select(x => x.ToString()).Concat(targetVariable.ToEnumerable()), transformedData);
     99      return nnModel.GetEstimatedClassValues(ds, Enumerable.Range(0, ds.Rows)).Select(x => nn2ncaClassMapping[x]);
    130100    }
     101
    131102    public NCAClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
    132103      return new NCAClassificationSolution(problemData, this);
    133104    }
     105
    134106    IClassificationSolution IClassificationModel.CreateClassificationSolution(IClassificationProblemData problemData) {
    135107      return CreateClassificationSolution(problemData);
     
    137109
    138110    public double[,] Reduce(Dataset dataset, IEnumerable<int> rows) {
    139       var result = new double[rows.Count(), transformationMatrix.GetLength(1)];
    140       int v = 0;
    141       foreach (var r in rows) {
    142         int i = 0;
    143         foreach (var variable in allowedInputVariables) {
    144           double val = dataset.GetDoubleValue(variable, r);
    145           for (int j = 0; j < result.GetLength(1); j++)
    146             result[v, j] += val * transformationMatrix[i, j];
    147           i++;
    148         }
    149         v++;
    150       }
     111      var scaledData = AlglibUtil.PrepareAndScaleInputMatrix(dataset, allowedInputVariables, rows, scaling);
     112      return Reduce(scaledData);
     113    }
     114
     115    private double[,] Reduce(double[,] scaledData) {
     116      var result = new double[scaledData.GetLength(0), transformationMatrix.GetLength(1)];
     117      for (int i = 0; i < scaledData.GetLength(0); i++)
     118        for (int j = 0; j < scaledData.GetLength(1); j++)
     119          for (int x = 0; x < transformationMatrix.GetLength(1); x++) {
     120            result[i, x] += scaledData[i, j] * transformationMatrix[j, x];
     121          }
     122      return result;
     123    }
     124
     125    private double[,] ReduceWithTarget(Dataset dataset, IEnumerable<int> rows, IEnumerable<double> targetValues) {
     126      var scaledData = AlglibUtil.PrepareAndScaleInputMatrix(dataset, allowedInputVariables, rows, scaling);
     127      return ReduceWithTarget(scaledData, targetValues);
     128    }
     129
     130    private double[,] ReduceWithTarget(double[,] scaledData, IEnumerable<double> targetValues) {
     131      var result = new double[scaledData.GetLength(0), transformationMatrix.GetLength(1) + 1];
     132      for (int i = 0; i < scaledData.GetLength(0); i++)
     133        for (int j = 0; j < scaledData.GetLength(1); j++)
     134          for (int x = 0; x < transformationMatrix.GetLength(1); x++) {
     135            result[i, x] += scaledData[i, j] * transformationMatrix[j, x];
     136          }
     137
     138      int r = 0;
     139      foreach (var d in targetValues) result[r++, transformationMatrix.GetLength(1)] = d;
     140
    151141      return result;
    152142    }
Note: See TracChangeset for help on using the changeset viewer.