Changeset 15993


Ignore:
Timestamp:
07/11/18 10:34:21 (3 years ago)
Author:
bburlacu
Message:

#2886: refactor code

Location:
branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/Analysis/RSquaredEvaluator.cs

    r15985 r15993  
    22using System.Diagnostics;
    33using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
     4using HeuristicLab.Analysis;
    45using HeuristicLab.Common;
    56using HeuristicLab.Core;
     
    2021    public static readonly string BestTrainingSolutionResultName = "Best solution (Training)";
    2122    public static readonly string BestComplexityResultName = "Best solution complexity";
     23    public static readonly string BestSolutions = "Best solutions";
    2224
    2325    private static readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
    24 
    25     public bool OptimizeConstants { get; set; }
    2626
    2727    public RSquaredEvaluator() { }
     
    3131
    3232    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner) : base(original, cloner) {
    33       this.OptimizeConstants = original.OptimizeConstants;
    3433    }
    3534
     
    5251    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
    5352      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
    54       EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
     53      EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase, algorithm.OptimizeConstants);
    5554    }
    5655
     
    7069    }
    7170
    72     private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
     71    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString, bool optimizeConstants) {
    7372      var results = algorithm.Results;
    7473      var grammar = algorithm.Grammar;
     
    7877      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));
    7978
    80       double r2 = Evaluate(problemData, tree, OptimizeConstants);
     79      double r2 = Evaluate(problemData, tree, optimizeConstants);
    8180      double bestR2 = results.ContainsKey(BestTrainingQualityResultName) ? GetValue<double>(results[BestTrainingQualityResultName].Value) : 0.0;
    8281      if (r2 < bestR2)
     
    9089        results.AddOrUpdateResult(BestComplexityResultName, new IntValue(complexity));
    9190        algorithm.BestTrainingSentence = symbolString;
     91
     92        // record best sentence quality & length
     93        DataTable dt;
     94        if (!results.ContainsKey(BestSolutions)) {
     95          var names = new[] { "Quality", "Relative Length", "Complexity", "Timestamp" };
     96          dt = new DataTable();
     97          foreach (var name in names) {
     98            dt.Rows.Add(new DataRow(name) { VisualProperties = { StartIndexZero = true } });
     99          }
     100          results.AddOrUpdateResult(BestSolutions, dt);
     101        }
     102        dt = (DataTable)results[BestSolutions].Value;
     103        dt.Rows["Quality"].Values.Add(r2);
     104        dt.Rows["Relative Length"].Values.Add((double)symbolString.Count() / algorithm.MaxSentenceLength);
     105        dt.Rows["Complexity"].Values.Add(complexity);
     106        dt.Rows["Timestamp"].Values.Add(algorithm.ExecutionTime.TotalMilliseconds / 1000d);
    92107      }
    93108    }
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/Grammar.cs

    r15975 r15993  
    243243    }
    244244
     245    // Returns the symbol count of the phrase obtained by repeatedly deriving from StartSymbol until it is a sentence or GetComplexity exceeds maxComplexity.
     246    public int GetMaxSentenceLength(int maxComplexity) {
     247      SymbolString s = new SymbolString(StartSymbol);
     248
     249      while (!s.IsSentence() && GetComplexity(s) <= maxComplexity) { // condition is checked before each derivation, so the final phrase may exceed maxComplexity - NOTE(review): confirm this is intended
     250        int expandedSymbolIndex = s.NextNonterminalIndex();
     251        NonterminalSymbol expandedSymbol = (NonterminalSymbol)s[expandedSymbolIndex];
     252
     253        var productions = Productions[expandedSymbol];
     254        var longestProduction = productions // Select the production that sorts first: fewest terminal symbols (ascending order) ...
     255          .OrderBy(CountTerminals)          // NOTE(review): the previous comment said "most terminal symbols / lowest nonterminal count", the opposite of this ordering - confirm which is intended.
     256          .ThenByDescending(CountNonTerminals) // ... ties broken by the highest nonterminal count.
     257          .First();
     258
     259        s = s.DerivePhrase(expandedSymbolIndex, longestProduction);
     260      }
     261
     262      return s.Count(); // length of the last derived phrase (it need not be a complete sentence when the complexity bound ended the loop)
     263    }
     264
     265    private int CountTerminals(Production p) { // number of symbols in production p that are TerminalSymbol instances
     266      return p.Count(s => s is TerminalSymbol);
     267    }
     268
     269    private int CountNonTerminals(Production p) { // number of symbols in production p that are NonterminalSymbol instances
     270      return p.Count(s => s is NonterminalSymbol);
     271    }
     272
    245273    public double EvaluatePhrase(SymbolString s, IRegressionProblemData problemData, bool optimizeConstants) {
    246274      SymbolicExpressionTree tree = ParseSymbolicExpressionTree(s);
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs

    r15987 r15993  
    2121  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    2222    #region properties and result names
    23     private readonly string VariableImportanceWeightName = "Variable Importance Weight";
    2423    private readonly string SearchStructureSizeName = "Search Structure Size";
    2524    private readonly string GeneratedPhrasesName = "Generated/Archived Phrases";
     
    3837    private readonly string GuiUpdateIntervalParameterName = "GUI Update Interval";
    3938    private readonly string GrammarSymbolsParameterName = "Grammar Symbols";
    40     private readonly string SearchCacheSizeParameterName = "Search Cache Size";
    4139    private readonly string SearchDataStructureSizeParameterName = "Search Data Structure Size";
    4240
     
    4846    public override bool SupportsPause { get { return true; } }
    4947
    50     protected IFixedValueParameter<DoubleValue> VariableImportanceWeightParameter {
    51       get { return (IFixedValueParameter<DoubleValue>)Parameters[VariableImportanceWeightName]; }
    52     }
    53 
    54     public double VariableImportanceWeight {
    55       get { return VariableImportanceWeightParameter.Value.Value; }
    56       set { VariableImportanceWeightParameter.Value.Value = value; }
    57     }
    58 
    5948    protected IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
    6049      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
     
    10594    }
    10695
    107     public IFixedValueParameter<IntValue> SearchCacheSizeParameter {
    108       get { return (IFixedValueParameter<IntValue>)Parameters[SearchCacheSizeParameterName]; }
    109     }
    110 
    111     public int SearchCacheSize {
    112       get { return SearchCacheSizeParameter.Value.Value; }
    113     }
    114 
    11596    public StorageType SearchDataStructure {
    11697      get { return SearchDataStructureParameter.Value.Value; }
     
    146127    [Storable]
    147128    internal SearchDataStore OpenPhrases { get; private set; }           // Stack/Queue/etc. for fetching the next node in the search tree. 
     129
     130    [Storable]
     131    public int MaxSentenceLength { get; private set; }
    148132
    149133    #region execution stats
     
    178162      SearchDataStructureParameter.Value.ValueChanged += (o, e) => Prepare();
    179163      SearchDataStructureSizeParameter.Value.ValueChanged += (o, e) => Prepare();
    180       SearchCacheSizeParameter.Value.ValueChanged += (o, e) => Prepare();
    181164    }
    182165
     
    188171      SearchDataStructureParameter.Value.ValueChanged -= (o, e) => Prepare();
    189172      SearchDataStructureSizeParameter.Value.ValueChanged -= (o, e) => Prepare();
    190       SearchCacheSizeParameter.Value.ValueChanged -= (o, e) => Prepare();
    191173    }
    192174
     
    197179
    198180    public GrammarEnumerationAlgorithm() {
    199       Parameters.Add(new FixedValueParameter<DoubleValue>(VariableImportanceWeightName, "Variable Weight.", new DoubleValue(1.0)));
    200181      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false)));
    201182      Parameters.Add(new FixedValueParameter<DoubleValue>(ErrorWeightParameterName, "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.", new DoubleValue(0.8)));
    202183      Parameters.Add(new FixedValueParameter<IntValue>(MaxComplexityParameterName, "The maximum number of variable symbols in a sentence.", new IntValue(12)));
    203184      Parameters.Add(new FixedValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(5000)));
    204       Parameters.Add(new FixedValueParameter<IntValue>(SearchCacheSizeParameterName, "The size of the search node cache.", new IntValue((int)1e5)));
    205185      Parameters.Add(new FixedValueParameter<IntValue>(SearchDataStructureSizeParameterName, "The size of the search data structure.", new IntValue((int)1e5)));
    206       Parameters.Add(new FixedValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.PriorityQueue)));
     186      Parameters.Add(new FixedValueParameter<EnumValue<StorageType>>(SearchDataStructureParameterName, new EnumValue<StorageType>(StorageType.SortedSet)));
    207187
    208188      SearchDataStructureParameter.Value.ValueChanged += (o, e) => Prepare();
    209189      SearchDataStructureSizeParameter.Value.ValueChanged += (o, e) => Prepare();
    210       SearchCacheSizeParameter.Value.ValueChanged += (o, e) => Prepare();
    211190
    212191      var availableAnalyzers = new IGrammarEnumerationAnalyzer[] {
     
    258237      OverwrittenSentencesCount = original.OverwrittenSentencesCount;
    259238      PhraseExpansionCount = original.PhraseExpansionCount;
    260 
    261       if (original.variableImportance != null)
    262         variableImportance = new Dictionary<VariableTerminalSymbol, double>(original.variableImportance);
    263239    }
    264240    #endregion
     
    274250      PhraseExpansionCount = 0;
    275251
    276       Analyzers.OfType<RSquaredEvaluator>().First().OptimizeConstants = OptimizeConstants;
    277252      Grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray(), GrammarSymbols.CheckedItems.Select(v => v.Value));
    278       OpenPhrases = new SearchDataStore(SearchDataStructure, SearchDataStructureSize, SearchCacheSize); // Select search strategy
    279 
    280       CalculateVariableImportances();
    281 
     253      OpenPhrases = new SearchDataStore(SearchDataStructure, SearchDataStructureSize); // Select search strategy
    282254      base.Prepare(); // this actually clears the results which will get reinitialized on Run()
    283     }
    284 
    285     private void CalculateVariableImportances() {
    286       variableImportance = new Dictionary<VariableTerminalSymbol, double>();
    287 
    288       RandomForestRegression rf = new RandomForestRegression();
    289       rf.Problem = Problem;
    290       rf.Start();
    291       IRegressionSolution rfSolution = (RandomForestRegressionSolution)rf.Results["Random forest regression solution"].Value;
    292       var rfImpacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts(
    293         rfSolution,
    294         RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.Training,
    295         RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle);
    296 
    297       // save the normalized importances
    298       var sum = rfImpacts.Sum(x => x.Item2);
    299       foreach (Tuple<string, double> rfImpact in rfImpacts) {
    300         VariableTerminalSymbol varSym = Grammar.VarTerminals.First(v => v.StringRepresentation == rfImpact.Item1);
    301         variableImportance[varSym] = rfImpact.Item2 / sum;
    302       }
    303255    }
    304256
     
    313265      }
    314266
    315       int maxSentenceLength = GetMaxSentenceLength();
     267      MaxSentenceLength = Grammar.GetMaxSentenceLength(MaxComplexity);
    316268      var errorWeight = ErrorWeight;
    317       var variableImportanceWeight = VariableImportanceWeight;
     269      var optimizeConstants = OptimizeConstants; // cache value to avoid parameter lookup
    318270      // main search loop
    319271      while (OpenPhrases.Count > 0) {
     
    375327
    376328            bool isCompleteSentence = IsCompleteSentence(newPhrase);
    377             double r2 = isCompleteSentence ? Grammar.EvaluatePhrase(newPhrase, Problem.ProblemData, OptimizeConstants) : fetchedSearchNode.R2;
    378             double phrasePriority = GetPriority(newPhrase, r2, maxSentenceLength, errorWeight, variableImportanceWeight);
     329            double r2 = isCompleteSentence ? Grammar.EvaluatePhrase(newPhrase, Problem.ProblemData, optimizeConstants) : fetchedSearchNode.R2;
     330            double phrasePriority = GetPriority(newPhrase, r2, MaxSentenceLength);
    379331
    380332            SearchNode newSearchNode = new SearchNode(phraseHash, phrasePriority, r2, newPhrase);
     
    386338    }
    387339
    388     protected double GetPriority(SymbolString phrase, double r2, int maxSentenceLength, double errorWeight, double variableImportanceWeight) {
    389       var distinctVars = phrase.OfType<VariableTerminalSymbol>().Distinct();
    390 
    391       var sum = 0d;
    392       foreach (var v in distinctVars) {
    393         sum += variableImportance[v];
    394       }
    395       var phraseVariableImportance = 1 - sum;
    396 
    397       double relLength = (double)phrase.Count() / maxSentenceLength;
    398       double error = 1.0 - r2;
    399       return error * relLength;
     340    protected static double GetPriority(SymbolString phrase, double r2, int maxSentenceLength) {
     341      return (1 - r2) * phrase.Count() / maxSentenceLength;
    400342    }
    401343
    402344    private bool IsCompleteSentence(SymbolString phrase) {
    403345      return !phrase.Any(x => x is NonterminalSymbol && x != Grammar.Expr);
    404     }
    405 
    406     private int GetMaxSentenceLength() {
    407       SymbolString s = new SymbolString(Grammar.StartSymbol);
    408 
    409       while (!s.IsSentence() && Grammar.GetComplexity(s) <= MaxComplexity) {
    410         int expandedSymbolIndex = s.NextNonterminalIndex();
    411         NonterminalSymbol expandedSymbol = (NonterminalSymbol)s[expandedSymbolIndex];
    412 
    413         var productions = Grammar.Productions[expandedSymbol];
    414         var longestProduction = productions // Find production with most terminal symbols to expand as much as possible...
    415           .OrderBy(CountTerminals)          // but with lowest complexity/nonterminal count to keep complexity low.                                                                                     
    416           .ThenByDescending(CountNonTerminals)
    417           .First();
    418 
    419         s = s.DerivePhrase(expandedSymbolIndex, longestProduction);
    420       }
    421 
    422       return s.Count();
    423     }
    424 
    425     private int CountTerminals(Production p) {
    426       return p.Count(s => s is TerminalSymbol);
    427     }
    428 
    429     private int CountNonTerminals(Production p) {
    430       return p.Count(s => s is NonterminalSymbol);
    431346    }
    432347
     
    461376      }
    462377
     378      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
    463379      var tree = Grammar.ParseSymbolicExpressionTree(BestTrainingSentence);
    464       var model = new SymbolicRegressionModel(Problem.ProblemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
    465       model.Scale(Problem.ProblemData);
     380      var model = new SymbolicRegressionModel(Problem.ProblemData.TargetVariable, tree, interpreter);
     381
     382      SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(
     383        interpreter,
     384        model.SymbolicExpressionTree,
     385        Problem.ProblemData,
     386        Problem.ProblemData.TrainingIndices,
     387        applyLinearScaling: true,
     388        maxIterations: 10,
     389        updateVariableWeights: false,
     390        updateConstantsInTree: true);
     391
    466392      var bestTrainingSolution = new SymbolicRegressionSolution(model, Problem.ProblemData);
    467393      Results.AddOrUpdateResult(BestTrainingModelResultName, model);
  • branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/SearchDataStructure.cs

    r15977 r15993  
    5252  class SearchDataStore : DeepCloneable, IEnumerable<SearchNode> {
    5353    [Storable]
    54     private LruCache<int, SearchNode> storedValues;
    55     //private Dictionary<int, SearchNode> storedValues; // Store hash-references and associated, actual values
     54    private Dictionary<int, SearchNode> storedValues; // Store hash-references and associated, actual values
    5655
    5756    [Storable]
     
    7271    [Storable]
    7372    private int searchDataStructureSize; // storage size for search nodes
    74 
    75     [Storable]
    76     private int cacheSize; // cache for already explored search nodes
    7773
    7874    [ExcludeFromObjectGraphTraversal]
     
    8783    protected SearchDataStore(bool deserializing) : this() { }
    8884
    89     public SearchDataStore(StorageType storageType, int searchDataStructureSize = (int)1e5, int cacheSize = (int)1e5) {
     85    public SearchDataStore(StorageType storageType, int searchDataStructureSize = (int)1e5) {
    9086      this.storageType = storageType;
    9187
    9288      this.searchDataStructureSize = searchDataStructureSize;
    93       this.cacheSize = cacheSize;
    94 
    95       storedValues = new LruCache<int, SearchNode>(this.cacheSize);
     89
     90      storedValues = new Dictionary<int, SearchNode>();
    9691      InitSearchDataStructure();
    9792    }
     
    142137
    143138    protected SearchDataStore(SearchDataStore original, Cloner cloner) : base(original, cloner) {
    144       storedValues = cloner.Clone(original.storedValues);
     139      storedValues = original.storedValues.ToDictionary(x => x.Key, x => cloner.Clone(x.Value));
    145140      storageType = original.storageType;
    146       cacheSize = original.cacheSize;
    147141      searchDataStructureSize = original.searchDataStructureSize;
    148142
     
    204198              var max = sortedSet.Max;
    205199              sortedSet.Remove(max);
    206               storedValues.Remove(max.Item2);
     200              storedValues.Remove(max.Item2); // should always be in sync with the sorted set
    207201            }
    208202            sortedSet.Add(Tuple.Create(prio, hash));
     
    211205            var elem = sortedSet.FirstOrDefault();
    212206            if (elem == null)
    213               return 0;
     207              return default(int);
    214208            sortedSet.Remove(elem);
    215209            return elem.Item2;
     
    226220        // size is the 0-based index of the last used element
    227221        if (priorityQueue.Size == capacity - 1) {
    228           // if the queue is at capacity we have to replace
    229           return;
     222          return; // if the queue is at capacity we have to return
    230223        }
    231224        priorityQueue.Insert(prio, hash);
Note: See TracChangeset for help on using the changeset viewer.