Free cookie consent management tool by TermsFeed Policy Generator

source: branches/BottomUpTreeEvaluation/BakedTreeEvaluator.cs @ 328

Last change on this file since 328 was 328, checked in by gkronber, 17 years ago

more work on GP evaluation - work in progress - bug nest

File size: 15.5 KB
RevLine 
[266]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
[223]23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.DataAnalysis;
[259]27using HeuristicLab.Core;
28using System.Xml;
[223]29
30namespace HeuristicLab.Functions {
[322]31  internal static class BakedTreeEvaluator {
32    private const int MAX_TREE_SIZE = 4096;
[327]33    private const int MAX_TREE_DEPTH = 20;
[321]34    private class Instr {
[328]35      public double result;
[318]36      public double d_arg0;
37      public int i_arg0;
38      public int i_arg1;
39      public int arity;
40      public int symbol;
41    }
42
[327]43    private static int[] nInstr;
44    private static Instr[,] evaluationTable;
[322]45    private static Dataset dataset;
46    private static int sampleIndex;
[223]47
[322]48
49    static BakedTreeEvaluator() {
[327]50      evaluationTable = new Instr[MAX_TREE_SIZE, MAX_TREE_DEPTH];
51      nInstr = new int[MAX_TREE_DEPTH];
52      for(int j = 0; j < MAX_TREE_DEPTH; j++) {
53        for(int i = 0; i < MAX_TREE_SIZE; i++) {
54          evaluationTable[i, j] = new Instr();
55        }
[322]56      }
57    }
58
59    public static void ResetEvaluator(List<LightWeightFunction> linearRepresentation) {
[327]60      int length;
61      for(int i = 0; i < MAX_TREE_DEPTH; i++) nInstr[i] = 0;
[328]62      //TranslateToInstr(0, linearRepresentation, out length);
63      int[] heights = new int[linearRepresentation.Count];
64      CalcHeights(linearRepresentation, heights, 0, out length);
65      TranslateToTable(0, linearRepresentation, heights);
[223]66    }
67
[328]68    private static int CalcHeights(List<LightWeightFunction> linearRepresentation, int[] heights, int p, out int branchLength) {
69      if(linearRepresentation[p].arity == 0) {
70        heights[p] = 1;
71        branchLength = 1;
72        return 1;
73      }
[327]74      int height = 0;
75      int length = 1;
[328]76      for(int i = 0; i < linearRepresentation[p].arity; i++) {
[327]77        int curBranchLength;
[328]78        int curHeight = CalcHeights(linearRepresentation, heights, p + length, out curBranchLength);
79        if(curHeight > height) {
80          height = curHeight;
81        }
[327]82        length += curBranchLength;
83      }
[328]84      heights[p] = height+1;
85      branchLength = length;
86      return height+1;
87    }
88
89    private static int TranslateToTable(int pos, List<LightWeightFunction> list, int[] heights) {
90      LightWeightFunction f = list[pos];
91      if(f.arity == 0) {
92        Instr instr = evaluationTable[nInstr[0], 0];
93        instr.symbol = EvaluatorSymbolTable.MapFunction(f.functionType);
94        switch(instr.symbol) {
95          case EvaluatorSymbolTable.VARIABLE: {
96              instr.i_arg0 = (int)f.data[0]; // var
97              instr.d_arg0 = f.data[1]; // weight
98              instr.i_arg1 = (int)f.data[2]; // sample-offset
99              break;
100            }
101          case EvaluatorSymbolTable.CONSTANT: {
102              instr.result = f.data[0]; // value
103              break;
104            }
105        }
106        nInstr[0]++;
107        return 1;
108      } else {
109        int length = 1;
110        int height = heights[pos];
111        for(int i = 0; i < f.arity; i++) {
112          int curBranchHeight = heights[pos + length];
113          if(curBranchHeight < height - 1) {
114            for(int j = curBranchHeight; j < height - 1; j++) {
115              evaluationTable[nInstr[j], j].symbol = EvaluatorSymbolTable.IDENTITY;
116              nInstr[j]++;
117            }
[318]118          }
[328]119          int curBranchLength = TranslateToTable(pos + length, list, heights);
120          length += curBranchLength;
121        }
122
123        Instr cell = evaluationTable[nInstr[height-1], height-1];
124        nInstr[height-1]++;
125        cell.arity = f.arity;
126        cell.symbol = EvaluatorSymbolTable.MapFunction(f.functionType);
127        return length;
[318]128      }
129    }
130
[328]131
132    //private static int TranslateToInstr(int pos, List<LightWeightFunction> linearRepresentation, out int branchLength) {
133    //  int height = 0;
134    //  int length = 1;
135    //  LightWeightFunction f = linearRepresentation[pos];
136    //  for(int i = 0; i < f.arity; i++) {
137    //    int curBranchLength;
138    //    int curBranchHeight = TranslateToInstr(pos + length, linearRepresentation, out curBranchLength);
139    //    if(curBranchHeight > height) height = curBranchHeight;
140    //    length += curBranchLength;
141    //  }
142    //  Instr instr = evaluationTable[nInstr[height], height];
[327]143    //  instr.arity = f.arity;
144    //  instr.symbol = EvaluatorSymbolTable.MapFunction(f.functionType);
145    //  switch(instr.symbol) {
146    //    case EvaluatorSymbolTable.VARIABLE: {
147    //        instr.i_arg0 = (int)f.data[0]; // var
148    //        instr.d_arg0 = f.data[1]; // weight
149    //        instr.i_arg1 = (int)f.data[2]; // sample-offset
150    //        break;
151    //      }
152    //    case EvaluatorSymbolTable.CONSTANT: {
[328]153    //        instr.result = f.data[0]; // value
[327]154    //        break;
155    //      }
156    //  }
[328]157    //  nInstr[height]++;
158    //  branchLength = length;
159    //  return height + 1;
[327]160    //}
161
[322]162    internal static double Evaluate(Dataset dataset, int sampleIndex) {
163      BakedTreeEvaluator.sampleIndex = sampleIndex;
164      BakedTreeEvaluator.dataset = dataset;
[327]165      return EvaluateTable();
[223]166    }
167
[327]168    private static double EvaluateTable() {
169      int terminalP = 0;
[328]170      // process remaining instr first
171      for(int i = 0; i < nInstr[0] % 4; i++) {
172        Instr curInstr = evaluationTable[terminalP++, 0];
173        if(curInstr.symbol == EvaluatorSymbolTable.VARIABLE) {
174          int row = sampleIndex + curInstr.i_arg1;
175          if(row < 0 || row >= dataset.Rows) curInstr.result = double.NaN;
176          else curInstr.result = curInstr.d_arg0 * dataset.GetValue(row, curInstr.i_arg0);
177        }
178      }
179      // unrolled loop
180      for(; terminalP < nInstr[0] - 4; terminalP += 4) {
[327]181        Instr curInstr0 = evaluationTable[terminalP, 0];
182        Instr curInstr1 = evaluationTable[terminalP + 1, 0];
[328]183        Instr curInstr2 = evaluationTable[terminalP + 2, 0];
184        Instr curInstr3 = evaluationTable[terminalP + 3, 0];
[327]185        if(curInstr0.symbol == EvaluatorSymbolTable.VARIABLE) {
186          int row = sampleIndex + curInstr0.i_arg1;
[328]187          if(row < 0 || row >= dataset.Rows) curInstr0.result = double.NaN;
188          else curInstr0.result = curInstr0.d_arg0 * dataset.GetValue(row, curInstr0.i_arg0);
[327]189        }
190        if(curInstr1.symbol == EvaluatorSymbolTable.VARIABLE) {
191          int row = sampleIndex + curInstr1.i_arg1;
[328]192          if(row < 0 || row >= dataset.Rows) curInstr1.result = double.NaN;
193          else curInstr1.result = curInstr1.d_arg0 * dataset.GetValue(row, curInstr1.i_arg0);
[327]194        }
[328]195        if(curInstr2.symbol == EvaluatorSymbolTable.VARIABLE) {
196          int row = sampleIndex + curInstr2.i_arg1;
197          if(row < 0 || row >= dataset.Rows) curInstr2.result = double.NaN;
198          else curInstr2.result = curInstr2.d_arg0 * dataset.GetValue(row, curInstr2.i_arg0);
199        }
200        if(curInstr3.symbol == EvaluatorSymbolTable.VARIABLE) {
201          int row = sampleIndex + curInstr3.i_arg1;
202          if(row < 0 || row >= dataset.Rows) curInstr3.result = double.NaN;
203          else curInstr3.result = curInstr3.d_arg0 * dataset.GetValue(row, curInstr3.i_arg0);
204        }
[327]205      }
206
207      int curLevel = 1;
208      while(nInstr[curLevel] > 0) {
209        int lastLayerInstrP = 0;
210        for(int curLayerInstrP = 0; curLayerInstrP < nInstr[curLevel]; curLayerInstrP++) {
211          Instr curInstr = evaluationTable[curLayerInstrP, curLevel];
212          switch(curInstr.symbol) {
213            case EvaluatorSymbolTable.MULTIPLICATION: {
[328]214                curInstr.result = evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]215                for(int i = 1; i < curInstr.arity; i++) {
[328]216                  curInstr.result *= evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]217                }
218                break;
[223]219              }
[327]220            case EvaluatorSymbolTable.ADDITION: {
[328]221                curInstr.result = evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]222                for(int i = 1; i < curInstr.arity; i++) {
[328]223                  curInstr.result += evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]224                }
225                break;
[223]226              }
[327]227            case EvaluatorSymbolTable.SUBTRACTION: {
228                if(curInstr.arity == 1) {
[328]229                  curInstr.result = -evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]230                } else {
[328]231                  curInstr.result = evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]232                  for(int i = 1; i < curInstr.arity; i++) {
[328]233                    curInstr.result -= evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]234                  }
235                }
236                break;
237              }
238            case EvaluatorSymbolTable.DIVISION: {
239                if(curInstr.arity == 1) {
[328]240                  curInstr.result = 1.0 / evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]241                } else {
[328]242                  curInstr.result = evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]243                  for(int i = 1; i < curInstr.arity; i++) {
[328]244                    curInstr.result /= evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]245                  }
246                }
[328]247                if(double.IsInfinity(curInstr.result)) curInstr.result = 0.0;
[327]248                break;
249              }
250            case EvaluatorSymbolTable.AVERAGE: {
[328]251                curInstr.result = evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]252                for(int i = 1; i < curInstr.arity; i++) {
[328]253                  curInstr.result += evaluationTable[lastLayerInstrP++, curLevel - 1].result;
[327]254                }
[328]255                curInstr.result /= curInstr.arity;
[327]256                break;
257              }
258            case EvaluatorSymbolTable.COSINUS: {
[328]259                curInstr.result = Math.Cos(evaluationTable[lastLayerInstrP++, curLevel - 1].result);
[327]260                break;
261              }
262            case EvaluatorSymbolTable.SINUS: {
[328]263                curInstr.result = Math.Sin(evaluationTable[lastLayerInstrP++, curLevel - 1].result);
[327]264                break;
265              }
266            case EvaluatorSymbolTable.EXP: {
[328]267                curInstr.result = Math.Exp(evaluationTable[lastLayerInstrP++, curLevel - 1].result);
[327]268                break;
269              }
270            case EvaluatorSymbolTable.LOG: {
[328]271                curInstr.result = Math.Log(evaluationTable[lastLayerInstrP++, curLevel - 1].result);
[327]272                break;
273              }
274            case EvaluatorSymbolTable.POWER: {
[328]275                double x = evaluationTable[lastLayerInstrP++, curLevel - 1].result;
276                double p = evaluationTable[lastLayerInstrP++, curLevel - 1].result;
277                curInstr.result = Math.Pow(x, p);
[327]278                break;
279              }
280            case EvaluatorSymbolTable.SIGNUM: {
[328]281                double value = evaluationTable[lastLayerInstrP++, curLevel - 1].result;
282                if(double.IsNaN(value)) curInstr.result = double.NaN;
283                else curInstr.result = Math.Sign(value);
[327]284                break;
285              }
286            case EvaluatorSymbolTable.SQRT: {
[328]287                curInstr.result = Math.Sqrt(evaluationTable[lastLayerInstrP++, curLevel - 1].result);
[327]288                break;
289              }
290            case EvaluatorSymbolTable.TANGENS: {
[328]291                curInstr.result = Math.Tan(evaluationTable[lastLayerInstrP++, curLevel - 1].result);
[327]292                break;
293              }
294            //case EvaluatorSymbolTable.AND: {
295            //    double result = 1.0;
296            //    // have to evaluate all sub-trees, skipping would probably not lead to a big gain because
297            //    // we have to iterate over the linear structure anyway
298            //    for(int i = 0; i < currInstr.arity; i++) {
299            //      double x = Math.Round(EvaluateBakedCode());
300            //      if(x == 0 || x == 1.0) result *= x;
301            //      else result = double.NaN;
302            //    }
303            //    return result;
304            //  }
305            //case EvaluatorSymbolTable.EQU: {
306            //    double x = EvaluateBakedCode();
307            //    double y = EvaluateBakedCode();
308            //    if(x == y) return 1.0; else return 0.0;
309            //  }
310            //case EvaluatorSymbolTable.GT: {
311            //    double x = EvaluateBakedCode();
312            //    double y = EvaluateBakedCode();
313            //    if(x > y) return 1.0;
314            //    else return 0.0;
315            //  }
316            //case EvaluatorSymbolTable.IFTE: {
317            //    double condition = Math.Round(EvaluateBakedCode());
318            //    double x = EvaluateBakedCode();
319            //    double y = EvaluateBakedCode();
320            //    if(condition < .5) return x;
321            //    else if(condition >= .5) return y;
322            //    else return double.NaN;
323            //  }
324            //case EvaluatorSymbolTable.LT: {
325            //    double x = EvaluateBakedCode();
326            //    double y = EvaluateBakedCode();
327            //    if(x < y) return 1.0;
328            //    else return 0.0;
329            //  }
330            //case EvaluatorSymbolTable.NOT: {
331            //    double result = Math.Round(EvaluateBakedCode());
332            //    if(result == 0.0) return 1.0;
333            //    else if(result == 1.0) return 0.0;
334            //    else return double.NaN;
335            //  }
336            //case EvaluatorSymbolTable.OR: {
337            //    double result = 0.0; // default is false
338            //    for(int i = 0; i < currInstr.arity; i++) {
339            //      double x = Math.Round(EvaluateBakedCode());
340            //      if(x == 1.0 && result == 0.0) result = 1.0; // found first true (1.0) => set to true
341            //      else if(x != 0.0) result = double.NaN; // if it was not true it can only be false (0.0) all other cases are undefined => (NaN)
342            //    }
343            //    return result;
344            //  }
345            //case EvaluatorSymbolTable.XOR: {
346            //    double x = Math.Round(EvaluateBakedCode());
347            //    double y = Math.Round(EvaluateBakedCode());
348            //    if(x == 0.0 && y == 0.0) return 0.0;
349            //    if(x == 1.0 && y == 0.0) return 1.0;
350            //    if(x == 0.0 && y == 1.0) return 1.0;
351            //    if(x == 1.0 && y == 1.0) return 0.0;
352            //    return double.NaN;
353            //  }
[328]354            case EvaluatorSymbolTable.IDENTITY: {
355                curInstr.result = evaluationTable[lastLayerInstrP++, curLevel - 1].result;
356                break;
357              }
[327]358            default: {
359                throw new NotImplementedException();
360              }
[223]361          }
[327]362        }
363        curLevel++;
[223]364      }
[328]365      return evaluationTable[0, curLevel - 1].result;
[223]366    }
367  }
368}
Note: See TracBrowser for help on using the repository browser.