Free cookie consent management tool by TermsFeed Policy Generator

Changeset 15420


Ignore:
Timestamp:
10/17/17 09:56:30 (7 years ago)
Author:
gkronber
Message:

#2796: debugging

Location:
branches/MCTS-SymbReg-2796
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs

    r15416 r15420  
    385385
    386386    public void Goto(int targetState) {
     387      Debug.Assert(followStates[CurrentState].Contains(targetState));
    387388      if (actions[CurrentState, targetState] != null)
    388389        actions[CurrentState, targetState].ForEach(a => a()); // execute all actions
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r15416 r15420  
    7171    // TODO: is it OK to initialize all constants to 1 (Obj 2)?
    7272    // TODO: improve memory usage
     73    // TODO: support empty test partition
    7374    #region static API
    7475
     
    138139      private readonly double[] predBuf, testPredBuf;
    139140      private readonly double[][] gradBuf;
     141
     142      // debugging stats
     143      // calculate for each level the number of alternatives the average 'inequality' of tries and 'inequality' of quality over the alternatives for each trie
     144      // inequality can be calculated using the Gini coefficient
     145      internal readonly double[] giniCoeffs = new double[100];
     146     
    140147
    141148      public State(IRegressionProblemData problemData, uint randSeed, int maxVariables, bool scaleVariables,
     
    436443        }
    437444      }
     445
    438446      #endregion
     447
     448#if DEBUG
     449      internal void ClearStats() {
     450        for (int i = 0; i < giniCoeffs.Length; i++) giniCoeffs[i] = -1;
     451      }
     452      internal void WriteStats() {
     453        Console.WriteLine(string.Join("\t", giniCoeffs.TakeWhile(x => x >= 0).Select(x => string.Format("{0:N3}", x))));
     454      }
     455
     456#endif
     457
    439458    }
    440459
     
    496515      bool success = false;
    497516      do {
     517#if DEBUG
     518        mctsState.ClearStats();
     519#endif
    498520        automaton.Reset();
    499521        success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q);
     
    502524      mctsState.effectiveRollouts++;
    503525
     526#if DEBUG
     527      mctsState.WriteStats();
     528#endif
    504529      //if (mctsState.effectiveRollouts % 100 == 1) {
    505         // Console.WriteLine(WriteTree(tree, mctsState));
    506         // Console.WriteLine(TraceTree(tree, mctsState));
     530      // Console.WriteLine(WriteTree(tree, mctsState));
     531      // Console.WriteLine(TraceTree(tree, mctsState));
    507532      //}
    508533      return q;
    509534    }
    510 
    511535
    512536    // search forward
     
    540564            selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand);
    541565          }
     566
     567          // STATS
     568          state.giniCoeffs[tree.level] = InequalityCoefficient(state.children[tree].Select(ch => (double)ch.actionStatistics.AverageQuality));
     569
    542570          tree = state.children[tree][selectedIdx];
    543571
     
    548576          int nFs;
    549577          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
    550           while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
     578          while (automaton.CurrentState != tree.state && nFs == 1 &&
     579            !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
    551580            automaton.Goto(possibleFollowStates[0]);
    552581            automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
     
    658687    }
    659688
     689    private static double InequalityCoefficient(IEnumerable<double> xs) {
     690      var arr = xs.ToArray();
     691      var sad = 0.0;
     692      var sum = 0.0;
     693
     694      for(int i=0;i<arr.Length;i++) {
     695        for(int j=0;j<arr.Length;j++) {
     696          sad += Math.Abs(arr[i] - arr[j]);
     697          sum += arr[j];
     698        }
     699      }
     700      return 0.5 * sad / sum;     
     701    }
    660702
    661703    private static double TransformQuality(double q) {
     
    871913      return Disassembler.CodeToString(code);
    872914    }
     915
    873916
    874917    private static string WriteStatistics(Tree tree, State state) {
  • branches/MCTS-SymbReg-2796/Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/MctsSymbolicRegressionTest.cs

    r15416 r15420  
    11using System;
     2using System.Collections.Generic;
    23using System.Linq;
    34using System.Threading;
     
    758759    [TestMethod]
    759760    [TestCategory("Algorithms.DataAnalysis")]
    760     [TestProperty("Time", "short")]
    761     public void MctsSymbReg_NoConstants_Poly10() {
     761    [TestProperty("Time", "long")]
     762    public void MctsSymbReg_NoConstants_Poly10_250rows() {
    762763      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
    763764      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));
     765      regProblem.TrainingPartition.Start = 0;
     766      regProblem.TrainingPartition.End = regProblem.Dataset.Rows;
     767      regProblem.TestPartition.Start = 0;
     768      regProblem.TestPartition.End = 2;
    764769      TestMctsWithoutConstants(regProblem, nVarRefs: 15, iterations: 200000, allowExp: false, allowLog: false, allowInv: false);
     770    }
     771    [TestMethod]
     772    [TestCategory("Algorithms.DataAnalysis")]
     773    [TestProperty("Time", "long")]
     774    public void MctsSymbReg_NoConstants_Poly10_10000rows() {
     775      // as poly-10 but more rows
     776      var rand = new FastRandom(1234);
     777      var x1 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     778      var x2 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     779      var x3 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     780      var x4 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     781      var x5 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     782      var x6 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     783      var x7 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     784      var x8 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     785      var x9 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     786      var x10 = Enumerable.Range(0, 10000).Select(_ => rand.NextDouble()).ToList();
     787      var ys = new List<double>();
     788      for (int i = 0; i < x1.Count; i++) {
     789        ys.Add(x1[i] * x2[i] + x3[i] * x4[i] + x5[i] * x6[i] + x1[i] * x7[i] * x9[i] + x3[i] * x6[i] * x10[i]);
     790      }
     791
     792      var ds = new Dataset(new string[] { "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "y" },
     793        new[] { x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, ys });
     794
     795
     796      var problemData = new RegressionProblemData(ds, new string[] { "a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}, "y");
     797
     798      problemData.TrainingPartition.Start = 0;
     799      problemData.TrainingPartition.End = problemData.Dataset.Rows;
     800      problemData.TestPartition.Start = 0;
     801      problemData.TestPartition.End = 2; // must not be empty
     802
     803
     804      TestMctsWithoutConstants(problemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false);
    765805    }
    766806
     
    797837      var @as = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList();
    798838      var bs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList();
    799       var cs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() *1.0e-3).ToList();
    800       var ds = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() ).ToList();
    801       var es = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() ).ToList();
     839      var cs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() * 1.0e-3).ToList();
     840      var ds = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList();
     841      var es = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList();
    802842      var ys = new double[@as.Count];
    803       for(int i=0;i<ys.Length;i++)
    804         ys[i] = @as[i] + bs[i] + @as[i]*bs[i]*cs[i];
     843      for (int i = 0; i < ys.Length; i++)
     844        ys[i] = @as[i] + bs[i] + @as[i] * bs[i] * cs[i];
    805845
    806846      var dataset = new Dataset(new string[] { "a", "b", "c", "d", "e", "y" }, new[] { @as, bs, cs, ds, es, ys.ToList() });
    807847
    808       var problemData = new RegressionProblemData(dataset, new string[] { "a", "b","c","d","e" }, "y");
     848      var problemData = new RegressionProblemData(dataset, new string[] { "a", "b", "c", "d", "e" }, "y");
    809849
    810850
     
    854894      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenInstanceProvider(seed: 1234);
    855895      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("F7 ")));
    856       TestMcts(regProblem, maxVariableReferences:5, allowExp: false, allowLog: true, allowInv: false);
     896      TestMcts(regProblem, maxVariableReferences: 5, allowExp: false, allowLog: true, allowInv: false);
    857897    }
    858898    [TestMethod]
     
    895935      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenInstanceProvider(seed: 1234);
    896936      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("F11 ")));
    897       TestMcts(regProblem, maxVariableReferences: 5, allowExp: true, allowLog: true, allowInv: false); 
     937      TestMcts(regProblem, maxVariableReferences: 5, allowExp: true, allowLog: true, allowInv: false);
    898938    }
    899939    [TestMethod]
Note: See TracChangeset for help on using the changeset viewer.