#region License Information
/* HeuristicLab
* Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see .
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Optimization;
using HeuristicLab.Problems.DataAnalysis;
using HEAL.Attic;
namespace HeuristicLab.Algorithms.DataAnalysis {
[StorableType("425AF262-A756-4E9A-B76F-4D2480BEA4FD")]
public class RegressionRuleModel : RegressionModel, IDecisionTreeModel {
#region Properties
[Storable]
public string[] SplitAttributes { get; set; }
[Storable]
private double[] SplitValues { get; set; }
[Storable]
private Comparison[] Comparisons { get; set; }
[Storable]
private IRegressionModel RuleModel { get; set; }
[Storable]
private IReadOnlyList variables;
#endregion
#region HLConstructors
[StorableConstructor]
protected RegressionRuleModel(StorableConstructorFlag _) : base(_) { }
protected RegressionRuleModel(RegressionRuleModel original, Cloner cloner) : base(original, cloner) {
if (original.SplitAttributes != null) SplitAttributes = original.SplitAttributes.ToArray();
if (original.SplitValues != null) SplitValues = original.SplitValues.ToArray();
if (original.Comparisons != null) Comparisons = original.Comparisons.ToArray();
RuleModel = cloner.Clone(original.RuleModel);
if (original.variables != null) variables = original.variables.ToList();
}
private RegressionRuleModel(string target) : base(target) { }
public override IDeepCloneable Clone(Cloner cloner) {
return new RegressionRuleModel(this, cloner);
}
#endregion
internal static RegressionRuleModel CreateRuleModel(string target, RegressionTreeParameters regressionTreeParams) {
return regressionTreeParams.LeafModel.ProvidesConfidence ? new ConfidenceRegressionRuleModel(target) : new RegressionRuleModel(target);
}
#region IRegressionModel
public override IEnumerable VariablesUsedForPrediction {
get { return variables; }
}
public override IEnumerable GetEstimatedValues(IDataset dataset, IEnumerable rows) {
if (RuleModel == null) throw new NotSupportedException("The model has not been built correctly");
return RuleModel.GetEstimatedValues(dataset, rows);
}
public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
return new RegressionSolution(this, problemData);
}
#endregion
public void Build(IReadOnlyList trainingRows, IReadOnlyList pruningRows, IScope statescope, ResultCollection results, CancellationToken cancellationToken) {
var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
variables = regressionTreeParams.AllowedInputVariables.ToList();
//build tree and select node with maximum coverage
var tree = RegressionNodeTreeModel.CreateTreeModel(regressionTreeParams.TargetVariable, regressionTreeParams);
tree.BuildModel(trainingRows, pruningRows, statescope, results, cancellationToken);
var nodeModel = tree.Root.EnumerateNodes().Where(x => x.IsLeaf).MaxItems(x => x.NumSamples).First();
var satts = new List();
var svals = new List();
var reops = new List();
//extract splits
for (var temp = nodeModel; temp.Parent != null; temp = temp.Parent) {
satts.Add(temp.Parent.SplitAttribute);
svals.Add(temp.Parent.SplitValue);
reops.Add(temp.Parent.Left == temp ? Comparison.LessEqual : Comparison.Greater);
}
Comparisons = reops.ToArray();
SplitAttributes = satts.ToArray();
SplitValues = svals.ToArray();
int np;
RuleModel = regressionTreeParams.LeafModel.BuildModel(trainingRows.Union(pruningRows).Where(r => Covers(regressionTreeParams.Data, r)).ToArray(), regressionTreeParams, cancellationToken, out np);
}
public void Update(IReadOnlyList rows, IScope statescope, CancellationToken cancellationToken) {
var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
int np;
RuleModel = regressionTreeParams.LeafModel.BuildModel(rows, regressionTreeParams, cancellationToken, out np);
}
public bool Covers(IDataset dataset, int row) {
return !SplitAttributes.Where((t, i) => !Comparisons[i].Compare(dataset.GetDoubleValue(t, row), SplitValues[i])).Any();
}
public string ToCompactString() {
var mins = new Dictionary();
var maxs = new Dictionary();
for (var i = 0; i < SplitAttributes.Length; i++) {
var n = SplitAttributes[i];
var v = SplitValues[i];
if (!mins.ContainsKey(n)) mins.Add(n, double.NegativeInfinity);
if (!maxs.ContainsKey(n)) maxs.Add(n, double.PositiveInfinity);
if (Comparisons[i] == Comparison.LessEqual) maxs[n] = Math.Min(maxs[n], v);
else mins[n] = Math.Max(mins[n], v);
}
if (maxs.Count == 0) return "";
var s = new StringBuilder();
foreach (var key in maxs.Keys)
s.Append(string.Format("{0} ∈ [{1:e2}; {2:e2}] && ", key, mins[key], maxs[key]));
s.Remove(s.Length - 4, 4);
return s.ToString();
}
[StorableType("7302AA30-9F58-42F3-BF6A-ECF1536508AB")]
private sealed class ConfidenceRegressionRuleModel : RegressionRuleModel, IConfidenceRegressionModel {
#region HLConstructors
[StorableConstructor]
private ConfidenceRegressionRuleModel(StorableConstructorFlag _) : base(_) { }
private ConfidenceRegressionRuleModel(ConfidenceRegressionRuleModel original, Cloner cloner) : base(original, cloner) { }
public ConfidenceRegressionRuleModel(string targetAttr) : base(targetAttr) { }
public override IDeepCloneable Clone(Cloner cloner) {
return new ConfidenceRegressionRuleModel(this, cloner);
}
#endregion
public IEnumerable GetEstimatedVariances(IDataset dataset, IEnumerable rows) {
return ((IConfidenceRegressionModel)RuleModel).GetEstimatedVariances(dataset, rows);
}
public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
return new ConfidenceRegressionSolution(this, problemData);
}
}
}
[StorableType("152DECE4-2692-4D53-B290-974806ADCD72")]
internal enum Comparison {
LessEqual,
Greater
}
internal static class ComparisonExtentions {
public static bool Compare(this Comparison op, double x, double y) {
switch (op) {
case Comparison.Greater:
return x > y;
case Comparison.LessEqual:
return x <= y;
default:
throw new ArgumentOutOfRangeException(op.ToString(), op, null);
}
}
}
}