#region License Information
/* HeuristicLab
* Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Operators;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Analyzers {
  /// <summary>
  /// Analyzer that walks a share of the highest-quality trees in the population and, with a
  /// configurable probability, replaces a tree by the first of its own subtrees (prefix order)
  /// whose evaluated R² is greater than the tree's stored quality. The number of replaced trees,
  /// the average quality improvement and the average length reduction are appended to result tables.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisUsefulGenesAnalyzer", "An analyzer which performs pruning by promoting genes in the population that outperform their containing individuals (the individuals are replaced by their subparts).")]
  public class SymbolicDataAnalysisUsefulGenesAnalyzer : SingleSuccessorOperator, ISymbolicDataAnalysisAnalyzer {
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string QualityParameterName = "Quality";
    private const string ResultCollectionParameterName = "Results";
    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string GenerationsParameterName = "Generations";
    private const string UpdateCounterParameterName = "UpdateCounter";
    private const string UpdateIntervalParameterName = "UpdateInterval";
    private const string MinimumGenerationsParameterName = "MinimumGenerations";
    private const string PruningProbabilityParameterName = "PruningProbability";
    private const string PercentageOfBestSolutionsParameterName = "PercentageOfBestSolutions";
    private const string PromotedSubtreesResultName = "Promoted subtrees";
    private const string AverageQualityImprovementResultName = "Average quality improvement";
    private const string AverageLengthReductionResultName = "Average length reduction";
    private const string RandomParameterName = "Random";

    #region Parameters
    /// <summary>The trees of the population that are candidates for pruning.</summary>
    public IScopeTreeLookupParameter<ISymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (IScopeTreeLookupParameter<ISymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    /// <summary>The qualities belonging to the trees; overwritten for every tree that gets replaced.</summary>
    public IScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (IScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    /// <summary>Random number generator used to decide whether a tree is examined at all.</summary>
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    /// <summary>The result collection that receives the three statistics tables.</summary>
    public ILookupParameter<ResultCollection> ResultCollectionParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultCollectionParameterName]; }
    }
    /// <summary>The interpreter used to evaluate candidate subtrees.</summary>
    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
    }
    /// <summary>The problem data; <see cref="Apply"/> casts it to <see cref="IRegressionProblemData"/>.</summary>
    // NOTE(review): declared with the general problem-data interface because Apply performs an
    // explicit cast to IRegressionProblemData - confirm the intended parameter type.
    public ILookupParameter<IDataAnalysisProblemData> ProblemDataParameter {
      get { return (ILookupParameter<IDataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
    }
    /// <summary>The current generation counter of the running algorithm.</summary>
    public ILookupParameter<IntValue> GenerationsParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationsParameterName]; }
    }
    /// <summary>Number of invocations since the analyzer last did its work.</summary>
    public ValueParameter<IntValue> UpdateCounterParameter {
      get { return (ValueParameter<IntValue>)Parameters[UpdateCounterParameterName]; }
    }
    /// <summary>Number of invocations between two runs of the analyzer.</summary>
    public ValueParameter<IntValue> UpdateIntervalParameter {
      get { return (ValueParameter<IntValue>)Parameters[UpdateIntervalParameterName]; }
    }
    /// <summary>Generation below which the analyzer does nothing.</summary>
    public ValueParameter<IntValue> MinimumGenerationsParameter {
      get { return (ValueParameter<IntValue>)Parameters[MinimumGenerationsParameterName]; }
    }
    /// <summary>Share of the best trees that is considered for pruning.</summary>
    public ValueParameter<PercentValue> PercentageOfBestSolutionsParameter {
      get { return (ValueParameter<PercentValue>)Parameters[PercentageOfBestSolutionsParameterName]; }
    }
    /// <summary>Probability with which a considered tree is actually examined.</summary>
    public ValueParameter<PercentValue> PruningProbabilityParameter {
      get { return (ValueParameter<PercentValue>)Parameters[PruningProbabilityParameterName]; }
    }
    #endregion

    #region Parameter properties
    public int UpdateCounter {
      get { return UpdateCounterParameter.Value.Value; }
      set { UpdateCounterParameter.Value.Value = value; }
    }
    public int UpdateInterval {
      get { return UpdateIntervalParameter.Value.Value; }
      set { UpdateIntervalParameter.Value.Value = value; }
    }
    public int MinimumGenerations {
      get { return MinimumGenerationsParameter.Value.Value; }
      set { MinimumGenerationsParameter.Value.Value = value; }
    }
    public double PercentageOfBestSolutions {
      get { return PercentageOfBestSolutionsParameter.Value.Value; }
      set { PercentageOfBestSolutionsParameter.Value.Value = value; }
    }
    public double PruningProbability {
      get { return PruningProbabilityParameter.Value.Value; }
      set { PruningProbabilityParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>
    /// Creates the analyzer and registers all of its parameters with their default values
    /// (update interval 1, minimum generations 50, 100 % of the best solutions, pruning probability 10 %).
    /// </summary>
    public SymbolicDataAnalysisUsefulGenesAnalyzer() {
      #region Add parameters
      Parameters.Add(new ScopeTreeLookupParameter<ISymbolicExpressionTree>(SymbolicExpressionTreeParameterName));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName));
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultCollectionParameterName));
      Parameters.Add(new LookupParameter<IDataAnalysisProblemData>(ProblemDataParameterName));
      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName));
      Parameters.Add(new LookupParameter<IntValue>(GenerationsParameterName));
      Parameters.Add(new ValueParameter<IntValue>(UpdateCounterParameterName, new IntValue(0)));
      Parameters.Add(new ValueParameter<IntValue>(UpdateIntervalParameterName, new IntValue(1)));
      Parameters.Add(new ValueParameter<IntValue>(MinimumGenerationsParameterName, "The minimum number of generations the algorithm must be let to evolve before applying this analyzer.", new IntValue(50)));
      Parameters.Add(new ValueParameter<PercentValue>(PercentageOfBestSolutionsParameterName, "How many of the best individuals should be pruned using this method.", new PercentValue(1.0)));
      Parameters.Add(new ValueParameter<PercentValue>(PruningProbabilityParameterName, "The probability to apply pruning", new PercentValue(0.1)));
      #endregion
    }

    /// <summary>Cloning constructor; all state is copied by the base class.</summary>
    protected SymbolicDataAnalysisUsefulGenesAnalyzer(SymbolicDataAnalysisUsefulGenesAnalyzer original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisUsefulGenesAnalyzer(this, cloner);
    }

    [StorableConstructor]
    protected SymbolicDataAnalysisUsefulGenesAnalyzer(bool deserializing)
      : base(deserializing) {
    }

    /// <summary>Always <c>false</c>.</summary>
    public bool EnabledByDefault {
      get { return false; }
    }

    /// <summary>Resets the update counter to zero before delegating to the base implementation.</summary>
    public override void InitializeState() {
      UpdateCounter = 0;
      base.InitializeState();
    }

    /// <summary>
    /// Runs one analysis step. Does nothing until <see cref="MinimumGenerations"/> is reached and
    /// afterwards only on every <see cref="UpdateInterval"/>-th call. When active, the selected share of
    /// the best trees is examined and each examined tree is replaced by the first subtree (prefix order,
    /// excluding the tree's own top node) that scores a higher R² than the tree's stored quality.
    /// </summary>
    /// <returns>The operation returned by the base implementation.</returns>
    public override IOperation Apply() {
      int generations = GenerationsParameter.ActualValue.Value;

      #region Update counter & update interval
      if (generations < MinimumGenerations) {
        return base.Apply();
      }
      UpdateCounter++;
      // '<' rather than '!=': if the interval is lowered below the current counter value,
      // an inequality test would never match again and the analyzer would stay silent forever.
      if (UpdateCounter < UpdateInterval) {
        return base.Apply();
      }
      UpdateCounter = 0;
      #endregion

      var trees = SymbolicExpressionTreeParameter.ActualValue.ToArray();
      var qualities = QualityParameter.ActualValue.ToArray();
      Array.Sort(qualities, trees); // sort trees array based on qualities (ascending, best at the end)

      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
      var problemData = (IRegressionProblemData)ProblemDataParameter.ActualValue;
      var rows = problemData.TrainingIndices.ToList();
      var random = RandomParameter.ActualValue;

      int replacedTrees = 0;
      int avgLengthReduction = 0; // running sum; turned into an average after the loop
      double avgQualityImprovement = 0; // running sum; turned into an average after the loop

      var count = (int)Math.Floor(trees.Length * PercentageOfBestSolutions);
      // Clamp so that a percentage above 100 % cannot drive the loop index below zero.
      count = Math.Max(0, Math.Min(trees.Length, count));

      for (int i = trees.Length - 1; i >= trees.Length - count; --i) {
        if (random.NextDouble() > PruningProbability) continue;

        var tree = trees[i];
        var quality = qualities[i].Value;
        var root = tree.Root.GetSubtree(0).GetSubtree(0);

        // Skip(1) leaves out 'root' itself, so only proper subtrees are candidates.
        foreach (var s in root.IterateNodesPrefix().Skip(1)) {
          var r2 = EvaluateSubtree(s, interpreter, problemData, rows);
          if (double.IsNaN(r2) || r2 <= quality) continue;

          avgQualityImprovement += (r2 - quality);
          // NOTE(review): tree.Length presumably also counts the two wrapper nodes above 'root',
          // which the pruned tree keeps, so this may overstate the reduction by 2 - verify.
          avgLengthReduction += (tree.Length - s.GetLength());
          replacedTrees++;

          // replace tree with its own subtree
          var startNode = tree.Root.GetSubtree(0);
          startNode.RemoveSubtree(0);
          startNode.AddSubtree(s);

          // update tree quality
          qualities[i].Value = r2;

          // The tree was modified while it is being enumerated; leave the enumeration immediately.
          break;
        }
      }

      avgQualityImprovement = replacedTrees == 0 ? 0 : avgQualityImprovement / replacedTrees;
      avgLengthReduction = replacedTrees == 0 ? 0 : (int)Math.Round((double)avgLengthReduction / replacedTrees);

      var results = ResultCollectionParameter.ActualValue;
      AddToResultTable(results, PromotedSubtreesResultName, replacedTrees);
      AddToResultTable(results, AverageQualityImprovementResultName, avgQualityImprovement);
      AddToResultTable(results, AverageLengthReductionResultName, avgLengthReduction);

      return base.Apply();
    }

    /// <summary>
    /// Appends <paramref name="value"/> to the data row called <paramref name="name"/> inside the data
    /// table stored under the same name in <paramref name="results"/>. Table, row and result entry are
    /// created on first use.
    /// </summary>
    private static void AddToResultTable(ResultCollection results, string name, double value) {
      DataTable table;
      if (results.ContainsKey(name)) {
        table = (DataTable)results[name].Value;
      } else {
        table = new DataTable(name);
        table.Rows.Add(new DataRow(name));
        results.Add(new Result(name, table));
      }
      table.Rows[name].Values.Add(value);
    }

    /// <summary>
    /// Evaluates <paramref name="subtree"/> on the given rows and returns the squared Pearson correlation
    /// between its outputs and the target variable.
    /// </summary>
    /// <param name="subtree">The node whose subtree is evaluated as a stand-alone model.</param>
    /// <param name="interpreter">Must be a <see cref="SymbolicDataAnalysisExpressionTreeLinearInterpreter"/>; it is cast directly.</param>
    /// <param name="problemData">Supplies the dataset and the name of the target variable.</param>
    /// <param name="rows">The dataset rows to evaluate on.</param>
    /// <returns>The R² value, or <see cref="double.NaN"/> when the calculator reports an error.</returns>
    /// <exception cref="InvalidCastException">Thrown when <paramref name="interpreter"/> is not the linear interpreter.</exception>
    private static double EvaluateSubtree(ISymbolicExpressionTreeNode subtree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IRegressionProblemData problemData, List<int> rows) {
      var linearInterpreter = (SymbolicDataAnalysisExpressionTreeLinearInterpreter)interpreter;
      var dataset = problemData.Dataset;
      var targetValues = dataset.GetDoubleValues(problemData.TargetVariable, rows);
      var estimatedValues = linearInterpreter.GetValues(subtree, dataset, rows);

      OnlineCalculatorError error;
      double r2 = OnlinePearsonsRSquaredCalculator.Calculate(targetValues, estimatedValues, out error);
      return (error == OnlineCalculatorError.None) ? r2 : double.NaN;
    }
  }
}