#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.GP.Interfaces;
using System;
using HeuristicLab.DataAnalysis;
using HeuristicLab.Modeling;

namespace HeuristicLab.GP.StructureIdentification {
  /// <summary>
  /// Simplifies the function trees of a slice of the population. For each selected tree,
  /// <c>TournamentSize</c> candidates are generated. Each candidate replaces one randomly chosen
  /// branch (no larger than <c>MaxPruningRatio</c> of the tree size) by a constant equal to that
  /// branch's mean output on the training rows. The candidate with the lowest gain value
  /// (see <see cref="Prune"/>) replaces the original tree.
  /// </summary>
  public class TournamentPruning : OperatorBase {
    /// <summary>Registers the variables read by <see cref="Apply"/>.</summary>
    public TournamentPruning()
      : base() {
      AddVariableInfo(new VariableInfo("Random", "", typeof(IRandom), VariableKind.In));
      AddVariableInfo(new VariableInfo("FunctionTree", "The tree to analyse", typeof(IGeneticProgrammingModel), VariableKind.In));
      AddVariableInfo(new VariableInfo("Dataset", "Dataset", typeof(Dataset), VariableKind.In));
      AddVariableInfo(new VariableInfo("TargetVariable", "", typeof(StringData), VariableKind.In));
      AddVariableInfo(new VariableInfo("TrainingSamplesStart", "Samples start", typeof(IntData), VariableKind.In));
      AddVariableInfo(new VariableInfo("TrainingSamplesEnd", "Samples end", typeof(IntData), VariableKind.In));
      AddVariableInfo(new VariableInfo("TreeEvaluator", "", typeof(ITreeEvaluator), VariableKind.In));
      AddVariableInfo(new VariableInfo("MaxPruningRatio", "Maximale relative size of the pruned branch", typeof(DoubleData), VariableKind.In));
      AddVariableInfo(new VariableInfo("TournamentSize", "Number of branches to compare for pruning", typeof(IntData), VariableKind.In));
      AddVariableInfo(new VariableInfo("PopulationPercentileStart", "", typeof(DoubleData), VariableKind.In));
      AddVariableInfo(new VariableInfo("PopulationPercentileEnd", "", typeof(DoubleData), VariableKind.In));
      AddVariableInfo(new VariableInfo("QualityGainWeight", "", typeof(DoubleData), VariableKind.In));
    }

    /// <summary>
    /// Prunes the model of every sub-scope whose index lies in
    /// [(int)(n * PopulationPercentileStart), (int)(n * PopulationPercentileEnd)),
    /// where n is the number of sub-scopes. Each selected model's <c>FunctionTree</c> is
    /// replaced in place by the result of <see cref="Prune"/>.
    /// </summary>
    /// <param name="scope">Scope whose sub-scopes each hold a "FunctionTree" model.</param>
    /// <returns>Always <c>null</c>.</returns>
    public override IOperation Apply(IScope scope) {
      IRandom random = scope.GetVariableValue<IRandom>("Random", true);
      double percentileStart = scope.GetVariableValue<DoubleData>("PopulationPercentileStart", true).Data;
      double percentileEnd = scope.GetVariableValue<DoubleData>("PopulationPercentileEnd", true).Data;
      int tournamentSize = scope.GetVariableValue<IntData>("TournamentSize", true).Data;
      Dataset dataset = scope.GetVariableValue<Dataset>("Dataset", true);
      string targetVariable = scope.GetVariableValue<StringData>("TargetVariable", true).Data;
      int samplesStart = scope.GetVariableValue<IntData>("TrainingSamplesStart", true).Data;
      int samplesEnd = scope.GetVariableValue<IntData>("TrainingSamplesEnd", true).Data;
      ITreeEvaluator evaluator = scope.GetVariableValue<ITreeEvaluator>("TreeEvaluator", true);
      double maxPruningRatio = scope.GetVariableValue<DoubleData>("MaxPruningRatio", true).Data;
      double qualityGainWeight = scope.GetVariableValue<DoubleData>("QualityGainWeight", true).Data;

      int n = scope.SubScopes.Count;
      // Truncate both boundaries separately. Truncating n * (end - start) instead can lose an
      // element to rounding, e.g. n = 10, start = 0.3, end = 0.7 yields 3 instead of 4.
      int firstIndex = (int)(n * percentileStart);
      int endIndex = (int)(n * percentileEnd);
      // Skip before Select so models outside the slice are never looked up.
      var models = scope.SubScopes
        .Skip(firstIndex)
        .Take(Math.Max(0, endIndex - firstIndex))
        .Select(subScope => subScope.GetVariableValue<IGeneticProgrammingModel>("FunctionTree", false));

      foreach (var model in models) {
        model.FunctionTree = Prune(random, model.FunctionTree, tournamentSize, dataset, targetVariable,
          samplesStart, samplesEnd, evaluator, maxPruningRatio, qualityGainWeight);
      }
      return null;
    }

