[5617]  1  #region License Information


 2  /* HeuristicLab


[14185]  3  * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


[5617]  4  *


 5  * This file is part of HeuristicLab.


 6  *


 7  * HeuristicLab is free software: you can redistribute it and/or modify


 8  * it under the terms of the GNU General Public License as published by


 9  * the Free Software Foundation, either version 3 of the License, or


 10  * (at your option) any later version.


 11  *


 12  * HeuristicLab is distributed in the hope that it will be useful,


 13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


 14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


 15  * GNU General Public License for more details.


 16  *


 17  * You should have received a copy of the GNU General Public License


 18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


 19  */


 20  #endregion


 21 


 22  using System;


[5777]  23  using System.Collections.Generic;


[5617]  24  using System.Linq;


[14542]  25  using System.Threading;


[5617]  26  using HeuristicLab.Common;


 27  using HeuristicLab.Core;


 28  using HeuristicLab.Data;


[5777]  29  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;


[5617]  30  using HeuristicLab.Optimization;


 31  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


 32  using HeuristicLab.Problems.DataAnalysis;


 33  using HeuristicLab.Problems.DataAnalysis.Symbolic;


[5624]  34  using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;


[5617]  35 


 36  namespace HeuristicLab.Algorithms.DataAnalysis {


 37  /// <summary>


 38  /// Linear regression data analysis algorithm.


 39  /// </summary>


[13238]  40  [Item("Linear Regression (LR)", "Linear regression data analysis algorithm (wrapper for ALGLIB).")]


[12504]  41  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 100)]


[5617]  42  [StorableClass]


 43  public sealed class LinearRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {


[5649]  44  private const string LinearRegressionModelResultName = "Linear regression solution";


[5617]  45 


 46  [StorableConstructor]


 47  private LinearRegression(bool deserializing) : base(deserializing) { }


 48  private LinearRegression(LinearRegression original, Cloner cloner)


 49  : base(original, cloner) {


 50  }


 51  public LinearRegression()


 52  : base() {


[5649]  53  Problem = new RegressionProblem();


[5617]  54  }


 55  [StorableHook(HookType.AfterDeserialization)]


 56  private void AfterDeserialization() { }


 57 


 58  public override IDeepCloneable Clone(Cloner cloner) {


 59  return new LinearRegression(this, cloner);


 60  }


 61 


 62  #region linear regression


[14542]  63  protected override void Run(CancellationToken cancellationToken) {


[5617]  64  double rmsError, cvRmsError;


[5624]  65  var solution = CreateLinearRegressionSolution(Problem.ProblemData, out rmsError, out cvRmsError);


[5649]  66  Results.Add(new Result(LinearRegressionModelResultName, "The linear regression solution.", solution));


 67  Results.Add(new Result("Root mean square error", "The root of the mean of squared errors of the linear regression solution on the training set.", new DoubleValue(rmsError)));


 68  Results.Add(new Result("Estimated root mean square error (crossvalidation)", "The estimated root of the mean of squared errors of the linear regression solution via cross validation.", new DoubleValue(cvRmsError)));


[5617]  69  }


 70 


[5624]  71  public static ISymbolicRegressionSolution CreateLinearRegressionSolution(IRegressionProblemData problemData, out double rmsError, out double cvRmsError) {


[12509]  72  var dataset = problemData.Dataset;


[5624]  73  string targetVariable = problemData.TargetVariable;


[5649]  74  IEnumerable<string> allowedInputVariables = problemData.AllowedInputVariables;


[8139]  75  IEnumerable<int> rows = problemData.TrainingIndices;


[14237]  76  var doubleVariables = allowedInputVariables.Where(dataset.VariableHasType<double>);


 77  var factorVariableNames = allowedInputVariables.Where(dataset.VariableHasType<string>);


[14240]  78  var factorVariables = AlglibUtil.GetFactorVariableValues(dataset, factorVariableNames, rows);


[14237]  79  double[,] binaryMatrix = AlglibUtil.PrepareInputMatrix(dataset, factorVariables, rows);


 80  double[,] doubleVarMatrix = AlglibUtil.PrepareInputMatrix(dataset, doubleVariables.Concat(new string[] { targetVariable }), rows);


 81  var inputMatrix = binaryMatrix.VertCat(doubleVarMatrix);


 82 


[6002]  83  if (inputMatrix.Cast<double>().Any(x => double.IsNaN(x)  double.IsInfinity(x)))


 84  throw new NotSupportedException("Linear regression does not support NaN or infinity values in the input dataset.");


[5617]  85 


[12817]  86  alglib.linearmodel lm = new alglib.linearmodel();


 87  alglib.lrreport ar = new alglib.lrreport();


[5617]  88  int nRows = inputMatrix.GetLength(0);


 89  int nFeatures = inputMatrix.GetLength(1)  1;


[12817]  90  double[] coefficients = new double[nFeatures + 1]; // last coefficient is for the constant


[5617]  91 


 92  int retVal = 1;


 93  alglib.lrbuild(inputMatrix, nRows, nFeatures, out retVal, out lm, out ar);


[5649]  94  if (retVal != 1) throw new ArgumentException("Error in calculation of linear regression solution");


[5617]  95  rmsError = ar.rmserror;


 96  cvRmsError = ar.cvrmserror;


 97 


 98  alglib.lrunpack(lm, out coefficients, out nFeatures);


 99 


 100  ISymbolicExpressionTree tree = new SymbolicExpressionTree(new ProgramRootSymbol().CreateTreeNode());


 101  ISymbolicExpressionTreeNode startNode = new StartSymbol().CreateTreeNode();


[5733]  102  tree.Root.AddSubtree(startNode);


[5617]  103  ISymbolicExpressionTreeNode addition = new Addition().CreateTreeNode();


[5733]  104  startNode.AddSubtree(addition);


[5617]  105 


 106  int col = 0;


[14237]  107  foreach (var kvp in factorVariables) {


 108  var varName = kvp.Key;


 109  foreach (var cat in kvp.Value) {


[14243]  110  BinaryFactorVariableTreeNode vNode =


 111  (BinaryFactorVariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.BinaryFactorVariable().CreateTreeNode();


[14237]  112  vNode.VariableName = varName;


 113  vNode.VariableValue = cat;


 114  vNode.Weight = coefficients[col];


 115  addition.AddSubtree(vNode);


 116  col++;


 117  }


 118  }


 119  foreach (string column in doubleVariables) {


[5617]  120  VariableTreeNode vNode = (VariableTreeNode)new HeuristicLab.Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();


 121  vNode.VariableName = column;


 122  vNode.Weight = coefficients[col];


[5733]  123  addition.AddSubtree(vNode);


[5617]  124  col++;


 125  }


 126 


 127  ConstantTreeNode cNode = (ConstantTreeNode)new Constant().CreateTreeNode();


 128  cNode.Value = coefficients[coefficients.Length  1];


[5733]  129  addition.AddSubtree(cNode);


[5617]  130 


[13941]  131  SymbolicRegressionSolution solution = new SymbolicRegressionSolution(new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeInterpreter()), (IRegressionProblemData)problemData.Clone());


[6555]  132  solution.Model.Name = "Linear Regression Model";


[7588]  133  solution.Name = "Linear Regression Solution";


[5624]  134  return solution;


[5617]  135  }


 136  #endregion


 137  }


 138  }

