Changeset 12923 for branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction
- Timestamp: 08/28/15 14:57:21 (9 years ago)
- Location: branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3
- Files: 5 added, 5 deleted, 13 edited
Legend:
- Unmodified
- Added
- Removed
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction-3.3.csproj
r12922 r12923 155 155 <ItemGroup> 156 156 <Compile Include="Interfaces\IStateFunction.cs" /> 157 <Compile Include="Interfaces\I QualityFunction.cs" />158 <Compile Include="Interfaces\ITabular QualityFunction.cs" />157 <Compile Include="Interfaces\IStateValueFunction.cs" /> 158 <Compile Include="Interfaces\ITabularStateValueFunction.cs" /> 159 159 <Compile Include="Interfaces\ISymbolicExpressionConstructionPolicy.cs" /> 160 160 <Compile Include="IteratedSymbolicExpressionConstruction.cs" /> … … 169 169 </Compile> 170 170 <Compile Include="Properties\AssemblyInfo.cs" /> 171 <Compile Include="QualityFunctions\Tabular QualityFunctionBase.cs" />172 <Compile Include="QualityFunctions\TabularAvg QualityFunction.cs" />173 <Compile Include="QualityFunctions\TabularMax QualityFunction.cs" />171 <Compile Include="QualityFunctions\TabularStateValueFunctionBase.cs" /> 172 <Compile Include="QualityFunctions\TabularAvgStateValueFunction.cs" /> 173 <Compile Include="QualityFunctions\TabularMaxStateValueFunction.cs" /> 174 174 <Compile Include="SearchTree.cs" /> 175 175 <Compile Include="StateFunctions\ParentChildStateFunction.cs" /> -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Interfaces/IStateFunction.cs
r12909 r12923 10 10 // creates a state from the information available at sequential derivation steps of symbolic expression trees 11 11 public interface IStateFunction : IItem { 12 object CreateState(ISymbolicExpressionTreeNode root, List< int> actions, ISymbolicExpressionTreeNode parentNode, int childIdx);12 object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx); 13 13 } 14 14 } -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Interfaces/ISymbolicExpressionConstructionPolicy.cs
r12909 r12923 9 9 public interface ISymbolicExpressionConstructionPolicy : IItem { 10 10 void Initialize(SymbolicExpressionTreeProblem problem, IRandom random); 11 ISymbolicExpressionTree Sample(out IEnumerable< Tuple<object, int>> stateActionSequence);12 void Update(IEnumerable< Tuple<object, int>> stateActionSequence, double quality);11 ISymbolicExpressionTree Sample(out IEnumerable<object> stateSequence); 12 void Update(IEnumerable<object> stateSequence, double quality); 13 13 } 14 14 } -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/IteratedSymbolicExpressionConstruction.cs
r12909 r12923 147 147 random.Reset(Seed); 148 148 149 //var policy = new RandomSymbolicExpressionConstructionPolicy(Problem, random);150 //var policy = new EpsGreedySymbolicExpressionConstructionPolicy<string>(Problem, random, new TabularMaxQualityFunction<string>(new DefaultStateFunction()));151 //var policy = new EpsGreedySymbolicExpressionConstructionPolicy(Problem, random, new TabularAvgQualityFunction<string>(new DefaultStateFunction()));152 //var policy = new EpsGreedySymbolicExpressionConstructionPolicy(Problem, random, new TabularMaxQualityFunction<string>(new ParentChildStateFunction()));153 //var policy = new EpsGreedySymbolicExpressionConstructionPolicy(Problem, random, new TabularAvgQualityFunction<string>(new ParentChildStateFunction()));154 //var policy = new UcbSymbolicExpressionConstructionPolicy<string>(Problem, random, new TabularMaxQualityFunction<string>(new DefaultStateFunction()));155 //var policy = new UcbWithStateAggregationSymbolicExpressionConstructionPolicy(Problem, random, 40);156 157 149 var policy = PolicyParameter.Value; 158 150 policy.Initialize(Problem, random); … … 186 178 double sumQuality = 0; // for average quality calculation 187 179 int resultUpdateInterval = ResultUpdateInterval; 188 while (evals < MaximumEvaluations) { 189 double quality = double.NaN; 190 ISymbolicExpressionTree tree = null; 191 IEnumerable<Tuple<object, int>> actionSequence; 192 tree = policy.Sample(out actionSequence); 193 quality = Problem.Evaluate(tree, random); 194 evals++; 195 sumQuality += quality; 196 197 policy.Update(actionSequence, quality); 198 199 // update statistics results in regular update intervals 200 if ((evals - 1) % resultUpdateInterval == resultUpdateInterval - 1) { 201 evaluations.Value = evals; 202 bestQualityRow.Values.Add(bestQuality.Value); 203 currentQualityRow.Values.Add(sumQuality / (double)resultUpdateInterval); 204 sumQuality = 0; 180 try { 181 while (evals < MaximumEvaluations) { 182 double quality = double.NaN; 183 
ISymbolicExpressionTree tree = null; 184 IEnumerable<object> stateSequence; 185 tree = policy.Sample(out stateSequence); 186 quality = Problem.Evaluate(tree, random); 187 evals++; 188 sumQuality += quality; 189 190 policy.Update(stateSequence, quality); 191 cancellationToken.ThrowIfCancellationRequested(); 192 193 // update statistics results in regular update intervals 194 if ((evals - 1) % resultUpdateInterval == resultUpdateInterval - 1) { 195 evaluations.Value = evals; 196 bestQualityRow.Values.Add(bestQuality.Value); 197 currentQualityRow.Values.Add(sumQuality / (double)resultUpdateInterval); 198 sumQuality = 0; 199 } 200 201 // update best solution results whenever a new better solution is found 202 if (Problem.IsBetter(quality, bestQuality.Value)) { 203 bestQuality.Value = quality; 204 bestFoundOnEvaluation.Value = evals; 205 206 // for problem-specific analyzer 207 solutions[0] = tree; 208 qualities[0] = quality; 209 } 210 211 // run problem-specific analyzer in each iteration 212 Problem.Analyze(solutions, qualities, Results, random); 205 213 } 206 207 // update best solution results whenever a new better solution is found 208 if (Problem.IsBetter(quality, bestQuality.Value)) { 209 bestQuality.Value = quality; 210 bestFoundOnEvaluation.Value = evals; 211 212 // for problem-specific analyzer 213 solutions[0] = tree; 214 qualities[0] = quality; 215 } 216 217 // run problem-specific analyzer in each iteration 218 Problem.Analyze(solutions, qualities, Results, random); 219 220 cancellationToken.ThrowIfCancellationRequested(); 214 } finally { 215 // update stats whenever the alg is stopped 216 evaluations.Value = evals; 217 bestQualityRow.Values.Add(bestQuality.Value); 218 currentQualityRow.Values.Add(sumQuality / (double)resultUpdateInterval); 221 219 } 222 220 } -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/BoltzmannExplorationSymbolicExpressionConstructionPolicy.cs
r12909 r12923 20 20 } 21 21 22 public ITabular QualityFunction QualityFunction {22 public ITabularStateValueFunction StateValueFunction { 23 23 get { 24 return ((IValueParameter<ITabular QualityFunction>)Parameters["Quality function"]).Value;24 return ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value; 25 25 } 26 set { ((IValueParameter<ITabular QualityFunction>)Parameters["Quality function"]).Value = value; }26 set { ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value = value; } 27 27 } 28 28 … … 39 39 : base() { 40 40 Parameters.Add(new FixedValueParameter<DoubleValue>("Beta", "The weighting factor beta", new DoubleValue(1.0))); 41 Parameters.Add(new ValueParameter<ITabular QualityFunction>("Quality function", "The quality function to use", new TabularAvgQualityFunction()));41 Parameters.Add(new ValueParameter<ITabularStateValueFunction>("Quality function", "The quality function to use", new TabularAvgStateValueFunction())); 42 42 } 43 43 44 protected sealed override int Select( object state, IEnumerable<int> actions, IRandom random) {45 44 protected sealed override int Select(IReadOnlyList<object> followStates, IRandom random) { 45 var idxs = Enumerable.Range(0, followStates.Count); 46 46 // find best action 47 var bestActions = new List<int>();48 47 var bestQuality = double.NegativeInfinity; 49 if ( actions.Any(a => QualityFunction.Tries(state, a) == 0)) {50 return actions.Where(a => QualityFunction.Tries(state, a) == 0).SampleRandom(random, 1).First();48 if (followStates.Any(s => StateValueFunction.Tries(s) == 0)) { 49 return idxs.Where(idx => StateValueFunction.Tries(followStates[idx]) == 0).SampleRandom(random); 51 50 } 52 51 53 52 // windowing 54 var max = actions.Select(a => QualityFunction.Q(state, a)).Max();55 var min = actions.Select(a => QualityFunction.Q(state, a)).Min();53 var max = followStates.Select(s => StateValueFunction.Value(s)).Max(); 54 var min = followStates.Select(s => 
StateValueFunction.Value(s)).Min(); 56 55 double range = max - min; 57 if (range.IsAlmost(0.0)) return actions.SampleRandom(random, 1).First();56 if (range.IsAlmost(0.0)) return idxs.SampleRandom(random); 58 57 59 var w = from a in actions60 select Math.Exp(Beta * ( QualityFunction.Q(state, a) - min) / range);58 var w = from s in followStates 59 select Math.Exp(Beta * (StateValueFunction.Value(s) - min) / range); 61 60 62 return actions.SampleProportional(random, 1, w).First();61 return idxs.SampleProportional(random, 1, w).First(); 63 62 64 63 } 65 64 66 public sealed override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) { 67 foreach (var t in stateActionSequence) { 68 var state = t.Item1; 69 var action = t.Item2; 70 QualityFunction.Update(state, action, quality); 65 public sealed override void Update(IEnumerable<object> stateSequence, double quality) { 66 foreach (var state in stateSequence) { 67 StateValueFunction.Update(state, quality); 71 68 } 72 69 } 73 70 74 protected override object CreateState(ISymbolicExpressionTreeNode root, List< int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {75 return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);71 protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) { 72 return StateValueFunction.StateFunction.CreateState(root, actions, parent, childIdx); 76 73 } 77 74 -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/EpsGreedySymbolicExpressionConstructionPolicy.cs
r12909 r12923 20 20 } 21 21 22 public I QualityFunction QualityFunction {22 public IStateValueFunction StateValueFunction { 23 23 get { 24 return ((IValueParameter<I QualityFunction>)Parameters["Quality function"]).Value;24 return ((IValueParameter<IStateValueFunction>)Parameters["Quality function"]).Value; 25 25 } 26 set { ((IValueParameter<I QualityFunction>)Parameters["Quality function"]).Value = value; }26 set { ((IValueParameter<IStateValueFunction>)Parameters["Quality function"]).Value = value; } 27 27 } 28 28 … … 30 30 : base() { 31 31 Parameters.Add(new FixedValueParameter<DoubleValue>("Eps", "The fraction of random pulls", new PercentValue(0.1, true))); 32 Parameters.Add(new ValueParameter<I QualityFunction>("Quality function", "The quality function to use", new TabularAvgQualityFunction()));32 Parameters.Add(new ValueParameter<IStateValueFunction>("Quality function", "The quality function to use", new TabularAvgStateValueFunction())); 33 33 } 34 34 35 protected override int Select(object state, IEnumerable<int> actions, IRandom random) { 35 protected override int Select(IReadOnlyList<object> followStates, IRandom random) { 36 var idxs = Enumerable.Range(0, followStates.Count); 36 37 if (random.NextDouble() < Eps) { 37 return actions.SampleRandom(random, 1).First();38 return idxs.SampleRandom(random); 38 39 } 39 40 40 41 // find best action 41 var best Actions = new List<int>();42 var bestFollowStates = new List<int>(); 42 43 var bestQuality = double.NegativeInfinity; 43 for each (var a in actions) {44 double quality = QualityFunction.Q(state, a);44 for (int idx = 0; idx < followStates.Count; idx++) { 45 double quality = StateValueFunction.Value(followStates[idx]); 45 46 46 47 if (quality >= bestQuality) { 47 48 if (quality > bestQuality) { 48 best Actions.Clear();49 bestFollowStates.Clear(); 49 50 bestQuality = quality; 50 51 } 51 best Actions.Add(a);52 bestFollowStates.Add(idx); 52 53 } 53 54 } 54 return best Actions.SampleRandom(random, 1).First();55 
return bestFollowStates.SampleRandom(random); 55 56 } 56 57 57 public override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) { 58 foreach (var t in stateActionSequence) { 59 var state = t.Item1; 60 var action = t.Item2; 61 QualityFunction.Update(state, action, quality); 58 public sealed override void Update(IEnumerable<object> stateSequence, double quality) { 59 foreach (var state in stateSequence) { 60 StateValueFunction.Update(state, quality); 62 61 } 63 62 } 64 63 65 protected override object CreateState(ISymbolicExpressionTreeNode root, List< int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {66 return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);64 protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) { 65 return StateValueFunction.StateFunction.CreateState(root, actions, parent, childIdx); 67 66 } 68 67 -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/RandomSymbolicExpressionConstructionPolicy.cs
r12909 r12923 17 17 } 18 18 19 protected override int Select(object state, IEnumerable<int> actions, IRandom random) { 20 return actions.SampleRandom(random, 1).First(); 19 protected override int Select(IReadOnlyList<object> followStates, IRandom random) { 20 var idxs = Enumerable.Range(0, followStates.Count); 21 return idxs.SampleRandom(random); 21 22 } 22 23 23 public override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) { 24 public sealed override void Update(IEnumerable<object> stateSequence, double quality) { 25 24 26 // ignore 25 27 } 26 28 27 protected override object CreateState(ISymbolicExpressionTreeNode root, List< int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {29 protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) { 28 30 return null; // doesn't use state information 29 31 } -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/SymbolicExpressionConstructionPolicyBase.cs
r12922 r12923 21 21 [Storable] 22 22 public IRandom Random { get; private set; } 23 private SearchTree searchTree;23 private SearchTree<ISymbol> searchTree; // tree of replacement symbols 24 24 25 25 private class Slot { … … 31 31 [StorableHook(HookType.AfterDeserialization)] 32 32 private void AfterDeserialization() { 33 searchTree = new SearchTree ();33 searchTree = new SearchTree<ISymbol>(); 34 34 } 35 35 protected SymbolicExpressionConstructionPolicyBase(SymbolicExpressionConstructionPolicyBase original, Cloner cloner) … … 39 39 40 40 // search tree is not cloned or stored 41 searchTree = new SearchTree ();41 searchTree = new SearchTree<ISymbol>(); 42 42 } 43 43 … … 52 52 this.Problem = problem; 53 53 this.Random = random; 54 this.searchTree = new SearchTree (); // represents all realized actionSequences as a prefix tree54 this.searchTree = new SearchTree<ISymbol>(); // represents all realized actionSequences as a prefix tree 55 55 } 56 56 57 public ISymbolicExpressionTree Sample(out IEnumerable< Tuple<object, int>> stateActionSequence) {58 var actions = new List< int>();57 public ISymbolicExpressionTree Sample(out IEnumerable<object> stateSequence) { 58 var actions = new List<ISymbol>(); 59 59 var states = new List<object>(); 60 60 … … 67 67 68 68 Contract.Assert(Problem.Encoding.FunctionDefinitions == 0); 69 openSlots.Push(new Slot() { parent = root, childIdx = 0, minSize = 2}); // at least two nodes are necessary below root69 openSlots.Push(new Slot() { parent = root, childIdx = 0, minSize = g.GetMinimumExpressionLength(root.Symbol) - 1 }); // at least two nodes are necessary below root 70 70 71 71 // tree size lower bound is the current tree size + the sum of the minimal size for all open slots … … 79 79 var childIdx = next.childIdx; 80 80 81 // states might be defined differently be different policies 82 // this allows policies to calculate the state as a function of the current tree and the position where it is changed, 83 // or as a function of the list 
of actions so far, 84 // or as a function of both 85 var currentState = CreateState(root, actions, parent, childIdx); 86 states.Add(currentState); 81 if (searchTree.IsLeafNode()) { 82 var allowedChildSymbols = g.GetAllowedChildSymbols(parent.Symbol, childIdx) 83 .Where(a => a.Enabled) 84 .Where(a => treeSize + g.GetMinimumExpressionLength(a) + openSlots.Select(e => e.minSize).Sum() <= maxLen); 87 85 88 89 // TODO: only filter the first time later use info from search tree 90 var alts = g.GetAllowedChildSymbols(parent.Symbol, childIdx) 91 .Where(a => treeSize + g.GetMinimumExpressionLength(a) + openSlots.Select(e => e.minSize).Sum() <= maxLen) 92 .ToArray(); 93 94 if (searchTree.IsLeafNode()) { 95 searchTree.ExpandCurrentNode(alts); 86 searchTree.ExpandCurrentNode(allowedChildSymbols); 96 87 } 97 88 98 if (!searchTree. PossibleActions.Any()) {89 if (!searchTree.ChildValues.Any()) { 99 90 throw new InvalidProgramException(string.Format("Couldn't construct a valid tree of maximum length {0} or all possible trees have been visited", maxLen)); 100 91 } 101 92 102 // select a symbol randomly for the child 103 // select random alternative 104 var selectedIdx = Select(currentState, searchTree.PossibleActions, Random); 105 actions.Add(selectedIdx); 93 var alternatives = searchTree.ChildValues.ToArray(); // TODO perf 106 94 107 // and add child node to parent 108 var childNode = alts[selectedIdx].CreateTreeNode(); 109 if (childNode.HasLocalParameters) { 110 throw new NotSupportedException("Symbols with parameters are not supported by construction policies for symbolic expressions. 
Try to reformulate the problem so that only discrete actions are necessary"); 111 // childNode.ResetLocalParameters(Random); 95 // generate follow states 96 var followStates = new object[alternatives.Length]; 97 for (int i = 0; i < followStates.Length; i++) { 98 // temporarily make the replacement and create the followState object 99 var childNode = alternatives[i].CreateTreeNode(); 100 if (childNode.HasLocalParameters) { 101 throw new NotSupportedException("Symbols with parameters are not supported by construction policies for symbolic expressions. " + 102 "Try to reformulate the problem so that only discrete actions are necessary"); 103 // childNode.ResetLocalParameters(Random); 104 } 105 parent.AddSubtree(childNode); 106 actions.Add(alternatives[i]); 107 108 // states might be defined differently be different policies 109 // this allows policies to calculate the state as a function of the current tree and the position where it is changed, 110 // or as a function of the list of actions so far, 111 // or as a function of both 112 followStates[i] = CreateState(root, actions, parent, childIdx); 113 114 // roll back the change 115 parent.RemoveSubtree(parent.SubtreeCount - 1); 116 actions.RemoveAt(actions.Count - 1); 112 117 } 113 118 114 Contract.Assert(parent.SubtreeCount == childIdx); 115 parent.AddSubtree(childNode); // enforce left-canonical derivation 116 treeSize++; 119 // select one of the follow states and prepare for the next step 120 var selectedIdx = Select(followStates, Random); 121 actions.Add(alternatives[selectedIdx]); 122 states.Add(followStates[selectedIdx]); 117 123 118 // push new slots 119 for (int chIdx = g.GetMinimumSubtreeCount(childNode.Symbol) - 1; chIdx >= 0; chIdx--) { 120 int minForChild = g.GetAllowedChildSymbols(childNode.Symbol, chIdx).Min(a => g.GetMinimumExpressionLength(a)); // min length of all possible alts for the slot 121 openSlots.Push(new Slot() { parent = childNode, childIdx = chIdx, minSize = minForChild }); 124 { 125 // and 
add child node to parent 126 var childNode = alternatives[selectedIdx].CreateTreeNode(); 127 128 Contract.Assert(parent.SubtreeCount == childIdx); // enforce left-canonical derivation 129 parent.AddSubtree(childNode); 130 treeSize++; 131 132 // push new slots 133 for (int chIdx = g.GetMinimumSubtreeCount(childNode.Symbol) - 1; chIdx >= 0; chIdx--) { 134 int minForChild = g.GetAllowedChildSymbols(childNode.Symbol, chIdx) 135 .Min(a => g.GetMinimumExpressionLength(a)); // min length of all possible alts for the slot 136 openSlots.Push(new Slot() { parent = childNode, childIdx = chIdx, minSize = minForChild }); 137 } 122 138 } 123 139 124 140 // if this is the last slot we never have to revisit selectedIdx 125 141 if (!openSlots.Any()) { 126 searchTree.RemoveBranch( selectedIdx);142 searchTree.RemoveBranch(alternatives[selectedIdx]); 127 143 } else { 128 searchTree.Follow( selectedIdx);144 searchTree.Follow(alternatives[selectedIdx]); 129 145 } 130 146 } 131 147 132 state ActionSequence = states.Zip(actions, Tuple.Create);148 stateSequence = states; 133 149 return new SymbolicExpressionTree(root); 134 150 } 135 151 152 /// <summary> 153 /// Choose one of the follow states 154 /// </summary> 155 /// <param name="followStates"></param> 156 /// <param name="random"></param> 157 /// <returns>The index of the selected follow state</returns> 158 protected abstract int Select(IReadOnlyList<object> followStates, IRandom random); 159 public abstract void Update(IEnumerable<object> stateSequence, double quality); 136 160 137 protected abstract int Select(object state, IEnumerable<int> possibleActions, IRandom random); 138 public abstract void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality); 139 140 protected abstract object CreateState(ISymbolicExpressionTreeNode root, List<int> actions, ISymbolicExpressionTreeNode parent, int childIdx); 161 protected abstract object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, 
ISymbolicExpressionTreeNode parent, int childIdx); 141 162 } 142 163 } -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/UcbSymbolicExpressionConstructionPolicy.cs
r12909 r12923 20 20 } 21 21 22 public ITabular QualityFunction QualityFunction {22 public ITabularStateValueFunction StateValueFunction { 23 23 get { 24 return ((IValueParameter<ITabular QualityFunction>)Parameters["Quality function"]).Value;24 return ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value; 25 25 } 26 set { ((IValueParameter<ITabular QualityFunction>)Parameters["Quality function"]).Value = value; }26 set { ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value = value; } 27 27 } 28 28 … … 39 39 : base() { 40 40 Parameters.Add(new FixedValueParameter<DoubleValue>("R", "The weighting factor for the confidence bound (should be scaled based on the range or the fitness values)", new DoubleValue(1.0))); 41 Parameters.Add(new ValueParameter<ITabular QualityFunction>("Quality function", "The quality function to use", new TabularAvgQualityFunction()));41 Parameters.Add(new ValueParameter<ITabularStateValueFunction>("Quality function", "The quality function to use", new TabularAvgStateValueFunction())); 42 42 } 43 43 44 protected sealed override int Select(object state, IEnumerable<int> actions, IRandom random) { 45 46 // find best action 47 var bestActions = new List<int>(); 44 protected sealed override int Select(IReadOnlyList<object> followStates, IRandom random) { 45 var bestFollowStates = new List<int>(); 48 46 var bestQuality = double.NegativeInfinity; 49 int totalTries = actions.Sum(a => QualityFunction.Tries(state, a));50 for each (var a in actions) {47 int totalTries = followStates.Sum(s => StateValueFunction.Tries(s)); 48 for (int idx = 0; idx < followStates.Count; idx++) { 51 49 double quality; 52 if (QualityFunction.Tries(state, a) == 0) { 50 var s = followStates[idx]; 51 if (StateValueFunction.Tries(s) == 0) { 53 52 quality = double.PositiveInfinity; 54 53 } else { 55 quality = QualityFunction.Q(state, a) + R * Math.Sqrt((2 * Math.Log(totalTries)) / QualityFunction.Tries(state, a));54 
quality = StateValueFunction.Value(s) + R * Math.Sqrt((2 * Math.Log(totalTries)) / StateValueFunction.Tries(s)); 56 55 } 57 56 if (quality >= bestQuality) { 58 57 if (quality > bestQuality) { 59 best Actions.Clear();58 bestFollowStates.Clear(); 60 59 bestQuality = quality; 61 60 } 62 best Actions.Add(a);61 bestFollowStates.Add(idx); 63 62 } 64 63 } 65 return best Actions.SampleRandom(random, 1).First();64 return bestFollowStates.SampleRandom(random); 66 65 } 67 66 68 public sealed override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) { 69 foreach (var t in stateActionSequence) { 70 var state = t.Item1; 71 var action = t.Item2; 72 QualityFunction.Update(state, action, quality); 67 public sealed override void Update(IEnumerable<object> stateSequence, double quality) { 68 foreach (var state in stateSequence) { 69 StateValueFunction.Update(state, quality); 73 70 } 74 71 } 75 72 76 protected override object CreateState(ISymbolicExpressionTreeNode root, List< int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {77 return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);73 protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) { 74 return StateValueFunction.StateFunction.CreateState(root, actions, parent, childIdx); 78 75 } 79 76 -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/Policies/UcbTunedSymbolicExpressionConstructionPolicy.cs
r12922 r12923 20 20 } 21 21 22 public ITabular QualityFunction QualityFunction {22 public ITabularStateValueFunction StateValueFunction { 23 23 get { 24 return ((IValueParameter<ITabular QualityFunction>)Parameters["Quality function"]).Value;24 return ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value; 25 25 } 26 set { ((IValueParameter<ITabular QualityFunction>)Parameters["Quality function"]).Value = value; }26 set { ((IValueParameter<ITabularStateValueFunction>)Parameters["Quality function"]).Value = value; } 27 27 } 28 28 … … 39 39 : base() { 40 40 Parameters.Add(new FixedValueParameter<DoubleValue>("R", "The weighting factor for the confidence bound (should be scaled based on the range or the fitness values)", new DoubleValue(1.0))); 41 Parameters.Add(new ValueParameter<ITabular QualityFunction>("Quality function", "The quality function to use", new TabularAvgQualityFunction()));41 Parameters.Add(new ValueParameter<ITabularStateValueFunction>("Quality function", "The quality function to use", new TabularAvgStateValueFunction())); 42 42 } 43 43 44 protected sealed override int Select(object state, IEnumerable<int> actions, IRandom random) { 45 46 // find best action 47 var bestActions = new List<int>(); 44 protected sealed override int Select(IReadOnlyList<object> followStates, IRandom random) { 45 var bestFollowStates = new List<int>(); 48 46 var bestQuality = double.NegativeInfinity; 49 int totalTries = actions.Sum(a => QualityFunction.Tries(state, a)); 50 foreach (var a in actions) { 47 int totalTries = followStates.Sum(s => StateValueFunction.Tries(s)); 48 for (int idx = 0; idx < followStates.Count; idx++) { 49 var s = followStates[idx]; 51 50 double quality; 52 if ( QualityFunction.Tries(state, a) == 0) {51 if (StateValueFunction.Tries(s) == 0) { 53 52 quality = double.PositiveInfinity; 54 53 } else { 55 double v = QualityFunction.QVariance(state, a) + Math.Sqrt(2 * Math.Log(totalTries) / QualityFunction.Tries(state, a));56 
quality = QualityFunction.Q(state, a) + R * Math.Sqrt(Math.Log(totalTries) / QualityFunction.Tries(state, a) * v);54 double v = StateValueFunction.ValueVariance(s) + Math.Sqrt(2 * Math.Log(totalTries) / StateValueFunction.Tries(s)); 55 quality = StateValueFunction.Value(s) + R * Math.Sqrt(Math.Log(totalTries) / StateValueFunction.Tries(s) * v); 57 56 } 58 57 if (quality >= bestQuality) { 59 58 if (quality > bestQuality) { 60 best Actions.Clear();59 bestFollowStates.Clear(); 61 60 bestQuality = quality; 62 61 } 63 best Actions.Add(a);62 bestFollowStates.Add(idx); 64 63 } 65 64 } 66 return best Actions.SampleRandom(random, 1).First();65 return bestFollowStates.SampleRandom(random); 67 66 } 68 67 69 public sealed override void Update(IEnumerable<Tuple<object, int>> stateActionSequence, double quality) { 70 foreach (var t in stateActionSequence) { 71 var state = t.Item1; 72 var action = t.Item2; 73 QualityFunction.Update(state, action, quality); 68 public sealed override void Update(IEnumerable<object> stateSequence, double quality) { 69 foreach (var state in stateSequence) { 70 StateValueFunction.Update(state, quality); 74 71 } 75 72 } 76 73 77 protected override object CreateState(ISymbolicExpressionTreeNode root, List< int> actions, ISymbolicExpressionTreeNode parent, int childIdx) {78 return QualityFunction.StateFunction.CreateState(root, actions, parent, childIdx);74 protected override object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parent, int childIdx) { 75 return StateValueFunction.StateFunction.CreateState(root, actions, parent, childIdx); 79 76 } 80 77 -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/SearchTree.cs
r12922 r12923 7 7 8 8 namespace HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction { 9 internal class SearchTree { 10 private class Node { 11 internal Node parent; 12 internal Node[] children; 9 internal class SearchTree<TValue> { 10 private class Node<TValue> { 11 internal TValue value; 12 internal Node<TValue> parent; 13 internal Node<TValue>[] children; 13 14 // children == null -> never visited 14 15 // children[i] != null -> visited at least once, still allowed … … 16 17 } 17 18 18 private Node root;19 private Node<TValue> root; 19 20 20 21 // for iteration 21 private Node currentNode;22 private Node<TValue> currentNode; 22 23 23 24 public SearchTree() { 24 root = new Node ();25 root = new Node<TValue>(); 25 26 currentNode = root; 26 27 } … … 34 35 } 35 36 36 public void ExpandCurrentNode <T>(IEnumerable<T> actions) {37 Contract.Assert( actions.Any());37 public void ExpandCurrentNode(IEnumerable<TValue> values) { 38 Contract.Assert(values.Any()); 38 39 Contract.Assert(currentNode.children == null); 39 currentNode.children = actions.Select(_ => new Node() {parent = currentNode }).ToArray();40 currentNode.children = values.Select(val => new Node<TValue>() { value = val, parent = currentNode }).ToArray(); 40 41 } 41 42 42 public void Follow(int action) { 43 Contract.Assert(currentNode.children != null); 44 Contract.Assert(currentNode.children[action] != null); 45 currentNode = currentNode.children[action]; 43 public void Follow(TValue value) { 44 // TODO: perf 45 int i = 0; 46 while (i < currentNode.children.Length && ( 47 currentNode.children[i] == null || !currentNode.children[i].value.Equals(value))) i++; 48 49 if (i >= currentNode.children.Length) throw new InvalidProgramException(); 50 currentNode = currentNode.children[i]; 46 51 } 47 52 48 public IEnumerable< int> PossibleActions {53 public IEnumerable<TValue> ChildValues { 49 54 get { 50 return Enumerable.Range(0, currentNode.children.Length) 51 .Where(i => currentNode.children[i] != null); 55 
return from ch in currentNode.children 56 where ch != null 57 select ch.value; 52 58 } 53 59 } 54 60 55 public void RemoveBranch(int action) { 56 Contract.Assert(currentNode.children != null); 57 Contract.Assert(currentNode.children[action] != null); 58 currentNode.children[action] = null; 61 public void RemoveBranch(TValue value) { 62 // TODO: perf 63 int i = 0; 64 while (i < currentNode.children.Length && ( 65 currentNode.children[i] == null || !currentNode.children[i].value.Equals(value))) i++; 66 67 if (i >= currentNode.children.Length) throw new InvalidProgramException(); 68 currentNode.children[i] = null; 59 69 60 70 RemoveRecursively(currentNode); 61 71 } 62 72 63 private void RemoveRecursively(Node node) {73 private void RemoveRecursively(Node<TValue> node) { 64 74 // when the last child has been removed we must remove the current node from it's parent 65 75 while (node.parent != null && node.children.All(ch => ch == null)) { -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/StateFunctions/DefaultStateFunction.cs
r12909 r12923 18 18 } 19 19 20 public object CreateState(ISymbolicExpressionTreeNode root, List< int> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {21 return string.Join(",", actions );20 public object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) { 21 return string.Join(",", actions.Select(a => a.Name)); 22 22 } 23 23 -
branches/HeuristicLab.Algorithms.IteratedSentenceConstruction/HeuristicLab.Algorithms.IteratedSymbolicExpressionConstruction/3.3/StateFunctions/ParentChildStateFunction.cs
r12909 r12923 19 19 } 20 20 21 public object CreateState(ISymbolicExpressionTreeNode root, List< int> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) {21 public object CreateState(ISymbolicExpressionTreeNode root, List<ISymbol> actions, ISymbolicExpressionTreeNode parentNode, int childIdx) { 22 22 return (parentNode == null ? "" : parentNode.Symbol.Name) + childIdx; 23 23 }
Note: See TracChangeset for help on using the changeset viewer.