#region License Information
/* HeuristicLab
* Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
namespace HeuristicLab.Algorithms.DataAnalysis {
// NOTE(review): the generic type arguments in this file had been stripped (all
// angle-bracket content was lost, e.g. "IReadOnlyList rows", "new Queue()").
// They have been restored from usage; confirm against the original HeuristicLab source.
/// <summary>
/// A single node of an M5 model tree. Inner nodes hold a binary split
/// (SplitAttr/SplitValue); leaves hold a trained regression model (NodeModel).
/// </summary>
[StorableClass]
internal class M5NodeModel : RegressionModel {
  #region Properties
  // True for leaves (which carry a NodeModel); false for inner split nodes.
  [Storable]
  internal bool IsLeaf { get; private set; }
  // Regression model installed at this node; only built for leaves.
  [Storable]
  internal IRegressionModel NodeModel { get; private set; }
  // Name of the input variable this inner node splits on (null for leaves).
  [Storable]
  internal string SplitAttr { get; private set; }
  // Split threshold: rows with SplitAttr <= SplitValue descend into Left, the rest into Right.
  [Storable]
  internal double SplitValue { get; private set; }
  [Storable]
  internal M5NodeModel Left { get; private set; }
  [Storable]
  internal M5NodeModel Right { get; private set; }
  [Storable]
  internal M5NodeModel Parent { get; set; }
  // Number of training rows that reached this node during Split.
  [Storable]
  internal int NumSamples { get; private set; }
  // Total parameter count of the subtree rooted here (maintained by Prune).
  [Storable]
  internal int NumParam { get; set; }
  // Parameter count of the leaf model built by BuildModel.
  [Storable]
  internal int NodeModelParams { get; set; }
  // Input variables available to this node's subtree (dataset column names).
  [Storable]
  private IReadOnlyList<string> Variables { get; set; }
  #endregion

  #region HLConstructors
  [StorableConstructor]
  protected M5NodeModel(bool deserializing) : base(deserializing) { }

  protected M5NodeModel(M5NodeModel original, Cloner cloner) : base(original, cloner) {
    IsLeaf = original.IsLeaf;
    NodeModel = cloner.Clone(original.NodeModel);
    SplitValue = original.SplitValue;
    SplitAttr = original.SplitAttr;
    Left = cloner.Clone(original.Left);
    Right = cloner.Clone(original.Right);
    Parent = cloner.Clone(original.Parent);
    NumParam = original.NumParam;
    NumSamples = original.NumSamples;
    Variables = original.Variables != null ? original.Variables.ToList() : null;
  }

  protected M5NodeModel(string targetAttr) : base(targetAttr) { }

  protected M5NodeModel(M5NodeModel parent) : base(parent.TargetVariable) {
    Parent = parent;
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new M5NodeModel(this, cloner);
  }

  /// <summary>
  /// Factory: creates a confidence-aware node when the configured leaf type
  /// produces confidence models, otherwise a plain node.
  /// </summary>
  public static M5NodeModel CreateNode(string targetAttr, M5CreationParameters m5CreationParams) {
    return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5NodeModel(targetAttr) : new M5NodeModel(targetAttr);
  }

  private static M5NodeModel CreateNode(M5NodeModel parent, M5CreationParameters m5CreationParams) {
    return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5NodeModel(parent) : new M5NodeModel(parent);
  }
  #endregion

  #region RegressionModel
  public override IEnumerable<string> VariablesUsedForPrediction {
    get { return Variables; }
  }

  /// <summary>
  /// Routes each row down the tree to a leaf model; for a leaf node, delegates
  /// directly to the installed NodeModel.
  /// </summary>
  public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
    if (!IsLeaf) return rows.Select(row => GetEstimatedValue(dataset, row));
    if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
    return NodeModel.GetEstimatedValues(dataset, rows);
  }

  public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    return new RegressionSolution(this, problemData);
  }
  #endregion

  /// <summary>
  /// Recursively grows the subtree: finds the best split for the given rows and,
  /// if splitting is possible and both children meet the minimum leaf size,
  /// creates and splits the child nodes; otherwise marks this node as a leaf.
  /// </summary>
  internal void Split(IReadOnlyList<int> rows, M5CreationParameters m5CreationParams, double globalStdDev) {
    Variables = m5CreationParams.AllowedInputVariables.ToArray();
    NumSamples = rows.Count;
    // reset node state so a node can be re-split
    Right = null;
    Left = null;
    NodeModel = null;
    SplitAttr = null;
    SplitValue = double.NaN;
    string attr;
    double splitValue;
    //IsLeaf = m5CreationParams.Data.GetDoubleValues(TargetVariable, rows).StandardDeviation() < globalStdDev * DevFraction;
    //if (IsLeaf) return;
    // become a leaf if the splitter cannot produce a valid split
    IsLeaf = !m5CreationParams.Split.Split(new RegressionProblemData(ReduceDataset(m5CreationParams.Data, rows), Variables, TargetVariable), m5CreationParams.MinLeafSize, out attr, out splitValue);
    if (IsLeaf) return;
    //split Dataset
    IReadOnlyList<int> leftRows, rightRows;
    SplitRows(rows, m5CreationParams.Data, attr, splitValue, out leftRows, out rightRows);
    // reject splits that would create undersized children
    if (leftRows.Count < m5CreationParams.MinLeafSize || rightRows.Count < m5CreationParams.MinLeafSize) {
      IsLeaf = true;
      return;
    }
    SplitAttr = attr;
    SplitValue = splitValue;
    //create subtrees
    Left = CreateNode(this, m5CreationParams);
    Left.Split(leftRows, m5CreationParams, globalStdDev);
    Right = CreateNode(this, m5CreationParams);
    Right.Split(rightRows, m5CreationParams, globalStdDev);
  }

  /// <summary>
  /// Bottom-up pruning: prunes both children first, then decides (via the
  /// configured pruning strategy) whether to collapse this subtree into a leaf.
  /// Returns true if this node ended up as a (pruned) leaf.
  /// </summary>
  internal bool Prune(IReadOnlyList<int> trainingRows, IReadOnlyList<int> testRows, M5CreationParameters m5CreationParams, CancellationToken cancellation, double globalStdDev) {
    if (IsLeaf) {
      BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);
      NumParam = NodeModelParams;
      return true;
    }
    //split training & holdout data
    IReadOnlyList<int> leftTest, rightTest;
    SplitRows(testRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTest, out rightTest);
    IReadOnlyList<int> leftTraining, rightTraining;
    SplitRows(trainingRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTraining, out rightTraining);
    //prune children first
    var lpruned = Left.Prune(leftTraining, leftTest, m5CreationParams, cancellation, globalStdDev);
    var rpruned = Right.Prune(rightTraining, rightTest, m5CreationParams, cancellation, globalStdDev);
    NumParam = Left.NumParam + Right.NumParam + 1;
    //TODO check if this reduces quality. It reduces training effort (considerably for some pruningTypes)
    // only consider collapsing this node when both children were pruned to leaves
    if (!lpruned && !rpruned) return false;
    BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);
    //check if children will be pruned
    if (!((PruningBase) m5CreationParams.Pruningtype).Prune(this, m5CreationParams, testRows, globalStdDev)) return false;
    //convert to leafNode
    // update the running leaf count in the results collection
    ((IntValue) m5CreationParams.Results[M5RuleModel.NoCurrentLeafesResultName].Value).Value -= EnumerateNodes().Count(x => x.IsLeaf) - 1;
    IsLeaf = true;
    Right = null;
    Left = null;
    NumParam = NodeModelParams;
    return true;
  }

  /// <summary>
  /// Builds the final leaf models: routes rows down to the leaves and trains a
  /// model of the given leaf type on each leaf's row subset.
  /// </summary>
  internal void InstallModels(IReadOnlyList<int> rows, IRandom random, IDataset data, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
    if (!IsLeaf) {
      IReadOnlyList<int> leftRows, rightRows;
      SplitRows(rows, data, SplitAttr, SplitValue, out leftRows, out rightRows);
      Left.InstallModels(leftRows, random, data, leafType, cancellation);
      Right.InstallModels(rightRows, random, data, leafType, cancellation);
      return;
    }
    BuildModel(rows, data, random, leafType, cancellation);
  }

  /// <summary>Breadth-first enumeration of this node and all descendants.</summary>
  internal IEnumerable<M5NodeModel> EnumerateNodes() {
    var queue = new Queue<M5NodeModel>();
    queue.Enqueue(this);
    while (queue.Count != 0) {
      var cur = queue.Dequeue();
      yield return cur;
      if (cur.Left == null && cur.Right == null) continue;
      if (cur.Left != null) queue.Enqueue(cur.Left);
      if (cur.Right != null) queue.Enqueue(cur.Right);
    }
  }

  /// <summary>Detaches this node from its parent so it can serve as a rule's root.</summary>
  internal void ToRuleNode() {
    Parent = null;
  }

  #region Helpers
  // Walks the tree for a single row; at each inner node rows with
  // SplitAttr <= SplitValue go left, otherwise right.
  private double GetEstimatedValue(IDataset dataset, int row) {
    if (!IsLeaf) return (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right).GetEstimatedValue(dataset, row);
    if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
    return NodeModel.GetEstimatedValues(dataset, new[] {row}).First();
  }

  // Trains this node's leaf model on the given rows; the reduced problem data
  // uses all rows as training partition and an empty test partition.
  private void BuildModel(IReadOnlyList<int> rows, IDataset data, IRandom random, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
    var reducedData = ReduceDataset(data, rows);
    var pd = new RegressionProblemData(reducedData, VariablesUsedForPrediction, TargetVariable);
    pd.TrainingPartition.Start = 0;
    pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = reducedData.Rows;
    int noparams;
    NodeModel = leafType.BuildModel(pd, random, cancellation, out noparams);
    NodeModelParams = noparams;
    cancellation.ThrowIfCancellationRequested();
  }

  // Projects the dataset to this node's input variables + target, restricted to the given rows.
  private IDataset ReduceDataset(IDataset data, IReadOnlyList<int> rows) {
    return new Dataset(VariablesUsedForPrediction.Concat(new[] {TargetVariable}), VariablesUsedForPrediction.Concat(new[] {TargetVariable}).Select(x => data.GetDoubleValues(x, rows).ToList()));
  }

  // Partitions rows by the split predicate (value <= splitValue goes left).
  private static void SplitRows(IReadOnlyList<int> rows, IDataset data, string splitAttr, double splitValue, out IReadOnlyList<int> leftRows, out IReadOnlyList<int> rightRows) {
    var assignment = data.GetDoubleValues(splitAttr, rows).Select(x => x <= splitValue).ToArray();
    leftRows = rows.Zip(assignment, (i, b) => new {i, b}).Where(x => x.b).Select(x => x.i).ToList();
    rightRows = rows.Zip(assignment, (i, b) => new {i, b}).Where(x => !x.b).Select(x => x.i).ToList();
  }
  #endregion

  /// <summary>
  /// Node variant whose leaf models also provide prediction variances.
  /// </summary>
  [StorableClass]
  private sealed class ConfidenceM5NodeModel : M5NodeModel, IConfidenceRegressionModel {
    #region HLConstructors
    [StorableConstructor]
    private ConfidenceM5NodeModel(bool deserializing) : base(deserializing) { }
    private ConfidenceM5NodeModel(ConfidenceM5NodeModel original, Cloner cloner) : base(original, cloner) { }
    public ConfidenceM5NodeModel(string targetAttr) : base(targetAttr) { }
    public ConfidenceM5NodeModel(M5NodeModel parent) : base(parent) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new ConfidenceM5NodeModel(this, cloner);
    }
    #endregion

    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      return IsLeaf ? ((IConfidenceRegressionModel) NodeModel).GetEstimatedVariances(dataset, rows) : rows.Select(row => GetEstimatedVariance(dataset, row));
    }

    // Single-row variance lookup: routes the row to the responsible leaf.
    private double GetEstimatedVariance(IDataset dataset, int row) {
      if (!IsLeaf)
        return ((IConfidenceRegressionModel) (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right)).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
      return ((IConfidenceRegressionModel) NodeModel).GetEstimatedVariances(dataset, new[] {row}).First();
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new ConfidenceRegressionSolution(this, problemData);
    }
  }
}
}