#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Regression model consisting of a symbolic expression tree plus an empirical distribution over
  /// the tree's numeric parameters. A prediction repeatedly draws a parameter vector from that
  /// distribution and evaluates the tree with it. The per-row mean of the resulting values is the
  /// estimate, and their population variance is the reported variance.
  /// </summary>
  [StorableClass]
  [Item("BayesianNonlinearRegressionModel", "")]
  public sealed class BayesianNonlinearRegressionModel : RegressionModel, IConfidenceRegressionModel {
    // Number of parameter vectors drawn for every prediction call.
    private const int SAMPLE_SIZE = 100;
    // Fixed seed: repeated predictions for the same rows draw the same parameter vectors.
    // TODO: make the seed configurable.
    private const int SAMPLING_SEED = 1234;

    /// <summary>Model structure. Its parameter values are taken from the empirical distribution at prediction time.</summary>
    [Storable]
    public ISymbolicExpressionTree Tree { get; private set; }

    // Each inner array is one parameter vector. Its layout must match the order in which
    // UpdateConstants visits constant nodes and factor-variable weights in Tree (prefix order).
    // Storable so that a deserialized model can still predict.
    [Storable]
    private double[][] parameterEmpiricalDistribution;
    public IEnumerable<double[]> ParameterEmpiricalDistribution {
      get { return parameterEmpiricalDistribution; }
    }

    /// <summary>Interpreter used to evaluate Tree on a dataset.</summary>
    [Storable]
    public ISymbolicDataAnalysisExpressionTreeInterpreter Interpreter { get; private set; }

    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return allowedInputVariables; }
    }

    [Storable]
    private string[] allowedInputVariables;

    [StorableConstructor]
    private BayesianNonlinearRegressionModel(bool deserializing) : base(deserializing) { }

    private BayesianNonlinearRegressionModel(BayesianNonlinearRegressionModel original, Cloner cloner)
      : base(original, cloner) {
      allowedInputVariables = (string[])original.allowedInputVariables.Clone();
      Tree = cloner.Clone(original.Tree);
      // This class never writes to the distribution after construction, so clones can share it.
      parameterEmpiricalDistribution = original.parameterEmpiricalDistribution;
      Interpreter = cloner.Clone(original.Interpreter);
    }

    /// <summary>Creates a model from a tree and an empirical distribution over its parameters.</summary>
    /// <param name="tree">Model structure.</param>
    /// <param name="parameterEmpiricalDistribution">Non-empty set of parameter vectors. The outer array is copied; the inner arrays are shared.</param>
    /// <param name="interpreter">Interpreter used to evaluate the tree.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="allowedInputVariables">Input variables the model uses.</param>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="ArgumentException">The parameter distribution is empty.</exception>
    public BayesianNonlinearRegressionModel(
      ISymbolicExpressionTree tree,
      double[][] parameterEmpiricalDistribution,
      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      string targetVariable,
      IEnumerable<string> allowedInputVariables)
      : base(targetVariable) {
      if (tree == null) throw new ArgumentNullException("tree");
      if (parameterEmpiricalDistribution == null) throw new ArgumentNullException("parameterEmpiricalDistribution");
      // Sampling picks a random index into the distribution, so it must not be empty.
      if (parameterEmpiricalDistribution.Length == 0)
        throw new ArgumentException("The parameter distribution must contain at least one parameter vector.", "parameterEmpiricalDistribution");
      if (interpreter == null) throw new ArgumentNullException("interpreter");
      if (allowedInputVariables == null) throw new ArgumentNullException("allowedInputVariables");

      this.name = ItemName;
      this.description = ItemDescription;
      this.Tree = tree;
      this.parameterEmpiricalDistribution = (double[][])parameterEmpiricalDistribution.Clone();
      this.Interpreter = interpreter;
      this.allowedInputVariables = allowedInputVariables.ToArray();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new BayesianNonlinearRegressionModel(this, cloner);
    }

    /// <summary>Per-row mean of the sampled predictions.</summary>
    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      var y = Sample(dataset, rows);
      return y.Select(yi => yi.Average());
    }

    /// <summary>Per-row population variance of the sampled predictions.</summary>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      var y = Sample(dataset, rows);
      return y.Select(yi => yi.VariancePop());
    }

    /// <summary>
    /// Evaluates the tree SAMPLE_SIZE times, each time with a parameter vector drawn uniformly
    /// (with replacement) from the empirical distribution.
    /// </summary>
    /// <returns>One list per requested row, each holding SAMPLE_SIZE predicted values.</returns>
    private IList<double>[] Sample(IDataset dataset, IEnumerable<int> rows) {
      // Materialize once: the rows are passed to the interpreter once per sample.
      var rowArray = rows.ToArray();
      var y = new List<double>[rowArray.Length];
      for (int r = 0; r < y.Length; r++)
        y[r] = new List<double>(SAMPLE_SIZE);

      // Work on a private copy so that predicting neither overwrites the parameters stored in Tree
      // nor races with other prediction calls on the shared tree.
      var tree = new Cloner().Clone(Tree);
      var rand = new System.Random(SAMPLING_SEED);
      for (int s = 0; s < SAMPLE_SIZE; s++) {
        var paramIdx = rand.Next(parameterEmpiricalDistribution.Length);
        UpdateConstants(tree, parameterEmpiricalDistribution[paramIdx]);
        int predRow = 0;
        foreach (var pred in Interpreter.GetSymbolicExpressionTreeValues(tree, dataset, rowArray)) {
          y[predRow].Add(pred);
          predRow++;
        }
      }
      return y;
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new ConfidenceRegressionSolution(this, new RegressionProblemData(problemData));
    }

    #region taken from ConstantOptEval (TODO: deduplicate)
    // Writes the parameter vector into the tree in prefix order: one value per ConstantTreeNode
    // and one value per weight of each FactorVariableTreeNode.
    // Weights of other variable nodes are intentionally left unchanged.
    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
      int i = 0;
      foreach (var node in tree.Root.IterateNodesPrefix()) {
        var constantTreeNode = node as ConstantTreeNode;
        var factorVarTreeNode = node as FactorVariableTreeNode;
        if (constantTreeNode != null) {
          constantTreeNode.Value = constants[i++];
        } else if (factorVarTreeNode != null) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            factorVarTreeNode.Weights[j] = constants[i++];
        }
      }
    }
    #endregion
  }
}