using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Symbolic regression algorithm that enumerates the sentences derivable from a fixed
  /// grammar (up to a maximum sentence length), evaluates every structurally distinct
  /// sentence as a regression model, and tracks the best training and test solutions.
  /// </summary>
  [Item("Grammar Enumeration Symbolic Regression", "Iterates all possible model structures for a fixed grammar.")]
  [StorableClass]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 250)]
  public class GrammarEnumerationAlgorithm : FixedDataAnalysisAlgorithm {
    // Names of the entries written to the result collection.
    private const string BestTrainingSolution = "Best solution (training)";
    private const string BestTrainingSolutionQuality = "Best solution quality (training)";
    private const string BestTestSolution = "Best solution (test)";
    private const string BestTestSolutionQuality = "Best solution quality (test)";

    // Names of the algorithm parameters.
    private const string MaxTreeSizeParameterName = "Max. Tree Nodes";
    private const string GuiUpdateIntervalParameterName = "GUI Update Interval";

    #region properties
    /// <summary>Parameter holding the maximum number of symbols a derived sentence may contain.</summary>
    public IValueParameter<IntValue> MaxTreeSizeParameter {
      get { return (IValueParameter<IntValue>)Parameters[MaxTreeSizeParameterName]; }
    }

    /// <summary>Maximum number of symbols a derived sentence may contain.</summary>
    public int MaxTreeSize {
      get { return MaxTreeSizeParameter.Value.Value; }
    }

    /// <summary>Parameter holding the number of generated sentences between two result refreshes.</summary>
    public IValueParameter<IntValue> GuiUpdateIntervalParameter {
      // Fixed: this previously looked up MaxTreeSizeParameterName and therefore
      // returned the max-tree-size parameter instead of the GUI update interval.
      get { return (IValueParameter<IntValue>)Parameters[GuiUpdateIntervalParameterName]; }
    }

    /// <summary>Number of generated sentences between two result refreshes.</summary>
    public int GuiUpdateInterval {
      get { return GuiUpdateIntervalParameter.Value.Value; }
    }
    #endregion

    // Grammar of the current run; created at the start of Run from the allowed input variables.
    private Grammar grammar;

    #region ctors
    /// <summary>Creates a deep clone of this algorithm using the given cloner.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GrammarEnumerationAlgorithm(this, cloner);
    }

    /// <summary>
    /// Creates the algorithm with a new <see cref="RegressionProblem"/> and the default
    /// parameter values (max. 4 symbols per sentence, result refresh every 4000 sentences).
    /// </summary>
    public GrammarEnumerationAlgorithm() {
      Problem = new RegressionProblem();

      Parameters.Add(new ValueParameter<IntValue>(MaxTreeSizeParameterName, "The maximum number of symbols (tree nodes) allowed in a generated sentence.", new IntValue(4)));
      Parameters.Add(new ValueParameter<IntValue>(GuiUpdateIntervalParameterName, "Number of generated sentences, until GUI is refreshed.", new IntValue(4000)));
    }

    // Cloning constructor; all copying is delegated to the base class.
    private GrammarEnumerationAlgorithm(GrammarEnumerationAlgorithm original, Cloner cloner) : base(original, cloner) { }
    #endregion

    /// <summary>
    /// Enumerates sentences of the grammar depth-first using an explicit stack. Partially
    /// derived symbol strings are expanded at their first nonterminal symbol; complete
    /// sentences are recorded and, if their hash has not been seen before, evaluated.
    /// When the enumeration ends (or cancellation is requested) the generated and the
    /// distinct sentences are added to the results.
    /// </summary>
    /// <param name="cancellationToken">Token checked once per loop iteration to stop the enumeration early.</param>
    protected override void Run(CancellationToken cancellationToken) {
      List<SymbolString> allGenerated = new List<SymbolString>();
      List<SymbolString> distinctGenerated = new List<SymbolString>();
      // NOTE(review): assumes Grammar.CalcHashCode returns int - confirm against Grammar.
      HashSet<int> evaluatedHashes = new HashSet<int>();

      grammar = new Grammar(Problem.ProblemData.AllowedInputVariables.ToArray());

      Stack<SymbolString> remainingTrees = new Stack<SymbolString>();
      remainingTrees.Push(new SymbolString(new[] { grammar.StartSymbol }));

      while (remainingTrees.Count > 0) {
        // Leave the loop instead of throwing so that the sentences found so far are still reported.
        if (cancellationToken.IsCancellationRequested) break;

        SymbolString currSymbolString = remainingTrees.Pop();

        if (currSymbolString.IsSentence()) {
          allGenerated.Add(currSymbolString);

          // HashSet.Add returns false for already known hashes, so each hash is evaluated only once.
          if (evaluatedHashes.Add(grammar.CalcHashCode(currSymbolString))) {
            EvaluateSentence(currSymbolString);
            distinctGenerated.Add(currSymbolString);
          }

          UpdateView(allGenerated, distinctGenerated);

        } else {
          // Expand the first nonterminal symbol with each of its production alternatives.
          int nonterminalSymbolIndex = currSymbolString.FindIndex(s => s is NonterminalSymbol);
          NonterminalSymbol expandedSymbol = (NonterminalSymbol)currSymbolString[nonterminalSymbolIndex];

          foreach (Production productionAlternative in expandedSymbol.Alternatives) {
            SymbolString newSentence = new SymbolString(currSymbolString);
            newSentence.RemoveAt(nonterminalSymbolIndex);
            newSentence.InsertRange(nonterminalSymbolIndex, productionAlternative);

            // Symbol strings exceeding the length limit are dropped and not expanded further.
            if (newSentence.Count <= MaxTreeSize) {
              remainingTrees.Push(newSentence);
            }
          }
        }
      }

      StringArray sentences = new StringArray(allGenerated.Select(r => r.ToString()).ToArray());
      Results.Add(new Result("All generated sentences", sentences));

      StringArray distinctSentences = new StringArray(distinctGenerated.Select(r => r.ToString()).ToArray());
      Results.Add(new Result("Distinct generated sentences", distinctSentences));
    }

    /// <summary>
    /// Refreshes the progress results (number of generated sentences, number of distinct
    /// sentences, average sentence length) whenever the number of generated sentences is a
    /// multiple of <see cref="GuiUpdateInterval"/>.
    /// </summary>
    /// <param name="allGenerated">All sentences generated so far, including duplicates.</param>
    /// <param name="distinctGenerated">The sentences that were evaluated, i.e. those with a previously unseen hash.</param>
    private void UpdateView(List<SymbolString> allGenerated, List<SymbolString> distinctGenerated) {
      int generatedSolutions = allGenerated.Count;
      int distinctSolutions = distinctGenerated.Count;
      int updateInterval = GuiUpdateInterval;

      // A non-positive interval would make the modulo below throw; treat it as "never refresh".
      if (updateInterval <= 0 || generatedSolutions % updateInterval != 0) return;

      // AddOrUpdateResult is used for every entry because this method runs repeatedly
      // and the entries already exist from the second refresh on.
      Results.AddOrUpdateResult("Generated Solutions", new IntValue(generatedSolutions));
      Results.AddOrUpdateResult("Distinct Solutions", new IntValue(distinctSolutions));

      DoubleValue averageTreeLength = new DoubleValue(allGenerated.Select(r => r.Count).Average());
      Results.AddOrUpdateResult("Average Tree Length of Solutions", averageTreeLength);
    }

    /// <summary>
    /// Parses the sentence into a symbolic expression tree, builds a regression solution
    /// for the current problem data from it, and updates the best-training and best-test
    /// result entries if the new solution has a higher R² on the respective partition.
    /// </summary>
    /// <param name="symbolString">The complete sentence to evaluate.</param>
    private void EvaluateSentence(SymbolString symbolString) {
      SymbolicExpressionTree tree = grammar.ParseSymbolicExpressionTree(symbolString);
      SymbolicRegressionModel model = new SymbolicRegressionModel(
        Problem.ProblemData.TargetVariable,
        tree,
        new SymbolicDataAnalysisExpressionTreeLinearInterpreter());

      IRegressionSolution newSolution = model.CreateRegressionSolution(Problem.ProblemData);

      IResult currBestTrainingSolutionResult;
      IResult currBestTestSolutionResult;
      if (!Results.TryGetValue(BestTrainingSolution, out currBestTrainingSolutionResult)
          || !Results.TryGetValue(BestTestSolution, out currBestTestSolutionResult)) {
        // No best solutions recorded yet: the new solution becomes the best for both partitions.
        Results.Add(new Result(BestTrainingSolution, newSolution));
        Results.Add(new Result(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly()));
        Results.Add(new Result(BestTestSolution, newSolution));
        Results.Add(new Result(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly()));
      } else {
        IRegressionSolution currBestTrainingSolution = (IRegressionSolution)currBestTrainingSolutionResult.Value;
        if (currBestTrainingSolution.TrainingRSquared < newSolution.TrainingRSquared) {
          currBestTrainingSolutionResult.Value = newSolution;
          Results.AddOrUpdateResult(BestTrainingSolutionQuality, new DoubleValue(newSolution.TrainingRSquared).AsReadOnly());
        }

        IRegressionSolution currBestTestSolution = (IRegressionSolution)currBestTestSolutionResult.Value;
        if (currBestTestSolution.TestRSquared < newSolution.TestRSquared) {
          currBestTestSolutionResult.Value = newSolution;
          Results.AddOrUpdateResult(BestTestSolutionQuality, new DoubleValue(newSolution.TestRSquared).AsReadOnly());
        }
      }
    }
  }
}