Changeset 15950
- Timestamp: 06/05/18 14:35:03 (7 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/GrammarEnumerationAlgorithm.cs
r15930 r15950 4 4 using System.Threading; 5 5 using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration; 6 using HeuristicLab.Collections;7 6 using HeuristicLab.Common; 8 7 using HeuristicLab.Core; … … 19 18 public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm<IRegressionProblem> { 20 19 #region properties and result names 20 private readonly string VariableImportanceWeightName = "Variable Importance Weight"; 21 21 private readonly string SearchStructureSizeName = "Search Structure Size"; 22 22 private readonly string GeneratedPhrasesName = "Generated/Archived Phrases"; … … 29 29 private readonly string ExpansionsPerSecondName = "Expansions per second"; 30 30 31 32 31 private readonly string OptimizeConstantsParameterName = "Optimize Constants"; 33 32 private readonly string ErrorWeightParameterName = "Error Weight"; … … 39 38 public override bool SupportsPause { get { return false; } } 40 39 40 protected IValueParameter<DoubleValue> VariableImportanceWeightParameter { 41 get { return (IValueParameter<DoubleValue>)Parameters[VariableImportanceWeightName]; } 42 } 43 44 protected double VariableImportanceWeight { 45 get { return VariableImportanceWeightParameter.Value.Value; } 46 } 47 41 48 protected IValueParameter<BoolValue> OptimizeConstantsParameter { 42 49 get { return (IValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; } … … 125 132 }; 126 133 134 Parameters.Add(new ValueParameter<DoubleValue>(VariableImportanceWeightName, "Variable Weight.", new DoubleValue(1.0))); 127 135 Parameters.Add(new ValueParameter<BoolValue>(OptimizeConstantsParameterName, "Run constant optimization in sentence evaluation.", new BoolValue(false))); 128 136 Parameters.Add(new ValueParameter<DoubleValue>(ErrorWeightParameterName, "Defines, how much weight is put on a phrase's r² value when priorizing phrases during search.", new DoubleValue(0.8))); … … 163 171 #endregion 164 172 173 private 
Dictionary<VariableTerminalSymbol, double> variableImportance; 174 165 175 protected override void Run(CancellationToken cancellationToken) { 166 176 #region init … … 191 201 #endregion 192 202 203 #region Variable Importance 204 variableImportance = new Dictionary<VariableTerminalSymbol, double>(); 205 206 RandomForestRegression rf = new RandomForestRegression(); 207 rf.Problem = Problem; 208 rf.Start(); 209 IRegressionSolution rfSolution = (RandomForestRegressionSolution)rf.Results["Random forest regression solution"].Value; 210 var rfImpacts = RegressionSolutionVariableImpactsCalculator.CalculateImpacts( 211 rfSolution, 212 RegressionSolutionVariableImpactsCalculator.DataPartitionEnum.Training, 213 RegressionSolutionVariableImpactsCalculator.ReplacementMethodEnum.Shuffle); 214 215 // save the normalized importances 216 var sum = rfImpacts.Sum(x => x.Item2); 217 foreach (Tuple<string, double> rfImpact in rfImpacts) { 218 VariableTerminalSymbol varSym = Grammar.VarTerminals.First(v => v.StringRepresentation == rfImpact.Item1); 219 variableImportance[varSym] = rfImpact.Item2 / sum; 220 } 221 #endregion 222 193 223 int maxSentenceLength = GetMaxSentenceLength(); 194 224 195 225 OpenPhrases.Store(new SearchNode(phrase0Hash, 0.0, 0.0, phrase0)); 226 227 var errorWeight = ErrorWeight; 228 var variableImportanceWeight = VariableImportanceWeight; 229 196 230 while (OpenPhrases.Count > 0) { 197 231 if (cancellationToken.IsCancellationRequested) break; … … 245 279 246 280 double r2 = GetR2(newPhrase, fetchedSearchNode.R2); 247 double phrasePriority = GetPriority(newPhrase, r2, maxSentenceLength );281 double phrasePriority = GetPriority(newPhrase, r2, maxSentenceLength, errorWeight, variableImportanceWeight); 248 282 249 283 SearchNode newSearchNode = new SearchNode(phraseHash, phrasePriority, r2, newPhrase); … … 256 290 } 257 291 258 protected double GetPriority(SymbolString phrase, double r2, int maxSentenceLength) { 292 protected double GetPriority(SymbolString phrase, 
double r2, int maxSentenceLength, double errorWeight, double variableImportanceWeight) { 293 var distinctVars = phrase.OfType<VariableTerminalSymbol>().Distinct(); 294 295 var sum = 0d; 296 foreach (var v in distinctVars) { 297 sum += variableImportance[v]; 298 } 299 var phraseVariableImportance = 1 - sum; 300 259 301 double relLength = (double)phrase.Count() / maxSentenceLength; 260 302 double error = 1.0 - r2; 261 303 262 return relLength + ErrorWeight * error;304 return relLength + errorWeight * error + variableImportanceWeight * phraseVariableImportance; 263 305 } 264 306
Note: See TracChangeset for help on using the changeset viewer.