Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/05/18 14:35:03 (6 years ago)
Author:
bburlacu
Message:

#2886: Try to use variable importance information (from a random forest) to guide the search.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs

    r15930 r15950  
    44using System.Threading;
    55using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
    6 using HeuristicLab.Collections;
    76using HeuristicLab.Common;
    87using HeuristicLab.Core;
     
    1918  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    2019    #region properties and result names
     20    private readonly string VariableImportanceWeightName = "Variable Importance Weight";
    2121    private readonly string SearchStructureSizeName = "Search Structure Size";
    2222    private readonly string GeneratedPhrasesName = "Generated/Archived Phrases";
     
    2929    private readonly string ExpansionsPerSecondName = "Expansions per second";
    3030
    31 
    3231    private readonly string OptimizeConstantsParameterName = "Optimize Constants";
    3332    private readonly string ErrorWeightParameterName = "Error Weight";
     
    3938    public override bool SupportsPause { get { return false; } }
    4039
     40    protected IValueParameter<DoubleValue> VariableImportanceWeightParameter {
     41      get { return (IValueParameter<DoubleValue>)Parameters[VariableImportanceWeightName]; }
     42    }
     43
     44    protected double VariableImportanceWeight {
     45      get { return VariableImportanceWeightParameter.Value.Value; }
     46    }
     47
    4148    protected IValueParameter<BoolValue> OptimizeConstantsParameter {
    4249      get { return (IValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
     
    125132      };
    126133
     134      Parameters.Add(new ValueParameter<DoubleValue>(VariableImportanceWeightName, "Variable Weight.", new DoubleValue(1.0)));
    127135      Parameters.Add(new ValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false)));
    128136      Parameters.Add(new ValueParameter<DoubleValue>(ErrorWeightParameterName, "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.", new DoubleValue(0.8)));
     
    163171    #endregion
    164172
     173    private Dictionary<VariableTerminalSymbol, double> variableImportance;
     174
    165175    protected override void Run(CancellationToken cancellationToken) {
    166176      #region init
     
    191201      #endregion
    192202
     203      #region Variable Importance
     204      variableImportance = new Dictionary<VariableTerminalSymbol, double>();
     205
     206      RandomForestRegression rf = new RandomForestRegression();
     207      rf.Problem = Problem;
     208      rf.Start();
     209      IRegressionSolution rfSolution = (RandomForestRegressionSolution)rf.Results["Random forest regression solution"].Value;
     210      var rfImpacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(
     211        rfSolution,
     212        RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.Training,
     213        RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle);
     214
     215      // save the normalized importances
     216      var sum = rfImpacts.Sum(x => x.Item2);
     217      foreach (Tuple<string, double> rfImpact in rfImpacts) {
     218        VariableTerminalSymbol varSym = Grammar.VarTerminals.First(v => v.StringRepresentation == rfImpact.Item1);
     219        variableImportance[varSym] = rfImpact.Item2 / sum;
     220      }
     221      #endregion
     222
    193223      int maxSentenceLength = GetMaxSentenceLength();
    194224
    195225      OpenPhrases.Store(new SearchNode(phrase0Hash, 0.0, 0.0, phrase0));
     226
     227      var errorWeight = ErrorWeight;
     228      var variableImportanceWeight = VariableImportanceWeight;
     229
    196230      while (OpenPhrases.Count > 0) {
    197231        if (cancellationToken.IsCancellationRequested) break;
     
    245279
    246280              double r2 = GetR2(newPhrase, fetchedSearchNode.R2);
    247               double phrasePriority = GetPriority(newPhrase, r2, maxSentenceLength);
     281              double phrasePriority = GetPriority(newPhrase, r2, maxSentenceLength, errorWeight, variableImportanceWeight);
    248282
    249283              SearchNode newSearchNode = new SearchNode(phraseHash, phrasePriority, r2, newPhrase);
     
    256290    }
    257291
    258     protected double GetPriority(SymbolString phrase, double r2, int maxSentenceLength) {
     292    protected double GetPriority(SymbolString phrase, double r2, int maxSentenceLength, double errorWeight, double variableImportanceWeight) {
     293      var distinctVars = phrase.OfType<VariableTerminalSymbol>().Distinct();
     294
     295      var sum = 0d;
     296      foreach (var v in distinctVars) {
     297        sum += variableImportance[v];
     298      }
     299      var phraseVariableImportance = 1 - sum;
     300
    259301      double relLength = (double)phrase.Count() / maxSentenceLength;
    260302      double error = 1.0 - r2;
    261303
    262       return relLength + ErrorWeight * error;
     304      return relLength + errorWeight * error + variableImportanceWeight * phraseVariableImportance;
    263305    }
    264306
Note: See TracChangeset for help on using the changeset viewer.