Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/26/16 18:25:24 (8 years ago)
Author:
gkronber
Message:

#2652: merged r14235 from trunk to stable

Location:
stable
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Algorithms.DataAnalysis

  • stable/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs

    r14186 r14308  
    58 58    [Storable]
    59 59    private int k;
     60    [Storable(DefaultValue = null)]
     61    private double[] weights; // not set for old versions loaded from disk
     62    [Storable(DefaultValue = null)]
     63    private double[] offsets; // not set for old versions loaded from disk
    60 64
    61 65    [StorableConstructor]
     
    93 97
    94 98      k = original.k;
     99      isCompatibilityLoaded = original.IsCompatibilityLoaded;
     100      if (!IsCompatibilityLoaded) {
     101        weights = new double[original.weights.Length];
     102        Array.Copy(original.weights, weights, weights.Length);
     103        offsets = new double[original.offsets.Length];
     104        Array.Copy(original.offsets, this.offsets, this.offsets.Length);
     105      }
    95 106      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    96 107      if (original.classValues != null)
    97 108        this.classValues = (double[])original.classValues.Clone();
    98 109    }
    99     public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
     110    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
    100 111      : base(targetVariable) {
    101 112      Name = ItemName;
     
    103 114      this.k = k;
    104 115      this.allowedInputVariables = allowedInputVariables.ToArray();
    105 
    106       var inputMatrix = AlglibUtil.PrepareInputMatrix(dataset,
    107                                    allowedInputVariables.Concat(new string[] { targetVariable }),
    108                                    rows);
     116      double[,] inputMatrix;
     117      if (IsCompatibilityLoaded) {
     118        // no scaling
     119        inputMatrix = AlglibUtil.PrepareInputMatrix(dataset,
     120          this.allowedInputVariables.Concat(new string[] { targetVariable }),
     121          rows);
     122      } else {
     123        this.offsets = this.allowedInputVariables
     124          .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1)
     125          .Concat(new double[] { 0 }) // no offset for target variable
     126          .ToArray();
     127        if (weights == null) {
     128          // automatic determination of weights (all features should have variance = 1)
     129          this.weights = this.allowedInputVariables
     130            .Select(name => 1.0 / dataset.GetDoubleValues(name, rows).StandardDeviationPop())
     131            .Concat(new double[] { 1.0 }) // no scaling for target variable
     132            .ToArray();
     133        } else {
     134          // user specified weights (+ 1 for target)
     135          this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
     136          if (this.weights.Length - 1 != this.allowedInputVariables.Length)
     137            throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
     138        }
     139        inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights);
     140      }
    109 141
    110 142      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
     
    132 164    }
    133 165
     166    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
     167      var x = new double[rows.Count(), variables.Count()];
     168      var colIdx = 0;
     169      foreach (var variableName in variables) {
     170        var rowIdx = 0;
     171        foreach (var val in dataset.GetDoubleValues(variableName, rows)) {
     172          x[rowIdx, colIdx] = (val + offsets[colIdx]) * factors[colIdx];
     173          rowIdx++;
     174        }
     175        colIdx++;
     176      }
     177      return x;
     178    }
     179
    134 180    public override IDeepCloneable Clone(Cloner cloner) {
    135 181      return new NearestNeighbourModel(this, cloner);
     
    137 183
    138 184    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    139       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     185      double[,] inputData;
     186      if (IsCompatibilityLoaded) {
     187        inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     188      } else {
     189        inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
     190      }
    140 191
    141 192      int n = inputData.GetLength(0);
    142 193      int columns = inputData.GetLength(1);
    143 194      double[] x = new double[columns];
    144       double[] y = new double[1];
    145 195      double[] dists = new double[k];
    146 196      double[,] neighbours = new double[k, columns + 1];
     
    152 202        int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
    153 203        alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
    154         alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
     204        alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours); // gkronber: this call changes the kdTree data structure
    155 205
    156 206        double distanceWeightedValue = 0.0;
     
    166 216    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    167 217      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
    168       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
    169 
     218      double[,] inputData;
     219      if (IsCompatibilityLoaded) {
     220        inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     221      } else {
     222        inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
     223      }
    170 224      int n = inputData.GetLength(0);
    171 225      int columns = inputData.GetLength(1);
     
    219 273    #endregion
    220 274
     275
     276    // BackwardsCompatibility3.3
     277    #region Backwards compatible code, remove with 3.4
     278
     279    private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true
     280    [Storable(DefaultValue = true)]
     281    public bool IsCompatibilityLoaded {
     282      get { return isCompatibilityLoaded; }
     283      set { isCompatibilityLoaded = value; }
     284    }
     285    #endregion
    221 286    #region persistence
    222 287    [Storable]
Note: See TracChangeset for help on using the changeset viewer.