    /// <summary>
    /// Runs a pruning tournament on <paramref name="tree"/>. Each round clones the tree and
    /// replaces one random branch of size at most (int)(size * <paramref name="maxPruningRatio"/>)
    /// by a constant holding the branch's mean output on rows [samplesStart, samplesEnd).
    /// Candidates are scored with
    /// <c>gain = (prunedMse / originalMse) * qualityGainWeight / (originalSize / prunedSize)</c>,
    /// and the candidate with the lowest gain wins.
    /// </summary>
    /// <returns>
    /// The winning pruned clone. Returns <paramref name="tree"/> itself if
    /// <paramref name="tournamentSize"/> is not positive, if the tree has no branch small enough
    /// to prune, or if no candidate's gain is below +Infinity (NaN gains never win).
    /// </returns>
    public static IFunctionTree Prune(IRandom random, IFunctionTree tree, int tournamentSize, Dataset dataset, string targetVariable,
      int samplesStart, int samplesEnd, ITreeEvaluator evaluator, double maxPruningRatio, double qualityGainWeight) {
      var evaluatedRows = Enumerable.Range(samplesStart, samplesEnd - samplesStart);
      var originalEstimates = evaluator.Evaluate(dataset, tree, evaluatedRows).ToArray();
      var targetValues = dataset.GetVariableValues(targetVariable, samplesStart, samplesEnd);
      int originalSize = tree.GetSize();
      double originalMse = SimpleMSEEvaluator.Calculate(Matrix.Create(targetValues, originalEstimates));
      int maxPrunedBranchSize = (int)(originalSize * maxPruningRatio);

      IFunctionTree bestTree = tree;
      double bestGain = double.PositiveInfinity;
      for (int i = 0; i < tournamentSize; i++) {
        var clonedTree = (IFunctionTree)tree.Clone();
        // Collect every (parent, child index) pair whose child branch is small enough to prune.
        var prunePoints = (from node in FunctionTreeIterator.IteratePrefix(clonedTree)
                           from child in node.SubTrees.Select((subTree, index) => new { SubTree = subTree, Index = index })
                           where child.SubTree.GetSize() <= maxPrunedBranchSize
                           select new { Parent = node, Branch = child.SubTree, SubTreeIndex = child.Index })
                          .ToList();
        // Every clone has the same shape, so an empty list means nothing can ever be pruned.
        // Indexing the empty list would throw.
        if (prunePoints.Count == 0) {
          return bestTree;
        }

        var selectedPrunePoint = prunePoints[random.Next(prunePoints.Count)];
        var branchMean = evaluator.Evaluate(dataset, selectedPrunePoint.Branch, evaluatedRows).Average();
        selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
        selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, CreateConstant(branchMean));

        var prunedEstimates = evaluator.Evaluate(dataset, clonedTree, evaluatedRows).ToArray();
        double prunedMse = SimpleMSEEvaluator.Calculate(Matrix.Create(targetValues, prunedEstimates));
        double prunedSize = clonedTree.GetSize();
        // The pruned MSE is usually larger than the original, while the pruned tree is always
        // smaller. For the same quality loss, the candidate that removes more nodes scores lower.
        double gain = ((prunedMse / originalMse) * qualityGainWeight) / (originalSize / prunedSize);
        if (gain < bestGain) {
          bestGain = gain;
          bestTree = clonedTree;
        }
      }
      return bestTree;
    }

    /// <summary>Creates a constant tree node whose value is <paramref name="constantValue"/>.</summary>
    private static FunctionTree CreateConstant(double constantValue) {
      var node = (ConstantFunctionTree)(new Constant()).GetTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}