Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ExperimentalFunctionsBaking/BakedFunctionTree.cs @ 203

Last change on this file since 203 was 203, checked in by gkronber, 16 years ago

prototype implementation for fast evaluation of basic functions

File size: 12.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using HeuristicLab.DataAnalysis;
28using HeuristicLab.Data;
29
namespace HeuristicLab.Functions {
  /// <summary>
  /// Function tree that stores its nodes as a flat, prefix-ordered list of doubles ("baked" code)
  /// so that <see cref="Evaluate"/> can walk the encoding directly instead of visiting node objects.
  /// </summary>
  /// <remarks>
  /// Every node is encoded as
  /// <c>[arity, functionSymbol, nLocalVariables, localValue_0 .. localValue_(n-1), subTree_0 .. subTree_(arity-1)]</c>,
  /// where each sub-tree uses the same encoding.
  /// The <see cref="SubTrees"/> and <see cref="LocalVariables"/> accessors lazily "expand" the
  /// corresponding part of the encoding into objects. Expanding removes that part from <c>code</c> and
  /// zeroes its count. <see cref="Evaluate"/> bakes any expanded parts back into <c>code</c>.
  /// </remarks>
  class BakedFunctionTree : ItemBase, IFunctionTree {
    // Flat encoding of this node, plus all of its sub-trees unless treesExpanded (see class remarks).
    private List<double> code;

    // A function symbol is (symbol base + op-code). The base is a multiple of OpCodeStride and is unique
    // per registered function instance. This lets EvaluateBakedCode find the operation without a
    // dictionary lookup, and it stays correct when several instances of one function type are mapped.
    private const int OpCodeStride = 8;
    private enum OpCode { Variable = 0, Constant = 1, Addition = 2, Substraction = 3, Multiplication = 4, Division = 5 }

    // Next unused symbol base; 10000 is a multiple of OpCodeStride and stays one because it only
    // ever grows by OpCodeStride.
    // NOTE(review): these tables are static and not synchronized.
    private static double nextFunctionSymbol = 10000;
    private static Dictionary<double, IFunction> symbolTable = new Dictionary<double, IFunction>();
    private static Dictionary<IFunction, double> reverseSymbolTable = new Dictionary<IFunction, double>();

    internal BakedFunctionTree() {
      code = new List<double>();
    }

    /// <summary>
    /// Creates a single-node tree for <paramref name="function"/> without sub-trees. Its local variables
    /// are clones of the function's local variables. Sub-trees and variables start out expanded.
    /// </summary>
    /// <exception cref="NotSupportedException"><paramref name="function"/> is not a supported GP function.</exception>
    internal BakedFunctionTree(IFunction function) : this() {
      code.Add(0);
      code.Add(MapFunction(function));
      code.Add(0);
      treesExpanded = true;
      subTrees = new List<IFunctionTree>();
      variables = new List<IVariable>();
      variablesExpanded = true;
      foreach(IVariableInfo variableInfo in function.VariableInfos) {
        if(variableInfo.Local) {
          variables.Add((IVariable)function.GetVariable(variableInfo.FormalName).Clone());
        }
      }
    }

    /// <summary>
    /// Recursively bakes an arbitrary function tree into the flat encoding.
    /// </summary>
    /// <exception cref="NotSupportedException">A function or local-variable type in the tree is not supported.</exception>
    internal BakedFunctionTree(IFunctionTree tree) : this() {
      code.Add(0); // arity; written once all sub-trees have been appended
      code.Add(MapFunction(tree.Function));
      // Stored as a double like every other entry, so no (byte) truncation is applied.
      code.Add(tree.LocalVariables.Count);
      // NOTE(review): assumes tree.LocalVariables enumerates in the same order as the function's local
      // VariableInfos, which is the order the LocalVariables getter decodes them in. Confirm this.
      foreach(IVariable variable in tree.LocalVariables) {
        code.Add(GetDoubleValue(variable.Value));
      }
      // Append the sub-tree encodings directly. This node is in the baked state, so going through
      // AddSubTree would only expand it again.
      int arity = 0;
      foreach(IFunctionTree subTree in tree.SubTrees) {
        ToBaked(subTree).AppendBakedCode(code);
        arity++;
      }
      code[0] = arity;
    }

