using System;
using System.Diagnostics;
using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Analyzer for <see cref="GrammarEnumerationAlgorithm"/> that evaluates every distinct
  /// sentence on the training partition and keeps track of the best model found so far.
  /// The best quality, model and (once the algorithm stops) solution are published in the
  /// algorithm's result collection.
  /// </summary>
  public class RSquaredEvaluator : Item, IGrammarEnumerationAnalyzer {
    // Names under which the results are stored in the algorithm's result collection.
    private const string BestTrainingQualityResultName = "Best R² (Training)";
    private const string BestTrainingModelResultName = "Best model (Training)";
    private const string BestTrainingSolutionResultName = "Best solution (Training)";

    // Interpreter shared by all evaluations performed by this instance.
    private readonly ISymbolicDataAnalysisExpressionTreeInterpreter expressionTreeLinearInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

    /// <summary>
    /// When <c>true</c>, constants of each sentence are optimized before the quality is taken;
    /// otherwise the tree is evaluated as parsed.
    /// </summary>
    public bool OptimizeConstants { get; set; }

    /// <summary>Creates an evaluator with constant optimization switched off.</summary>
    public RSquaredEvaluator() { }

    /// <summary>Cloning constructor; copies <see cref="OptimizeConstants"/> from <paramref name="original"/>.</summary>
    /// <param name="original">The instance being cloned.</param>
    /// <param name="cloner">The cloner driving the current deep-clone operation.</param>
    protected RSquaredEvaluator(RSquaredEvaluator original, Cloner cloner)
      : base(original, cloner) {
      // Chaining to the base cloning constructor hands the cloner to the base class, as
      // HeuristicLab's cloning convention requires; the original skipped this call.
      this.OptimizeConstants = original.OptimizeConstants;
    }

    /// <summary>Creates a deep clone of this evaluator.</summary>
    /// <param name="cloner">The cloner driving the current deep-clone operation.</param>
    /// <returns>A new <see cref="RSquaredEvaluator"/> with the same settings.</returns>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RSquaredEvaluator(this, cloner);
    }

    /// <summary>Subscribes this evaluator to the events of <paramref name="algorithm"/>.</summary>
    /// <param name="algorithm">The algorithm whose sentences should be evaluated.</param>
    public void Register(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started += OnStarted;
      algorithm.Stopped += OnStopped;
      algorithm.DistinctSentenceGenerated += AlgorithmOnDistinctSentenceGenerated;
    }

    /// <summary>Removes the subscriptions added by <see cref="Register"/>.</summary>
    /// <param name="algorithm">The algorithm to detach from.</param>
    public void Deregister(GrammarEnumerationAlgorithm algorithm) {
      algorithm.Started -= OnStarted;
      algorithm.Stopped -= OnStopped;
      algorithm.DistinctSentenceGenerated -= AlgorithmOnDistinctSentenceGenerated;
    }

    // Evaluates each newly generated distinct sentence.
    private void AlgorithmOnDistinctSentenceGenerated(object sender, PhraseAddedEventArgs phraseAddedEventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;
      EvaluateSentence(algorithm, phraseAddedEventArgs.NewPhrase);
    }

    // Initializes the quality result and clears the best sentence when the algorithm starts.
    private void OnStarted(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;

      // -1.0 lies below every possible R² value, so the first evaluated sentence always wins.
      algorithm.Results.Add(new Result(BestTrainingQualityResultName, new DoubleValue(-1.0)));
      algorithm.BestTrainingSentence = null;
    }

    // Wraps the best model (if any was found) into a regression solution when the algorithm stops.
    private void OnStopped(object sender, EventArgs eventArgs) {
      GrammarEnumerationAlgorithm algorithm = (GrammarEnumerationAlgorithm)sender;

      if (algorithm.Results.ContainsKey(BestTrainingModelResultName)) {
        SymbolicRegressionModel model = (SymbolicRegressionModel)algorithm.Results[BestTrainingModelResultName].Value;
        IRegressionSolution bestTrainingSolution = new RegressionSolution(model, algorithm.Problem.ProblemData);

        algorithm.Results.AddOrUpdateResult(BestTrainingSolutionResultName, bestTrainingSolution);
      }
    }

    /// <summary>
    /// Parses <paramref name="symbolString"/> into an expression tree, determines its R² on the
    /// training partition and stores it as the new best model if it is better than the current
    /// best, or equally good but less complex.
    /// </summary>
    /// <param name="algorithm">The algorithm providing grammar, problem data and results.</param>
    /// <param name="symbolString">The sentence to evaluate.</param>
    private void EvaluateSentence(GrammarEnumerationAlgorithm algorithm, SymbolString symbolString) {
      var problemData = algorithm.Problem.ProblemData;

      SymbolicExpressionTree tree = algorithm.Grammar.ParseSymbolicExpressionTree(symbolString);
      Debug.Assert(SymbolicRegressionConstantOptimizationEvaluator.CanOptimizeConstants(tree));

      // TODO: Initialize constant values randomly
      // TODO: Restarts

      double r2;

      SymbolicRegressionModel model = new SymbolicRegressionModel(
        problemData.TargetVariable,
        tree,
        expressionTreeLinearInterpreter);

      if (OptimizeConstants) {
        // The returned quality is stored directly as R² (presumably the optimizer reports R²;
        // verify against SymbolicRegressionConstantOptimizationEvaluator).
        r2 = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(expressionTreeLinearInterpreter,
          tree,
          problemData,
          problemData.TrainingIndices,
          applyLinearScaling: false,
          maxIterations: 50,
          updateVariableWeights: true,
          updateConstantsInTree: true);

        // Snap constants that are numerically indistinguishable from zero to exactly 0.0.
        foreach (var symbolicExpressionTreeNode in tree.IterateNodesPostfix()) {
          ConstantTreeNode constTreeNode = symbolicExpressionTreeNode as ConstantTreeNode;
          if (constTreeNode != null && constTreeNode.Value.IsAlmost(0.0)) {
            constTreeNode.Value = 0.0;
          }
        }
      } else {
        var target = problemData.TargetVariableTrainingValues;
        var estVals = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);

        OnlineCalculatorError error;
        double r = OnlinePearsonsRCalculator.Calculate(target, estVals, out error);

        // The calculator yields Pearson's R; square it so the value really is R² (as the
        // result name states) and negative correlations are not ranked below zero correlation.
        // Any calculation error counts as the worst possible fit.
        r2 = error == OnlineCalculatorError.None ? r * r : 0.0;
      }

      var bestR2Result = (DoubleValue)algorithm.Results[BestTrainingQualityResultName].Value;
      bool better = r2 > bestR2Result.Value;
      bool equallyGood = r2.IsAlmost(bestR2Result.Value);
      bool shorter = false;

      // Complexity only matters as a tie-breaker between equally good sentences.
      if (!better && equallyGood) {
        shorter = algorithm.BestTrainingSentence != null &&
          algorithm.Grammar.GetComplexity(algorithm.BestTrainingSentence) > algorithm.Grammar.GetComplexity(symbolString);
      }

      if (better || (equallyGood && shorter)) {
        bestR2Result.Value = r2;
        algorithm.Results.AddOrUpdateResult(BestTrainingModelResultName, model);
        algorithm.BestTrainingSentence = symbolString;
      }
    }
  }
}