#region License Information
/* HeuristicLab
* Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
namespace HeuristicLab.Algorithms.DataAnalysis {
[StorableClass]
internal class M5RuleModel : RegressionModel, IM5MetaModel {
// NOTE(review): presumably the name of a result entry that reports the current number of leaves;
// the constant is not used anywhere in this file — confirm against its callers.
internal const string NoCurrentLeafesResultName = "Number of current Leafs";
#region Properties
// SplitAtts, SplitVals and RelOps are parallel arrays: entry i of each describes one
// condition "variable SplitAtts[i] <RelOps[i]> SplitVals[i]" of this rule (see Covers).

/// <summary>Names of the variables tested by the rule's conditions.</summary>
[Storable]
internal string[] SplitAtts { get; private set; }

/// <summary>Threshold values the variables in SplitAtts are compared against.</summary>
[Storable]
private double[] SplitVals { get; set; }

/// <summary>Relational operator used by each condition.</summary>
[Storable]
private RelOp[] RelOps { get; set; }

/// <summary>Model that produces the estimates of this rule; null until the rule has been built.</summary>
[Storable]
protected IRegressionModel RuleModel { get; set; }

/// <summary>Input variable names; returned by VariablesUsedForPrediction.</summary>
// The element type is string: the list is filled from AllowedInputVariables and its
// entries are used as dataset variable names when the model is updated.
[Storable]
private IReadOnlyList<string> Variables { get; set; }
#endregion
#region HLConstructors
/// <summary>
/// Constructor marked for the persistence framework; it only forwards the flag to the base class.
/// </summary>
[StorableConstructor]
protected M5RuleModel(bool deserializing)
  : base(deserializing) {
}
/// <summary>
/// Cloning constructor: copies the rule conditions and the variable list of
/// <paramref name="original"/> and clones its rule model through <paramref name="cloner"/>.
/// </summary>
protected M5RuleModel(M5RuleModel original, Cloner cloner)
  : base(original, cloner) {
  // The wrapped model is a deep-cloneable object and therefore goes through the cloner.
  RuleModel = cloner.Clone(original.RuleModel);

  // Arrays and the variable list are copied so that the clone does not share them
  // with the original; members that are null on the original stay null on the clone.
  SplitAtts = original.SplitAtts == null ? null : original.SplitAtts.ToArray();
  SplitVals = original.SplitVals == null ? null : original.SplitVals.ToArray();
  RelOps = original.RelOps == null ? null : original.RelOps.ToArray();
  Variables = original.Variables == null ? null : original.Variables.ToList();
}
/// <summary>
/// Creates a rule model without conditions or rule model; <paramref name="target"/> is
/// forwarded to the base class constructor.
/// </summary>
private M5RuleModel(string target)
  : base(target) {
}
/// <summary>Creates a copy of this model via the cloning constructor.</summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new M5RuleModel(this, cloner);
  return copy;
}
#endregion
/// <summary>
/// Factory for rule models: returns a ConfidenceM5RuleModel when the configured leaf type
/// produces confidence models, and a plain M5RuleModel otherwise.
/// </summary>
/// <param name="target">Passed on to the constructor of the created model.</param>
/// <param name="m5CreationParams">Creation parameters; only LeafType is inspected here.</param>
internal static M5RuleModel CreateRuleModel(string target, M5CreationParameters m5CreationParams) {
  // NOTE(review): the generic argument of ILeafType was missing in the reviewed text and has
  // been restored as IConfidenceRegressionModel, the interface ConfidenceM5RuleModel casts its
  // rule model to — confirm against the declaration of ILeafType.
  if (m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel>) {
    return new ConfidenceM5RuleModel(target);
  }
  return new M5RuleModel(target);
}
#region IRegressionModel
/// <summary>
/// The variable names stored in Variables; null as long as BuildClassifier has not run
/// (and the model was not cloned or restored from a built one).
/// </summary>
public override IEnumerable<string> VariablesUsedForPrediction {
  get { return Variables; }
}
/// <summary>Returns the estimates of the wrapped rule model for the given rows.</summary>
/// <param name="dataset">Dataset handed on to the rule model.</param>
/// <param name="rows">Row indices handed on to the rule model.</param>
/// <exception cref="NotSupportedException">The rule model has not been set yet.</exception>
public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
  if (RuleModel == null) {
    throw new NotSupportedException("M5P has not been built correctly");
  }
  return RuleModel.GetEstimatedValues(dataset, rows);
}
/// <summary>Wraps this model and <paramref name="problemData"/> in a new RegressionSolution.</summary>
public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
  var solution = new RegressionSolution(this, problemData);
  return solution;
}
#endregion
#region IM5Component
/// <summary>
/// Builds this rule: a tree model is built on the given rows, a leaf with the largest
/// NumSamples is selected, the split conditions on the path from that leaf to the root
/// are recorded, and the leaf's node model becomes the rule model.
/// </summary>
/// <param name="trainingRows">Row indices handed to the tree's BuildClassifier.</param>
/// <param name="holdoutRows">Row indices handed to the tree's BuildClassifier.</param>
/// <param name="m5CreationParams">Supplies the allowed input variables, the target variable and the tree settings.</param>
/// <param name="cancellation">Token handed to the tree's BuildClassifier.</param>
public void BuildClassifier(IReadOnlyList<int> trainingRows, IReadOnlyList<int> holdoutRows, M5CreationParameters m5CreationParams, CancellationToken cancellation) {
  Variables = m5CreationParams.AllowedInputVariables.ToList();

  var tree = M5TreeModel.CreateTreeModel(m5CreationParams.TargetVariable, m5CreationParams);
  ((IM5MetaModel) tree).BuildClassifier(trainingRows, holdoutRows, m5CreationParams, cancellation);

  // MaxItems presumably yields every leaf with the maximal sample count; the first one is taken.
  var nodeModel = tree.Root.EnumerateNodes().Where(x => x.IsLeaf).MaxItems(x => x.NumSamples).First();

  var satts = new List<string>();
  var svals = new List<double>();
  var reops = new List<RelOp>();

  // Extract the splits by walking from the selected leaf up to the root, so the
  // conditions are stored leaf-first. Being the left child of a node is recorded as
  // "<= split value", being any other child as "> split value".
  for (var temp = nodeModel; temp.Parent != null; temp = temp.Parent) {
    satts.Add(temp.Parent.SplitAttr);
    svals.Add(temp.Parent.SplitValue);
    reops.Add(temp.Parent.Left == temp ? RelOp.Lessequal : RelOp.Greater);
  }

  // The path has to be read before ToRuleNode is called on the leaf.
  // NOTE(review): ToRuleNode presumably detaches the node from its tree — confirm.
  nodeModel.ToRuleNode();
  RuleModel = nodeModel.NodeModel;

  RelOps = reops.ToArray();
  SplitAtts = satts.ToArray();
  SplitVals = svals.ToArray();
}
/// <summary>
/// Rebuilds the rule model on the given rows using the random source, data and leaf type
/// from <paramref name="m5UpdateParameters"/>; the rule conditions are left unchanged.
/// </summary>
/// <param name="rows">Row indices of the update data used for rebuilding.</param>
/// <param name="m5UpdateParameters">Provides Random, Data and LeafType.</param>
/// <param name="cancellation">Token handed on to the model building.</param>
public void UpdateModel(IReadOnlyList<int> rows, M5UpdateParameters m5UpdateParameters, CancellationToken cancellation) {
  BuildModel(rows, m5UpdateParameters.Random, m5UpdateParameters.Data, m5UpdateParameters.LeafType, cancellation);
}
#endregion
/// <summary>
/// Checks whether the given row satisfies every condition of this rule.
/// A rule without conditions covers every row.
/// </summary>
/// <param name="dataset">Dataset the variable values are read from.</param>
/// <param name="row">Index of the row to test.</param>
public bool Covers(IDataset dataset, int row) {
  // Conditions are evaluated lazily in order; evaluation stops at the first one that fails.
  return SplitAtts
    .Select((attribute, index) => RelOps[index].Compare(dataset.GetDoubleValue(attribute, row), SplitVals[index]))
    .All(holds => holds);
}
/// <summary>
/// Renders the rule conditions as one interval per variable, joined by " &amp;&amp; ".
/// Several conditions on the same variable are merged into the tightest interval.
/// </summary>
/// <returns>The textual rule, or an empty string when the rule has no conditions.</returns>
public string ToCompactString() {
  var mins = new Dictionary<string, double>();
  var maxs = new Dictionary<string, double>();
  for (var i = 0; i < SplitAtts.Length; i++) {
    var n = SplitAtts[i];
    var v = SplitVals[i];
    // Every variable starts with an unbounded interval in both dictionaries.
    if (!mins.ContainsKey(n)) mins.Add(n, double.NegativeInfinity);
    if (!maxs.ContainsKey(n)) maxs.Add(n, double.PositiveInfinity);
    // "<=" conditions tighten the upper bound, ">" conditions the lower bound.
    if (RelOps[i] == RelOp.Lessequal) maxs[n] = Math.Min(maxs[n], v);
    else mins[n] = Math.Max(mins[n], v);
  }
  if (maxs.Count == 0) return "";

  var s = new StringBuilder();
  foreach (var key in maxs.Keys)
    s.AppendFormat("{0} ∈ [{1:e2}; {2:e2}] && ", key, mins[key], maxs[key]);
  // Drop the trailing " && " separator.
  s.Remove(s.Length - 4, 4);
  return s.ToString();
}
#region Helpers
/// <summary>
/// Rebuilds RuleModel with the given leaf type on a dataset that holds only the prediction
/// variables and the target variable, restricted to the given rows.
/// </summary>
/// <param name="rows">Row indices of <paramref name="data"/> to build on.</param>
/// <param name="random">Random source handed to the leaf type.</param>
/// <param name="data">Dataset the values are read from.</param>
/// <param name="leafType">Builder that creates the new rule model.</param>
/// <param name="cancellation">Handed to the leaf type and checked once more after building.</param>
// NOTE(review): the generic argument of ILeafType was missing in the reviewed text and has been
// restored as IRegressionModel (the type of RuleModel) — confirm against the ILeafType declaration.
private void BuildModel(IReadOnlyList<int> rows, IRandom random, IDataset data, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
  // Column order: all prediction variables first, the target variable last.
  var variableNames = VariablesUsedForPrediction.Concat(new[] {TargetVariable}).ToList();
  var reducedData = new Dataset(variableNames, variableNames.Select(x => data.GetDoubleValues(x, rows).ToList()));
  var pd = new RegressionProblemData(reducedData, VariablesUsedForPrediction, TargetVariable);
  // All rows of the reduced dataset are used for training; the test partition is empty.
  pd.TrainingPartition.Start = 0;
  pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = reducedData.Rows;
  int noparams; // required by the BuildModel signature, not used here
  RuleModel = leafType.BuildModel(pd, random, cancellation, out noparams);
  cancellation.ThrowIfCancellationRequested();
}
#endregion
/// <summary>
/// Rule model variant that also exposes variance estimates by forwarding to its rule model,
/// which is cast to IConfidenceRegressionModel.
/// </summary>
[StorableClass]
private sealed class ConfidenceM5RuleModel : M5RuleModel, IConfidenceRegressionModel {
  #region HLConstructors
  [StorableConstructor]
  private ConfidenceM5RuleModel(bool deserializing) : base(deserializing) { }
  private ConfidenceM5RuleModel(ConfidenceM5RuleModel original, Cloner cloner) : base(original, cloner) { }
  public ConfidenceM5RuleModel(string targetAttr) : base(targetAttr) { }
  public override IDeepCloneable Clone(Cloner cloner) {
    return new ConfidenceM5RuleModel(this, cloner);
  }
  #endregion

  /// <summary>Returns the variance estimates of the wrapped rule model for the given rows.</summary>
  /// <exception cref="System.InvalidCastException">The rule model does not implement IConfidenceRegressionModel.</exception>
  public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
    return ((IConfidenceRegressionModel) RuleModel).GetEstimatedVariances(dataset, rows);
  }

  /// <summary>Wraps this model and <paramref name="problemData"/> in a new ConfidenceRegressionSolution.</summary>
  public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    return new ConfidenceRegressionSolution(this, problemData);
  }
}
}
/// <summary>Relational operators a rule condition can use.</summary>
internal enum RelOp {
  /// <summary>Less than or equal (x &lt;= y).</summary>
  Lessequal = 0,
  /// <summary>Strictly greater (x &gt; y).</summary>
  Greater = 1
}
/// <summary>Extension methods for evaluating <see cref="RelOp"/> values.</summary>
internal static class RelOpExtentions {
  /// <summary>Applies the operator to the operands, i.e. evaluates "x op y".</summary>
  /// <param name="op">Operator to apply.</param>
  /// <param name="x">Left operand.</param>
  /// <param name="y">Right operand.</param>
  /// <returns>True if the relation holds.</returns>
  /// <exception cref="ArgumentOutOfRangeException"><paramref name="op"/> is not a defined operator.</exception>
  public static bool Compare(this RelOp op, double x, double y) {
    switch (op) {
      case RelOp.Greater:
        return x > y;
      case RelOp.Lessequal:
        return x <= y;
      default:
        // The first constructor argument is the parameter NAME; the offending value goes second.
        throw new ArgumentOutOfRangeException("op", op, "Unknown relational operator.");
    }
  }
}
}