    // Converts a supported local-variable value to its baked double representation.
    private static double GetDoubleValue(IItem value) {
      if(value is DoubleData) {
        return ((DoubleData)value).Data;
      } else if(value is ConstrainedDoubleData) {
        return ((ConstrainedDoubleData)value).Data;
      } else if(value is IntData) {
        return ((IntData)value).Data;
      } else if(value is ConstrainedIntData) {
        return ((ConstrainedIntData)value).Data;
      } else throw new NotSupportedException("Invalid datatype of local variable for GP");
    }

    // Writes a baked double back into a supported local-variable value (inverse of GetDoubleValue).
    private static void SetDoubleValue(IItem value, double bakedValue) {
      if(value is ConstrainedDoubleData) {
        ((ConstrainedDoubleData)value).Data = bakedValue;
      } else if(value is ConstrainedIntData) {
        ((ConstrainedIntData)value).Data = (int)bakedValue;
      } else if(value is DoubleData) {
        ((DoubleData)value).Data = bakedValue;
      } else if(value is IntData) {
        ((IntData)value).Data = (int)bakedValue;
      } else throw new NotSupportedException("Invalid local variable type for GP.");
    }

    // Returns the symbol of function, registering it on first use.
    private static double MapFunction(IFunction function) {
      double symbol;
      if(!reverseSymbolTable.TryGetValue(function, out symbol)) {
        // Resolve the op-code first. An unsupported function then throws before any table is modified,
        // instead of staying registered under a symbol that is later handed out again.
        OpCode opCode = GetOpCode(function);
        symbol = nextFunctionSymbol + (int)opCode;
        reverseSymbolTable[function] = symbol;
        symbolTable[symbol] = function;
        nextFunctionSymbol += OpCodeStride;
      }
      return symbol;
    }

    // Maps a supported function instance to the operation EvaluateBakedCode performs for it.
    private static OpCode GetOpCode(IFunction function) {
      if(function is Variable) {
        return OpCode.Variable;
      } else if(function is Constant) {
        return OpCode.Constant;
      } else if(function is Addition) {
        return OpCode.Addition;
      } else if(function is Substraction) {
        return OpCode.Substraction;
      } else if(function is Multiplication) {
        return OpCode.Multiplication;
      } else if(function is Division) {
        return OpCode.Division;
      } else throw new NotSupportedException("Unsupported function " + function);
    }

    // Extracts the op-code stored in the low part of a symbol produced by MapFunction.
    private static OpCode GetOpCode(double functionSymbol) {
      return (OpCode)((int)functionSymbol % OpCodeStride);
    }

    // Returns tree itself if it is already baked; otherwise bakes a copy of it.
    private static BakedFunctionTree ToBaked(IFunctionTree tree) {
      BakedFunctionTree baked = tree as BakedFunctionTree;
      return baked ?? new BakedFunctionTree(tree);
    }

    // Number of code entries occupied by the baked branch that starts at branchRoot
    // (header, local values and all sub-branches).
    private int BranchLength(int branchRoot) {
      int arity = (int)code[branchRoot];
      int len = 3 + (int)code[branchRoot + 2];
      for(int i = 0; i < arity; i++) {
        len += BranchLength(branchRoot + len);
      }
      return len;
    }

