Free cookie consent management tool by TermsFeed Policy Generator

Changeset 15416


Ignore:
Timestamp:
10/11/17 08:15:19 (5 years ago)
Author:
gkronber
Message:

#2796 worked on MCTS for symbreg

Location:
branches/MCTS-SymbReg-2796
Files:
7 edited

Legend:

Unmodified
Added
Removed
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis.MCTSSymbReg.csproj

    r15414 r15416  
    2020    <DebugType>full</DebugType>
    2121    <Optimize>false</Optimize>
    22     <OutputPath>bin\</OutputPath>
     22    <OutputPath>..\..\..\..\trunk\sources\bin\</OutputPath>
    2323    <DefineConstants>DEBUG;TRACE</DefineConstants>
    2424    <ErrorReport>prompt</ErrorReport>
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Automaton.cs

    r15414 r15416  
    139139        codeGenerator.Emit1(OpCodes.LoadConst0);
    140140        constraintHandler.Reset();
    141       }, "0, Reset");
     141      }, "0");
    142142      AddTransition(StateTermEnd, StateExprEnd, () => {
    143143        codeGenerator.Emit1(OpCodes.Add);
     
    155155          constraintHandler.StartTerm();
    156156        },
    157         "c, StartTerm");
     157        "c");
    158158      AddTransition(StateFactorEnd, StateTermEnd,
    159159        () => {
     
    161161          constraintHandler.EndTerm();
    162162        },
    163         "*, EndTerm");
     163        "*");
    164164
    165165      AddTransition(StateFactorEnd, StateFactorStart,
     
    172172        AddTransition(StateFactorStart, StateVariableFactorStart, () => {
    173173          constraintHandler.StartFactor(StateVariableFactorStart);
    174         }, "StartFactor");
     174        }, "");
    175175      if (allowExp)
    176176        AddTransition(StateFactorStart, StateExpFactorStart, () => {
    177177          constraintHandler.StartFactor(StateExpFactorStart);
    178         }, "StartFactor");
     178        }, "");
    179179      if (allowLog)
    180180        AddTransition(StateFactorStart, StateLogFactorStart, () => {
    181181          constraintHandler.StartFactor(StateLogFactorStart);
    182         }, "StartFactor");
     182        }, "");
    183183      if (allowInv)
    184184        AddTransition(StateFactorStart, StateInvFactorStart, () => {
    185185          constraintHandler.StartFactor(StateInvFactorStart);
    186         }, "StartFactor");
    187       AddTransition(StateVariableFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "EndFactor");
    188       AddTransition(StateExpFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "EndFactor");
    189       AddTransition(StateLogFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "EndFactor");
    190       AddTransition(StateInvFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "EndFactor");
     186        }, "");
     187      AddTransition(StateVariableFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
     188      AddTransition(StateExpFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
     189      AddTransition(StateLogFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
     190      AddTransition(StateInvFactorEnd, StateFactorEnd, () => { constraintHandler.EndFactor(); }, "");
    191191
    192192      // VarFact -> var_1 ... var_n
     
    202202            constraintHandler.AddVarToCurrentFactor(varState);
    203203          },
    204           "var_" + varIdx + ", AddVar");
     204          "var_" + varIdx + "");
    205205        AddTransition(curDynVarState, StateVariableFactorEnd);
    206206        curDynVarState++;
     
    234234            constraintHandler.AddVarToCurrentFactor(varState);
    235235          },
    236           "var_" + varIdx + ", AddVar");
     236          "var_" + varIdx + "");
    237237        AddTransition(curDynVarState, StateExpFEnd,
    238238          () => {
     
    249249          constraintHandler.StartNewTermInPoly();
    250250        },
    251         "0, StartTermInPoly");
     251        "0");
    252252      AddTransition(StateLogTEnd, StateLogFactorEnd,
    253253        () => {
     
    289289            constraintHandler.AddVarToCurrentFactor(varState);
    290290          },
    291           "var_" + varIdx + ", AddVar");
     291          "var_" + varIdx + "");
    292292        AddTransition(curDynVarState, StateLogTFEnd);
    293293        curDynVarState++;
     
    300300          constraintHandler.StartNewTermInPoly();
    301301        },
    302         "c, StartTermInPoly");
     302        "c");
    303303      AddTransition(StateInvTEnd, StateInvFactorEnd,
    304304        () => {
     
    338338            constraintHandler.AddVarToCurrentFactor(varState);
    339339          },
    340           "var_" + varIdx + ", AddVar");
     340          "var_" + varIdx + "");
    341341        AddTransition(curDynVarState, StateInvTFEnd);
    342342        curDynVarState++;
     
    442442      codeGenerator.Reset();
    443443      constraintHandler.Reset();
     444    }
     445
     446    internal string GetActionString(int fromState, int toState) {
     447      return actionStrings[fromState,toState] != null ? string.Join(" , ", actionStrings[fromState, toState]) : "";
    444448    }
    445449
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionAlgorithm.cs

    r15403 r15416  
    304304          if (bestQ > bestQuality.Value) {
    305305            bestSolutionIteration.Value = i;
     306            if (state.BestSolutionTrainingQuality > 0.99999) break;
    306307          }
    307308          bestQuality.Value = bestQ;
     
    329330
    330331      // final results (assumes that at least one iteration was calculated)
    331       if (n > 0) {
     332      if (n > 0) {       
    332333        if (bestQ > bestQuality.Value) {
    333334          bestSolutionIteration.Value = iterations.Value + n;
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r15414 r15416  
    6161    // TODO: Solve Poly-10
    6262    // TODO: After state unification the recursive backpropagation of results takes a lot of time. How can this be improved?
     63    // TODO: Why is the algorithm so slow for rather greedy policies (e.g. low C value in UCB)?
     64    // TODO: check if we can use a quality measure with range [-1..1] in policies
    6365    // TODO: unit tests for benchmark problems which contain log / exp / x^-1 but without numeric constants
    6466    // TODO: check if transformation of y is correct and works (Obj 2)
     
    117119      private readonly ExpressionEvaluator evaluator, testEvaluator;
    118120
     121      internal readonly Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
     122      internal readonly Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
     123      internal readonly Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();
     124
    119125      // values for best solution
    120       private double bestRSq;
     126      private double bestR;
    121127      private byte[] bestCode;
    122128      private int bestNParams;
     
    187193
    188194        // reset best solution
    189         this.bestRSq = 0;
     195        this.bestR = 0;
    190196        // code for default solution (constant model)
    191197        this.bestCode = new byte[] { (byte)OpCodes.LoadConst0, (byte)OpCodes.Exit };
     
    208214        get {
    209215          evaluator.Exec(bestCode, x, bestConsts, predBuf);
    210           return RSq(y, predBuf);
     216          return Rho(y, predBuf);
    211217        }
    212218      }
     
    215221        get {
    216222          testEvaluator.Exec(bestCode, testX, bestConsts, testPredBuf);
    217           return RSq(testY, testPredBuf);
     223          return Rho(testY, testPredBuf);
    218224        }
    219225      }
     
    248254
    249255        // single objective best
    250         if (q > bestRSq) {
    251           bestRSq = q;
     256        if (q > bestR) {
     257          bestR = q;
    252258          bestNParams = nParams;
    253259          this.bestCode = new byte[code.Length];
     
    266272      }
    267273
    268       private void Eval(byte[] code, int nParams, out double rsq, out double[] optConsts) {
     274      private void Eval(byte[] code, int nParams, out double rho, out double[] optConsts) {
    269275        // we make a first pass to determine a valid starting configuration for all constants
    270276        // constant c in log(c + f(x)) is adjusted to guarantee that x is positive (see expression evaluator)
     
    280286          // if we don't need to optimize parameters then we are done
    281287          // changing scale and offset does not influence r²
    282           rsq = RSq(y, predBuf);
     288          rho = Rho(y, predBuf);
    283289          optConsts = constsBuf;
    284290        } else {
     
    289295          funcEvaluations++;
    290296
    291           rsq = RSq(y, predBuf);
     297          rho = Rho(y, predBuf);
    292298          optConsts = constsBuf;
    293299        }
     
    297303
    298304      #region helpers
    299       private static double RSq(IEnumerable<double> x, IEnumerable<double> y) {
     305      private static double Rho(IEnumerable<double> x, IEnumerable<double> y) {
    300306        OnlineCalculatorError error;
    301307        double r = OnlinePearsonsRCalculator.Calculate(x, y, out error);
    302         return error == OnlineCalculatorError.None ? r * r : 0.0;
     308        return error == OnlineCalculatorError.None ? r : 0.0;
    303309      }
    304310
     
    491497      do {
    492498        automaton.Reset();
    493         success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, out q);
     499        success = TryTreeSearchRec2(rand, tree, automaton, eval, treePolicy, mctsState, out q);
    494500        mctsState.totalRollouts++;
    495501      } while (!success && !tree.Done);
    496502      mctsState.effectiveRollouts++;
    497503
    498       if (mctsState.effectiveRollouts % 10 == 1) {
    499         //Console.WriteLine(WriteTree(tree));
    500         //Console.WriteLine(TraceTree(tree));
    501       }
     504      //if (mctsState.effectiveRollouts % 100 == 1) {
     505        // Console.WriteLine(WriteTree(tree, mctsState));
     506        // Console.WriteLine(TraceTree(tree, mctsState));
     507      //}
    502508      return q;
    503509    }
    504 
    505     private static Dictionary<Tree, List<Tree>> children = new Dictionary<Tree, List<Tree>>();
    506     private static Dictionary<Tree, List<Tree>> parents = new Dictionary<Tree, List<Tree>>();
    507     private static Dictionary<ulong, Tree> nodes = new Dictionary<ulong, Tree>();
    508 
    509510
    510511
    511512    // search forward
    512513    private static bool TryTreeSearchRec2(IRandom rand, Tree tree, Automaton automaton, Func<byte[], int, double> eval, IPolicy treePolicy,
     514      State state,
    513515      out double q) {
    514516      // ROLLOUT AND EXPANSION
     
    527529
    528530      while (!automaton.IsFinalState(automaton.CurrentState)) {
    529         if (children.ContainsKey(tree)) {
    530           if (children[tree].All(ch => ch.Done)) {
     531        if (state.children.ContainsKey(tree)) {
     532          if (state.children[tree].All(ch => ch.Done)) {
    531533            tree.Done = true;
    532534            break;
     
    535537          // UCT selection within tree
    536538          int selectedIdx = 0;
    537           if (children[tree].Count > 1) {
    538             selectedIdx = treePolicy.Select(children[tree].Select(ch => ch.actionStatistics), rand);
    539           }
    540           tree = children[tree][selectedIdx];
     539          if (state.children[tree].Count > 1) {
     540            selectedIdx = treePolicy.Select(state.children[tree].Select(ch => ch.actionStatistics), rand);
     541          }
     542          tree = state.children[tree][selectedIdx];
    541543
    542544          // move the automaton forward until reaching the state
     
    556558          int[] possibleFollowStates;
    557559          int nFs;
     560          string actionString = "";
    558561          automaton.FollowStates(automaton.CurrentState, out possibleFollowStates, out nFs);
    559562          while (nFs == 1 && !automaton.IsEvalState(possibleFollowStates[0]) && !automaton.IsFinalState(possibleFollowStates[0])) {
     563            actionString += " " + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[0]);
    560564            // no alternatives -> just go to the next state
    561565            automaton.Goto(possibleFollowStates[0]);
     
    568572          }
    569573          var newChildren = new List<Tree>(nFs);
    570           children.Add(tree, newChildren);
     574          state.children.Add(tree, newChildren);
    571575          for (int i = 0; i < nFs; i++) {
    572576            Tree child = null;
     
    574578            if (automaton.IsEvalState(possibleFollowStates[i])) {
    575579              var hc = Hashcode(automaton);
    576               if (!nodes.TryGetValue(hc, out child)) {
     580              if (!state.nodes.TryGetValue(hc, out child)) {
    577581                child = new Tree() {
    578582                  children = null,
    579583                  state = possibleFollowStates[i],
    580584                  actionStatistics = treePolicy.CreateActionStatistics(),
    581                   expr = string.Empty, // ExprStr(automaton),
     585                  expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
    582586                  level = tree.level + 1
    583587                };
    584                 nodes.Add(hc, child);
     588                state.nodes.Add(hc, child);
    585589              }
    586590              // only allow forward edges (don't add the child if we would go back in the graph)
    587               else if (child.level > tree.level) {
     591              else if (child.level > tree.level)  {
    588592                // whenever we join paths we need to propagate back the statistics of the existing node through the newly created link
    589593                // to all parents
    590                 BackpropagateStatistics(child.actionStatistics, tree);
     594                BackpropagateStatistics(child.actionStatistics, tree, state);
    591595              } else {
    592596                // prevent cycles
    593597                Debug.Assert(child.level <= tree.level);
    594598                child = null;
    595               }
     599              }   
    596600            } else {
    597601              child = new Tree() {
     
    599603                state = possibleFollowStates[i],
    600604                actionStatistics = treePolicy.CreateActionStatistics(),
    601                 expr = string.Empty, // ExprStr(automaton),
     605                expr = actionString + automaton.GetActionString(automaton.CurrentState, possibleFollowStates[i]),
    602606                level = tree.level + 1
    603607              };
     
    614618
    615619          foreach (var ch in newChildren) {
    616             if (!parents.ContainsKey(ch)) {
    617               parents.Add(ch, new List<Tree>());
     620            if (!state.parents.ContainsKey(ch)) {
     621              state.parents.Add(ch, new List<Tree>());
    618622            }
    619             parents[ch].Add(tree);
    620           }
    621 
    622 
    623           // follow one of the children
    624           tree = SelectFinalOrRandom2(automaton, tree, rand);
     623            state.parents[ch].Add(tree);
     624          }
     625
     626
     627          // follow one of the children 
     628          tree = SelectStateLeadingToFinal(automaton, tree, rand, state);
    625629          automaton.Goto(tree.state);
    626630        }
     
    636640        automaton.GetCode(out code, out nParams);
    637641        q = eval(code, nParams);
     642        // Console.WriteLine("{0:N4}\t{1}", q*q, tree.expr);
    638643        q = TransformQuality(q);
    639644        success = true;
    640645      } else {
    641646        // we got stuck in roll-out (not evaluation necessary!)
     647        // Console.WriteLine("\t" + ExprStr(automaton) + " STOP");
    642648        q = 0.0;
    643649        success = false;
     
    647653      // Update statistics
    648654      // Set branch to done if all children are done.
    649       BackpropagateQuality(tree, q, treePolicy);
     655      BackpropagateQuality(tree, q, treePolicy, state);
    650656
    651657      return success;
     
    655661    private static double TransformQuality(double q) {
    656662      // no transformation
    657       return q;
     663      // return q;
    658664
    659665      // EXPERIMENTAL!
     666
     667      // Fisher transformation
     668      // (assumes q is Correl(pred, target)
     669
     670      q = Math.Min(q,  0.99999999);
     671      q = Math.Max(q, -0.99999999);
     672      return 0.5 * Math.Log((1 + q) / (1 - q));
     673
    660674      // optimal result: q = 1 -> return huge value
    661675      // if (q >= 1.0) return 1E16;
     
    665679
    666680    // backpropagate existing statistics to all parents
    667     private static void BackpropagateStatistics(IActionStatistics stats, Tree tree) {
     681    private static void BackpropagateStatistics(IActionStatistics stats, Tree tree, State state) {
    668682      tree.actionStatistics.Add(stats);
    669       if (parents.ContainsKey(tree)) {
    670         foreach (var parent in parents[tree]) {
    671           BackpropagateStatistics(stats, parent);
     683      if (state.parents.ContainsKey(tree)) {
     684        foreach (var parent in state.parents[tree]) {
     685          BackpropagateStatistics(stats, parent, state);
    672686        }
    673687      }
     
    681695    }
    682696
    683     private static void BackpropagateQuality(Tree tree, double q, IPolicy policy) {
     697    private static void BackpropagateQuality(Tree tree, double q, IPolicy policy, State state) {
    684698      if (q > 0) policy.Update(tree.actionStatistics, q);
    685       if (children.ContainsKey(tree) && children[tree].All(ch => ch.Done)) {
     699      if (state.children.ContainsKey(tree) && state.children[tree].All(ch => ch.Done)) {
    686700        tree.Done = true;
    687701        // children[tree] = null; keep all nodes
    688702      }
    689703
    690       if (parents.ContainsKey(tree)) {
    691         foreach (var parent in parents[tree]) {
    692           BackpropagateQuality(parent, q, policy);
    693         }
    694       }
    695     }
    696 
    697     private static Tree SelectFinalOrRandom2(Automaton automaton, Tree tree, IRandom rand) {
    698       // if one of the new children leads to a final state then go there
    699       // otherwise choose a random child
    700       int selectedChildIdx = -1;
    701       // find first final state if there is one
    702       var children = MctsSymbolicRegressionStatic.children[tree];
    703       for (int i = 0; i < children.Count; i++) {
    704         if (automaton.IsFinalState(children[i].state)) {
     704      if (state.parents.ContainsKey(tree)) {
     705        foreach (var parent in state.parents[tree]) {
     706          BackpropagateQuality(parent, q, policy, state);
     707        }
     708      }
     709    }
     710
     711    private static Tree SelectStateLeadingToFinal(Automaton automaton, Tree tree, IRandom rand, State state) {
     712      // find the child with the smallest state value (smaller values are closer to the final state)
     713      int selectedChildIdx = 0;
     714      var children = state.children[tree];
     715      Tree minChild = children.First();
     716      for (int i = 1; i < children.Count; i++) {
     717        if(children[i].state < minChild.state)
    705718          selectedChildIdx = i;
    706           break;
    707         }
    708       }
    709       // no final state -> select the first child
    710       if (selectedChildIdx == -1) {
    711         selectedChildIdx = 0;
    712719      }
    713720      return children[selectedChildIdx];
     
    865872    }
    866873
    867     private static string WriteStatistics(Tree tree) {
     874    private static string WriteStatistics(Tree tree, State state) {
    868875      var sb = new System.IO.StringWriter();
    869876      sb.WriteLine("{0} {1:N5}", tree.actionStatistics.Tries, tree.actionStatistics.AverageQuality);
    870       if (children.ContainsKey(tree)) {
    871         foreach (var ch in children[tree]) {
     877      if (state.children.ContainsKey(tree)) {
     878        foreach (var ch in state.children[tree]) {
    872879          sb.WriteLine("{0} {1:N5}", ch.actionStatistics.Tries, ch.actionStatistics.AverageQuality);
    873880        }
     
    876883    }
    877884
    878     private static string TraceTree(Tree tree) {
     885    private static string TraceTree(Tree tree, State state) {
    879886      var sb = new StringBuilder();
    880887      sb.Append(
     
    885892      int nodeId = 0;
    886893
    887       TraceTreeRec(tree, 0, sb, ref nodeId);
     894      TraceTreeRec(tree, 0, sb, ref nodeId, state);
    888895      sb.Append("}");
    889896      return sb.ToString();
    890897    }
    891898
    892     private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId) {
     899    private static void TraceTreeRec(Tree tree, int parentId, StringBuilder sb, ref int nextId, State state) {
    893900      var avgNodeQ = tree.actionStatistics.AverageQuality;
    894901      var tries = tree.actionStatistics.Tries;
    895902      if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    896903      var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    897 
    898       sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();
     904      hue = 0.0;
     905
     906      sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue).AppendLine();
    899907
    900908      var list = new List<Tuple<int, int, Tree>>();
    901       if (children.ContainsKey(tree)) {
    902         foreach (var ch in children[tree]) {
     909      if (state.children.ContainsKey(tree)) {
     910        foreach (var ch in state.children[tree]) {
    903911          nextId++;
    904912          avgNodeQ = ch.actionStatistics.AverageQuality;
     
    906914          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    907915          hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
    908           sb.AppendFormat("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
    909           sb.AppendFormat("{0} -> {1}", parentId, nextId, avgNodeQ).AppendLine();
     916          hue = 0.0;
     917          sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
     918          sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", parentId, nextId, avgNodeQ, ch.expr).AppendLine();
    910919          list.Add(Tuple.Create(tries, nextId, ch));
    911920        }
     921
     922        foreach(var tup in list) {
     923          var ch = tup.Item3;
     924          var chId = tup.Item2;
     925          if(state.children.ContainsKey(ch) && state.children[ch].Count == 1) {
     926            var chch = state.children[ch].First();
     927            nextId++;
     928            avgNodeQ = chch.actionStatistics.AverageQuality;
     929            tries = chch.actionStatistics.Tries;
     930            if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
     931            hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
     932            hue = 0.0;
     933            sb.AppendFormat("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", nextId, avgNodeQ, tries, hue).AppendLine();
     934            sb.AppendFormat("{0} -> {1} [label=\"{3}\"]", chId, nextId, avgNodeQ, chch.expr).AppendLine();
     935          }
     936        }
     937
    912938        foreach (var tup in list.OrderByDescending(t => t.Item1).Take(1)) {
    913           TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId);
    914         }
    915       }
    916     }
    917 
    918     private static string WriteTree(Tree tree) {
     939          TraceTreeRec(tup.Item3, tup.Item2, sb, ref nextId, state);
     940        }
     941      }
     942    }
     943
     944    private static string WriteTree(Tree tree, State state) {
    919945      var sb = new System.IO.StringWriter(System.Globalization.CultureInfo.InvariantCulture);
    920946      var nodeIds = new Dictionary<Tree, int>();
     
    924950  node [style=filled];
    925951");
    926       int threshold = nodes.Count > 500 ? 10 : 0;
    927       foreach (var kvp in children) {
     952      int threshold = /* state.nodes.Count > 500 ? 10 : */ 0;
     953      foreach (var kvp in state.children) {
    928954        var parent = kvp.Key;
    929955        int parentId;
     
    934960          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    935961          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
     962          hue = 0.0;
    936963          if (parent.actionStatistics.Tries > threshold)
    937             sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
     964            sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", parentId, avgNodeQ, tries, hue);
    938965          nodeIds.Add(parent, parentId);
    939966        }
     
    949976          if (double.IsNaN(avgNodeQ)) avgNodeQ = 0.0;
    950977          var hue = (1 - avgNodeQ) / 360.0 * 240.0; // 0 equals red, 240 equals blue
     978          hue = 0.0;
    951979          if (tries > threshold) {
    952             sb.Write("{0} [label=\"{1:N3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
     980            sb.Write("{0} [label=\"{1:E3} {2}\" color=\"{3:N3} 0.999 0.999\"]; ", childId, avgNodeQ, tries, hue);
    953981            var edgeLabel = child.expr;
    954982            // if (parent.expr.Length > 0) edgeLabel = edgeLabel.Replace(parent.expr, "");
  • branches/MCTS-SymbReg-2796/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Policies/Ucb.cs

    r15414 r15416  
    7878        else totalTries += a.Tries;
    7979      }
    80       // if there are unvisited actions select a random action
     80      // if there are unvisited actions select a random unvisited action
    8181      if (buf.Any()) {
    8282        return buf[rand.Next(buf.Count)];
    8383      }
     84                     
    8485      Debug.Assert(totalTries > 0);
    8586      double logTotalTries = Math.Log(totalTries);
  • branches/MCTS-SymbReg-2796/Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/MctsSymbolicRegressionTest.cs

    r15414 r15416  
    77using HeuristicLab.Problems.DataAnalysis;
    88using HeuristicLab.Problems.Instances.DataAnalysis;
     9using HeuristicLab.Random;
    910using Microsoft.VisualStudio.TestTools.UnitTesting;
    1011
     
    550551      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.NguyenInstanceProvider(seed: 1234);
    551552      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("F7 ")));
    552       TestMctsWithoutConstants(regProblem, nVarRefs: 10, iterations: 1000000, allowExp: false, allowLog: true, allowInv: false);
    553     }
     553      TestMctsWithoutConstants(regProblem, nVarRefs: 10, iterations: 100000, allowExp: false, allowLog: true, allowInv: false);
     554    }
     555
     556    [TestMethod]
     557    [TestCategory("Algorithms.DataAnalysis")]
     558    [TestProperty("Time", "short")]
     559    public void MctsSymbReg_NoConstants_Poly10_Part1() {
     560      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
     561      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));
     562
     563      //  Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10
     564      //  Y' = X1*X2 + X3*X4 + X5*X6
     565      // simplify problem by changing target
     566      var ds = ((Dataset)regProblem.Dataset).ToModifiable();
     567      var ys = ds.GetDoubleValues("Y").ToArray();
     568      var x1 = ds.GetDoubleValues("X1").ToArray();
     569      var x2 = ds.GetDoubleValues("X2").ToArray();
     570      var x3 = ds.GetDoubleValues("X3").ToArray();
     571      var x4 = ds.GetDoubleValues("X4").ToArray();
     572      var x5 = ds.GetDoubleValues("X5").ToArray();
     573      var x6 = ds.GetDoubleValues("X6").ToArray();
     574      var x7 = ds.GetDoubleValues("X7").ToArray();
     575      var x8 = ds.GetDoubleValues("X8").ToArray();
     576      var x9 = ds.GetDoubleValues("X9").ToArray();
     577      var x10 = ds.GetDoubleValues("X10").ToArray();
     578      for (int i = 0; i < ys.Length; i++) {
     579        ys[i] -= x1[i] * x7[i] * x9[i];
     580        ys[i] -= x3[i] * x6[i] * x10[i];
     581      }
     582      ds.ReplaceVariable("Y", ys.ToList());
     583
     584      var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable);
     585
     586
     587      TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false);
     588    }
     589
     590    [TestMethod]
     591    [TestCategory("Algorithms.DataAnalysis")]
     592    [TestProperty("Time", "short")]
     593    public void MctsSymbReg_NoConstants_Poly10_Part2() {
     594      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
     595      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));
     596
     597      //  Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10
     598      //  Y' = X1*X7*X9 + X3*X6*X10
     599      // simplify problem by changing target
     600      var ds = ((Dataset)regProblem.Dataset).ToModifiable();
     601      var ys = ds.GetDoubleValues("Y").ToArray();
     602      var x1 = ds.GetDoubleValues("X1").ToArray();
     603      var x2 = ds.GetDoubleValues("X2").ToArray();
     604      var x3 = ds.GetDoubleValues("X3").ToArray();
     605      var x4 = ds.GetDoubleValues("X4").ToArray();
     606      var x5 = ds.GetDoubleValues("X5").ToArray();
     607      var x6 = ds.GetDoubleValues("X6").ToArray();
     608      var x7 = ds.GetDoubleValues("X7").ToArray();
     609      var x8 = ds.GetDoubleValues("X8").ToArray();
     610      var x9 = ds.GetDoubleValues("X9").ToArray();
     611      var x10 = ds.GetDoubleValues("X10").ToArray();
     612      for (int i = 0; i < ys.Length; i++) {
     613        ys[i] -= x1[i] * x2[i];
     614        ys[i] -= x3[i] * x4[i];
     615        ys[i] -= x5[i] * x6[i];
     616      }
     617      ds.ReplaceVariable("Y", ys.ToList());
     618
     619      var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable);
     620
     621
     622      TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false);
     623    }
     624
     625    [TestMethod]
     626    [TestCategory("Algorithms.DataAnalysis")]
     627    [TestProperty("Time", "short")]
     628    public void MctsSymbReg_NoConstants_Poly10_Part3() {
     629      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
     630      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));
     631
     632      //  Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10
     633      //  Y' = X1*X2 + X1*X7*X9
     634      // simplify problem by changing target
     635      var ds = ((Dataset)regProblem.Dataset).ToModifiable();
     636      var ys = ds.GetDoubleValues("Y").ToArray();
     637      var x1 = ds.GetDoubleValues("X1").ToArray();
     638      var x2 = ds.GetDoubleValues("X2").ToArray();
     639      var x3 = ds.GetDoubleValues("X3").ToArray();
     640      var x4 = ds.GetDoubleValues("X4").ToArray();
     641      var x5 = ds.GetDoubleValues("X5").ToArray();
     642      var x6 = ds.GetDoubleValues("X6").ToArray();
     643      var x7 = ds.GetDoubleValues("X7").ToArray();
     644      var x8 = ds.GetDoubleValues("X8").ToArray();
     645      var x9 = ds.GetDoubleValues("X9").ToArray();
     646      var x10 = ds.GetDoubleValues("X10").ToArray();
     647      for (int i = 0; i < ys.Length; i++) {
     648        ys[i] -= x3[i] * x4[i];
     649        ys[i] -= x5[i] * x6[i];
     650        ys[i] -= x3[i] * x6[i] * x10[i];
     651      }
     652      ds.ReplaceVariable("Y", ys.ToList());
     653
     654      var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable);
     655
     656
     657      TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false);
     658    }
     659
     660    [TestMethod]
     661    [TestCategory("Algorithms.DataAnalysis")]
     662    [TestProperty("Time", "short")]
     663    public void MctsSymbReg_NoConstants_Poly10_Part4() {
     664      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
     665      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));
     666
     667      //  Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10
     668      //  Y' = X3*X4 + X5*X6 + X3*X6*X10
     669      // simplify problem by changing target
     670      var ds = ((Dataset)regProblem.Dataset).ToModifiable();
     671      var ys = ds.GetDoubleValues("Y").ToArray();
     672      var x1 = ds.GetDoubleValues("X1").ToArray();
     673      var x2 = ds.GetDoubleValues("X2").ToArray();
     674      var x3 = ds.GetDoubleValues("X3").ToArray();
     675      var x4 = ds.GetDoubleValues("X4").ToArray();
     676      var x5 = ds.GetDoubleValues("X5").ToArray();
     677      var x6 = ds.GetDoubleValues("X6").ToArray();
     678      var x7 = ds.GetDoubleValues("X7").ToArray();
     679      var x8 = ds.GetDoubleValues("X8").ToArray();
     680      var x9 = ds.GetDoubleValues("X9").ToArray();
     681      var x10 = ds.GetDoubleValues("X10").ToArray();
     682      for (int i = 0; i < ys.Length; i++) {
     683        ys[i] -= x1[i] * x2[i];
     684        ys[i] -= x1[i] * x7[i] * x9[i];
     685      }
     686      ds.ReplaceVariable("Y", ys.ToList());
     687      var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable);
     688
     689
     690      TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false);
     691    }
     692
     693    [TestMethod]
     694    [TestCategory("Algorithms.DataAnalysis")]
     695    [TestProperty("Time", "short")]
     696    public void MctsSymbReg_NoConstants_Poly10_Part5() {
     697      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
     698      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));
     699
     700      //  Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10
     701      //  Y' = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9
     702      // simplify problem by changing target
     703      var ds = ((Dataset)regProblem.Dataset).ToModifiable();
     704      var ys = ds.GetDoubleValues("Y").ToArray();
     705      var x1 = ds.GetDoubleValues("X1").ToArray();
     706      var x2 = ds.GetDoubleValues("X2").ToArray();
     707      var x3 = ds.GetDoubleValues("X3").ToArray();
     708      var x4 = ds.GetDoubleValues("X4").ToArray();
     709      var x5 = ds.GetDoubleValues("X5").ToArray();
     710      var x6 = ds.GetDoubleValues("X6").ToArray();
     711      var x7 = ds.GetDoubleValues("X7").ToArray();
     712      var x8 = ds.GetDoubleValues("X8").ToArray();
     713      var x9 = ds.GetDoubleValues("X9").ToArray();
     714      var x10 = ds.GetDoubleValues("X10").ToArray();
     715      for (int i = 0; i < ys.Length; i++) {
     716        ys[i] -= x3[i] * x6[i] * x10[i];
     717      }
     718      ds.ReplaceVariable("Y", ys.ToList());
     719      var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable);
     720
     721
     722      TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 15, iterations: 100000, allowExp: false, allowLog: false, allowInv: false);
     723    }
     724
     725    [TestMethod]
     726    [TestCategory("Algorithms.DataAnalysis")]
     727    [TestProperty("Time", "short")]
     728    public void MctsSymbReg_NoConstants_Poly10_Part6() {
     729      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
     730      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));
     731
     732      //  Y = X1*X2 + X3*X4 + X5*X6 + X1*X7*X9 + X3*X6*X10
     733      //  Y' = X1*X2 + X3*X4 + X5*X6 + X3*X6*X10
     734      // simplify problem by changing target
     735      var ds = ((Dataset)regProblem.Dataset).ToModifiable();
     736      var ys = ds.GetDoubleValues("Y").ToArray();
     737      var x1 = ds.GetDoubleValues("X1").ToArray();
     738      var x2 = ds.GetDoubleValues("X2").ToArray();
     739      var x3 = ds.GetDoubleValues("X3").ToArray();
     740      var x4 = ds.GetDoubleValues("X4").ToArray();
     741      var x5 = ds.GetDoubleValues("X5").ToArray();
     742      var x6 = ds.GetDoubleValues("X6").ToArray();
     743      var x7 = ds.GetDoubleValues("X7").ToArray();
     744      var x8 = ds.GetDoubleValues("X8").ToArray();
     745      var x9 = ds.GetDoubleValues("X9").ToArray();
     746      var x10 = ds.GetDoubleValues("X10").ToArray();
     747      for (int i = 0; i < ys.Length; i++) {
     748        ys[i] -= x1[i] * x7[i] * x9[i];
     749      }
     750      ds.ReplaceVariable("Y", ys.ToList());
     751      var modifiedProblemData = new RegressionProblemData(ds, regProblem.AllowedInputVariables, regProblem.TargetVariable);
     752
     753
     754      TestMctsWithoutConstants(modifiedProblemData, nVarRefs: 9, iterations: 100000, allowExp: false, allowLog: false, allowInv: false);
     755    }
     756
    554757
    555758    [TestMethod]
     
    559762      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.VariousInstanceProvider(seed: 1234);
    560763      var regProblem = provider.LoadData(provider.GetDataDescriptors().Single(x => x.Name.Contains("Poly-10")));
    561       TestMctsWithoutConstants(regProblem, nVarRefs: 15, iterations: 1000000, allowExp: false, allowLog: false, allowInv: false);
     764      TestMctsWithoutConstants(regProblem, nVarRefs: 15, iterations: 200000, allowExp: false, allowLog: false, allowInv: false);
     765    }
     766
     767    [TestMethod]
     768    [TestCategory("Algorithms.DataAnalysis")]
     769    [TestProperty("Time", "short")]
     770    public void MctsSymbReg_NoConstants_TwoVars() {
     771
     772      // y = x1 + x2 + x1*x2 + x1*x2*x2 + x1*x1*x2
     773      var rand = new FastRandom(1234);
     774      var x1 = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList();
     775      var x2 = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList();
     776      var ys = x1.Zip(x2, (x1i, x2i) => x1i + x2i + x1i * x2i + x1i * x2i * x2i + x1i * x1i * x2i).ToList();
     777
     778      var ds = new Dataset(new string[] { "a", "b", "y" }, new[] { x1, x2, ys });
     779
     780      var problemData = new RegressionProblemData(ds, new string[] { "a", "b" }, "y");
     781
     782
     783      TestMctsWithoutConstants(problemData, nVarRefs: 10, iterations: 10000, allowExp: false, allowLog: false, allowInv: false);
     784    }
     785
     786    [TestMethod]
     787    [TestCategory("Algorithms.DataAnalysis")]
     788    [TestProperty("Time", "short")]
     789    public void MctsSymbReg_NoConstants_Misleading() {
     790
     791      // y = a + baaaaa (the effect of the second term should be very small)
     792      // the alg will quickly find that a has big effect and will search below a
     793      // since we prevent a + a... the algorithm must find the correct expression via a + b...
     794      // however b has a small effect so the branch might not be identified as relevant
     795
     796      var rand = new FastRandom(1234);
     797      var @as = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList();
     798      var bs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble()).ToList();
     799      var cs = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() *1.0e-3).ToList();
     800      var ds = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() ).ToList();
     801      var es = Enumerable.Range(0, 100).Select(_ => rand.NextDouble() ).ToList();
     802      var ys = new double[@as.Count];
     803      for(int i=0;i<ys.Length;i++)
     804        ys[i] = @as[i] + bs[i] + @as[i]*bs[i]*cs[i];
     805
     806      var dataset = new Dataset(new string[] { "a", "b", "c", "d", "e", "y" }, new[] { @as, bs, cs, ds, es, ys.ToList() });
     807
     808      var problemData = new RegressionProblemData(dataset, new string[] { "a", "b","c","d","e" }, "y");
     809
     810
     811      TestMctsWithoutConstants(problemData, nVarRefs: 10, iterations: 10000, allowExp: false, allowLog: false, allowInv: false);
    562812    }
    563813    #endregion
     
    8221072      mctsSymbReg.ConstantOptimizationIterations = -1;
    8231073
     1074      // random policy
     1075      // var epsPolicy = new EpsilonGreedy();
     1076      // epsPolicy.Eps = 1.0;
     1077      // mctsSymbReg.Policy = epsPolicy;
     1078
     1079      // UCB tuned
     1080      // var ucbTuned = new UcbTuned();
     1081      // ucbTuned.C = 1.5;
     1082      // mctsSymbReg.Policy = ucbTuned;
     1083
     1084
    8241085      #endregion
    8251086      RunAlgorithm(mctsSymbReg);
  • branches/MCTS-SymbReg-2796/Tests/Test.csproj

    r15403 r15416  
    6363      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Problems.Instances.DataAnalysis-3.3.dll</HintPath>
    6464    </Reference>
     65    <Reference Include="HeuristicLab.Random-3.3, Version=3.3.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
     66      <SpecificVersion>False</SpecificVersion>
     67      <HintPath>..\..\..\trunk\sources\bin\HeuristicLab.Random-3.3.dll</HintPath>
     68    </Reference>
    6569    <Reference Include="System" />
    6670  </ItemGroup>
Note: See TracChangeset for help on using the changeset viewer.