Free cookie consent management tool by TermsFeed Policy Generator

Changeset 3995


Ignore:
Timestamp:
07/05/10 13:58:47 (14 years ago)
Author:
mkommend
Message:

improved !SymbolicRegressionScaledMSEEvaluator (ticket #1074)

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/SymbolicRegressionMeanSquaredErrorEvaluator.cs

    r3513 r3995  
    7272
    7373    public static double Calculate(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, double lowerEstimationLimit, double upperEstimationLimit, Dataset dataset, string targetVariable, int start, int end) {
    74       int targetVariableIndex = dataset.GetVariableIndex(targetVariable);
    7574      var estimatedValues = from x in interpreter.GetSymbolicExpressionTreeValues(solution, dataset, Enumerable.Range(start, end - start))
    7675                            let boundedX = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, x))
    7776                            select double.IsNaN(boundedX) ? upperEstimationLimit : boundedX;
    78       var originalValues = from row in Enumerable.Range(start, end - start) select dataset[row, targetVariableIndex];
     77      var originalValues = dataset.GetEnumeratedVariableValues(targetVariable, start, end);
    7978      return SimpleMSEEvaluator.Calculate(originalValues, estimatedValues);
    8079    }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/SymbolicRegressionScaledMeanSquaredErrorEvaluator.cs

    r3807 r3995  
    7575
    7676    public static double Calculate(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, double lowerEstimationLimit, double upperEstimationLimit, Dataset dataset, string targetVariable, int start, int end, out double beta, out double alpha) {
    77       var estimatedValues = CalculateScaledEstimatedValues(interpreter, solution, dataset, targetVariable, start, end, out beta, out alpha);
    78       estimatedValues = from x in estimatedValues
    79                         let boundedX = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, x))
    80                         select double.IsNaN(boundedX) ? upperEstimationLimit : boundedX;
    81       var originalValues = dataset.GetVariableValues(targetVariable, start, end);
    82       return SimpleMSEEvaluator.Calculate(originalValues, estimatedValues);
     77      IEnumerable<double> originalValues = dataset.GetEnumeratedVariableValues(targetVariable, start, end);
     78      IEnumerable<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(solution, dataset, Enumerable.Range(start, end - start));
     79      CalculateScalingParameters(originalValues, estimatedValues, out beta, out alpha);
     80
     81      return CalculateWithScaling(interpreter, solution, lowerEstimationLimit, upperEstimationLimit, dataset, targetVariable, start, end, beta, alpha);
    8382    }
    8483
    8584    public static double CalculateWithScaling(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, double lowerEstimationLimit, double upperEstimationLimit, Dataset dataset, string targetVariable, int start, int end, double beta, double alpha) {
    86       var estimatedValues = from x in interpreter.GetSymbolicExpressionTreeValues(solution, dataset, Enumerable.Range(start, end - start))
    87                             let boundedX = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, x * beta + alpha))
    88                             select double.IsNaN(boundedX) ? upperEstimationLimit : boundedX;
    89       var originalValues = dataset.GetVariableValues(targetVariable, start, end);
    90       return SimpleMSEEvaluator.Calculate(originalValues, estimatedValues);
     85      //IEnumerable<double> estimatedValues = from x in interpreter.GetSymbolicExpressionTreeValues(solution, dataset, Enumerable.Range(start, end - start))
     86      //                                      let boundedX = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, x * beta + alpha))
     87      //                                      select double.IsNaN(boundedX) ? upperEstimationLimit : boundedX;
     88      IEnumerable<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(solution, dataset, Enumerable.Range(start, end - start));
     89      IEnumerable<double> originalValues = dataset.GetEnumeratedVariableValues(targetVariable, start, end);
     90      IEnumerator<double> originalEnumerator = originalValues.GetEnumerator();
     91      IEnumerator<double> estimatedEnumerator = estimatedValues.GetEnumerator();
     92      double cnt = 0;
     93      double sse = 0;
     94
     95      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
     96        double estimated = estimatedEnumerator.Current * beta + alpha;
     97        double original = originalEnumerator.Current;
     98        estimated = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, estimated));
     99        if (double.IsNaN(estimated))
     100          estimated = upperEstimationLimit;
     101        if (!double.IsNaN(estimated) && !double.IsInfinity(estimated) &&
     102            !double.IsNaN(original) && !double.IsInfinity(original)) {
     103          double error = estimated - original;
     104          sse += error * error;
     105          cnt++;
     106        }
     107      }
     108
     109      if (estimatedEnumerator.MoveNext() || originalEnumerator.MoveNext()) {
     110        throw new ArgumentException("Number of elements in original and estimated enumeration doesn't match.");
     111      } else if (cnt == 0) {
     112        throw new ArgumentException("Mean squared errors is not defined for input vectors of NaN or Inf");
     113      } else {
     114        double mse = sse / cnt;
     115        return mse;
     116      }
    91117    }
    92118
    93     private static IEnumerable<double> CalculateScaledEstimatedValues(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, Dataset dataset, string targetVariable, int start, int end, out double beta, out double alpha) {
    94       int targetVariableIndex = dataset.GetVariableIndex(targetVariable);
    95       var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(solution, dataset, Enumerable.Range(start, end - start)).ToArray();
    96       var originalValues = dataset.GetVariableValues(targetVariable, start, end);
    97       CalculateScalingParameters(originalValues, estimatedValues, out beta, out alpha);
    98       for (int i = 0; i < estimatedValues.Length; i++)
    99         estimatedValues[i] = estimatedValues[i] * beta + alpha;
    100       return estimatedValues;
    101     }
     119    /// <summary>
     120    /// Calculates linear scaling parameters in one pass.
     121    /// The formulas to calculate the scaling parameters were taken from Scaled Symblic Regression by Maarten Keijzer.
     122    /// http://www.springerlink.com/content/x035121165125175/
     123    /// </summary>
     124    public static void CalculateScalingParameters(IEnumerable<double> original, IEnumerable<double> estimated, out double beta, out double alpha) {
     125      IEnumerator<double> originalEnumerator = original.GetEnumerator();
     126      IEnumerator<double> estimatedEnumerator = estimated.GetEnumerator();
    102127
     128      int cnt = 0;
     129      double tSum = 0;
     130      double ySum = 0;
     131      double yySum = 0;
     132      double ytSum = 0;
    103133
    104     public static void CalculateScalingParameters(IEnumerable<double> original, IEnumerable<double> estimated, out double beta, out double alpha) {
    105       double[] originalValues = original.ToArray();
    106       double[] estimatedValues = estimated.ToArray();
    107       if (originalValues.Length != estimatedValues.Length) throw new ArgumentException();
    108       var filteredResult = (from row in Enumerable.Range(0, originalValues.Length)
    109                             let t = originalValues[row]
    110                             let e = estimatedValues[row]
    111                             where IsValidValue(t)
    112                             where IsValidValue(e)
    113                             select new { Estimation = e, Target = t })
    114                    .OrderBy(x => Math.Abs(x.Target))            // make sure small values are considered before large values
    115                    .ToArray();     
     134      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
     135        double y = estimatedEnumerator.Current;
     136        double t = originalEnumerator.Current;
     137        if (IsValidValue(t) && IsValidValue(y)) {
     138          cnt++;
     139          tSum += t;
     140          ySum += y;
     141          yySum += y * y;
     142          ytSum += t * y;
     143        }
     144      }
    116145
    117       // calculate alpha and beta on the subset of rows with valid values
    118       originalValues = filteredResult.Select(x => x.Target).ToArray();
    119       estimatedValues = filteredResult.Select(x => x.Estimation).ToArray();
    120       int n = originalValues.Length;
    121       if (n > 2) {
    122         double tMean = originalValues.Average();
    123         double xMean = estimatedValues.Average();
    124         double sumXT = 0;
    125         double sumXX = 0;
    126         for (int i = 0; i < n; i++) {
    127           // calculate alpha and beta on the subset of rows with valid values
    128           double x = estimatedValues[i];
    129           double t = originalValues[i];
    130           sumXT += (x - xMean) * (t - tMean);
    131           sumXX += (x - xMean) * (x - xMean);
    132         }
    133         if (!sumXX.IsAlmost(0.0)) {
    134           beta = sumXT / sumXX;
    135         } else {
     146      if (estimatedEnumerator.MoveNext() || originalEnumerator.MoveNext())
     147        throw new ArgumentException("Number of elements in original and estimated enumeration doesn't match.");
     148      if (cnt < 2) {
     149        alpha = 0;
     150        beta = 1;
     151      } else {
     152        double tMean = tSum / cnt;
     153        double yMean = ySum / cnt;
     154        //division by cnt is omited because the variance and covariance are divided afterwards.
     155        double yVariance = yySum - 2 * yMean * ySum + cnt * yMean * yMean;
     156        double ytCovariance = ytSum - tMean * ySum - yMean * tSum + cnt * yMean * tMean;
     157
     158        if (yVariance.IsAlmost(0.0))
    136159          beta = 1;
    137         }
    138         alpha = tMean - beta * xMean;
    139       } else {
    140         alpha = 0.0;
    141         beta = 1.0;
     160        else
     161          beta = ytCovariance / yVariance;
     162
     163        alpha = tMean - beta * yMean;
    142164      }
    143165    }
Note: See TracChangeset for help on using the changeset viewer.