    // Appends the complete baked encoding of this node and its sub-trees to target, including any
    // currently expanded variables or sub-trees. Neither this node nor its sub-trees are modified.
    private void AppendBakedCode(List<double> target) {
      // Locals still present in code; this is 0 once they are expanded (the LocalVariables getter and
      // the IFunction constructor both leave code[2] == 0).
      int storedLocals = (int)code[2];
      int storedBranchStart = 3 + storedLocals;

      target.Add(treesExpanded ? subTrees.Count : code[0]);
      target.Add(code[1]);
      if(variablesExpanded) {
        target.Add(variables.Count);
        foreach(IVariable variable in variables) {
          target.Add(GetDoubleValue(variable.Value));
        }
      } else {
        target.Add(storedLocals);
        target.AddRange(code.GetRange(3, storedLocals));
      }

      if(treesExpanded) {
        foreach(IFunctionTree subTree in subTrees) {
          ToBaked(subTree).AppendBakedCode(target);
        }
      } else {
        target.AddRange(code.GetRange(storedBranchStart, code.Count - storedBranchStart));
      }
    }

    // Returns a fresh list holding the complete baked encoding of this tree (see AppendBakedCode).
    private List<double> GetBakedCode() {
      List<double> result = new List<double>(code.Count);
      AppendBakedCode(result);
      return result;
    }

    // Folds expanded variables and sub-trees back into code so that EvaluateBakedCode sees the full tree.
    private void Bake() {
      if(!treesExpanded && !variablesExpanded) return;
      List<double> baked = GetBakedCode();
      if(treesExpanded) {
        treesExpanded = false;
        subTrees.Clear();
      }
      if(variablesExpanded) {
        variablesExpanded = false;
        variables.Clear();
      }
      code = baked;
    }

    private bool treesExpanded = false;
    private List<IFunctionTree> subTrees;
    /// <summary>
    /// Sub-trees of this node. On first access the baked sub-tree encodings are cut out of <c>code</c>
    /// and wrapped as separate <see cref="BakedFunctionTree"/> objects.
    /// </summary>
    public IList<IFunctionTree> SubTrees {
      get {
        if(!treesExpanded) {
          subTrees = new List<IFunctionTree>();
          int arity = (int)code[0];
          int nLocalVariables = (int)code[2];
          int branchIndex = 3 + nLocalVariables;
          for(int i = 0; i < arity; i++) {
            BakedFunctionTree subTree = new BakedFunctionTree();
            int branchLen = BranchLength(branchIndex);
            subTree.code = code.GetRange(branchIndex, branchLen);
            branchIndex += branchLen;
            subTrees.Add(subTree);
          }
          treesExpanded = true;
          code.RemoveRange(3 + nLocalVariables, code.Count - (3 + nLocalVariables));
          code[0] = 0;
        }
        return subTrees;
      }
    }

    private bool variablesExpanded = false;
    private List<IVariable> variables;
    /// <summary>
    /// Local variables of this node. On first access they are rebuilt as clones of the function's local
    /// variables, filled with the baked values, and those values are removed from <c>code</c>.
    /// </summary>
    /// <exception cref="NotSupportedException">A local variable has an unsupported value type.</exception>
    public ICollection<IVariable> LocalVariables {
      get {
        if(!variablesExpanded) {
          variables = new List<IVariable>();
          IFunction function = symbolTable[code[1]];
          int localVariableIndex = 3;
          foreach(IVariableInfo variableInfo in function.VariableInfos) {
            if(variableInfo.Local) {
              IVariable clone = (IVariable)function.GetVariable(variableInfo.FormalName).Clone();
              SetDoubleValue(clone.Value, code[localVariableIndex]);
              variables.Add(clone);
              localVariableIndex++;
            }
          }
          variablesExpanded = true;
          code[2] = 0;
          code.RemoveRange(3, variables.Count);
        }
        return variables;
      }
    }

    public IFunction Function {
      get { return symbolTable[code[1]]; }
    }

    public IVariable GetLocalVariable(string name) {
      throw new NotImplementedException();
    }

    public void AddVariable(IVariable variable) {
      throw new NotImplementedException();
    }

    public void RemoveVariable(string name) {
      throw new NotImplementedException();
    }

    /// <summary>Appends a sub-tree; baked sub-trees are expanded first if necessary.</summary>
    public void AddSubTree(IFunctionTree tree) {
      SubTrees.Add(tree);
    }

