Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/28/15 14:57:21 (9 years ago)
Author:
gkronber
Message:

#2471

  • refactoring to use state value function V(s) instead of state/action value function Q(s,a)
  • added test case for artificial ant problem
Location:
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3
Files:
5 added
5 deleted
13 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction-3.3.csproj

    r12922 r12923  
    155155  <ItemGroup>
    156156    <Compile Include="Interfaces\IStateFunction.cs" />
    157     <Compile Include="Interfaces\IQualityFunction.cs" />
    158     <Compile Include="Interfaces\ITabularQualityFunction.cs" />
     157    <Compile Include="Interfaces\IStateValueFunction.cs" />
     158    <Compile Include="Interfaces\ITabularStateValueFunction.cs" />
    159159    <Compile Include="Interfaces\ISymbolicExpressionConstructionPolicy.cs" />
    160160    <Compile Include="IteratedSymbolicExpressionConstruction.cs" />
     
    169169    </Compile>
    170170    <Compile Include="Properties\AssemblyInfo.cs" />
    171     <Compile Include="QualityFunctions\TabularQualityFunctionBase.cs" />
    172     <Compile Include="QualityFunctions\TabularAvgQualityFunction.cs" />
    173     <Compile Include="QualityFunctions\TabularMaxQualityFunction.cs" />
     171    <Compile Include="QualityFunctions\TabularStateValueFunctionBase.cs" />
     172    <Compile Include="QualityFunctions\TabularAvgStateValueFunction.cs" />
     173    <Compile Include="QualityFunctions\TabularMaxStateValueFunction.cs" />
    174174    <Compile Include="SearchTree.cs" />
    175175    <Compile Include="StateFunctions\ParentChildStateFunction.cs" />
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Interfaces/IStateFunction.cs

    r12909 r12923  
    1010  // creates a state from the information available at sequential derivation steps of symbolic expression trees
    1111  public interface IStateFunction : IItem {
    12     object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parentNode, int childIdx);
     12    object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx);
    1313  }
    1414}
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Interfaces/ISymbolicExpressionConstructionPolicy.cs

    r12909 r12923  
    99  public interface ISymbolicExpressionConstructionPolicy : IItem {
    1010    void Initialize(SymbolicExpressionTreeProblem problem, IRandom random);
    11     ISymbolicExpressionTree Sample(out IEnumerable<Tuple<object, int>> stateActionSequence);
    12     void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality);
     11    ISymbolicExpressionTree Sample(out IEnumerable<object> stateSequence);
     12    void Update(IEnumerable<object> stateSequence, double quality);
    1313  }
    1414}
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/IteratedSymbolicExpressionConstruction.cs

    r12909 r12923  
    147147      random.Reset(Seed);
    148148
    149       //var policy = new RandomSymbolicExpressionConstructionPolicy(Problem, random);
    150       //var policy = new EpsGreedySymbolicExpressionConstructionPolicy<string>(Problem, random, new TabularMaxQualityFunction<string>(new DefaultStateFunction()));
    151       //var policy = new EpsGreedySymbolicExpressionConstructionPolicy(Problem, random, new TabularAvgQualityFunction<string>(new DefaultStateFunction()));
    152       //var policy = new EpsGreedySymbolicExpressionConstructionPolicy(Problem, random, new TabularMaxQualityFunction<string>(new ParentChildStateFunction()));
    153       //var policy = new EpsGreedySymbolicExpressionConstructionPolicy(Problem, random, new TabularAvgQualityFunction<string>(new ParentChildStateFunction()));
    154       //var policy = new UcbSymbolicExpressionConstructionPolicy<string>(Problem, random, new TabularMaxQualityFunction<string>(new DefaultStateFunction()));
    155       //var policy = new UcbWithStateAggregationSymbolicExpressionConstructionPolicy(Problem, random, 40);
    156 
    157149      var policy = PolicyParameter.Value;
    158150      policy.Initialize(Problem, random);
     
    186178      double sumQuality = 0; // for average quality calculation
    187179      int resultUpdateInterval = ResultUpdateInterval;
    188       while (evals < MaximumEvaluations) {
    189         double quality = double.NaN;
    190         ISymbolicExpressionTree tree = null;
    191         IEnumerable<Tuple<object, int>> actionSequence;
    192         tree = policy.Sample(out actionSequence);
    193         quality = Problem.Evaluate(tree, random);
    194         evals++;
    195         sumQuality += quality;
    196 
    197         policy.Update(actionSequence, quality);
    198 
    199         // update statistics results in regular update intervals
    200         if ((evals - 1) % resultUpdateInterval == resultUpdateInterval - 1) {
    201           evaluations.Value = evals;
    202           bestQualityRow.Values.Add(bestQuality.Value);
    203           currentQualityRow.Values.Add(sumQuality / (double)resultUpdateInterval);
    204           sumQuality = 0;
     180      try {
     181        while (evals < MaximumEvaluations) {
     182          double quality = double.NaN;
     183          ISymbolicExpressionTree tree = null;
     184          IEnumerable<object> stateSequence;
     185          tree = policy.Sample(out stateSequence);
     186          quality = Problem.Evaluate(tree, random);
     187          evals++;
     188          sumQuality += quality;
     189
     190          policy.Update(stateSequence, quality);
     191          cancellationToken.ThrowIfCancellationRequested();
     192
     193          // update statistics results in regular update intervals
     194          if ((evals - 1) % resultUpdateInterval == resultUpdateInterval - 1) {
     195            evaluations.Value = evals;
     196            bestQualityRow.Values.Add(bestQuality.Value);
     197            currentQualityRow.Values.Add(sumQuality / (double)resultUpdateInterval);
     198            sumQuality = 0;
     199          }
     200
     201          // update best solution results whenever a new better solution is found
     202          if (Problem.IsBetter(quality, bestQuality.Value)) {
     203            bestQuality.Value = quality;
     204            bestFoundOnEvaluation.Value = evals;
     205
     206            // for problem-specific analyzer
     207            solutions[0] = tree;
     208            qualities[0] = quality;
     209          }
     210
     211          // run problem-specific analyzer in each iteration
     212          Problem.Analyze(solutions, qualities, Results, random);
    205213        }
    206 
    207         // update best solution results whenever a new better solution is found
    208         if (Problem.IsBetter(quality, bestQuality.Value)) {
    209           bestQuality.Value = quality;
    210           bestFoundOnEvaluation.Value = evals;
    211 
    212           // for problem-specific analyzer
    213           solutions[0] = tree;
    214           qualities[0] = quality;
    215         }
    216 
    217         // run problem-specific analyzer in each iteration
    218         Problem.Analyze(solutions, qualities, Results, random);
    219 
    220         cancellationToken.ThrowIfCancellationRequested();
     214      } finally {
     215        // update stats whenever the alg is stopped
     216        evaluations.Value = evals;
     217        bestQualityRow.Values.Add(bestQuality.Value);
     218        currentQualityRow.Values.Add(sumQuality / (double)resultUpdateInterval);
    221219      }
    222220    }
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/BoltzmannExplorationSymbolicExpressionConstructionPolicy.cs

    r12909 r12923  
    2020    }
    2121
    22     public ITabularQualityFunction QualityFunction {
     22    public ITabularStateValueFunction StateValueFunction {
    2323      get {
    24         return ((IValueParameter<ITabularQualityFunction>)Parameters["Quality function"]).Value;
     24        return ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value;
    2525      }
    26       set { ((IValueParameter<ITabularQualityFunction>)Parameters["Quality function"]).Value = value; }
     26      set { ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value = value; }
    2727    }
    2828
     
    3939      : base() {
    4040      Parameters.Add(new FixedValueParameter<DoubleValue>("Beta", "The weighting factor beta", new DoubleValue(1.0)));
    41       Parameters.Add(new ValueParameter<ITabularQualityFunction>("Quality function", "The quality function to use", new TabularAvgQualityFunction()));
     41      Parameters.Add(new ValueParameter<ITabularStateValueFunction>("Quality function", "The quality function to use", new TabularAvgStateValueFunction()));
    4242    }
    4343
    44     protected sealed override int Select(object state, IEnumerable<int> actions, IRandom random) {
    45 
     44    protected sealed override int Select(IReadOnlyList<object> followStates, IRandom random) {
     45      var idxs = Enumerable.Range(0, followStates.Count);
    4646      // find best action
    47       var bestActions = new List<int>();
    4847      var bestQuality = double.NegativeInfinity;
    49       if (actions.Any(a => QualityFunction.Tries(state, a) == 0)) {
    50         return actions.Where(a => QualityFunction.Tries(state, a) == 0).SampleRandom(random, 1).First();
     48      if (followStates.Any(s => StateValueFunction.Tries(s) == 0)) {
     49        return idxs.Where(idx => StateValueFunction.Tries(followStates[idx]) == 0).SampleRandom(random);
    5150      }
    5251
    5352      // windowing
    54       var max = actions.Select(a => QualityFunction.Q(state, a)).Max();
    55       var min = actions.Select(a => QualityFunction.Q(state, a)).Min();
     53      var max = followStates.Select(s => StateValueFunction.Value(s)).Max();
     54      var min = followStates.Select(s => StateValueFunction.Value(s)).Min();
    5655      double range = max - min;
    57       if (range.IsAlmost(0.0)) return actions.SampleRandom(random, 1).First();
     56      if (range.IsAlmost(0.0)) return idxs.SampleRandom(random);
    5857
    59       var w = from a in actions
    60               select Math.Exp(Beta * (QualityFunction.Q(state, a) - min) / range);
     58      var w = from s in followStates
     59              select Math.Exp(Beta * (StateValueFunction.Value(s) - min) / range);
    6160
    62       return actions.SampleProportional(random, 1, w).First();
     61      return idxs.SampleProportional(random, 1, w).First();
    6362
    6463    }
    6564
    66     public sealed override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) {
    67       foreach (var t in stateActionSequence) {
    68         var state = t.Item1;
    69         var action = t.Item2;
    70         QualityFunction.Update(state, action, quality);
     65    public sealed override void Update(IEnumerable<object> stateSequence, double quality) {
     66      foreach (var state in stateSequence) {
     67        StateValueFunction.Update(state, quality);
    7168      }
    7269    }
    7370
    74     protected override object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
    75       return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);
     71    protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
     72      return StateValueFunction.StateFunction.CreateState(root, actions, parent, childIdx);
    7673    }
    7774
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/EpsGreedySymbolicExpressionConstructionPolicy.cs

    r12909 r12923  
    2020    }
    2121
    22     public IQualityFunction QualityFunction {
     22    public IStateValueFunction StateValueFunction {
    2323      get {
    24         return ((IValueParameter<IQualityFunction>)Parameters["Quality function"]).Value;
     24        return ((IValueParameter<IStateValueFunction>)Parameters["Quality function"]).Value;
    2525      }
    26       set { ((IValueParameter<IQualityFunction>)Parameters["Quality function"]).Value = value; }
     26      set { ((IValueParameter<IStateValueFunction>)Parameters["Quality function"]).Value = value; }
    2727    }
    2828
     
    3030      : base() {
    3131      Parameters.Add(new FixedValueParameter<DoubleValue>("Eps", "The fraction of random pulls", new PercentValue(0.1, true)));
    32       Parameters.Add(new ValueParameter<IQualityFunction>("Quality function", "The quality function to use", new TabularAvgQualityFunction()));
     32      Parameters.Add(new ValueParameter<IStateValueFunction>("Quality function", "The quality function to use", new TabularAvgStateValueFunction()));
    3333    }
    3434
    35     protected override int Select(object state, IEnumerable<int> actions, IRandom random) {
     35    protected override int Select(IReadOnlyList<object> followStates, IRandom random) {
     36      var idxs = Enumerable.Range(0, followStates.Count);
    3637      if (random.NextDouble() < Eps) {
    37         return actions.SampleRandom(random, 1).First();
     38        return idxs.SampleRandom(random);
    3839      }
    3940
    4041      // find best action
    41       var bestActions = new List<int>();
     42      var bestFollowStates = new List<int>();
    4243      var bestQuality = double.NegativeInfinity;
    43       foreach (var a in actions) {
    44         double quality = QualityFunction.Q(state, a);
     44      for (int idx = 0; idx < followStates.Count; idx++) {
     45        double quality = StateValueFunction.Value(followStates[idx]);
    4546
    4647        if (quality >= bestQuality) {
    4748          if (quality > bestQuality) {
    48             bestActions.Clear();
     49            bestFollowStates.Clear();
    4950            bestQuality = quality;
    5051          }
    51           bestActions.Add(a);
     52          bestFollowStates.Add(idx);
    5253        }
    5354      }
    54       return bestActions.SampleRandom(random, 1).First();
     55      return bestFollowStates.SampleRandom(random);
    5556    }
    5657
    57     public override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) {
    58       foreach (var t in stateActionSequence) {
    59         var state = t.Item1;
    60         var action = t.Item2;
    61         QualityFunction.Update(state, action, quality);
     58    public sealed override void Update(IEnumerable<object> stateSequence, double quality) {
     59      foreach (var state in stateSequence) {
     60        StateValueFunction.Update(state, quality);
    6261      }
    6362    }
    6463
    65     protected override object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
    66       return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);
     64    protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
     65      return StateValueFunction.StateFunction.CreateState(root, actions, parent, childIdx);
    6766    }
    6867
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/RandomSymbolicExpressionConstructionPolicy.cs

    r12909 r12923  
    1717    }
    1818
    19     protected override int Select(object state, IEnumerable<int> actions, IRandom random) {
    20       return actions.SampleRandom(random, 1).First();
     19    protected override int Select(IReadOnlyList<object> followStates, IRandom random) {
     20      var idxs = Enumerable.Range(0, followStates.Count);
     21      return idxs.SampleRandom(random);
    2122    }
    2223
    23     public override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) {
     24    public sealed override void Update(IEnumerable<object> stateSequence, double quality) {
     25
    2426      // ignore
    2527    }
    2628
    27     protected override object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
     29    protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
    2830      return null; // doesn't use state information
    2931    }
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/SymbolicExpressionConstructionPolicyBase.cs

    r12922 r12923  
    2121    [Storable]
    2222    public IRandom Random { get; private set; }
    23     private SearchTree searchTree;
     23    private SearchTree<ISymbol> searchTree; // tree of replacement symbols
    2424
    2525    private class Slot {
     
    3131    [StorableHook(HookType.AfterDeserialization)]
    3232    private void AfterDeserialization() {
    33       searchTree = new SearchTree();
     33      searchTree = new SearchTree<ISymbol>();
    3434    }
    3535    protected SymbolicExpressionConstructionPolicyBase(SymbolicExpressionConstructionPolicyBase original, Cloner cloner)
     
    3939
    4040      // search tree is not cloned or stored
    41       searchTree = new SearchTree();
     41      searchTree = new SearchTree<ISymbol>();
    4242    }
    4343
     
    5252      this.Problem = problem;
    5353      this.Random = random;
    54       this.searchTree = new SearchTree(); // represents all realized actionSequences as a prefix tree
     54      this.searchTree = new SearchTree<ISymbol>(); // represents all realized actionSequences as a prefix tree
    5555    }
    5656
    57     public ISymbolicExpressionTree Sample(out IEnumerable<Tuple<object, int>> stateActionSequence) {
    58       var actions = new List<int>();
     57    public ISymbolicExpressionTree Sample(out IEnumerable<object> stateSequence) {
     58      var actions = new List<ISymbol>();
    5959      var states = new List<object>();
    6060
     
    6767
    6868      Contract.Assert(Problem.Encoding.FunctionDefinitions == 0);
    69       openSlots.Push(new Slot() { parent = root, childIdx = 0, minSize = 2 }); // at least two nodes are necessary below root
     69      openSlots.Push(new Slot() { parent = root, childIdx = 0, minSize = g.GetMinimumExpressionLength(root.Symbol) - 1 }); // minimum size below root: the grammar's minimum expression length for the root symbol minus one
    7070
    7171      // tree size lower bound is the current tree size + the sum of the minimal size for all open slots
     
    7979        var childIdx = next.childIdx;
    8080
    81         // states might be defined differently be different policies
    82         // this allows policies to calculate the state as a function of the current tree and the position where it is changed,
    83         // or as a function of the list of actions so far,
    84         // or as a function of both
    85         var currentState = CreateState(root, actions, parent, childIdx);
    86         states.Add(currentState);
     81        if (searchTree.IsLeafNode()) {
     82          var allowedChildSymbols = g.GetAllowedChildSymbols(parent.Symbol, childIdx)
     83            .Where(a => a.Enabled)
     84            .Where(a => treeSize + g.GetMinimumExpressionLength(a) + openSlots.Select(e => e.minSize).Sum() <= maxLen);
    8785
    88 
    89         // TODO: only filter the first time later use info from search tree
    90         var alts = g.GetAllowedChildSymbols(parent.Symbol, childIdx)
    91           .Where(a => treeSize + g.GetMinimumExpressionLength(a) + openSlots.Select(e => e.minSize).Sum() <= maxLen)
    92           .ToArray();
    93 
    94         if (searchTree.IsLeafNode()) {
    95           searchTree.ExpandCurrentNode(alts);
     86          searchTree.ExpandCurrentNode(allowedChildSymbols);
    9687        }
    9788
    98         if (!searchTree.PossibleActions.Any()) {
     89        if (!searchTree.ChildValues.Any()) {
    9990          throw new InvalidProgramException(string.Format("Couldn't construct a valid tree of maximum length {0} or all possible trees have been visited", maxLen));
    10091        }
    10192
    102         // select a symbol randomly for the child
    103         // select random alternative
    104         var selectedIdx = Select(currentState, searchTree.PossibleActions, Random);
    105         actions.Add(selectedIdx);
     93        var alternatives = searchTree.ChildValues.ToArray(); // TODO perf
    10694
    107         // and add child node to parent
    108         var childNode = alts[selectedIdx].CreateTreeNode();
    109         if (childNode.HasLocalParameters) {
    110           throw new NotSupportedException("Symbols with parameters are not supported by construction policies for symbolic expressions. Try to reformulate the problem so that only discrete actions are necessary");
    111           // childNode.ResetLocalParameters(Random);
     95        // generate follow states
     96        var followStates = new object[alternatives.Length];
     97        for (int i = 0; i < followStates.Length; i++) {
     98          // temporarily make the replacement and create the followState object
     99          var childNode = alternatives[i].CreateTreeNode();
     100          if (childNode.HasLocalParameters) {
     101            throw new NotSupportedException("Symbols with parameters are not supported by construction policies for symbolic expressions. " +
     102                                            "Try to reformulate the problem so that only discrete actions are necessary");
     103            // childNode.ResetLocalParameters(Random);
     104          }
     105          parent.AddSubtree(childNode);
     106          actions.Add(alternatives[i]);
     107
     108          // states might be defined differently by different policies
     109          // this allows policies to calculate the state as a function of the current tree and the position where it is changed,
     110          // or as a function of the list of actions so far,
     111          // or as a function of both
     112          followStates[i] = CreateState(root, actions, parent, childIdx);
     113
     114          // roll back the change
     115          parent.RemoveSubtree(parent.SubtreeCount - 1);
     116          actions.RemoveAt(actions.Count - 1);
    112117        }
    113118
    114         Contract.Assert(parent.SubtreeCount == childIdx);
    115         parent.AddSubtree(childNode); // enforce left-canonical derivation
    116         treeSize++;
     119        // select one of the follow states and prepare for the next step
     120        var selectedIdx = Select(followStates, Random);
     121        actions.Add(alternatives[selectedIdx]);
     122        states.Add(followStates[selectedIdx]);
    117123
    118         // push new slots
    119         for (int chIdx = g.GetMinimumSubtreeCount(childNode.Symbol) - 1; chIdx >= 0; chIdx--) {
    120           int minForChild = g.GetAllowedChildSymbols(childNode.Symbol, chIdx).Min(a => g.GetMinimumExpressionLength(a)); // min length of all possible alts for the slot
    121           openSlots.Push(new Slot() { parent = childNode, childIdx = chIdx, minSize = minForChild });
     124        {
     125          // and add child node to parent
     126          var childNode = alternatives[selectedIdx].CreateTreeNode();
     127
     128          Contract.Assert(parent.SubtreeCount == childIdx); // enforce left-canonical derivation
     129          parent.AddSubtree(childNode);
     130          treeSize++;
     131
     132          // push new slots
     133          for (int chIdx = g.GetMinimumSubtreeCount(childNode.Symbol) - 1; chIdx >= 0; chIdx--) {
     134            int minForChild = g.GetAllowedChildSymbols(childNode.Symbol, chIdx)
     135              .Min(a => g.GetMinimumExpressionLength(a)); // min length of all possible alts for the slot
     136            openSlots.Push(new Slot() { parent = childNode, childIdx = chIdx, minSize = minForChild });
     137          }
    122138        }
    123139
    124140        // if this is the last slot we never have to revisit selectedIdx
    125141        if (!openSlots.Any()) {
    126           searchTree.RemoveBranch(selectedIdx);
     142          searchTree.RemoveBranch(alternatives[selectedIdx]);
    127143        } else {
    128           searchTree.Follow(selectedIdx);
     144          searchTree.Follow(alternatives[selectedIdx]);
    129145        }
    130146      }
    131147
    132       stateActionSequence = states.Zip(actions, Tuple.Create);
     148      stateSequence = states;
    133149      return new SymbolicExpressionTree(root);
    134150    }
    135151
     152    /// <summary>
     153    /// Choose one of the follow states
     154    /// </summary>
     155    /// <param name="followStates"></param>
     156    /// <param name="random"></param>
     157    /// <returns>The index of the selected follow state</returns>
     158    protected abstract int Select(IReadOnlyList<object> followStates, IRandom random);
     159    public abstract void Update(IEnumerable<object> stateSequence, double quality);
    136160
    137     protected abstract int Select(object state, IEnumerable<int> possibleActions, IRandom random);
    138     public abstract void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality);
    139 
    140     protected abstract object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parent, int childIdx);
     161    protected abstract object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx);
    141162  }
    142163}
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/UcbSymbolicExpressionConstructionPolicy.cs

    r12909 r12923  
    2020    }
    2121
    22     public ITabularQualityFunction QualityFunction {
     22    public ITabularStateValueFunction StateValueFunction {
    2323      get {
    24         return ((IValueParameter<ITabularQualityFunction>)Parameters["Quality function"]).Value;
     24        return ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value;
    2525      }
    26       set { ((IValueParameter<ITabularQualityFunction>)Parameters["Quality function"]).Value = value; }
     26      set { ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value = value; }
    2727    }
    2828
     
    3939      : base() {
    4040      Parameters.Add(new FixedValueParameter<DoubleValue>("R", "The weighting factor for the confidence bound (should be scaled based on the range or the fitness values)", new DoubleValue(1.0)));
    41       Parameters.Add(new ValueParameter<ITabularQualityFunction>("Quality function", "The quality function to use", new TabularAvgQualityFunction()));
     41      Parameters.Add(new ValueParameter<ITabularStateValueFunction>("Quality function", "The quality function to use", new TabularAvgStateValueFunction()));
    4242    }
    4343
    44     protected sealed override int Select(object state, IEnumerable<int> actions, IRandom random) {
    45 
    46       // find best action
    47       var bestActions = new List<int>();
     44    protected sealed override int Select(IReadOnlyList<object> followStates, IRandom random) {
     45      var bestFollowStates = new List<int>();
    4846      var bestQuality = double.NegativeInfinity;
    49       int totalTries = actions.Sum(a => QualityFunction.Tries(state, a));
    50       foreach (var a in actions) {
     47      int totalTries = followStates.Sum(s => StateValueFunction.Tries(s));
     48      for (int idx = 0; idx < followStates.Count; idx++) {
    5149        double quality;
    52         if (QualityFunction.Tries(state, a) == 0) {
     50        var s = followStates[idx];
     51        if (StateValueFunction.Tries(s) == 0) {
    5352          quality = double.PositiveInfinity;
    5453        } else {
    55           quality = QualityFunction.Q(state, a) + R * Math.Sqrt((2 * Math.Log(totalTries)) / QualityFunction.Tries(state, a));
     54          quality = StateValueFunction.Value(s) + R * Math.Sqrt((2 * Math.Log(totalTries)) / StateValueFunction.Tries(s));
    5655        }
    5756        if (quality >= bestQuality) {
    5857          if (quality > bestQuality) {
    59             bestActions.Clear();
     58            bestFollowStates.Clear();
    6059            bestQuality = quality;
    6160          }
    62           bestActions.Add(a);
     61          bestFollowStates.Add(idx);
    6362        }
    6463      }
    65       return bestActions.SampleRandom(random, 1).First();
     64      return bestFollowStates.SampleRandom(random);
    6665    }
    6766
    68     public sealed override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) {
    69       foreach (var t in stateActionSequence) {
    70         var state = t.Item1;
    71         var action = t.Item2;
    72         QualityFunction.Update(state, action, quality);
     67    public sealed override void Update(IEnumerable<object> stateSequence, double quality) {
     68      foreach (var state in stateSequence) {
     69        StateValueFunction.Update(state, quality);
    7370      }
    7471    }
    7572
    76     protected override object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
    77       return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);
     73    protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
     74      return StateValueFunction.StateFunction.CreateState(root, actions, parent, childIdx);
    7875    }
    7976
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/UcbTunedSymbolicExpressionConstructionPolicy.cs

    r12922 r12923  
    2020    }
    2121
    22     public ITabularQualityFunction QualityFunction {
     22    public ITabularStateValueFunction StateValueFunction {
    2323      get {
    24         return ((IValueParameter<ITabularQualityFunction>)Parameters["Quality function"]).Value;
     24        return ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value;
    2525      }
    26       set { ((IValueParameter<ITabularQualityFunction>)Parameters["Quality function"]).Value = value; }
     26      set { ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value = value; }
    2727    }
    2828
     
    3939      : base() {
    4040      Parameters.Add(new FixedValueParameter<DoubleValue>("R", "The weighting factor for the confidence bound (should be scaled based on the range or the fitness values)", new DoubleValue(1.0)));
    41       Parameters.Add(new ValueParameter<ITabularQualityFunction>("Quality function", "The quality function to use", new TabularAvgQualityFunction()));
     41      Parameters.Add(new ValueParameter<ITabularStateValueFunction>("Quality function", "The quality function to use", new TabularAvgStateValueFunction()));
    4242    }
    4343
    44     protected sealed override int Select(object state, IEnumerable<int> actions, IRandom random) {
    45 
    46       // find best action
    47       var bestActions = new List<int>();
     44    protected sealed override int Select(IReadOnlyList<object> followStates, IRandom random) {
     45      var bestFollowStates = new List<int>();
    4846      var bestQuality = double.NegativeInfinity;
    49       int totalTries = actions.Sum(a => QualityFunction.Tries(state, a));
    50       foreach (var a in actions) {
     47      int totalTries = followStates.Sum(s => StateValueFunction.Tries(s));
     48      for (int idx = 0; idx < followStates.Count; idx++) {
     49        var s = followStates[idx];
    5150        double quality;
    52         if (QualityFunction.Tries(state, a) == 0) {
     51        if (StateValueFunction.Tries(s) == 0) {
    5352          quality = double.PositiveInfinity;
    5453        } else {
    55           double v = QualityFunction.QVariance(state, a) + Math.Sqrt(2 * Math.Log(totalTries) / QualityFunction.Tries(state, a));
    56           quality = QualityFunction.Q(state, a) + R * Math.Sqrt(Math.Log(totalTries) / QualityFunction.Tries(state, a) * v);
     54          double v = StateValueFunction.ValueVariance(s) + Math.Sqrt(2 * Math.Log(totalTries) / StateValueFunction.Tries(s));
     55          quality = StateValueFunction.Value(s) + R * Math.Sqrt(Math.Log(totalTries) / StateValueFunction.Tries(s) * v);
    5756        }
    5857        if (quality >= bestQuality) {
    5958          if (quality > bestQuality) {
    60             bestActions.Clear();
     59            bestFollowStates.Clear();
    6160            bestQuality = quality;
    6261          }
    63           bestActions.Add(a);
     62          bestFollowStates.Add(idx);
    6463        }
    6564      }
    66       return bestActions.SampleRandom(random, 1).First();
     65      return bestFollowStates.SampleRandom(random);
    6766    }
    6867
    69     public sealed override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) {
    70       foreach (var t in stateActionSequence) {
    71         var state = t.Item1;
    72         var action = t.Item2;
    73         QualityFunction.Update(state, action, quality);
     68    public sealed override void Update(IEnumerable<object> stateSequence, double quality) {
     69      foreach (var state in stateSequence) {
     70        StateValueFunction.Update(state, quality);
    7471      }
    7572    }
    7673
    77     protected override object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
    78       return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);
     74    protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) {
     75      return StateValueFunction.StateFunction.CreateState(root, actions, parent, childIdx);
    7976    }
    8077
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/SearchTree.cs

    r12922 r12923  
    77
    88namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction {
    9   internal class SearchTree {
    10     private class Node {
    11       internal Node parent;
    12       internal Node[] children;
     9  internal class SearchTree<TValue> {
     10    private class Node<TValue> {
     11      internal TValue value;
     12      internal Node<TValue> parent;
     13      internal Node<TValue>[] children;
    1314      // children == null -> never visited
    1415      // children[i] != null -> visited at least once, still allowed
     
    1617    }
    1718
    18     private Node root;
     19    private Node<TValue> root;
    1920
    2021    // for iteration
    21     private Node currentNode;
     22    private Node<TValue> currentNode;
    2223
    2324    public SearchTree() {
    24       root = new Node();
     25      root = new Node<TValue>();
    2526      currentNode = root;
    2627    }
     
    3435    }
    3536
    36     public void ExpandCurrentNode<T>(IEnumerable<T> actions) {
    37       Contract.Assert(actions.Any());
     37    public void ExpandCurrentNode(IEnumerable<TValue> values) {
     38      Contract.Assert(values.Any());
    3839      Contract.Assert(currentNode.children == null);
    39       currentNode.children = actions.Select(_ => new Node() { parent = currentNode }).ToArray();
     40      currentNode.children = values.Select(val => new Node<TValue>() { value = val, parent = currentNode }).ToArray();
    4041    }
    4142
    42     public void Follow(int action) {
    43       Contract.Assert(currentNode.children != null);
    44       Contract.Assert(currentNode.children[action] != null);
    45       currentNode = currentNode.children[action];
     43    public void Follow(TValue value) {
     44      // TODO: perf
     45      int i = 0;
     46      while (i < currentNode.children.Length && (
     47        currentNode.children[i] == null || !currentNode.children[i].value.Equals(value))) i++;
     48
     49      if (i >= currentNode.children.Length) throw new InvalidProgramException();
     50      currentNode = currentNode.children[i];
    4651    }
    4752
    48     public IEnumerable<int> PossibleActions {
     53    public IEnumerable<TValue> ChildValues {
    4954      get {
    50         return Enumerable.Range(0, currentNode.children.Length)
    51                          .Where(i => currentNode.children[i] != null);
     55        return from ch in currentNode.children
     56               where ch != null
     57               select ch.value;
    5258      }
    5359    }
    5460
    55     public void RemoveBranch(int action) {
    56       Contract.Assert(currentNode.children != null);
    57       Contract.Assert(currentNode.children[action] != null);
    58       currentNode.children[action] = null;
     61    public void RemoveBranch(TValue value) {
     62      // TODO: perf
     63      int i = 0;
     64      while (i < currentNode.children.Length && (
     65        currentNode.children[i] == null || !currentNode.children[i].value.Equals(value))) i++;
     66
     67      if (i >= currentNode.children.Length) throw new InvalidProgramException();
     68      currentNode.children[i] = null;
    5969
    6070      RemoveRecursively(currentNode);
    6171    }
    6272
    63     private void RemoveRecursively(Node node) {
     73    private void RemoveRecursively(Node<TValue> node) {
    6474      // when the last child has been removed we must remove the current node from its parent
    6575      while (node.parent != null && node.children.All(ch => ch == null)) {
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/StateFunctions/DefaultStateFunction.cs

    r12909 r12923  
    1818    }
    1919
    20     public object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {
    21       return string.Join(",", actions);
     20    public object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {
     21      return string.Join(",", actions.Select(a => a.Name));
    2222    }
    2323
  • branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/StateFunctions/ParentChildStateFunction.cs

    r12909 r12923  
    1919    }
    2020
    21     public object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {
     21    public object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {
    2222      return (parentNode == null ? "" : parentNode.Symbol.Name) + childIdx;
    2323    }
Note: See TracChangeset for help on using the changeset viewer.