
Timestamp: 04/07/15 14:31:06
Author: gkronber
Message: #2283 created a new branch to separate development from aballeit
Location: branches/HeuristicLab.Problems.GrammaticalOptimization-gkr
Files: 1 edited, 1 copied

Legend: unchanged context lines are prefixed with a space, added lines with "+", removed lines with "-".
  • branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Problems.GrammaticalOptimization.SymbReg/SymbolicRegressionProblem.cs

--- SymbolicRegressionProblem.cs@12099
+++ SymbolicRegressionProblem.cs@12290

@@ -3,4 +3,6 @@
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Globalization;
+using System.IO;
 using System.Linq;
 using System.Security;
     
@@ -69,6 +71,10 @@
     private readonly bool useConstantOpt;
     public string Name { get; private set; }
-
-    public SymbolicRegressionProblem(Random random, string partOfName, bool useConstantOpt = true) {
+    private Random random;
+    private double lambda;
+
+
+    // lambda should be tuned using CV
+    public SymbolicRegressionProblem(Random random, string partOfName, double lambda = 1.0, bool useConstantOpt = true) {
       var instanceProviders = new RegressionInstanceProvider[]
       {new RegressionRealWorldInstanceProvider(),
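Note on the new lambda parameter: as the added comment says, lambda should be tuned by cross-validation. A minimal, self-contained sketch of such a grid search follows; TrainAndScore is a hypothetical placeholder for fitting with a given lambda on a training fold and returning the validation error, and is not part of this changeset.

    using System;
    using System.Linq;

    static class LambdaCv {
      // hypothetical stand-in: fit on trainIdx with the given lambda, return error on valIdx
      static double TrainAndScore(int[] trainIdx, int[] valIdx, double lambda) {
        return Math.Abs(lambda - 0.1); // dummy objective so the sketch runs
      }

      static void Main() {
        var rows = Enumerable.Range(0, 100).ToArray();
        var lambdas = new[] { 0.0, 0.01, 0.1, 1.0, 10.0 };
        const int k = 5; // 5-fold CV
        var best = lambdas
          .Select(lam => new {
            lam,
            cvError = Enumerable.Range(0, k).Average(fold => {
              var val = rows.Where(r => r % k == fold).ToArray();
              var train = rows.Where(r => r % k != fold).ToArray();
              return TrainAndScore(train, val, lam);
            })
          })
          .OrderBy(a => a.cvError)
          .First();
        Console.WriteLine("best lambda = {0}", best.lam);
      }
    }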
     
@@ -80,11 +86,25 @@
     };
       var instanceProvider = instanceProviders.FirstOrDefault(prov => prov.GetDataDescriptors().Any(dd => dd.Name.Contains(partOfName)));
-      if (instanceProvider == null) throw new ArgumentException("instance not found");
+      IRegressionProblemData problemData = null;
+      if (instanceProvider != null) {
+        var dds = instanceProvider.GetDataDescriptors();
+        problemData = instanceProvider.LoadData(dds.Single(ds => ds.Name.Contains(partOfName)));
+
+      } else if (File.Exists(partOfName)) {
+        // check if it is a file
+        var prov = new RegressionCSVInstanceProvider();
+        problemData = prov.ImportData(partOfName);
+        problemData.TrainingPartition.Start = 0;
+        problemData.TrainingPartition.End = problemData.Dataset.Rows;
+        // no test partition
+        problemData.TestPartition.Start = problemData.Dataset.Rows;
+        problemData.TestPartition.End = problemData.Dataset.Rows;
+      } else {
+        throw new ArgumentException("instance not found");
+      }
 
       this.useConstantOpt = useConstantOpt;
 
-      var dds = instanceProvider.GetDataDescriptors();
-      var problemData = instanceProvider.LoadData(dds.Single(ds => ds.Name.Contains(partOfName)));
-      this.Name = problemData.Name;
+      this.Name = problemData.Name + string.Format("lambda={0:N2}", lambda);
 
       this.N = problemData.TrainingIndices.Count();
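The instance lookup is now a three-way dispatch: partOfName is first matched against the names of the built-in regression instances, then tried as a path to a CSV file (imported with all rows as training data and an empty test partition), and only then rejected. A small self-contained sketch of the same dispatch order; the instance names and file name are hypothetical:

    using System;
    using System.IO;
    using System.Linq;

    static class InstanceDispatch {
      static void Main() {
        var knownInstances = new[] { "Tower", "Chemical-I" }; // stand-ins for the providers' data descriptors
        var partOfName = "my-data.csv";

        var match = knownInstances.FirstOrDefault(n => n.Contains(partOfName));
        if (match != null)
          Console.WriteLine("load built-in instance: " + match);
        else if (File.Exists(partOfName))
          Console.WriteLine("import CSV: training = all rows, test partition empty");
        else
          throw new ArgumentException("instance not found"); // same error as before, but only after both lookups fail
      }
    }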
     
@@ -94,12 +114,32 @@
       this.y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
 
+      var varEst = new OnlineMeanAndVarianceCalculator();
+
+      var means = new double[d];
+      var stdDevs = new double[d];
+
       int i = 0;
+      foreach (var inputVariable in problemData.AllowedInputVariables) {
+        varEst.Reset();
+        problemData.Dataset.GetDoubleValues(inputVariable).ToList().ForEach(varEst.Add);
+        if (varEst.VarianceErrorState != OnlineCalculatorError.None) throw new ArgumentException();
+        means[i] = varEst.Mean;
+        stdDevs[i] = Math.Sqrt(varEst.Variance);
+        i++;
+      }
+
+      i = 0;
       foreach (var r in problemData.TrainingIndices) {
         int j = 0;
         foreach (var inputVariable in problemData.AllowedInputVariables) {
-          x[i, j++] = problemData.Dataset.GetDoubleValue(inputVariable, r);
+          x[i, j] = (problemData.Dataset.GetDoubleValue(inputVariable, r) - means[j]) / stdDevs[j];
+          j++;
         }
         i++;
       }
+
+      this.random = random;
+      this.lambda = lambda;
+
       // initialize ERC values
       erc = Enumerable.Range(0, 10).Select(_ => Rand.RandNormal(random) * 10).ToArray();
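The constructor now standardizes every input column to zero mean and unit variance before filling the data matrix, i.e. x'[i,j] = (x[i,j] - mean_j) / stdDev_j. Note that a constant input column (stdDev_j = 0) would produce NaN values here. A self-contained sketch of the transformation; whether OnlineMeanAndVarianceCalculator uses the population or the sample variance is not visible in this diff, so the sketch assumes the population form:

    using System;
    using System.Linq;

    static class ZScore {
      static void Main() {
        var col = new double[] { 2.0, 4.0, 6.0, 8.0 };
        double mean = col.Average();
        double stdDev = Math.Sqrt(col.Sum(v => (v - mean) * (v - mean)) / col.Length);
        var scaled = col.Select(v => (v - mean) / stdDev).ToArray();
        Console.WriteLine(string.Join(", ", scaled)); // zero mean, unit variance
      }
    }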
     
@@ -132,5 +172,4 @@
         return OptimizeConstantsAndEvaluate(sentence);
       else {
-
         Debug.Assert(SimpleEvaluate(sentence) == SimpleEvaluate(extender.CanonicalRepresentation(sentence)));
         return SimpleEvaluate(sentence);
     
@@ -154,7 +193,7 @@
 
     public IEnumerable<Feature> GetFeatures(string phrase) {
-      // throw new NotImplementedException();
-      phrase = CanonicalRepresentation(phrase);
-      return phrase.Split('+').Distinct().Select(t => new Feature(t, 1.0));
+      throw new NotImplementedException();
+      //phrase = CanonicalRepresentation(phrase);
+      //return phrase.Split('+').Distinct().Select(t => new Feature(t, 1.0));
       // return new Feature[] { new Feature(phrase, 1.0) };
     }
     
@@ -176,39 +215,72 @@
       if (!constants.Any()) return SimpleEvaluate(sentence);
 
-      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(constants, variables); // variate constants leave variables fixed to data
-
-      double[] c = constants.Select(_ => 1.0).ToArray(); // start with ones
-
+      // L2 regularization
+      // not possible with lsfit, would need to change to minlm below
+      // func = TermBuilder.Sum(func, lambda * TermBuilder.Sum(constants.Select(con => con * con)));
+
+      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(constants, variables); // variate constants, leave variables fixed to data
+
+      // 10 restarts with random starting points
+      double[] bestStart = null;
+      double bestError = double.PositiveInfinity;
+      int info;
       alglib.lsfitstate state;
       alglib.lsfitreport rep;
-      int info;
-
-
-      int k = c.Length;
-
       alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
       alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
-
-      const int maxIterations = 10;
-      try {
-        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
-        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
-        //alglib.lsfitsetgradientcheck(state, 0.001);
-        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
-        alglib.lsfitresults(state, out info, out c, out rep);
-      } catch (ArithmeticException) {
-        return 0.0;
-      } catch (alglib.alglibexception) {
-        return 0.0;
-      }
-
-      //info == -7  => constant optimization failed due to wrong gradient
-      if (info == -7) throw new ArgumentException();
+      for (int t = 0; t < 10; t++) {
+        double[] cStart = constants.Select(_ => Rand.RandNormal(random) * 10).ToArray();
+        double[] cEnd;
+        // start with normally distributed (N(0, 10)) weights
+
+
+        int k = cStart.Length;
+
+
+        const int maxIterations = 10;
+        try {
+          alglib.lsfitcreatefg(x, y, cStart, n, m, k, false, out state);
+          alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
+          //alglib.lsfitsetgradientcheck(state, 0.001);
+          alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
+          alglib.lsfitresults(state, out info, out cEnd, out rep);
+          if (info != -7 && rep.rmserror < bestError) {
+            bestStart = cStart;
+            bestError = rep.rmserror;
+          }
+        } catch (ArithmeticException) {
+          return 0.0;
+        } catch (alglib.alglibexception) {
+          return 0.0;
+        }
+      }
+
+      // 100 iteration steps from the best starting point
       {
-        var rowData = new double[d];
-        return HeuristicLab.Common.Extensions.RSq(y, Enumerable.Range(0, N).Select(i => {
-          for (int j = 0; j < d; j++) rowData[j] = x[i, j];
-          return compiledFunc.Evaluate(c, rowData);
-        }));
+        double[] c = bestStart;
+
+        int k = c.Length;
+
+        const int maxIterations = 100;
+        try {
+          alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
+          alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
+          //alglib.lsfitsetgradientcheck(state, 0.001);
+          alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
+          alglib.lsfitresults(state, out info, out c, out rep);
+        } catch (ArithmeticException) {
+          return 0.0;
+        } catch (alglib.alglibexception) {
+          return 0.0;
+        }
+        //info == -7  => constant optimization failed due to wrong gradient
+        if (info == -7) throw new ArgumentException();
+        {
+          var rowData = new double[d];
+          return HeuristicLab.Common.Extensions.RSq(y, Enumerable.Range(0, N).Select(i => {
+            for (int j = 0; j < d; j++) rowData[j] = x[i, j];
+            return compiledFunc.Evaluate(c, rowData);
+          }));
+        }
       }
     }
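The rewritten OptimizeConstantsAndEvaluate replaces the single run from all-ones starting values with a multi-start scheme: ten short (10-iteration) lsfit runs from random starting points, after which a long (100-iteration) run starts again from the best starting point rather than from the best end point, so the winning short run's steps are effectively recomputed. A self-contained sketch of this pattern on a toy objective; the crude local search is a stand-in for alglib's lsfit, not the actual API:

    using System;

    static class MultiStart {
      static double F(double c) => (c - 3.0) * (c - 3.0) + 1.0; // toy objective

      // crude coordinate search standing in for the short/long lsfit runs
      static double LocalSearch(double start, int maxIterations) {
        double c = start, step = 1.0;
        for (int it = 0; it < maxIterations; it++) {
          if (F(c + step) < F(c)) c += step;
          else if (F(c - step) < F(c)) c -= step;
          else step /= 2;
        }
        return c;
      }

      static void Main() {
        var rand = new Random(0);
        double bestStart = 0.0, bestError = double.PositiveInfinity;
        // 10 restarts from random starting points, few iterations each
        for (int t = 0; t < 10; t++) {
          double cStart = (rand.NextDouble() * 2 - 1) * 10;
          double cEnd = LocalSearch(cStart, maxIterations: 10);
          if (F(cEnd) < bestError) { bestStart = cStart; bestError = F(cEnd); }
        }
        // long run, restarted from the best starting point (as in the changeset)
        double final = LocalSearch(bestStart, maxIterations: 100);
        Console.WriteLine("minimum near c = {0:F3}", final);
      }
    }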