Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
11/16/08 00:39:50 (16 years ago)
Author:
gkronber
Message:

worked on #364 (Improve GP evaluation performance)

  • removed list of Instr
  • changes that didn't affect performance directly: reduced size of Instr class, added 'constant-folding' and pre-calculation of skip-lengths


Location:
trunk/sources/HeuristicLab.GP.StructureIdentification
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.GP.StructureIdentification/BakedTreeEvaluator.cs

    r702 r767  
    4141    private class Instr {
    4242      public double d_arg0;
    43       public int i_arg0;
    44       public int i_arg1;
    45       public int arity;
    46       public int symbol;
     43      public short i_arg0;
     44      public short i_arg1;
     45      public byte arity;
     46      public byte symbol;
     47      public ushort exprLength;
    4748      public IFunction function;
    4849    }
    4950
    50     private List<Instr> code;
    5151    private Instr[] codeArr;
    5252    private int PC;
     
    5454    private int sampleIndex;
    5555
    56 
    57     public BakedTreeEvaluator() {
    58       code = new List<Instr>();
    59     }
    60 
    6156    public void ResetEvaluator(BakedFunctionTree functionTree, Dataset dataset, int targetVariable, int start, int end, double punishmentFactor) {
    6257      this.dataset = dataset;
    6358      double maximumPunishment = punishmentFactor * dataset.GetRange(targetVariable);
    6459
    65       // get the mean of the values of the target variable to determin the max and min bounds of the estimated value
     60      // get the mean of the values of the target variable to determine the max and min bounds of the estimated value
    6661      double targetMean = dataset.GetMean(targetVariable, start, end - 1);
    6762      estimatedValueMin = targetMean - maximumPunishment;
     
    6964
    7065      List<LightWeightFunction> linearRepresentation = functionTree.LinearRepresentation;
    71       code.Clear();
    72       foreach(LightWeightFunction f in linearRepresentation) {
    73         Instr curInstr = new Instr();
    74         TranslateToInstr(f, curInstr);
    75         code.Add(curInstr);
    76       }
    77 
    78       codeArr = code.ToArray<Instr>();
    79     }
    80 
    81     private void TranslateToInstr(LightWeightFunction f, Instr instr) {
     66      codeArr = new Instr[linearRepresentation.Count];
     67      int i = 0;
     68      foreach (LightWeightFunction f in linearRepresentation) {
     69        codeArr[i++] = TranslateToInstr(f);
     70      }
     71      exprIndex = 0;
     72      ushort exprLength;
     73      bool constExpr;
     74      PatchExpressionLengthsAndConstants(0, out constExpr, out exprLength);
     75    }
     76
     77    ushort exprIndex;
     78    private void PatchExpressionLengthsAndConstants(ushort index, out bool constExpr, out ushort exprLength) {
     79      exprLength = 1;
     80      if (codeArr[index].arity == 0) {
     81        // when no children then it's a constant expression only if the terminal is a constant
     82        constExpr = codeArr[index].symbol == EvaluatorSymbolTable.CONSTANT;
     83      } else {
     84        constExpr = true; // when there are children it's a constant expression if all children are constant;
     85      }
     86      for (int i = 0; i < codeArr[index].arity; i++) {
     87        exprIndex++;
     88        ushort branchLength;
     89        bool branchConstExpr;
     90        PatchExpressionLengthsAndConstants(exprIndex, out branchConstExpr, out branchLength);
     91        exprLength += branchLength;
     92        constExpr &= branchConstExpr;
     93      }
     94      codeArr[index].exprLength = exprLength;
     95
     96      if (constExpr) {
     97        codeArr[index].symbol = EvaluatorSymbolTable.CONSTANT;
     98        PC = index;
     99        codeArr[index].d_arg0 = EvaluateBakedCode();
     100      }
     101    }
     102
     103    private Instr TranslateToInstr(LightWeightFunction f) {
     104      Instr instr = new Instr();
    82105      instr.arity = f.arity;
    83106      instr.symbol = EvaluatorSymbolTable.MapFunction(f.functionType);
    84       switch(instr.symbol) {
     107      switch (instr.symbol) {
    85108        case EvaluatorSymbolTable.DIFFERENTIAL:
    86109        case EvaluatorSymbolTable.VARIABLE: {
    87             instr.i_arg0 = (int)f.data[0]; // var
     110            instr.i_arg0 = (byte)f.data[0]; // var
    88111            instr.d_arg0 = f.data[1]; // weight
    89             instr.i_arg1 = (int)f.data[2]; // sample-offset
     112            instr.i_arg1 = (byte)f.data[2]; // sample-offset
     113            instr.exprLength = 1;
    90114            break;
    91115          }
    92116        case EvaluatorSymbolTable.CONSTANT: {
    93117            instr.d_arg0 = f.data[0]; // value
     118            instr.exprLength = 1;
    94119            break;
    95120          }
    96121        case EvaluatorSymbolTable.UNKNOWN: {
    97122            instr.function = f.functionType;
     123            instr.exprLength = 1;
    98124            break;
    99125          }
    100126      }
     127      return instr;
    101128    }
    102129
     
    106133
    107134      double estimated = EvaluateBakedCode();
    108       if(double.IsNaN(estimated) || double.IsInfinity(estimated)) {
     135      if (double.IsNaN(estimated) || double.IsInfinity(estimated)) {
    109136        estimated = estimatedValueMax;
    110       } else if(estimated > estimatedValueMax) {
     137      } else if (estimated > estimatedValueMax) {
    111138        estimated = estimatedValueMax;
    112       } else if(estimated < estimatedValueMin) {
     139      } else if (estimated < estimatedValueMin) {
    113140        estimated = estimatedValueMin;
    114141      }
     
    118145    // skips a whole branch
    119146    private void SkipBakedCode() {
    120       int i = 1;
    121       while(i > 0) {
    122         i += code[PC++].arity;
    123         i--;
    124       }
     147      PC += codeArr[PC].exprLength;
    125148    }
    126149
    127150    private double EvaluateBakedCode() {
    128151      Instr currInstr = codeArr[PC++];
    129       switch(currInstr.symbol) {
     152      switch (currInstr.symbol) {
    130153        case EvaluatorSymbolTable.VARIABLE: {
    131154            int row = sampleIndex + currInstr.i_arg1;
    132             if(row < 0 || row >= dataset.Rows) return double.NaN;
     155            if (row < 0 || row >= dataset.Rows) return double.NaN;
    133156            else return currInstr.d_arg0 * dataset.GetValue(row, currInstr.i_arg0);
    134157          }
    135158        case EvaluatorSymbolTable.CONSTANT: {
     159            PC += currInstr.exprLength - 1;
    136160            return currInstr.d_arg0;
    137161          }
    138162        case EvaluatorSymbolTable.DIFFERENTIAL: {
    139163            int row = sampleIndex + currInstr.i_arg1;
    140             if(row < 1 || row >= dataset.Rows) return double.NaN;
     164            if (row < 1 || row >= dataset.Rows) return double.NaN;
    141165            else return currInstr.d_arg0 * (dataset.GetValue(row, currInstr.i_arg0) - dataset.GetValue(row - 1, currInstr.i_arg0));
    142166          }
    143167        case EvaluatorSymbolTable.MULTIPLICATION: {
    144168            double result = EvaluateBakedCode();
    145             for(int i = 1; i < currInstr.arity; i++) {
     169            for (int i = 1; i < currInstr.arity; i++) {
    146170              result *= EvaluateBakedCode();
    147171            }
     
    150174        case EvaluatorSymbolTable.ADDITION: {
    151175            double sum = EvaluateBakedCode();
    152             for(int i = 1; i < currInstr.arity; i++) {
     176            for (int i = 1; i < currInstr.arity; i++) {
    153177              sum += EvaluateBakedCode();
    154178            }
     
    156180          }
    157181        case EvaluatorSymbolTable.SUBTRACTION: {
    158             if(currInstr.arity == 1) {
     182            if (currInstr.arity == 1) {
    159183              return -EvaluateBakedCode();
    160184            } else {
    161185              double result = EvaluateBakedCode();
    162               for(int i = 1; i < currInstr.arity; i++) {
     186              for (int i = 1; i < currInstr.arity; i++) {
    163187                result -= EvaluateBakedCode();
    164188              }
     
    168192        case EvaluatorSymbolTable.DIVISION: {
    169193            double result;
    170             if(currInstr.arity == 1) {
     194            if (currInstr.arity == 1) {
    171195              result = 1.0 / EvaluateBakedCode();
    172196            } else {
    173197              result = EvaluateBakedCode();
    174               for(int i = 1; i < currInstr.arity; i++) {
     198              for (int i = 1; i < currInstr.arity; i++) {
    175199                result /= EvaluateBakedCode();
    176200              }
    177201            }
    178             if(double.IsInfinity(result)) return 0.0;
     202            if (double.IsInfinity(result)) return 0.0;
    179203            else return result;
    180204          }
    181205        case EvaluatorSymbolTable.AVERAGE: {
    182206            double sum = EvaluateBakedCode();
    183             for(int i = 1; i < currInstr.arity; i++) {
     207            for (int i = 1; i < currInstr.arity; i++) {
    184208              sum += EvaluateBakedCode();
    185209            }
     
    205229        case EvaluatorSymbolTable.SIGNUM: {
    206230            double value = EvaluateBakedCode();
    207             if(double.IsNaN(value)) return double.NaN;
     231            if (double.IsNaN(value)) return double.NaN;
    208232            else return Math.Sign(value);
    209233          }
     
    216240        case EvaluatorSymbolTable.AND: { // only defined for inputs 1 and 0
    217241            double result = EvaluateBakedCode();
    218             for(int i = 1; i < currInstr.arity; i++) {
    219               if(result == 0.0) SkipBakedCode();
     242            for (int i = 1; i < currInstr.arity; i++) {
     243              if (result == 0.0) SkipBakedCode();
    220244              else {
    221245                result = EvaluateBakedCode();
     
    228252            double x = EvaluateBakedCode();
    229253            double y = EvaluateBakedCode();
    230             if(Math.Abs(x - y) < EPSILON) return 1.0; else return 0.0;
     254            if (Math.Abs(x - y) < EPSILON) return 1.0; else return 0.0;
    231255          }
    232256        case EvaluatorSymbolTable.GT: {
    233257            double x = EvaluateBakedCode();
    234258            double y = EvaluateBakedCode();
    235             if(x > y) return 1.0;
     259            if (x > y) return 1.0;
    236260            else return 0.0;
    237261          }
     
    240264            Debug.Assert(condition == 0.0 || condition == 1.0);
    241265            double result;
    242             if(condition == 0.0) {
     266            if (condition == 0.0) {
    243267              result = EvaluateBakedCode(); SkipBakedCode();
    244268            } else {
     
    250274            double x = EvaluateBakedCode();
    251275            double y = EvaluateBakedCode();
    252             if(x < y) return 1.0;
     276            if (x < y) return 1.0;
    253277            else return 0.0;
    254278          }
     
    260284        case EvaluatorSymbolTable.OR: { // only defined for inputs 0 or 1
    261285            double result = EvaluateBakedCode();
    262             for(int i = 1; i < currInstr.arity; i++) {
    263               if(result > 0.0) SkipBakedCode();
     286            for (int i = 1; i < currInstr.arity; i++) {
     287              if (result > 0.0) SkipBakedCode();
    264288              else {
    265289                result = EvaluateBakedCode();
  • trunk/sources/HeuristicLab.GP.StructureIdentification/SymbolTable.cs

    r656 r767  
    2929namespace HeuristicLab.GP.StructureIdentification {
    3030  class EvaluatorSymbolTable : StorableBase {
    31     public const int ADDITION = 1;
    32     public const int AND = 2;
    33     public const int AVERAGE = 3;
    34     public const int CONSTANT = 4;
    35     public const int COSINUS = 5;
    36     public const int DIFFERENTIAL = 25;
    37     public const int DIVISION = 6;
    38     public const int EQU = 7;
    39     public const int EXP = 8;
    40     public const int GT = 9;
    41     public const int IFTE = 10;
    42     public const int LT = 11;
    43     public const int LOG = 12;
    44     public const int MULTIPLICATION = 13;
    45     public const int NOT = 14;
    46     public const int OR = 15;
    47     public const int POWER = 16;
    48     public const int SIGNUM = 17;
    49     public const int SINUS = 18;
    50     public const int SQRT = 19;
    51     public const int SUBTRACTION = 20;
    52     public const int TANGENS = 21;
    53     public const int VARIABLE = 22;
    54     public const int XOR = 23;
    55     public const int UNKNOWN = 24;
     31    public const byte ADDITION = 1;
     32    public const byte AND = 2;
     33    public const byte AVERAGE = 3;
     34    public const byte CONSTANT = 4;
     35    public const byte COSINUS = 5;
     36    public const byte DIFFERENTIAL = 25;
     37    public const byte DIVISION = 6;
     38    public const byte EQU = 7;
     39    public const byte EXP = 8;
     40    public const byte GT = 9;
     41    public const byte IFTE = 10;
     42    public const byte LT = 11;
     43    public const byte LOG = 12;
     44    public const byte MULTIPLICATION = 13;
     45    public const byte NOT = 14;
     46    public const byte OR = 15;
     47    public const byte POWER = 16;
     48    public const byte SIGNUM = 17;
     49    public const byte SINUS = 18;
     50    public const byte SQRT = 19;
     51    public const byte SUBTRACTION = 20;
     52    public const byte TANGENS = 21;
     53    public const byte VARIABLE = 22;
     54    public const byte XOR = 23;
     55    public const byte UNKNOWN = 24;
    5656
    57     private static Dictionary<Type, int> staticTypes = new Dictionary<Type, int>();
     57    private static Dictionary<Type, byte> staticTypes = new Dictionary<Type, byte>();
    5858
    5959    // needs to be public for persistence mechanism (Activator.CreateInstance needs empty constructor)
    6060    static EvaluatorSymbolTable() {
    61       staticTypes = new Dictionary<Type, int>();
     61      staticTypes = new Dictionary<Type, byte>();
    6262      staticTypes[typeof(Addition)] = ADDITION;
    6363      staticTypes[typeof(And)] = AND;
     
    8686    }
    8787
    88     internal static int MapFunction(IFunction function) {
     88    internal static byte MapFunction(IFunction function) {
    8989      if(staticTypes.ContainsKey(function.GetType())) return staticTypes[function.GetType()];
    9090      else return UNKNOWN;
Note: See TracChangeset for help on using the changeset viewer.