Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/19/21 16:07:45 (2 years ago)
Author:
mkommend
Message:

#2521: Merged trunk changes into branch.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs

    r17226 r18086  
    3232  /// Represents a nearest neighbour model for regression and classification
    3333  /// </summary>
    34   [StorableType("A76C0823-3077-4ACE-8A40-E9B717C7DB60")]
     34  [StorableType("04A07DF6-6EB5-4D29-B7AE-5BE204CAF6BC")]
    3535  [Item("NearestNeighbourModel", "Represents a nearest neighbour model for regression and classification.")]
    3636  public sealed class NearestNeighbourModel : ClassificationModel, INearestNeighbourModel {
    3737
    38     private readonly object kdTreeLockObject = new object();
    39 
    40     private alglib.nearestneighbor.kdtree kdTree;
    41     public alglib.nearestneighbor.kdtree KDTree {
    42       get { return kdTree; }
    43       set {
    44         if (value != kdTree) {
    45           if (value == null) throw new ArgumentNullException();
    46           kdTree = value;
    47           OnChanged(EventArgs.Empty);
    48         }
    49       }
     38    private alglib.knnmodel model;
     39    [Storable]
     40    private string SerializedModel {
     41      get { alglib.knnserialize(model, out var ser); return ser; }
     42      set { if (value != null) alglib.knnunserialize(value, out model); }
    5043    }
    5144
     
    6053    [Storable]
    6154    private int k;
    62     [Storable(DefaultValue = false)]
    63     private bool selfMatch;
    64     [Storable(DefaultValue = null)]
    65     private double[] weights; // not set for old versions loaded from disk
    66     [Storable(DefaultValue = null)]
    67     private double[] offsets; // not set for old versions loaded from disk
     55    [Storable]
     56    private double[] weights;
     57    [Storable]
     58    private double[] offsets;
    6859
    6960    [StorableConstructor]
    70     private NearestNeighbourModel(StorableConstructorFlag _) : base(_) {
    71       kdTree = new alglib.nearestneighbor.kdtree();
    72     }
     61    private NearestNeighbourModel(StorableConstructorFlag _) : base(_) { }
    7362    private NearestNeighbourModel(NearestNeighbourModel original, Cloner cloner)
    7463      : base(original, cloner) {
    75       kdTree = new alglib.nearestneighbor.kdtree();
    76       kdTree.approxf = original.kdTree.approxf;
    77       kdTree.boxmax = (double[])original.kdTree.boxmax.Clone();
    78       kdTree.boxmin = (double[])original.kdTree.boxmin.Clone();
    79       kdTree.buf = (double[])original.kdTree.buf.Clone();
    80       kdTree.curboxmax = (double[])original.kdTree.curboxmax.Clone();
    81       kdTree.curboxmin = (double[])original.kdTree.curboxmin.Clone();
    82       kdTree.curdist = original.kdTree.curdist;
    83       kdTree.debugcounter = original.kdTree.debugcounter;
    84       kdTree.idx = (int[])original.kdTree.idx.Clone();
    85       kdTree.kcur = original.kdTree.kcur;
    86       kdTree.kneeded = original.kdTree.kneeded;
    87       kdTree.n = original.kdTree.n;
    88       kdTree.nodes = (int[])original.kdTree.nodes.Clone();
    89       kdTree.normtype = original.kdTree.normtype;
    90       kdTree.nx = original.kdTree.nx;
    91       kdTree.ny = original.kdTree.ny;
    92       kdTree.r = (double[])original.kdTree.r.Clone();
    93       kdTree.rneeded = original.kdTree.rneeded;
    94       kdTree.selfmatch = original.kdTree.selfmatch;
    95       kdTree.splits = (double[])original.kdTree.splits.Clone();
    96       kdTree.tags = (int[])original.kdTree.tags.Clone();
    97       kdTree.x = (double[])original.kdTree.x.Clone();
    98       kdTree.xy = (double[,])original.kdTree.xy.Clone();
    99       selfMatch = original.selfMatch;
     64      if (original.model != null)
     65        model = (alglib.knnmodel)original.model.make_copy();
    10066      k = original.k;
    101       isCompatibilityLoaded = original.IsCompatibilityLoaded;
    102       if (!IsCompatibilityLoaded) {
    103         weights = new double[original.weights.Length];
    104         Array.Copy(original.weights, weights, weights.Length);
    105         offsets = new double[original.offsets.Length];
    106         Array.Copy(original.offsets, this.offsets, this.offsets.Length);
    107       }
     67      weights = new double[original.weights.Length];
     68      Array.Copy(original.weights, weights, weights.Length);
     69      offsets = new double[original.offsets.Length];
     70      Array.Copy(original.offsets, this.offsets, this.offsets.Length);
     71
    10872      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    10973      if (original.classValues != null)
    11074        this.classValues = (double[])original.classValues.Clone();
    11175    }
    112     public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, bool selfMatch, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
     76    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
    11377      : base(targetVariable) {
    11478      Name = ItemName;
    11579      Description = ItemDescription;
    116       this.selfMatch = selfMatch;
    11780      this.k = k;
    11881      this.allowedInputVariables = allowedInputVariables.ToArray();
    11982      double[,] inputMatrix;
    120       if (IsCompatibilityLoaded) {
    121         // no scaling
    122         inputMatrix = dataset.ToArray(
    123           this.allowedInputVariables.Concat(new string[] { targetVariable }),
    124           rows);
     83      this.offsets = this.allowedInputVariables
     84        .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1)
     85        .Concat(new double[] { 0 }) // no offset for target variable
     86        .ToArray();
     87      if (weights == null) {
     88        // automatic determination of weights (all features should have variance = 1)
     89        this.weights = this.allowedInputVariables
     90          .Select(name => {
     91            var pop = dataset.GetDoubleValues(name, rows).StandardDeviationPop();
     92            return pop.IsAlmost(0) ? 1.0 : 1.0 / pop;
     93          })
     94          .Concat(new double[] { 1.0 }) // no scaling for target variable
     95          .ToArray();
    12596      } else {
    126         this.offsets = this.allowedInputVariables
    127           .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1)
    128           .Concat(new double[] { 0 }) // no offset for target variable
    129           .ToArray();
    130         if (weights == null) {
    131           // automatic determination of weights (all features should have variance = 1)
    132           this.weights = this.allowedInputVariables
    133             .Select(name => {
    134               var pop = dataset.GetDoubleValues(name, rows).StandardDeviationPop();
    135               return pop.IsAlmost(0) ? 1.0 : 1.0 / pop;
    136             })
    137             .Concat(new double[] { 1.0 }) // no scaling for target variable
    138             .ToArray();
    139         } else {
    140           // user specified weights (+ 1 for target)
    141           this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
    142           if (this.weights.Length - 1 != this.allowedInputVariables.Length)
    143             throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
    144         }
    145         inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights);
    146       }
     97        // user specified weights (+ 1 for target)
     98        this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
     99        if (this.weights.Length - 1 != this.allowedInputVariables.Length)
     100          throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
     101      }
     102      inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights);
    147103
    148104      if (inputMatrix.ContainsNanOrInfinity())
    149105        throw new NotSupportedException(
    150106          "Nearest neighbour model does not support NaN or infinity values in the input dataset.");
    151 
    152       this.kdTree = new alglib.nearestneighbor.kdtree();
    153107
    154108      var nRows = inputMatrix.GetLength(0);
     
    167121        }
    168122      }
    169       alglib.nearestneighbor.kdtreebuild(inputMatrix, nRows, inputMatrix.GetLength(1) - 1, 1, 2, kdTree);
     123
     124      alglib.knnbuildercreate(out var knnbuilder);
     125      if (classValues == null) {
     126        alglib.knnbuildersetdatasetreg(knnbuilder, inputMatrix, nRows, nFeatures, nout: 1);
     127      } else {
     128        alglib.knnbuildersetdatasetcls(knnbuilder, inputMatrix, nRows, nFeatures, classValues.Length);
     129      }
     130      alglib.knnbuilderbuildknnmodel(knnbuilder, k, 0.0, out model, out var report); // eps=0 (exact k-nn search is performed)
     131
    170132    }
    171133
     
    184146    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    185147      double[,] inputData;
    186       if (IsCompatibilityLoaded) {
    187         inputData = dataset.ToArray(allowedInputVariables, rows);
    188       } else {
    189         inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
    190       }
     148      inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
    191149
    192150      int n = inputData.GetLength(0);
    193151      int columns = inputData.GetLength(1);
    194152      double[] x = new double[columns];
    195       double[] dists = new double[k];
    196       double[,] neighbours = new double[k, columns + 1];
    197 
     153
     154      alglib.knncreatebuffer(model, out var buf);
     155      var y = new double[1];
    198156      for (int row = 0; row < n; row++) {
    199157        for (int column = 0; column < columns; column++) {
    200158          x[column] = inputData[row, column];
    201159        }
    202         int numNeighbours;
    203         lock (kdTreeLockObject) { // gkronber: the following calls change the kdTree data structure
    204           numNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, selfMatch);
    205           alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
    206           alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
    207         }
    208         if (selfMatch) {
    209           // weights for neighbours are 1/d.
    210           // override distances (=0) of exact matches using 1% of the distance of the next closest non-self-match neighbour -> selfmatches weight 100x more than the next closest neighbor.
    211           // if all k neighbours are selfmatches then they all have weight 0.01.
    212           double minDist = dists[0] + 1;
    213           for (int i = 0; i < numNeighbours; i++) {
    214             if ((minDist > dists[i]) && (dists[i] != 0)) {
    215               minDist = dists[i];
    216             }
    217           }
    218           minDist /= 100.0;
    219           for (int i = 0; i < numNeighbours; i++) {
    220             if (dists[i] == 0) {
    221               dists[i] = minDist;
    222             }
    223           }
    224         }
    225         double distanceWeightedValue = 0.0;
    226         double distsSum = 0.0;
    227         for (int i = 0; i < numNeighbours; i++) {
    228           distanceWeightedValue += neighbours[i, columns] / dists[i];
    229           distsSum += 1.0 / dists[i];
    230         }
    231         yield return distanceWeightedValue / distsSum;
     160        alglib.knntsprocess(model, buf, x, ref y); // thread-safe process
     161        yield return y[0];
    232162      }
    233163    }
     
    236166      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
    237167      double[,] inputData;
    238       if (IsCompatibilityLoaded) {
    239         inputData = dataset.ToArray(allowedInputVariables, rows);
    240       } else {
    241         inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
    242       }
     168      inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
     169
    243170      int n = inputData.GetLength(0);
    244171      int columns = inputData.GetLength(1);
    245172      double[] x = new double[columns];
    246       int[] y = new int[classValues.Length];
    247       double[] dists = new double[k];
    248       double[,] neighbours = new double[k, columns + 1];
    249 
     173
     174      alglib.knncreatebuffer(model, out var buf);
     175      var y = new double[classValues.Length];
    250176      for (int row = 0; row < n; row++) {
    251177        for (int column = 0; column < columns; column++) {
    252178          x[column] = inputData[row, column];
    253179        }
    254         int numNeighbours;
    255         lock (kdTreeLockObject) {
    256           // gkronber: the following calls change the kdTree data structure
    257           numNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, selfMatch);
    258           alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
    259           alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
    260         }
    261         Array.Clear(y, 0, y.Length);
    262         for (int i = 0; i < numNeighbours; i++) {
    263           int classValue = (int)Math.Round(neighbours[i, columns]);
    264           y[classValue]++;
    265         }
    266 
    267         // find class for with the largest probability value
    268         int maxProbClassIndex = 0;
    269         double maxProb = y[0];
    270         for (int i = 1; i < y.Length; i++) {
    271           if (maxProb < y[i]) {
    272             maxProb = y[i];
    273             maxProbClassIndex = i;
    274           }
    275         }
    276         yield return classValues[maxProbClassIndex];
     180        alglib.knntsprocess(model, buf, x, ref y); // thread-safe process
     181        // find most probably class
     182        var maxC = 0;
     183        for (int i = 1; i < y.Length; i++)
     184          if (maxC < y[i]) maxC = i;
     185        yield return classValues[maxC];
    277186      }
    278187    }
     
    303212      return new NearestNeighbourClassificationSolution(this, new ClassificationProblemData(problemData));
    304213    }
    305 
    306     #region events
    307     public event EventHandler Changed;
    308     private void OnChanged(EventArgs e) {
    309       var handlers = Changed;
    310       if (handlers != null)
    311         handlers(this, e);
    312     }
    313     #endregion
    314 
    315 
    316     // BackwardsCompatibility3.3
    317     #region Backwards compatible code, remove with 3.4
    318 
    319     private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true
    320     [Storable(DefaultValue = true)]
    321     public bool IsCompatibilityLoaded {
    322       get { return isCompatibilityLoaded; }
    323       set { isCompatibilityLoaded = value; }
    324     }
    325     #endregion
    326     #region persistence
    327     [Storable]
    328     public double KDTreeApproxF {
    329       get { return kdTree.approxf; }
    330       set { kdTree.approxf = value; }
    331     }
    332     [Storable]
    333     public double[] KDTreeBoxMax {
    334       get { return kdTree.boxmax; }
    335       set { kdTree.boxmax = value; }
    336     }
    337     [Storable]
    338     public double[] KDTreeBoxMin {
    339       get { return kdTree.boxmin; }
    340       set { kdTree.boxmin = value; }
    341     }
    342     [Storable]
    343     public double[] KDTreeBuf {
    344       get { return kdTree.buf; }
    345       set { kdTree.buf = value; }
    346     }
    347     [Storable]
    348     public double[] KDTreeCurBoxMax {
    349       get { return kdTree.curboxmax; }
    350       set { kdTree.curboxmax = value; }
    351     }
    352     [Storable]
    353     public double[] KDTreeCurBoxMin {
    354       get { return kdTree.curboxmin; }
    355       set { kdTree.curboxmin = value; }
    356     }
    357     [Storable]
    358     public double KDTreeCurDist {
    359       get { return kdTree.curdist; }
    360       set { kdTree.curdist = value; }
    361     }
    362     [Storable]
    363     public int KDTreeDebugCounter {
    364       get { return kdTree.debugcounter; }
    365       set { kdTree.debugcounter = value; }
    366     }
    367     [Storable]
    368     public int[] KDTreeIdx {
    369       get { return kdTree.idx; }
    370       set { kdTree.idx = value; }
    371     }
    372     [Storable]
    373     public int KDTreeKCur {
    374       get { return kdTree.kcur; }
    375       set { kdTree.kcur = value; }
    376     }
    377     [Storable]
    378     public int KDTreeKNeeded {
    379       get { return kdTree.kneeded; }
    380       set { kdTree.kneeded = value; }
    381     }
    382     [Storable]
    383     public int KDTreeN {
    384       get { return kdTree.n; }
    385       set { kdTree.n = value; }
    386     }
    387     [Storable]
    388     public int[] KDTreeNodes {
    389       get { return kdTree.nodes; }
    390       set { kdTree.nodes = value; }
    391     }
    392     [Storable]
    393     public int KDTreeNormType {
    394       get { return kdTree.normtype; }
    395       set { kdTree.normtype = value; }
    396     }
    397     [Storable]
    398     public int KDTreeNX {
    399       get { return kdTree.nx; }
    400       set { kdTree.nx = value; }
    401     }
    402     [Storable]
    403     public int KDTreeNY {
    404       get { return kdTree.ny; }
    405       set { kdTree.ny = value; }
    406     }
    407     [Storable]
    408     public double[] KDTreeR {
    409       get { return kdTree.r; }
    410       set { kdTree.r = value; }
    411     }
    412     [Storable]
    413     public double KDTreeRNeeded {
    414       get { return kdTree.rneeded; }
    415       set { kdTree.rneeded = value; }
    416     }
    417     [Storable]
    418     public bool KDTreeSelfMatch {
    419       get { return kdTree.selfmatch; }
    420       set { kdTree.selfmatch = value; }
    421     }
    422     [Storable]
    423     public double[] KDTreeSplits {
    424       get { return kdTree.splits; }
    425       set { kdTree.splits = value; }
    426     }
    427     [Storable]
    428     public int[] KDTreeTags {
    429       get { return kdTree.tags; }
    430       set { kdTree.tags = value; }
    431     }
    432     [Storable]
    433     public double[] KDTreeX {
    434       get { return kdTree.x; }
    435       set { kdTree.x = value; }
    436     }
    437     [Storable]
    438     public double[,] KDTreeXY {
    439       get { return kdTree.xy; }
    440       set { kdTree.xy = value; }
    441     }
    442     #endregion
    443214  }
    444215}
Note: See TracChangeset for help on using the changeset viewer.