#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Algorithms.DataAnalysis {
  [StorableClass]
  internal class M5NodeModel : RegressionModel {
    #region Properties
    [Storable]
    internal bool IsLeaf { get; private set; }
    [Storable]
    internal IRegressionModel NodeModel { get; private set; }
    [Storable]
    internal string SplitAttr { get; private set; }
    [Storable]
    internal double SplitValue { get; private set; }
    [Storable]
    internal M5NodeModel Left { get; private set; }
    [Storable]
    internal M5NodeModel Right { get; private set; }
    [Storable]
    internal M5NodeModel Parent { get; set; }
    [Storable]
    internal int NumSamples { get; private set; }
    [Storable]
    internal int NumParam { get; set; }
    [Storable]
    internal int NodeModelParams { get; set; }
    [Storable]
    private IReadOnlyList<string> Variables { get; set; }
    #endregion

    #region HLConstructors
    [StorableConstructor]
    protected M5NodeModel(bool deserializing) : base(deserializing) { }
    protected M5NodeModel(M5NodeModel original, Cloner cloner) : base(original, cloner) {
      IsLeaf = original.IsLeaf;
      NodeModel = cloner.Clone(original.NodeModel);
      SplitValue = original.SplitValue;
      SplitAttr = original.SplitAttr;
      Left = cloner.Clone(original.Left);
      Right = cloner.Clone(original.Right);
      Parent = cloner.Clone(original.Parent);
      NumParam = original.NumParam;
      NumSamples = original.NumSamples;
      Variables = original.Variables != null ? original.Variables.ToList() : null;
    }
    protected M5NodeModel(string targetAttr) : base(targetAttr) { }
    protected M5NodeModel(M5NodeModel parent) : base(parent.TargetVariable) {
      Parent = parent;
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new M5NodeModel(this, cloner);
    }
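    // Factory methods: create a confidence-aware node when the configured leaf type
    // produces confidence regression models, otherwise a plain node.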
    public static M5NodeModel CreateNode(string targetAttr, M5CreationParameters m5CreationParams) {
      return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel>
        ? new ConfidenceM5NodeModel(targetAttr)
        : new M5NodeModel(targetAttr);
    }
    private static M5NodeModel CreateNode(M5NodeModel parent, M5CreationParameters m5CreationParams) {
      return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel>
        ? new ConfidenceM5NodeModel(parent)
        : new M5NodeModel(parent);
    }
    #endregion

    #region RegressionModel
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return Variables; }
    }
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      if (!IsLeaf) return rows.Select(row => GetEstimatedValue(dataset, row));
      if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
      return NodeModel.GetEstimatedValues(dataset, rows);
    }
    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, problemData);
    }
    #endregion

    internal void Split(IReadOnlyList<int> rows, M5CreationParameters m5CreationParams, double globalStdDev) {
      Variables = m5CreationParams.AllowedInputVariables.ToArray();
      NumSamples = rows.Count;
      Right = null;
      Left = null;
      NodeModel = null;
      SplitAttr = null;
      SplitValue = double.NaN;
      string attr;
      double splitValue;
      //IsLeaf = m5CreationParams.Data.GetDoubleValues(TargetVariable, rows).StandardDeviation() < globalStdDev * DevFraction;
      //if (IsLeaf) return;
      IsLeaf = !m5CreationParams.Split.Split(new RegressionProblemData(ReduceDataset(m5CreationParams.Data, rows), Variables, TargetVariable),
        m5CreationParams.MinLeafSize, out attr, out splitValue);
      if (IsLeaf) return;

      //split Dataset
      IReadOnlyList<int> leftRows, rightRows;
      SplitRows(rows, m5CreationParams.Data, attr, splitValue, out leftRows, out rightRows);
      if (leftRows.Count < m5CreationParams.MinLeafSize || rightRows.Count < m5CreationParams.MinLeafSize) {
        IsLeaf = true;
        return;
      }
      SplitAttr = attr;
      SplitValue = splitValue;

      //create subtrees
      Left = CreateNode(this, m5CreationParams);
      Left.Split(leftRows, m5CreationParams, globalStdDev);
      Right = CreateNode(this, m5CreationParams);
      Right.Split(rightRows, m5CreationParams, globalStdDev);
    }
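    // Prunes the subtree bottom-up: leaves are fitted with the pruning leaf model,
    // then the configured pruning criterion decides on the holdout rows whether this
    // node should be collapsed into a leaf. Returns true if this node ends up as a leaf.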
    internal bool Prune(IReadOnlyList<int> trainingRows, IReadOnlyList<int> testRows, M5CreationParameters m5CreationParams, CancellationToken cancellation, double globalStdDev) {
      if (IsLeaf) {
        BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);
        NumParam = NodeModelParams;
        return true;
      }

      //split training & holdout data
      IReadOnlyList<int> leftTest, rightTest;
      SplitRows(testRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTest, out rightTest);
      IReadOnlyList<int> leftTraining, rightTraining;
      SplitRows(trainingRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTraining, out rightTraining);

      //prune children first
      var lpruned = Left.Prune(leftTraining, leftTest, m5CreationParams, cancellation, globalStdDev);
      var rpruned = Right.Prune(rightTraining, rightTest, m5CreationParams, cancellation, globalStdDev);
      NumParam = Left.NumParam + Right.NumParam + 1;

      //TODO check if this reduces quality. It reduces training effort (considerably for some pruningTypes)
      if (!lpruned && !rpruned) return false;

      BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);

      //check if children will be pruned
      if (!((PruningBase) m5CreationParams.Pruningtype).Prune(this, m5CreationParams, testRows, globalStdDev)) return false;

      //convert to leafNode
      ((IntValue) m5CreationParams.Results[M5RuleModel.NoCurrentLeafesResultName].Value).Value -= EnumerateNodes().Count(x => x.IsLeaf) - 1;
      IsLeaf = true;
      Right = null;
      Left = null;
      NumParam = NodeModelParams;
      return true;
    }

    internal void InstallModels(IReadOnlyList<int> rows, IRandom random, IDataset data, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
      if (!IsLeaf) {
        IReadOnlyList<int> leftRows, rightRows;
        SplitRows(rows, data, SplitAttr, SplitValue, out leftRows, out rightRows);
        Left.InstallModels(leftRows, random, data, leafType, cancellation);
        Right.InstallModels(rightRows, random, data, leafType, cancellation);
        return;
      }
      BuildModel(rows, data, random, leafType, cancellation);
    }

    internal IEnumerable<M5NodeModel> EnumerateNodes() {
      var queue = new Queue<M5NodeModel>();
      queue.Enqueue(this);
      while (queue.Count != 0) {
        var cur = queue.Dequeue();
        yield return cur;
        if (cur.Left == null && cur.Right == null) continue;
        if (cur.Left != null) queue.Enqueue(cur.Left);
        if (cur.Right != null) queue.Enqueue(cur.Right);
      }
    }

    internal void ToRuleNode() {
      Parent = null;
    }

    #region Helpers
    private double GetEstimatedValue(IDataset dataset, int row) {
      if (!IsLeaf) return (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right).GetEstimatedValue(dataset, row);
      if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
      return NodeModel.GetEstimatedValues(dataset, new[] {row}).First();
    }

    private void BuildModel(IReadOnlyList<int> rows, IDataset data, IRandom random, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
      var reducedData = ReduceDataset(data, rows);
      var pd = new RegressionProblemData(reducedData, VariablesUsedForPrediction, TargetVariable);
      pd.TrainingPartition.Start = 0;
      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = reducedData.Rows;

      int noparams;
      NodeModel = leafType.BuildModel(pd, random, cancellation, out noparams);
      NodeModelParams = noparams;
      cancellation.ThrowIfCancellationRequested();
    }

    private IDataset ReduceDataset(IDataset data, IReadOnlyList<int> rows) {
      return new Dataset(VariablesUsedForPrediction.Concat(new[] {TargetVariable}),
        VariablesUsedForPrediction.Concat(new[] {TargetVariable}).Select(x => data.GetDoubleValues(x, rows).ToList()));
    }

    private static void SplitRows(IReadOnlyList<int> rows, IDataset data, string splitAttr, double splitValue, out IReadOnlyList<int> leftRows, out IReadOnlyList<int> rightRows) {
      var assignment = data.GetDoubleValues(splitAttr, rows).Select(x => x <= splitValue).ToArray();
      leftRows = rows.Zip(assignment, (i, b) => new {i, b}).Where(x => x.b).Select(x => x.i).ToList();
      rightRows = rows.Zip(assignment, (i, b) => new {i, b}).Where(x => !x.b).Select(x => x.i).ToList();
    }
    #endregion
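    // Node variant used when the leaf type produces confidence models: it forwards the
    // leaf models' variance estimates so the whole tree implements IConfidenceRegressionModel.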
    [StorableClass]
    private sealed class ConfidenceM5NodeModel : M5NodeModel, IConfidenceRegressionModel {
      #region HLConstructors
      [StorableConstructor]
      private ConfidenceM5NodeModel(bool deserializing) : base(deserializing) { }
      private ConfidenceM5NodeModel(ConfidenceM5NodeModel original, Cloner cloner) : base(original, cloner) { }
      public ConfidenceM5NodeModel(string targetAttr) : base(targetAttr) { }
      public ConfidenceM5NodeModel(M5NodeModel parent) : base(parent) { }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new ConfidenceM5NodeModel(this, cloner);
      }
      #endregion

      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
        return IsLeaf
          ? ((IConfidenceRegressionModel) NodeModel).GetEstimatedVariances(dataset, rows)
          : rows.Select(row => GetEstimatedVariance(dataset, row));
      }

      private double GetEstimatedVariance(IDataset dataset, int row) {
        if (!IsLeaf) return ((IConfidenceRegressionModel) (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right)).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
        return ((IConfidenceRegressionModel) NodeModel).GetEstimatedVariances(dataset, new[] {row}).First();
      }

      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
        return new ConfidenceRegressionSolution(this, problemData);
      }
    }
  }
}