Changeset 14235


Ignore:
Timestamp:
08/04/16 17:09:54 (3 years ago)
Author:
gkronber
Message:

#2652: added scaling and optional specification of feature-weights for kNN

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourClassification.cs

    r14185 r14235  
    4040    private const string KParameterName = "K";
    4141    private const string NearestNeighbourClassificationModelResultName = "Nearest neighbour classification solution";
     42    private const string WeightsParameterName = "Weights";
     43
    4244
    4345    #region parameter properties
    4446    public IFixedValueParameter<IntValue> KParameter {
    4547      get { return (IFixedValueParameter<IntValue>)Parameters[KParameterName]; }
     48    }
     49    public IValueParameter<DoubleArray> WeightsParameter {
     50      get { return (IValueParameter<DoubleArray>)Parameters[WeightsParameterName]; }
    4651    }
    4752    #endregion
     
    5358        else KParameter.Value.Value = value;
    5459      }
     60    }
     61    public DoubleArray Weights {
     62      get { return WeightsParameter.Value; }
     63      set { WeightsParameter.Value = value; }
    5564    }
    5665    #endregion
     
    6473      : base() {
    6574      Parameters.Add(new FixedValueParameter<IntValue>(KParameterName, "The number of nearest neighbours to consider for regression.", new IntValue(3)));
     75      Parameters.Add(new OptionalValueParameter<DoubleArray>(WeightsParameterName, "Optional: use weights to specify individual scaling values for all features. If not set the weights are calculated automatically (each feature is scaled to unit variance)"));
    6676      Problem = new ClassificationProblem();
    6777    }
    6878    [StorableHook(HookType.AfterDeserialization)]
    69     private void AfterDeserialization() { }
     79    private void AfterDeserialization() {
     80      // BackwardsCompatibility3.3
     81      #region Backwards compatible code, remove with 3.4
     82      if (!Parameters.ContainsKey(WeightsParameterName)) {
     83        Parameters.Add(new OptionalValueParameter<DoubleArray>(WeightsParameterName, "Optional: use weights to specify individual scaling values for all features. If not set the weights are calculated automatically (each feature is scaled to unit variance)"));
     84      }
     85      #endregion
     86    }
    7087
    7188    public override IDeepCloneable Clone(Cloner cloner) {
     
    7592    #region nearest neighbour
    7693    protected override void Run() {
    77       var solution = CreateNearestNeighbourClassificationSolution(Problem.ProblemData, K);
     94      double[] weights = null;
     95      if (Weights != null) weights = Weights.CloneAsArray();
     96      var solution = CreateNearestNeighbourClassificationSolution(Problem.ProblemData, K, weights);
    7897      Results.Add(new Result(NearestNeighbourClassificationModelResultName, "The nearest neighbour classification solution.", solution));
    7998    }
    8099
    81     public static IClassificationSolution CreateNearestNeighbourClassificationSolution(IClassificationProblemData problemData, int k) {
     100    public static IClassificationSolution CreateNearestNeighbourClassificationSolution(IClassificationProblemData problemData, int k, double[] weights = null) {
    82101      var problemDataClone = (IClassificationProblemData)problemData.Clone();
    83       return new NearestNeighbourClassificationSolution(Train(problemDataClone, k), problemDataClone);
     102      return new NearestNeighbourClassificationSolution(Train(problemDataClone, k, weights), problemDataClone);
    84103    }
    85104
    86     public static INearestNeighbourModel Train(IClassificationProblemData problemData, int k) {
     105    public static INearestNeighbourModel Train(IClassificationProblemData problemData, int k, double[] weights = null) {
    87106      return new NearestNeighbourModel(problemData.Dataset,
    88107        problemData.TrainingIndices,
     
    90109        problemData.TargetVariable,
    91110        problemData.AllowedInputVariables,
     111        weights,
    92112        problemData.ClassValues.ToArray());
    93113    }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourModel.cs

    r14185 r14235  
    5858    [Storable]
    5959    private int k;
     60    [Storable(DefaultValue = null)]
     61    private double[] weights; // not set for old versions loaded from disk
     62    [Storable(DefaultValue = null)]
     63    private double[] offsets; // not set for old versions loaded from disk
    6064
    6165    [StorableConstructor]
     
    9397
    9498      k = original.k;
     99      isCompatibilityLoaded = original.IsCompatibilityLoaded;
     100      if (!IsCompatibilityLoaded) {
     101        weights = new double[original.weights.Length];
     102        Array.Copy(original.weights, weights, weights.Length);
     103        offsets = new double[original.offsets.Length];
     104        Array.Copy(original.offsets, this.offsets, this.offsets.Length);
     105      }
    95106      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
    96107      if (original.classValues != null)
    97108        this.classValues = (double[])original.classValues.Clone();
    98109    }
    99     public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, double[] classValues = null)
     110    public NearestNeighbourModel(IDataset dataset, IEnumerable<int> rows, int k, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<double> weights = null, double[] classValues = null)
    100111      : base(targetVariable) {
    101112      Name = ItemName;
     
    103114      this.k = k;
    104115      this.allowedInputVariables = allowedInputVariables.ToArray();
    105 
    106       var inputMatrix = AlglibUtil.PrepareInputMatrix(dataset,
    107                                    allowedInputVariables.Concat(new string[] { targetVariable }),
    108                                    rows);
     116      double[,] inputMatrix;
     117      if (IsCompatibilityLoaded) {
     118        // no scaling
     119        inputMatrix = AlglibUtil.PrepareInputMatrix(dataset,
     120          this.allowedInputVariables.Concat(new string[] { targetVariable }),
     121          rows);
     122      } else {
     123        this.offsets = this.allowedInputVariables
     124          .Select(name => dataset.GetDoubleValues(name, rows).Average() * -1)
     125          .Concat(new double[] { 0 }) // no offset for target variable
     126          .ToArray();
     127        if (weights == null) {
     128          // automatic determination of weights (all features should have variance = 1)
     129          this.weights = this.allowedInputVariables
     130            .Select(name => 1.0 / dataset.GetDoubleValues(name, rows).StandardDeviationPop())
     131            .Concat(new double[] { 1.0 }) // no scaling for target variable
     132            .ToArray();
     133        } else {
     134          // user specified weights (+ 1 for target)
     135          this.weights = weights.Concat(new double[] { 1.0 }).ToArray();
     136          if (this.weights.Length - 1 != this.allowedInputVariables.Length)
     137            throw new ArgumentException("The number of elements in the weight vector must match the number of input variables");
     138        }
     139        inputMatrix = CreateScaledData(dataset, this.allowedInputVariables.Concat(new string[] { targetVariable }), rows, this.offsets, this.weights);
     140      }
    109141
    110142      if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x) || double.IsInfinity(x)))
     
    132164    }
    133165
     166    private static double[,] CreateScaledData(IDataset dataset, IEnumerable<string> variables, IEnumerable<int> rows, double[] offsets, double[] factors) {
     167      var x = new double[rows.Count(), variables.Count()];
     168      var colIdx = 0;
     169      foreach (var variableName in variables) {
     170        var rowIdx = 0;
     171        foreach (var val in dataset.GetDoubleValues(variableName, rows)) {
     172          x[rowIdx, colIdx] = (val + offsets[colIdx]) * factors[colIdx];
     173          rowIdx++;
     174        }
     175        colIdx++;
     176      }
     177      return x;
     178    }
     179
    134180    public override IDeepCloneable Clone(Cloner cloner) {
    135181      return new NearestNeighbourModel(this, cloner);
     
    137183
    138184    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    139       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     185      double[,] inputData;
     186      if (IsCompatibilityLoaded) {
     187        inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     188      } else {
     189        inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
     190      }
    140191
    141192      int n = inputData.GetLength(0);
    142193      int columns = inputData.GetLength(1);
    143194      double[] x = new double[columns];
    144       double[] y = new double[1];
    145195      double[] dists = new double[k];
    146196      double[,] neighbours = new double[k, columns + 1];
     
    152202        int actNeighbours = alglib.nearestneighbor.kdtreequeryknn(kdTree, x, k, false);
    153203        alglib.nearestneighbor.kdtreequeryresultsdistances(kdTree, ref dists);
    154         alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours);
     204        alglib.nearestneighbor.kdtreequeryresultsxy(kdTree, ref neighbours); // gkronber: this call changes the kdTree data structure
    155205
    156206        double distanceWeightedValue = 0.0;
     
    166216    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
    167217      if (classValues == null) throw new InvalidOperationException("No class values are defined.");
    168       double[,] inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
    169 
     218      double[,] inputData;
     219      if (IsCompatibilityLoaded) {
     220        inputData = AlglibUtil.PrepareInputMatrix(dataset, allowedInputVariables, rows);
     221      } else {
     222        inputData = CreateScaledData(dataset, allowedInputVariables, rows, offsets, weights);
     223      }
    170224      int n = inputData.GetLength(0);
    171225      int columns = inputData.GetLength(1);
     
    219273    #endregion
    220274
     275
     276    // BackwardsCompatibility3.3
     277    #region Backwards compatible code, remove with 3.4
     278
     279    private bool isCompatibilityLoaded = false; // new kNN models have the value false, kNN models loaded from disc have the value true
     280    [Storable(DefaultValue = true)]
     281    public bool IsCompatibilityLoaded {
     282      get { return isCompatibilityLoaded; }
     283      set { isCompatibilityLoaded = value; }
     284    }
     285    #endregion
    221286    #region persistence
    222287    [Storable]
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/NearestNeighbour/NearestNeighbourRegression.cs

    r14185 r14235  
    3939    private const string KParameterName = "K";
    4040    private const string NearestNeighbourRegressionModelResultName = "Nearest neighbour regression solution";
     41    private const string WeightsParameterName = "Weights";
    4142
    4243    #region parameter properties
    4344    public IFixedValueParameter<IntValue> KParameter {
    4445      get { return (IFixedValueParameter<IntValue>)Parameters[KParameterName]; }
     46    }
     47
     48    public IValueParameter<DoubleArray> WeightsParameter {
     49      get { return (IValueParameter<DoubleArray>)Parameters[WeightsParameterName]; }
    4550    }
    4651    #endregion
     
    5257        else KParameter.Value.Value = value;
    5358      }
     59    }
     60
     61    public DoubleArray Weights {
     62      get { return WeightsParameter.Value; }
     63      set { WeightsParameter.Value = value; }
    5464    }
    5565    #endregion
     
    6373      : base() {
    6474      Parameters.Add(new FixedValueParameter<IntValue>(KParameterName, "The number of nearest neighbours to consider for regression.", new IntValue(3)));
     75      Parameters.Add(new OptionalValueParameter<DoubleArray>(WeightsParameterName, "Optional: use weights to specify individual scaling values for all features. If not set the weights are calculated automatically (each feature is scaled to unit variance)"));
    6576      Problem = new RegressionProblem();
    6677    }
     78
    6779    [StorableHook(HookType.AfterDeserialization)]
    68     private void AfterDeserialization() { }
     80    private void AfterDeserialization() {
     81      // BackwardsCompatibility3.3
     82      #region Backwards compatible code, remove with 3.4
     83      if (!Parameters.ContainsKey(WeightsParameterName)) {
     84        Parameters.Add(new OptionalValueParameter<DoubleArray>(WeightsParameterName, "Optional: use weights to specify individual scaling values for all features. If not set the weights are calculated automatically (each feature is scaled to unit variance)"));
     85      }
     86      #endregion
     87    }
    6988
    7089    public override IDeepCloneable Clone(Cloner cloner) {
     
    7493    #region nearest neighbour
    7594    protected override void Run() {
    76       var solution = CreateNearestNeighbourRegressionSolution(Problem.ProblemData, K);
     95      double[] weights = null;
     96      if (Weights != null) weights = Weights.CloneAsArray();
     97      var solution = CreateNearestNeighbourRegressionSolution(Problem.ProblemData, K, weights);
    7798      Results.Add(new Result(NearestNeighbourRegressionModelResultName, "The nearest neighbour regression solution.", solution));
    7899    }
    79100
    80     public static IRegressionSolution CreateNearestNeighbourRegressionSolution(IRegressionProblemData problemData, int k) {
     101    public static IRegressionSolution CreateNearestNeighbourRegressionSolution(IRegressionProblemData problemData, int k, double[] weights = null) {
    81102      var clonedProblemData = (IRegressionProblemData)problemData.Clone();
    82       return new NearestNeighbourRegressionSolution(Train(problemData, k), clonedProblemData);
     103      return new NearestNeighbourRegressionSolution(Train(problemData, k, weights), clonedProblemData);
    83104    }
    84105
    85     public static INearestNeighbourModel Train(IRegressionProblemData problemData, int k) {
     106    public static INearestNeighbourModel Train(IRegressionProblemData problemData, int k, double[] weights = null) {
    86107      return new NearestNeighbourModel(problemData.Dataset,
    87108        problemData.TrainingIndices,
    88109        k,
    89110        problemData.TargetVariable,
    90         problemData.AllowedInputVariables);
     111        problemData.AllowedInputVariables,
     112        weights);
    91113    }
    92114    #endregion
Note: See TracChangeset for help on using the changeset viewer.