Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/11/20 13:39:48 (4 years ago)
Author:
pfleck
Message:

#3040 First draft of different-vector-length strategies (cut, fill, resample, cycle, ...)

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeVectorInterpreter.cs

    r17604 r17721  
    2929using HeuristicLab.Parameters;
    3030using HEAL.Attic;
     31using MathNet.Numerics;
    3132using MathNet.Numerics.Statistics;
    32 
    3333using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
    3434
     
    4343      Sum,
    4444      First,
     45      L1Norm,
     46      L2Norm,
    4547      NaN,
    4648      Exception
     49    }
     50    public static double Aggregate(Aggregation aggregation, DoubleVector vector) {
     51      switch (aggregation) {
     52        case Aggregation.Mean: return Statistics.Mean(vector);
     53        case Aggregation.Median: return Statistics.Median(vector);
     54        case Aggregation.Sum: return vector.Sum();
     55        case Aggregation.First: return vector.First();
     56        case Aggregation.L1Norm: return vector.L1Norm();
     57        case Aggregation.L2Norm: return vector.L2Norm();
     58        case Aggregation.NaN: return double.NaN;
     59        case Aggregation.Exception: throw new InvalidOperationException("Result of the tree is not a scalar.");
     60        default: throw new ArgumentOutOfRangeException(nameof(aggregation), aggregation, null);
     61      }
     62    }
     63
     64    [StorableType("73DCBB45-916F-4139-8ADC-57BA610A1B66")]
     65    public enum VectorLengthStrategy {
     66      ExceptionIfDifferent,
     67      FillShorterWithNaN,
     68      FillShorterWithNeutralElement,
     69      CutLonger,
     70      ResampleToLonger,
     71      ResampleToShorter,
     72      CycleShorter
     73    }
     74
     75    #region Implementation VectorLengthStrategy
     76    public static (DoubleVector, DoubleVector) ExceptionIfDifferent(DoubleVector lhs, DoubleVector rhs) {
     77      if (lhs.Count != rhs.Count)
     78        throw new InvalidOperationException($"Vector Lengths incompatible ({lhs.Count} vs. {rhs.Count}");
     79      return (lhs, rhs);
     80    }
     81
     82    public static (DoubleVector, DoubleVector) FillShorter(DoubleVector lhs, DoubleVector rhs, double fillElement) {
     83      var targetLength = Math.Max(lhs.Count, rhs.Count);
     84
     85      DoubleVector PadVector(DoubleVector v) {
     86        if (v.Count == targetLength) return v;
     87        var p = DoubleVector.Build.Dense(targetLength, fillElement);
     88        v.CopySubVectorTo(p, 0, 0, v.Count);
     89        return p;
     90      }
     91
     92      return (PadVector(lhs), PadVector(rhs));
     93    }
     94
     95    public static (DoubleVector, DoubleVector) CutLonger(DoubleVector lhs, DoubleVector rhs) {
     96      var targetLength = Math.Min(lhs.Count, rhs.Count);
     97
     98      DoubleVector CutVector(DoubleVector v) {
     99        if (v.Count == targetLength) return v;
     100        return v.SubVector(0, targetLength);
     101      }
     102
     103      return (CutVector(lhs), CutVector(rhs));
     104    }
     105
     106    private static DoubleVector ResampleToLength(DoubleVector v, int targetLength) {
     107      if (v.Count == targetLength) return v;
     108
     109      var indices = Enumerable.Range(0, v.Count).Select(x => (double)x);
     110      var interpolation = Interpolate.Linear(indices, v);
     111
     112      var resampledIndices = Enumerable.Range(0, targetLength).Select(i => (double)i / targetLength * v.Count);
     113      var interpolatedValues = resampledIndices.Select(interpolation.Interpolate);
     114
     115      return DoubleVector.Build.DenseOfEnumerable(interpolatedValues);
     116    }
     117    public static (DoubleVector, DoubleVector) ResampleToLonger(DoubleVector lhs, DoubleVector rhs) {
     118      var maxLength = Math.Max(lhs.Count, rhs.Count);
     119      return (ResampleToLength(lhs, maxLength), ResampleToLength(rhs, maxLength));
     120    }
     121    public static (DoubleVector, DoubleVector) ResampleToShorter(DoubleVector lhs, DoubleVector rhs) {
     122      var minLength = Math.Min(lhs.Count, rhs.Count);
     123      return (ResampleToLength(lhs, minLength), ResampleToLength(rhs, minLength));
     124    }
     125
     126    public static (DoubleVector, DoubleVector) CycleShorter(DoubleVector lhs, DoubleVector rhs) {
     127      var targetLength = Math.Max(lhs.Count, rhs.Count);
     128
     129      DoubleVector CycleVector(DoubleVector v) {
     130        if (v.Count == targetLength) return v;
     131        var cycledValues = Enumerable.Range(0, targetLength).Select(i => v[i % v.Count]);
     132        return DoubleVector.Build.DenseOfEnumerable(cycledValues);
     133      }
     134
     135      return (CycleVector(lhs), CycleVector(rhs));
     136    }
     137    #endregion
     138
     139    public static (DoubleVector lhs, DoubleVector rhs) ApplyVectorLengthStrategy(VectorLengthStrategy strategy, DoubleVector lhs, DoubleVector rhs,
     140      double neutralElement = double.NaN) {
     141
     142      switch (strategy) {
     143        case VectorLengthStrategy.ExceptionIfDifferent: return ExceptionIfDifferent(lhs, rhs);
     144        case VectorLengthStrategy.FillShorterWithNaN: return FillShorter(lhs, rhs, double.NaN);
     145        case VectorLengthStrategy.FillShorterWithNeutralElement: return FillShorter(lhs, rhs, neutralElement);
     146        case VectorLengthStrategy.CutLonger: return CutLonger(lhs, rhs);
     147        case VectorLengthStrategy.ResampleToLonger: return ResampleToLonger(lhs, rhs);
     148        case VectorLengthStrategy.ResampleToShorter: return ResampleToShorter(lhs, rhs);
     149        case VectorLengthStrategy.CycleShorter: return CycleShorter(lhs, rhs);
     150        default: throw new ArgumentOutOfRangeException(nameof(strategy), strategy, null);
     151      }
    47152    }
    48153
     
    56161    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    57162    private const string FinalAggregationParameterName = "FinalAggregation";
     163    private const string DifferentVectorLengthStrategyParameterName = "DifferentVectorLengthStrategy";
    58164
    59165    public override bool CanChangeName {
     
    /// <summary>
    /// The parameter stored under <c>FinalAggregationParameterName</c>, cast to its concrete type.
    /// </summary>
    public IFixedValueParameter<EnumValue<Aggregation>> FinalAggregationParameter =>
      (IFixedValueParameter<EnumValue<Aggregation>>)Parameters[FinalAggregationParameterName];
     180    public IFixedValueParameter<EnumValue<VectorLengthStrategy>> DifferentVectorLengthStrategyParameter {
     181      get { return (IFixedValueParameter<EnumValue<VectorLengthStrategy>>)Parameters[DifferentVectorLengthStrategyParameterName]; }
    73182    }
    74183    #endregion
     
    83192      set { FinalAggregationParameter.Value.Value = value; }
    84193    }
     194    public VectorLengthStrategy DifferentVectorLengthStrategy {
     195      get { return DifferentVectorLengthStrategyParameter.Value.Value; }
     196      set { DifferentVectorLengthStrategyParameter.Value.Value = value; }
     197    }
    85198    #endregion
    86199
     
    103216      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    104217      Parameters.Add(new FixedValueParameter<EnumValue<Aggregation>>(FinalAggregationParameterName, "If root node of the expression tree results in a Vector it is aggregated according to this parameter", new EnumValue<Aggregation>(Aggregation.Mean)));
     218      Parameters.Add(new FixedValueParameter<EnumValue<VectorLengthStrategy>>(DifferentVectorLengthStrategyParameterName, "", new EnumValue<VectorLengthStrategy>(VectorLengthStrategy.ExceptionIfDifferent)));
    105219    }
    106220
     
    109223      if (!Parameters.ContainsKey(FinalAggregationParameterName)) {
    110224        Parameters.Add(new FixedValueParameter<EnumValue<Aggregation>>(FinalAggregationParameterName, "If root node of the expression tree results in a Vector it is aggregated according to this parameter", new EnumValue<Aggregation>(Aggregation.Mean)));
     225      }
     226      if (!Parameters.ContainsKey(DifferentVectorLengthStrategyParameterName)) {
     227        Parameters.Add(new FixedValueParameter<EnumValue<VectorLengthStrategy>>(DifferentVectorLengthStrategyParameterName, "", new EnumValue<VectorLengthStrategy>(VectorLengthStrategy.ExceptionIfDifferent)));
    111228      }
    112229    }
     
    133250          yield return result.Scalar;
    134251        else if (result.IsVector) {
    135           if (FinalAggregation == Aggregation.Mean) yield return result.Vector.Mean();
    136           else if (FinalAggregation == Aggregation.Median) yield return Statistics.Median(result.Vector);
    137           else if (FinalAggregation == Aggregation.Sum) yield return result.Vector.Sum();
    138           else if (FinalAggregation == Aggregation.First) yield return result.Vector.First();
    139           else if (FinalAggregation == Aggregation.Exception) throw new InvalidOperationException("Result of the tree is not a scalar.");
    140           else yield return double.NaN;
     252          yield return Aggregate(FinalAggregation, result.Vector);
    141253        } else
    142254          yield return double.NaN;
     
    206318
    207319    private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
     320      Func<DoubleVector, DoubleVector, (DoubleVector, DoubleVector)> lengthStrategy,
    208321      Func<double, double, double> ssFunc = null,
    209322      Func<double, DoubleVector, DoubleVector> svFunc = null,
    210323      Func<DoubleVector, double, DoubleVector> vsFunc = null,
    211324      Func<DoubleVector, DoubleVector, DoubleVector> vvFunc = null) {
     325
    212326      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
    213327      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
    214328      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
    215       if (lhs.IsVector && rhs.IsVector && vvFunc != null) return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
     329      if (lhs.IsVector && rhs.IsVector && vvFunc != null) {
     330        if (lhs.Vector.Count == rhs.Vector.Count) {
     331          return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
     332        } else {
     333          var (lhsVector, rhsVector) = lengthStrategy(lhs.Vector, rhs.Vector);
     334          return new EvaluationResult(vvFunc(lhsVector, rhsVector));
     335        }
     336      }
    216337      return EvaluationResult.NaN;
    217338    }
     
    250371    }
    251372    private static EvaluationResult AggregateMultipleApply(EvaluationResult lhs, EvaluationResult rhs,
     373      Func<DoubleVector, DoubleVector, (DoubleVector, DoubleVector)> lengthStrategy,
    252374      Func<double, double, double> ssFunc = null,
    253375      Func<double, DoubleVector, double> svFunc = null,
     
    257379      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
    258380      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
    259       if (lhs.IsVector && rhs.IsVector && vvFunc != null) return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
     381      if (lhs.IsVector && rhs.IsVector && vvFunc != null) {
     382        if (lhs.Vector.Count == rhs.Vector.Count) {
     383          return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
     384        } else {
     385          var (lhsVector, rhsVector) = lengthStrategy(lhs.Vector, rhs.Vector);
     386          return new EvaluationResult(vvFunc(lhsVector, rhsVector));
     387        }
     388      }
    260389      return EvaluationResult.NaN;
    261390    }
     
    283412              var op = Evaluate(dataset, ref row, state);
    284413              cur = ArithmeticApply(cur, op,
     414                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
    285415                (s1, s2) => s1 + s2,
    286416                (s1, v2) => s1 + v2,
     
    295425              var op = Evaluate(dataset, ref row, state);
    296426              cur = ArithmeticApply(cur, op,
     427                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
    297428                (s1, s2) => s1 - s2,
    298429                (s1, v2) => s1 - v2,
     
    307438              var op = Evaluate(dataset, ref row, state);
    308439              cur = ArithmeticApply(cur, op,
     440                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
    309441                (s1, s2) => s1 * s2,
    310442                (s1, v2) => s1 * v2,
     
    319451              var op = Evaluate(dataset, ref row, state);
    320452              cur = ArithmeticApply(cur, op,
     453                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
    321454                (s1, s2) => s1 / s2,
    322455                (s1, v2) => s1 / v2,
     
    362495            var y = Evaluate(dataset, ref row, state);
    363496            return ArithmeticApply(x, y,
     497              (lhs, rhs) => lhs.Count < rhs.Count
     498                ? CutLonger(lhs, rhs)
     499                : ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
    364500              (s1, s2) => Math.Pow(s1, Math.Round(s2)),
    365501              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(DoubleVector.Round(v2)),
     
    383519            var y = Evaluate(dataset, ref row, state);
    384520            return ArithmeticApply(x, y,
     521              (lhs, rhs) => lhs.Count < rhs.Count
     522                ? CutLonger(lhs, rhs)
     523                : ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
    385524              (s1, s2) => Math.Pow(s1, 1.0 / Math.Round(s2)),
    386525              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(1.0 / DoubleVector.Round(v2)),
     
    410549            return AggregateApply(cur,
    411550              s => s,
    412               v => v.Mean());
     551              v => Statistics.Mean(v));
    413552          }
    414553        case OpCodes.StandardDeviation: {
     
    458597            var x2 = Evaluate(dataset, ref row, state);
    459598            return AggregateMultipleApply(x1, x2,
    460               //(s1, s2) => s1 - s2,
    461               //(s1, v2) => Math.Sqrt((s1 - v2).PointwisePower(2).Sum()),
    462               //(v1, s2) => Math.Sqrt((v1 - s2).PointwisePower(2).Sum()),
    463               vvFunc: (v1, v2) => v1.Count == v2.Count ? Math.Sqrt((v1 - v2).PointwisePower(2).Sum()) : double.NaN);
     599              (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
     600              (s1, s2) => s1 - s2,
     601              (s1, v2) => Math.Sqrt((s1 - v2).PointwisePower(2).Sum()),
     602              (v1, s2) => Math.Sqrt((v1 - s2).PointwisePower(2).Sum()),
     603              (v1, v2) => Math.Sqrt((v1 - v2).PointwisePower(2).Sum()));
    464604          }
    465605        case OpCodes.Covariance: {
     
    467607            var x2 = Evaluate(dataset, ref row, state);
    468608            return AggregateMultipleApply(x1, x2,
    469               //(s1, s2) => 0,
    470               //(s1, v2) => 0,
    471               //(v1, s2) => 0,
    472               vvFunc: (v1, v2) => v1.Count == v2.Count ? Statistics.PopulationCovariance(v1, v2) : double.NaN);
     609              (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
     610              (s1, s2) => 0,
     611              (s1, v2) => 0,
     612              (v1, s2) => 0,
     613              (v1, v2) => Statistics.PopulationCovariance(v1, v2));
    473614          }
    474615        case OpCodes.Variable: {
Note: See TracChangeset for help on using the changeset viewer.