#region License Information /* HeuristicLab * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL) * * This file is part of HeuristicLab. * * HeuristicLab is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * HeuristicLab is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with HeuristicLab. If not, see . */ #endregion using System; using System.Collections.Generic; using System.Linq; using HeuristicLab.Common; using HeuristicLab.Core; using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; using HeuristicLab.Problems.DataAnalysis; namespace HeuristicLab.Algorithms.DataAnalysis { [StorableClass] [Item("OneR Classification Model", "A model that uses intervals for one variable to determine the class.")] public sealed class OneRClassificationModel : ClassificationModel { public override IEnumerable VariablesUsedForPrediction { get { return new[] { Variable }; } } [Storable] private string variable; public string Variable { get { return variable; } } [Storable] private double[] splits; public double[] Splits { get { return splits; } } [Storable] private double[] classes; public double[] Classes { get { return classes; } } [Storable] private double missingValuesClass; public double MissingValuesClass { get { return missingValuesClass; } } [StorableConstructor] private OneRClassificationModel(bool deserializing) : base(deserializing) { } private OneRClassificationModel(OneRClassificationModel original, Cloner cloner) : base(original, cloner) { this.variable = (string)original.variable; this.splits = (double[])original.splits.Clone(); this.classes = (double[])original.classes.Clone(); this.missingValuesClass = original.missingValuesClass; } public override IDeepCloneable Clone(Cloner cloner) { return new OneRClassificationModel(this, cloner); } public OneRClassificationModel(string targetVariable, string variable, double[] splits, double[] classes, double missingValuesClass = double.NaN) : base(targetVariable) { if (splits.Length != classes.Length) { throw new ArgumentException("Number of splits and classes has to be equal."); } if (!Double.IsPositiveInfinity(splits[splits.Length - 1])) { throw new ArgumentException("Last split has to be double.PositiveInfinity, so that all values are covered."); } this.name = ItemName; this.description = ItemDescription; this.variable = variable; this.splits = splits; this.classes = classes; this.missingValuesClass = missingValuesClass; } // uses sorting to return the values in the order of rows, instead of using nested for loops // to avoid O(n²) runtime public override IEnumerable GetEstimatedClassValues(IDataset dataset, IEnumerable rows) { var values = dataset.GetDoubleValues(Variable, rows).ToArray(); var rowsArray = rows.ToArray(); var order = Enumerable.Range(0, rowsArray.Length).ToArray(); double[] estimated = new double[rowsArray.Length]; Array.Sort(rowsArray, order); Array.Sort(values, rowsArray); int curSplit = 0, curIndex = 0; while (curIndex < values.Length && Double.IsNaN(values[curIndex])) { estimated[curIndex] = MissingValuesClass; curIndex++; } while (curSplit < Splits.Length) { while (curIndex < values.Length && Splits[curSplit] > values[curIndex]) { estimated[curIndex] = classes[curSplit]; curIndex++; } curSplit++; } Array.Sort(rowsArray, estimated); Array.Sort(order, estimated); return estimated; } public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) { return new OneRClassificationSolution(this, new ClassificationProblemData(problemData)); } } }