1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 | using ExcelDna.Integration;
|
---|
5 | using HeuristicLab.Problems.DataAnalysis;
|
---|
6 |
|
---|
7 | namespace HeuristicLabExcel {
|
---|
8 |
|
---|
/// <summary>
/// Worksheet functions exposed to Excel through ExcelDNA.
/// </summary>
public class ExcelFunctions {
  // Largest number of data rows PredictRandomForest accepts.
  private const int MaxRows = 5000;

  // Fixed settings handed to the random forest trainer.
  // NOTE(review): the names follow the presumed parameter order (nTrees, r, m, seed) of
  // RandomForestRegression.CreateRandomForestRegressionSolution - confirm against the HeuristicLab API.
  private const int NumberOfTrees = 100;
  private const double RandomForestR = 0.5;
  private const double RandomForestM = 0.5;
  private const int RandomSeed = 31415;

  /* Standard example from ExcelDNA */
  /// <summary>
  /// Returns the product of <paramref name="x"/> and <paramref name="y"/>.
  /// </summary>
  [ExcelFunction(Description = "Multiplies two numbers", Category = "Useful functions")]
  public static double MultiplyThem(double x, double y) {
    return x * y;
  }

  /// <summary>
  /// Trains a random forest regression model on the first <c>y.Length</c> rows of
  /// <paramref name="x"/> and returns the model's estimated values as a single column.
  /// </summary>
  /// <param name="x">
  /// Input matrix, one row per observation and one column per input variable. It must have
  /// more rows than columns and at most 5000 rows.
  /// </param>
  /// <param name="y">
  /// Target values for the leading rows of <paramref name="x"/>. Rows of <paramref name="x"/>
  /// beyond <c>y.Length</c> form the test partition and get a target of 0 in the dataset.
  /// </param>
  /// <returns>
  /// An <c>nRows x 1</c> array filled, in order, with the values enumerated from the solution's
  /// <c>EstimatedValues</c> (presumably one estimate per dataset row - confirm against HeuristicLab).
  /// </returns>
  /// <exception cref="ArgumentNullException">
  /// <paramref name="x"/> or <paramref name="y"/> is null.
  /// </exception>
  /// <exception cref="ArgumentException">
  /// <paramref name="x"/> has more than 5000 rows or not more rows than columns, or
  /// <paramref name="y"/> is empty or longer than the number of rows of <paramref name="x"/>.
  /// </exception>
  [ExcelFunction(Description = "Random forest", Category = "Data Analysis")]
  public static double[,] PredictRandomForest(double[,] x, double[] y) {
    if (x == null) {
      throw new ArgumentNullException("x");
    }
    if (y == null) {
      throw new ArgumentNullException("y");
    }

    int nRows = x.GetLength(0);
    int nCols = x.GetLength(1);
    if (nRows > MaxRows) {
      throw new ArgumentException("At most " + MaxRows + " rows are supported.", "x");
    }
    if (nCols >= nRows) {
      throw new ArgumentException("The input range must have more rows than columns.", "x");
    }
    // The training partition is [0, y.Length), so y must be non-empty and must fit into the rows of x.
    if (y.Length == 0 || y.Length > nRows) {
      throw new ArgumentException("The number of target values must be between 1 and the number of input rows.", "y");
    }

    // Materialize the variable names once; they are enumerated by several consumers below.
    var inputs = Enumerable.Range(0, nCols).Select(i => "x" + i).ToArray();
    var target = "y";
    var variables = inputs.Concat(new[] { target }).ToArray();

    // Copy the inputs and the targets into one matrix; the target is the last column.
    // Rows without a target value keep the default of 0.
    var xy = new double[nRows, nCols + 1];
    for (int r = 0; r < nRows; r++) {
      for (int c = 0; c < nCols; c++) {
        xy[r, c] = x[r, c];
      }
      if (r < y.Length) {
        xy[r, nCols] = y[r];
      }
    }
    var ds = new Dataset(variables, xy);

    // Rows with a known target are used for training, the remaining rows for testing.
    var problemData = new RegressionProblemData(ds, inputs, target);
    problemData.TrainingPartition.Start = 0;
    problemData.TrainingPartition.End = y.Length;
    problemData.TestPartition.Start = y.Length;
    problemData.TestPartition.End = nRows;

    // The error measures are required by the call signature but are not used here.
    double rmsError;
    double oobAvgRelError;
    double oobRmsError;
    double avgRelError;
    var rf = HeuristicLab.Algorithms.DataAnalysis.RandomForestRegression.CreateRandomForestRegressionSolution(
      problemData, NumberOfTrees, RandomForestR, RandomForestM, RandomSeed,
      out rmsError, out avgRelError, out oobRmsError, out oobAvgRelError);

    // Copy the estimates into a single output column. foreach disposes the enumerator and
    // stops cleanly if fewer than nRows values are produced (those rows then stay 0).
    var res = new double[nRows, 1];
    int row = 0;
    foreach (var estimate in rf.EstimatedValues) {
      if (row >= nRows) {
        break;
      }
      res[row, 0] = estimate;
      row++;
    }
    return res;
  }
}
|
---|
60 | }
|
---|