    /// <summary>Inserts a sub-tree at <paramref name="index"/>; baked sub-trees are expanded first if necessary.</summary>
    public void InsertSubTree(int index, IFunctionTree tree) {
      SubTrees.Insert(index, tree);
    }

    /// <summary>Removes the sub-tree at <paramref name="index"/>; baked sub-trees are expanded first if necessary.</summary>
    public void RemoveSubTree(int index) {
      SubTrees.RemoveAt(index);
    }

    // Read position in code while EvaluateBakedCode walks the tree.
    private int PC;

    /// <summary>
    /// Evaluates the tree for one sample of <paramref name="dataset"/>. Any expanded parts are baked
    /// back into the flat encoding first.
    /// </summary>
    public double Evaluate(Dataset dataset, int sampleIndex) {
      Bake();
      PC = 0;
      return EvaluateBakedCode(dataset, sampleIndex);
    }

    // Evaluates the node that starts at PC and leaves PC just past that node's whole branch.
    private double EvaluateBakedCode(Dataset dataset, int sampleIndex) {
      int arity = (int)code[PC++];
      double functionSymbol = code[PC++];
      int nLocalVariables = (int)code[PC++];
      // Skip exactly this node's local values so that the sub-trees (if any) are read from the right place.
      int localVariablesStart = PC;
      PC += nLocalVariables;

      switch(GetOpCode(functionSymbol)) {
        case OpCode.Variable: {
            // locals, in baked order: column index, weight, sample offset
            int column = (int)code[localVariablesStart];
            double weight = code[localVariablesStart + 1];
            int offset = (int)code[localVariablesStart + 2];
            return weight * dataset.GetValue(sampleIndex + offset, column);
          }
        case OpCode.Constant:
          return code[localVariablesStart];
        case OpCode.Addition: {
            double sum = 0.0;
            for(int i = 0; i < arity; i++) {
              sum += EvaluateBakedCode(dataset, sampleIndex);
            }
            return sum;
          }
        case OpCode.Substraction: {
            if(arity == 1) {
              return -EvaluateBakedCode(dataset, sampleIndex);
            }
            double difference = EvaluateBakedCode(dataset, sampleIndex);
            for(int i = 1; i < arity; i++) {
              difference -= EvaluateBakedCode(dataset, sampleIndex);
            }
            return difference;
          }
        case OpCode.Multiplication: {
            double product = 1.0;
            for(int i = 0; i < arity; i++) {
              product *= EvaluateBakedCode(dataset, sampleIndex);
            }
            return product;
          }
        case OpCode.Division: {
            // Protected division: a zero divisor yields 0 rather than infinity/NaN.
            if(arity == 1) {
              double divisor = EvaluateBakedCode(dataset, sampleIndex);
              return divisor == 0 ? 0.0 : 1.0 / divisor;
            }
            double quotient = EvaluateBakedCode(dataset, sampleIndex);
            for(int i = 1; i < arity; i++) {
              double divisor = EvaluateBakedCode(dataset, sampleIndex);
              if(divisor == 0) quotient = 0;
              else quotient /= divisor;
            }
            return quotient;
          }
        default:
          throw new NotSupportedException();
      }
    }

    public override System.Xml.XmlNode GetXmlNode(string name, System.Xml.XmlDocument document, IDictionary<Guid, IStorable> persistedObjects) {
      throw new NotImplementedException();
    }

    public override void Populate(System.Xml.XmlNode node, IDictionary<Guid, IStorable> restoredObjects) {
      throw new NotImplementedException();
    }

    /// <summary>
    /// Returns an independent baked copy of the whole tree. Expanded variables and sub-trees are
    /// included, and this instance is not modified.
    /// </summary>
    public override object Clone(IDictionary<Guid, object> clonedObjects) {
      BakedFunctionTree clone = new BakedFunctionTree();
      clone.code = GetBakedCode();
      return clone;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.