#region License Information
/* HeuristicLab
* Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
using HEAL.Attic;
namespace HeuristicLab.Algorithms.DataAnalysis {
/// <summary>
/// Base class for leaf model types of a regression tree. It owns the dampening
/// parameters and drives the creation of one regression model per leaf node;
/// the actual model fitting is delegated to the abstract
/// <see cref="Build(IRegressionProblemData, IRandom, CancellationToken, out int)"/>.
/// </summary>
[StorableType("F3A9CCD4-975F-4F55-BE24-3A3E932591F6")]
public abstract class LeafBase : ParameterizedNamedItem, ILeafModel {
  /// <summary>Name of the scope variable that holds the <see cref="LeafBuildingState"/>.</summary>
  public const string LeafBuildingStateVariableName = "LeafBuildingState";
  /// <summary>Name of the parameter that switches dampening on or off.</summary>
  public const string UseDampeningParameterName = "UseDampening";
  /// <summary>Name of the parameter that holds the dampening strength.</summary>
  public const string DampeningParameterName = "DampeningStrength";

  /// <summary>Parameter holding the dampening strength.</summary>
  public IFixedValueParameter<DoubleValue> DampeningParameter {
    get { return (IFixedValueParameter<DoubleValue>)Parameters[DampeningParameterName]; }
  }

  /// <summary>Parameter holding the flag that enables dampening.</summary>
  public IFixedValueParameter<BoolValue> UseDampeningParameter {
    get { return (IFixedValueParameter<BoolValue>)Parameters[UseDampeningParameterName]; }
  }

  /// <summary>Gets or sets whether built leaf models are wrapped in a dampened model.</summary>
  public bool UseDampening {
    get { return UseDampeningParameter.Value.Value; }
    set { UseDampeningParameter.Value.Value = value; }
  }

  /// <summary>
  /// Gets or sets the dampening strength. Dampening is only applied by
  /// <see cref="BuildModel"/> when this value is greater than 0.0.
  /// </summary>
  public double Dampening {
    get { return DampeningParameter.Value.Value; }
    set { DampeningParameter.Value.Value = value; }
  }

  #region Constructors & Cloning
  [StorableConstructor]
  protected LeafBase(StorableConstructorFlag _) : base(_) { }

  protected LeafBase(LeafBase original, Cloner cloner) : base(original, cloner) { }

  /// <summary>
  /// Registers the dampening parameters with their defaults (dampening off, strength 1.5).
  /// </summary>
  protected LeafBase() {
    Parameters.Add(new FixedValueParameter<BoolValue>(UseDampeningParameterName, "Whether logistic dampening should be used to prevent extreme extrapolation (default=false)", new BoolValue(false)));
    Parameters.Add(new FixedValueParameter<DoubleValue>(DampeningParameterName, "Determines the strength of logistic dampening. Must be > 0.0. Larger numbers lead to more conservative predictions. (default=1.5)", new DoubleValue(1.5)));
  }
  #endregion

  #region IModelType
  /// <summary>Whether models of this leaf type provide confidence information.</summary>
  public abstract bool ProvidesConfidence { get; }

  /// <summary>Returns the minimal leaf size for the given problem data.</summary>
  public abstract int MinLeafSize(IRegressionProblemData pd);

  /// <summary>
  /// Adds a fresh <see cref="LeafBuildingState"/> to <paramref name="states"/> under
  /// <see cref="LeafBuildingStateVariableName"/>.
  /// </summary>
  /// <param name="states">Scope that receives the state variable.</param>
  public void Initialize(IScope states) {
    states.Variables.Add(new Variable(LeafBuildingStateVariableName, new LeafBuildingState()));
  }

  /// <summary>
  /// Builds a leaf model for every leaf of <paramref name="tree"/> and attaches it to the node.
  /// Progress is kept in the <see cref="LeafBuildingState"/> stored in
  /// <paramref name="stateScope"/>, so a call that was interrupted by an exception
  /// (e.g. cancellation) can be repeated and continues with the remaining leaves.
  /// </summary>
  /// <param name="tree">Tree whose leaves receive models.</param>
  /// <param name="trainingRows">Training rows that are distributed to the leaves along the tree's splits.</param>
  /// <param name="stateScope">Scope containing the regression tree parameters and the leaf building state.</param>
  /// <param name="cancellationToken">Token forwarded to the model building.</param>
  public void Build(RegressionNodeTreeModel tree, IReadOnlyList<int> trainingRows, IScope stateScope, CancellationToken cancellationToken) {
    var parameters = (RegressionTreeParameters)stateScope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
    var state = (LeafBuildingState)stateScope.Variables[LeafBuildingStateVariableName].Value;

    // Code 0: the leaves have not been collected yet; do it once and advance to "building models".
    if (state.Code == 0) {
      state.FillLeafs(tree, trainingRows, parameters.Data);
      state.Code = 1;
    }

    while (state.nodeQueue.Count != 0) {
      // Peek first and dequeue only after the model has been set: if BuildModel throws,
      // the leaf and its rows stay queued and are processed again on the next call.
      var n = state.nodeQueue.Peek();
      var t = state.trainingRowsQueue.Peek();
      int numP; // required by the out parameter; the value is not used here
      n.SetLeafModel(BuildModel(t, parameters, cancellationToken, out numP));
      state.nodeQueue.Dequeue();
      state.trainingRowsQueue.Dequeue();
    }
  }

  /// <summary>
  /// Builds a single leaf model on the given rows. The dataset is first reduced to
  /// <paramref name="rows"/> and the allowed input / target variables, then the model is built
  /// via <see cref="Build(IRegressionProblemData, IRandom, CancellationToken, out int)"/> and,
  /// if <see cref="UseDampening"/> is set and <see cref="Dampening"/> is positive, wrapped by
  /// <c>DampenedModel.DampenModel</c>.
  /// </summary>
  /// <param name="rows">Rows of <c>parameters.Data</c> to train on.</param>
  /// <param name="parameters">Regression tree parameters (data, variables, random number generator).</param>
  /// <param name="cancellation">Token forwarded to the model building and checked before returning.</param>
  /// <param name="numberOfParameters">Number of parameters reported by the model building.</param>
  /// <returns>The (possibly dampened) regression model.</returns>
  /// <exception cref="System.OperationCanceledException">Cancellation was requested.</exception>
  public IRegressionModel BuildModel(IReadOnlyList<int> rows, RegressionTreeParameters parameters, CancellationToken cancellation, out int numberOfParameters) {
    var reducedData = RegressionTreeUtilities.ReduceDataset(parameters.Data, rows, parameters.AllowedInputVariables.ToArray(), parameters.TargetVariable);
    var pd = new RegressionProblemData(reducedData, parameters.AllowedInputVariables.ToArray(), parameters.TargetVariable);
    // All rows of the reduced dataset are used for training; both test partition bounds
    // are set to the row count, i.e. the test partition is empty.
    pd.TrainingPartition.Start = 0;
    pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = reducedData.Rows;

    var model = Build(pd, parameters.Random, cancellation, out numberOfParameters);
    if (UseDampening && Dampening > 0.0) {
      model = DampenedModel.DampenModel(model, pd, Dampening);
    }

    cancellation.ThrowIfCancellationRequested();
    return model;
  }

  /// <summary>Builds the concrete leaf model on the given problem data.</summary>
  /// <param name="pd">Problem data to train on.</param>
  /// <param name="random">Random number generator.</param>
  /// <param name="cancellationToken">Token for cancelling the build.</param>
  /// <param name="numberOfParameters">Number of parameters of the built model.</param>
  public abstract IRegressionModel Build(IRegressionProblemData pd, IRandom random, CancellationToken cancellationToken, out int numberOfParameters);
  #endregion

  /// <summary>
  /// Persistable progress of <see cref="LeafBase.Build(RegressionNodeTreeModel, IReadOnlyList{int}, IScope, CancellationToken)"/>:
  /// the leaves that still need a model together with their training rows.
  /// </summary>
  [StorableType("495243C0-6C15-4328-B30D-FFBFA0F54DCB")]
  public class LeafBuildingState : Item {
    // NOTE(review): the generic arguments of this file were lost; RegressionNodeModel is
    // restored as the node type of RegressionNodeTreeModel (type of tree.Root) - confirm.
    /// <summary>Leaves that still need a model; runs in lockstep with <see cref="trainingRowsQueue"/>.</summary>
    [Storable]
    public Queue<RegressionNodeModel> nodeQueue = new Queue<RegressionNodeModel>();

    /// <summary>Training rows of the leaves in <see cref="nodeQueue"/>, in the same order.</summary>
    [Storable]
    public Queue<IReadOnlyList<int>> trainingRowsQueue = new Queue<IReadOnlyList<int>>();

    // Code denotes the current action (needed for pausing):
    // 0 ... nothing has been done yet
    // 1 ... building models
    [Storable]
    public int Code = 0;

    #region HLConstructors & Cloning
    [StorableConstructor]
    protected LeafBuildingState(StorableConstructorFlag _) : base(_) { }

    /// <summary>
    /// Cloning constructor: nodes are cloned through <paramref name="cloner"/>,
    /// each row list is copied into a new array.
    /// </summary>
    protected LeafBuildingState(LeafBuildingState original, Cloner cloner) : base(original, cloner) {
      nodeQueue = new Queue<RegressionNodeModel>(original.nodeQueue.Select(cloner.Clone));
      trainingRowsQueue = new Queue<IReadOnlyList<int>>(original.trainingRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
      Code = original.Code;
    }

    public LeafBuildingState() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new LeafBuildingState(this, cloner);
    }
    #endregion

    /// <summary>
    /// Clears both queues and refills them with all leaves of <paramref name="tree"/> and the
    /// training rows that reach each leaf. The tree is traversed breadth-first starting at the
    /// root; at every inner node the rows are divided with
    /// <c>RegressionTreeUtilities.SplitRows</c> using the node's split attribute and value.
    /// </summary>
    /// <param name="tree">Tree whose leaves are collected.</param>
    /// <param name="trainingRows">Rows arriving at the root.</param>
    /// <param name="data">Dataset used to evaluate the splits.</param>
    public void FillLeafs(RegressionNodeTreeModel tree, IReadOnlyList<int> trainingRows, IDataset data) {
      var helperQueue = new Queue<RegressionNodeModel>();
      var trainingHelperQueue = new Queue<IReadOnlyList<int>>();
      nodeQueue.Clear();
      trainingRowsQueue.Clear();

      helperQueue.Enqueue(tree.Root);
      trainingHelperQueue.Enqueue(trainingRows);

      while (helperQueue.Count != 0) {
        var n = helperQueue.Dequeue();
        var t = trainingHelperQueue.Dequeue();
        if (n.IsLeaf) {
          nodeQueue.Enqueue(n);
          trainingRowsQueue.Enqueue(t);
          continue;
        }

        // Inner node: hand the rows down to the children (left first, then right,
        // so node and row queues stay aligned).
        IReadOnlyList<int> leftTraining, rightTraining;
        RegressionTreeUtilities.SplitRows(t, data, n.SplitAttribute, n.SplitValue, out leftTraining, out rightTraining);
        helperQueue.Enqueue(n.Left);
        helperQueue.Enqueue(n.Right);
        trainingHelperQueue.Enqueue(leftTraining);
        trainingHelperQueue.Enqueue(rightTraining);
      }
    }
  }
}
}