Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.Instances.DataAnalysis/3.3/Regression/VariableNetworks/LinearVariableNetwork.cs @ 16713

Last change on this file since 16713 was 16565, checked in by gkronber, 6 years ago

#2520: merged changes from PersistenceOverhaul branch (r16451:16564) into trunk

File size: 4.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Problems.DataAnalysis;
28using HeuristicLab.Random;
29
30namespace HeuristicLab.Problems.Instances.DataAnalysis {
31  public sealed class LinearVariableNetwork : VariableNetwork {
32    private int numberOfFeatures;
33    private double noiseRatio;
34
35    public override string Name { get { return string.Format("LinearVariableNetwork-{0:0%} ({1} dim)", noiseRatio, numberOfFeatures); } }
36
37    public LinearVariableNetwork(int numberOfFeatures, double noiseRatio,
38      IRandom rand)
39      : base(250, 250, numberOfFeatures, noiseRatio, rand) {
40      this.noiseRatio = noiseRatio;
41      this.numberOfFeatures = numberOfFeatures;
42    }
43
44    protected override IEnumerable<double> GenerateRandomFunction(IRandom rand, List<List<double>> xs, out string[] selectedVarNames, out double[] relevance) {
45      int nl = SampleNumberOfVariables(rand, numberOfFeatures);
46
47      var selectedIdx = Enumerable.Range(0, xs.Count).Shuffle(rand)
48        .Take(nl).ToArray();
49
50      var selectedVars = selectedIdx.Select(i => xs[i]).ToArray();
51      selectedVarNames = selectedIdx.Select(i => VariableNames[i]).ToArray();
52      return SampleLinearFunction(rand, selectedVars, out relevance);
53    }
54
55    private IEnumerable<double> SampleLinearFunction(IRandom rand, List<double>[] xs, out double[] relevance) {
56      int nl = xs.Length;
57      int nRows = xs.First().Count;
58
59      // sample standardized coefficients iid ~ N(0, 1)
60      var c = Enumerable.Range(0, nRows).Select(_ => NormalDistributedRandom.NextDouble(rand, 0, 1)).ToArray();
61
62      // calculate scaled coefficients (variables with large variance should have smaller coefficients)
63      var scaledC = Enumerable.Range(0, nl)
64        .Select(i => c[i] / xs[i].StandardDeviationPop())
65        .ToArray();
66
67      var y = EvaluteLinearModel(xs, scaledC);
68
69      relevance = CalculateRelevance(y, xs, scaledC);
70
71      return y;
72    }
73
74    private double[] EvaluteLinearModel(List<double>[] xs, double[] c) {
75      int nRows = xs.First().Count;
76      var y = new double[nRows];
77      for(int row = 0; row < nRows; row++) {
78        y[row] = xs.Select(xi => xi[row]).Zip(c, (xij, cj) => xij * cj).Sum();
79        y[row] /= c.Length;
80      }
81      return y;
82    }
83
84    // calculate variable relevance based on removal of variables
85    //  1) to remove a variable we set it's coefficient to zero
86    //  2) calculate MSE of the original target values (y) to the updated targes y' (after variable removal)
87    //  3) relevance is larger if MSE(y,y') is large
88    //  4) scale impacts so that the most important variable has impact = 1
89    private double[] CalculateRelevance(double[] y, List<double>[] xs, double[] l) {
90      var changedL = new double[l.Length];
91      var relevance = new double[l.Length];
92      for(int i = 0; i < l.Length; i++) {
93        Array.Copy(l, changedL, changedL.Length);
94        changedL[i] = 0.0;
95
96        var yChanged = EvaluteLinearModel(xs, changedL);
97
98        OnlineCalculatorError error;
99        var mse = OnlineMeanSquaredErrorCalculator.Calculate(y, yChanged, out error);
100        if(error != OnlineCalculatorError.None) mse = double.MaxValue;
101        relevance[i] = mse;
102      }
103      // scale so that max relevance is 1.0
104      var maxRel = relevance.Max();
105      for(int i = 0; i < relevance.Length; i++) relevance[i] /= maxRel;
106      return relevance;
107    }
108  }
109}
Note: See TracBrowser for help on using the repository browser.