#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace HeuristicLab.VariableInteractionNetworks {
  /// <summary>
  /// Analyzer that records, per allowed input variable, the average loss of quality observed when
  /// that variable is replaced by its training-partition median in the best trees of the population.
  /// One value per variable is appended to a <see cref="DataTable"/> each time the analyzer fires.
  /// </summary>
  /// <remarks>
  /// The trees are ranked by sorting the qualities ascending and reversing, i.e. larger quality values
  /// are treated as better. NOTE(review): this presumes a maximization problem - confirm with the caller.
  /// </remarks>
  [Item("SymbolicRegressionVariableImpactsAnalyzer", "An analyzer which calculates variable impacts based on the average node impacts from the tree")]
  [StorableClass]
  public class SymbolicRegressionVariableImpactsAnalyzer : SymbolicDataAnalysisAnalyzer {
    private const string UpdateCounterParameterName = "UpdateCounter";
    private const string UpdateIntervalParameterName = "UpdateInterval";
    public const string QualityParameterName = "Quality";
    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ApplyLinearScalingParameterName = "ApplyLinearScaling";
    private const string MaxCOIterationsParameterName = "MaxCOIterations";
    private const string EstimationLimitsParameterName = "EstimationLimits";
    private const string EvaluatorParameterName = "Evaluator";
    private const string VariableImpactsParameterName = "AverageVariableImpacts";
    private const string PercentageBestParameterName = "PercentageBest";
    private const string LastGenerationsParameterName = "LastGenerations";
    private const string MaximumGenerationsParameterName = "MaximumGenerations";
    private const string OptimizeConstantsParameterName = "OptimizeConstants";
    private const string PruneTreesParameterName = "PruneTrees";

    // name and description of the entry written to the result collection
    private const string VariableImpactsResultName = "Average variable impacts";
    private const string VariableImpactsResultDescription = "The relative variable relevance calculated as the average relative variable frequency over the whole run.";

    // Helper objects; they are not persisted and are re-created on construction, cloning and deserialization.
    private SymbolicDataAnalysisExpressionTreeSimplifier simplifier;
    private SymbolicRegressionSolutionImpactValuesCalculator impactsCalculator;

    #region parameters
    public ValueParameter<IntValue> UpdateCounterParameter {
      get { return (ValueParameter<IntValue>)Parameters[UpdateCounterParameterName]; }
    }
    public ValueParameter<IntValue> UpdateIntervalParameter {
      get { return (ValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
    }
    public IScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (IScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
    }
    public ILookupParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (ILookupParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public ILookupParameter<BoolValue> ApplyLinearScalingParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
    }
    public IFixedValueParameter<IntValue> MaxCOIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaxCOIterationsParameterName]; }
    }
    public ILookupParameter<DoubleLimit> EstimationLimitsParameter {
      get { return (ILookupParameter<DoubleLimit>)Parameters[EstimationLimitsParameterName]; }
    }
    public ILookupParameter<ISymbolicDataAnalysisSingleObjectiveEvaluator<IRegressionProblemData>> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicDataAnalysisSingleObjectiveEvaluator<IRegressionProblemData>>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<DataTable> VariableImpactsParameter {
      get { return (ILookupParameter<DataTable>)Parameters[VariableImpactsParameterName]; }
    }
    public IFixedValueParameter<PercentValue> PercentageBestParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[PercentageBestParameterName]; }
    }
    public IFixedValueParameter<IntValue> LastGenerationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[LastGenerationsParameterName]; }
    }
    public IFixedValueParameter<BoolValue> OptimizeConstantsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[OptimizeConstantsParameterName]; }
    }
    public IFixedValueParameter<BoolValue> PruneTreesParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[PruneTreesParameterName]; }
    }
    private ILookupParameter<IntValue> MaximumGenerationsParameter {
      get { return (ILookupParameter<IntValue>)Parameters[MaximumGenerationsParameterName]; }
    }
    #endregion

    #region parameter properties
    /// <summary>Number of <see cref="Apply"/> calls since the analyzer last fired.</summary>
    public int UpdateCounter {
      get { return UpdateCounterParameter.Value.Value; }
      set { UpdateCounterParameter.Value.Value = value; }
    }
    /// <summary>Number of <see cref="Apply"/> calls between two firings of the analyzer.</summary>
    public int UpdateInterval {
      get { return UpdateIntervalParameter.Value.Value; }
      set { UpdateIntervalParameter.Value.Value = value; }
    }
    #endregion

    public SymbolicRegressionVariableImpactsAnalyzer() {
      #region add parameters
      Parameters.Add(new ValueParameter<IntValue>(UpdateCounterParameterName, new IntValue(0)));
      Parameters.Add(new ValueParameter<IntValue>(UpdateIntervalParameterName, new IntValue(1)));
      Parameters.Add(new LookupParameter<IRegressionProblemData>(ProblemDataParameterName));
      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The individual qualities."));
      Parameters.Add(new LookupParameter<BoolValue>(ApplyLinearScalingParameterName));
      Parameters.Add(new LookupParameter<DoubleLimit>(EstimationLimitsParameterName));
      Parameters.Add(new FixedValueParameter<IntValue>(MaxCOIterationsParameterName, new IntValue(3)));
      Parameters.Add(new LookupParameter<DataTable>(VariableImpactsParameterName, "The relative variable relevance calculated as the average relative variable frequency over the whole run."));
      // PercentValue holds a fraction (1.0 == 100 %). The former default of 100 meant 10000 %,
      // which only worked because taking more trees than exist yields all of them.
      Parameters.Add(new FixedValueParameter<PercentValue>(PercentageBestParameterName, new PercentValue(1.0)));
      Parameters.Add(new FixedValueParameter<IntValue>(LastGenerationsParameterName, new IntValue(10)));
      Parameters.Add(new FixedValueParameter<BoolValue>(OptimizeConstantsParameterName, new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<BoolValue>(PruneTreesParameterName, new BoolValue(false)));
      Parameters.Add(new LookupParameter<IntValue>(MaximumGenerationsParameterName, "The maximum number of generations which should be processed."));
      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisSingleObjectiveEvaluator<IRegressionProblemData>>(EvaluatorParameterName));
      #endregion

      InitializeHelpers();
    }

    [StorableConstructor]
    protected SymbolicRegressionVariableImpactsAnalyzer(bool deserializing) : base(deserializing) { }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      InitializeHelpers();

      // backwards compatibility: add the evaluator parameter when the persisted instance lacks it
      if (!Parameters.ContainsKey(EvaluatorParameterName))
        Parameters.Add(new LookupParameter<ISymbolicDataAnalysisSingleObjectiveEvaluator<IRegressionProblemData>>(EvaluatorParameterName));
    }

    protected SymbolicRegressionVariableImpactsAnalyzer(SymbolicRegressionVariableImpactsAnalyzer original, Cloner cloner)
      : base(original, cloner) {
      InitializeHelpers();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionVariableImpactsAnalyzer(this, cloner);
    }

    /// <summary>
    /// Appends one average-impact value per allowed input variable to the variable impacts table.
    /// </summary>
    /// <remarks>
    /// Does nothing except delegating to the base class while the update interval has not elapsed, or
    /// while the current generation lies before the last <c>LastGenerations</c> generations of the run
    /// (a <c>LastGenerations</c> value of zero or less disables that restriction).
    /// Only clones of the population's trees are simplified, optimized and pruned; neither the trees
    /// nor the quality values of the population are modified.
    /// </remarks>
    /// <returns>The operation returned by the base class implementation.</returns>
    public override IOperation Apply() {
      #region Update counter & update interval
      UpdateCounter++;
      // '<' rather than '!=': the analyzer keeps firing even if the interval is lowered below the current counter
      if (UpdateCounter < UpdateInterval) {
        return base.Apply();
      }
      UpdateCounter = 0;
      #endregion

      var results = ResultCollectionParameter.ActualValue;
      int maxGen = MaximumGenerationsParameter.ActualValue.Value;
      int gen = ((IntValue)results["Generations"].Value).Value;
      int lastGen = LastGenerationsParameter.Value.Value;

      if (lastGen > 0 && gen < maxGen - lastGen)
        return base.Apply();

      // rank the trees by quality, largest quality first
      var trees = SymbolicExpressionTree.ToArray();
      var qualities = QualityParameter.ActualValue.ToArray();
      Array.Sort(qualities, trees);
      Array.Reverse(qualities);
      Array.Reverse(trees);

      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;
      var constantOptimizationIterations = MaxCOIterationsParameter.Value.Value; // fixed value parameter => Value
      var estimationLimits = EstimationLimitsParameter.ActualValue; // lookup parameter => ActualValue
      var percentageBest = PercentageBestParameter.Value.Value;
      var optimizeConstants = OptimizeConstantsParameter.Value.Value;
      var pruneTrees = PruneTreesParameter.Value.Value;

      var allowedInputVariables = problemData.AllowedInputVariables.ToList();
      var trainingRows = problemData.TrainingIndices.ToList();

      var dataTable = VariableImpactsParameter.ActualValue;
      if (dataTable == null) {
        dataTable = new DataTable("Variable impacts", "Average impact of variables over the population");
        dataTable.VisualProperties.XAxisTitle = "Generation";
        dataTable.VisualProperties.YAxisTitle = "Average variable impact";
        foreach (var v in allowedInputVariables) {
          dataTable.Rows.Add(new DataRow(v) { VisualProperties = { StartIndexZero = true } });
        }
        VariableImpactsParameter.ActualValue = dataTable;
      }

      // clamp so that fractions outside [0, 1] (e.g. persisted from older versions) remain valid
      int nTrees = Math.Max(0, Math.Min(trees.Length, (int)Math.Round(trees.Length * percentageBest)));
      var bestTrees = trees.Take(nTrees).Select(x => (ISymbolicExpressionTree)x.Clone()).ToList();
      // simplify trees before doing anything else
      var simplifiedTrees = bestTrees.Select(x => simplifier.Simplify(x)).ToList();

      // Baseline qualities live in a local copy: writing into the DoubleValue objects would silently
      // change the qualities of the individuals in the population.
      var baselineQualities = qualities.Take(nTrees).Select(q => q.Value).ToArray();

      if (optimizeConstants) {
        for (int i = 0; i < simplifiedTrees.Count; ++i) {
          baselineQualities[i] = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, simplifiedTrees[i], problemData, trainingRows, applyLinearScaling, constantOptimizationIterations, estimationLimits.Upper, estimationLimits.Lower);
        }
      }

      if (pruneTrees) {
        for (int i = 0; i < simplifiedTrees.Count; ++i) {
          simplifiedTrees[i] = SymbolicRegressionPruningOperator.Prune(simplifiedTrees[i], impactsCalculator, interpreter, problemData, estimationLimits, trainingRows);
        }
      }

      // the evaluator is only needed (and only looked up) when constants are not re-optimized
      var evaluator = optimizeConstants ? null : EvaluatorParameter.ActualValue;

      // map each variable to a list of indices of the trees that contain it
      var variablesToTreeIndices = allowedInputVariables.ToDictionary(x => x, x => Enumerable.Range(0, simplifiedTrees.Count).Where(i => ContainsVariable(simplifiedTrees[i], x)).ToList());

      foreach (var mapping in variablesToTreeIndices) {
        var variableName = mapping.Key;
        var treeIndices = mapping.Value;
        var pd = CreateProblemDataWithMedianVariable(problemData, variableName, allowedInputVariables, trainingRows);

        var impactSum = 0d;
        // iterate over the indices of the trees containing the variable (not over 0..Count-1)
        foreach (var treeIndex in treeIndices) {
          var tree = simplifiedTrees[treeIndex];
          double newQuality;
          if (optimizeConstants) {
            // optimize a clone, otherwise the re-fitted constants would leak into the evaluation of the next variable
            var clone = (ISymbolicExpressionTree)tree.Clone();
            newQuality = SymbolicRegressionConstantOptimizationEvaluator.OptimizeConstants(interpreter, clone, pd, trainingRows, applyLinearScaling, constantOptimizationIterations, estimationLimits.Upper, estimationLimits.Lower);
          } else {
            newQuality = evaluator.Evaluate(this.ExecutionContext, tree, pd, pd.TrainingIndices);
          }
          impactSum += baselineQualities[treeIndex] - newQuality; // impact calculated this way may be negative
        }

        // a variable that occurs in no tree has no impact (avoids 0/0 = NaN in the chart)
        var averageImpact = treeIndices.Count > 0 ? impactSum / treeIndices.Count : 0d;
        dataTable.Rows[variableName].Values.Add(averageImpact);
      }

      if (results.ContainsKey(VariableImpactsResultName))
        results[VariableImpactsResultName].Value = dataTable;
      else
        results.Add(new Result(VariableImpactsResultName, VariableImpactsResultDescription, dataTable));

      return base.Apply();
    }

    // (Re-)creates the non-persisted helper objects.
    private void InitializeHelpers() {
      impactsCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
      simplifier = new SymbolicDataAnalysisExpressionTreeSimplifier();
    }

    // Builds problem data over a copy of the double variables of the dataset in which the given variable
    // is set to its training median on all training rows; partitions are copied from the original.
    private static IRegressionProblemData CreateProblemDataWithMedianVariable(IRegressionProblemData problemData, string variableName, IEnumerable<string> allowedInputVariables, IList<int> trainingRows) {
      var dataset = problemData.Dataset;
      var median = dataset.GetDoubleValues(variableName, trainingRows).Median();

      var ds = new ModifiableDataset(dataset.DoubleVariables, dataset.DoubleVariables.Select(x => dataset.GetReadOnlyDoubleValues(x).ToList()));
      foreach (var row in trainingRows) {
        ds.SetVariableValue(median, variableName, row);
      }

      var pd = new RegressionProblemData(ds, allowedInputVariables, problemData.TargetVariable);
      pd.TrainingPartition.Start = problemData.TrainingPartition.Start;
      pd.TrainingPartition.End = problemData.TrainingPartition.End;
      pd.TestPartition.Start = problemData.TestPartition.Start;
      pd.TestPartition.End = problemData.TestPartition.End;
      return pd;
    }

    // Returns true if any variable node of the tree references the given variable name.
    private static bool ContainsVariable(ISymbolicExpressionTree tree, string variableName) {
      return tree.IterateNodesPrefix().OfType<VariableTreeNode>().Any(x => x.VariableName == variableName);
    }
  }
}