#region License Information
/* HeuristicLab
* Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.GP.Interfaces;
using System;
using HeuristicLab.DataAnalysis;
using HeuristicLab.Modeling;
namespace HeuristicLab.GP.StructureIdentification {
public class TournamentPruning : OperatorBase {
public TournamentPruning()
: base() {
AddVariableInfo(new VariableInfo("Random", "", typeof(IRandom), VariableKind.In));
AddVariableInfo(new VariableInfo("FunctionTree", "The tree to analyse", typeof(IGeneticProgrammingModel), VariableKind.In));
AddVariableInfo(new VariableInfo("Dataset", "Dataset", typeof(Dataset), VariableKind.In));
AddVariableInfo(new VariableInfo("TargetVariable", "", typeof(StringData), VariableKind.In));
AddVariableInfo(new VariableInfo("TrainingSamplesStart", "Samples start", typeof(IntData), VariableKind.In));
AddVariableInfo(new VariableInfo("TrainingSamplesEnd", "Samples end", typeof(IntData), VariableKind.In));
AddVariableInfo(new VariableInfo("TreeEvaluator", "", typeof(ITreeEvaluator), VariableKind.In));
AddVariableInfo(new VariableInfo("MaxPruningRatio", "Maximale relative size of the pruned branch", typeof(DoubleData), VariableKind.In));
AddVariableInfo(new VariableInfo("TournamentSize", "Number of branches to compare for pruning", typeof(IntData), VariableKind.
In));
AddVariableInfo(new VariableInfo("PopulationPercentileStart", "", typeof(DoubleData), VariableKind.In));
AddVariableInfo(new VariableInfo("PopulationPercentileEnd", "", typeof(DoubleData), VariableKind.In));
AddVariableInfo(new VariableInfo("QualityGainWeight", "", typeof(DoubleData), VariableKind.In));
}
public override IOperation Apply(IScope scope) {
IRandom random = scope.GetVariableValue("Random", true);
double percentileStart = scope.GetVariableValue("PopulationPercentileStart", true).Data;
double percentileEnd = scope.GetVariableValue("PopulationPercentileEnd", true).Data;
int tournamentSize = scope.GetVariableValue("TournamentSize", true).Data;
Dataset dataset = scope.GetVariableValue("Dataset", true);
string targetVariable = scope.GetVariableValue("TargetVariable", true).Data;
int samplesStart = scope.GetVariableValue("TrainingSamplesStart", true).Data;
int samplesEnd = scope.GetVariableValue("TrainingSamplesEnd", true).Data;
ITreeEvaluator evaluator = scope.GetVariableValue("TreeEvaluator", true);
double maxPruningRatio = scope.GetVariableValue("MaxPruningRatio", true).Data;
double qualityGainWeight = scope.GetVariableValue("QualityGainWeight", true).Data;
int n = scope.SubScopes.Count;
// for each tree in the given percentile
var trees = (from subScope in scope.SubScopes
select subScope.GetVariableValue("FunctionTree", false))
.Skip((int)(n * percentileStart))
.Take((int)(n * (percentileEnd - percentileStart)));
foreach (var tree in trees) {
tree.FunctionTree = Prune(random, tree.FunctionTree, tournamentSize, dataset, targetVariable, samplesStart, samplesEnd, evaluator, maxPruningRatio, qualityGainWeight);
}
return null;
}
public static IFunctionTree Prune(IRandom random, IFunctionTree tree, int tournamentSize,
Dataset dataset, string targetVariable, int samplesStart, int samplesEnd, ITreeEvaluator evaluator,
double maxPruningRatio, double qualityGainWeight) {
var evaluatedRows = Enumerable.Range(samplesStart, samplesEnd - samplesStart);
var estimatedValues = evaluator.Evaluate(dataset, tree, evaluatedRows).ToArray();
var targetValues = dataset.GetVariableValues(targetVariable, samplesStart, samplesEnd);
int originalSize = tree.GetSize();
double originalMse = SimpleMSEEvaluator.Calculate(Matrix.Create(targetValues, estimatedValues));
int maxPrunedBranchSize = (int)(tree.GetSize() * maxPruningRatio);
IFunctionTree bestTree = tree;
double bestGain = double.PositiveInfinity;
for (int i = 0; i < tournamentSize; i++) {
var clonedTree = (IFunctionTree)tree.Clone();
var prunePoints = (from node in FunctionTreeIterator.IteratePrefix(clonedTree)
from subTree in node.SubTrees
where subTree.GetSize() <= maxPrunedBranchSize
select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
.ToList();
var selectedPrunePoint = prunePoints[random.Next(prunePoints.Count)];
var branchValues = evaluator.Evaluate(dataset, selectedPrunePoint.Branch, evaluatedRows);
var branchMean = branchValues.Average();
selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
var constNode = CreateConstant(branchMean);
selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);
estimatedValues = evaluator.Evaluate(dataset, clonedTree, evaluatedRows).ToArray();
double prunedMse = SimpleMSEEvaluator.Calculate(Matrix.Create(targetValues, estimatedValues));
double prunedSize = clonedTree.GetSize();
// MSE of the pruned tree is larger than the original tree in most cases
// size of the pruned tree is always smaller than the size of the original tree
// same change in quality => prefer pruning operation that removes a larger tree
double gain = ((prunedMse / originalMse) * qualityGainWeight) /
(originalSize / prunedSize);
if (gain < bestGain) {
bestGain = gain;
bestTree = clonedTree;
}
}
return bestTree;
}
private static FunctionTree CreateConstant(double constantValue) {
var node = (ConstantFunctionTree)(new Constant()).GetTreeNode();
node.Value = constantValue;
return node;
}
}
}