Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3127-MRGP-VarPro-Exploration/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis.Symbolic-3.4/VarProTest.cs @ 18097

Last change on this file since 18097 was 17988, checked in by gkronber, 3 years ago

#3127: initial import of VarProMRGP implementation (depends on new NativeInterpreter which supports VarPro)

File size: 5.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Runtime.InteropServices;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.NativeInterpreter;
8using HeuristicLab.Problems.DataAnalysis.Symbolic;
9using HeuristicLab.Problems.DataAnalysis.Symbolic.Tests;
10using HeuristicLab.Random;
11using Microsoft.VisualStudio.TestTools.UnitTesting;
12
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Tests {
  /// <summary>
  /// Tests for fitting expressions with variable projection (VarPro) through
  /// <c>NativeWrapper.GetValuesVarPro</c>.
  /// </summary>
  [TestClass]
  public class VarProTest {

    /// <summary>
    /// Fits one and two exponential terms (plus an intercept) to noise-free synthetic data and
    /// checks that the optimizable constants and the linear coefficients are recovered.
    /// Every fit is executed twice without resetting the instructions, to check that re-fitting
    /// an already optimized expression yields the same result.
    /// </summary>
    [TestMethod]
    public void Exponential() {
      var rand = new MersenneTwister(31415);
      int n = 30;
      var x1 = Enumerable.Range(0, n).Select(_ => rand.NextDouble() * 2 - 1).ToArray();
      var x2 = Enumerable.Range(0, n).Select(_ => rand.NextDouble() * 2 - 1).ToArray();

      // EmitVar stores the address of the pinned array in the instruction, so both arrays must stay
      // pinned while the native code runs. The handles are released in finally so that a failing
      // assertion does not leave the arrays pinned for the rest of the test run.
      GCHandle x1Handle = default(GCHandle), x2Handle = default(GCHandle);
      try {
        x1Handle = GCHandle.Alloc(x1, GCHandleType.Pinned);
        x2Handle = GCHandle.Alloc(x2, GCHandleType.Pinned);

        // single term: y = exp(-0.3 * x1)
        // model:       coeff[0] * exp(c * x1) + coeff[1]
        {
          var y = x1.Select(xi => Math.Exp(-0.3 * xi)).ToArray();

          // instructions are listed children-first; the last one is the root of the term
          var code = new NativeInterpreter.NativeInstruction[4];
          code[0] = EmitConst(1.0);                     // c (optimized)
          code[1] = EmitVar(x1Handle, 1.0);             // x1 (not optimized)
          code[2] = Emit(OpCode.Mul, code[0], code[1]); // c * x1
          code[3] = Emit(OpCode.Exp, code[2]);          // exp(c * x1)

          var termIndices = new[] { code.Length - 1 };
          var rows = Enumerable.Range(0, n).ToArray();
          var coeff = new double[termIndices.Length + 1];
          var options = new SolverOptions();
          var yPred = new double[y.Length];

          // the second run starts from the constants found in the first run
          for (int run = 0; run < 2; run++) {
            RunVarPro(code, termIndices, rows, coeff, options, yPred, y);
            Assert.AreEqual(-0.3, code[0].value, 1e-6);
            Assert.AreEqual(1.0, code[1].value);
          }
        }

        // two terms: y = 3 * exp(-0.3 * x1) + 4 * exp(0.5 * x2) + 5
        // model:     coeff[0] * exp(c1 * x1) + coeff[1] * exp(c2 * x2) + coeff[2]
        {
          var y = new double[n];
          for (int i = 0; i < n; i++) {
            y[i] = 3.0 * Math.Exp(-0.3 * x1[i]) + 4.0 * Math.Exp(0.5 * x2[i]) + 5.0;
          }

          var code = new NativeInterpreter.NativeInstruction[8];

          // first term: exp(c1 * x1)
          code[0] = EmitConst(1.0);
          code[1] = EmitVar(x1Handle, 1.0);
          code[2] = Emit(OpCode.Mul, code[0], code[1]);
          code[3] = Emit(OpCode.Exp, code[2]);

          // second term: exp(c2 * x2)
          code[4] = EmitConst(1.0);
          code[5] = EmitVar(x2Handle, 1.0);
          code[6] = Emit(OpCode.Mul, code[4], code[5]);
          code[7] = Emit(OpCode.Exp, code[6]);

          var termIndices = new[] { 3, 7 }; // roots of the two terms
          var rows = Enumerable.Range(0, n).ToArray();
          var coeff = new double[termIndices.Length + 1];
          var options = new SolverOptions();
          var yPred = new double[y.Length];

          // the second run starts from the constants found in the first run
          for (int run = 0; run < 2; run++) {
            RunVarPro(code, termIndices, rows, coeff, options, yPred, y);
            Assert.AreEqual(-0.3, code[0].value, 1e-6);
            Assert.AreEqual(1.0, code[1].value);
            Assert.AreEqual(0.5, code[4].value, 1e-6);
            Assert.AreEqual(1.0, code[5].value);
            Assert.AreEqual(3.0, coeff[0], 1e-6);
            Assert.AreEqual(4.0, coeff[1], 1e-6);
            Assert.AreEqual(5.0, coeff[2], 1e-6);
          }
        }
      } finally {
        if (x2Handle.IsAllocated) x2Handle.Free();
        if (x1Handle.IsAllocated) x1Handle.Free();
      }
    }

    /// <summary>
    /// Clears <paramref name="coeff"/> and <paramref name="yPred"/>, then calls
    /// <c>GetValuesVarPro</c> for all terms in <paramref name="termIndices"/> over all
    /// <paramref name="rows"/>. The assertions in this test rely on the optimized constants being
    /// written back into <paramref name="code"/>, and on <paramref name="coeff"/> receiving one
    /// linear coefficient per term followed by the intercept.
    /// </summary>
    private static void RunVarPro(NativeInstruction[] code, int[] termIndices, int[] rows, double[] coeff,
                                  SolverOptions options, double[] yPred, double[] y) {
      Array.Clear(coeff, 0, coeff.Length);
      Array.Clear(yPred, 0, yPred.Length);
      NativeInterpreter.NativeWrapper.GetValuesVarPro(code, code.Length, termIndices, termIndices.Length,
        rows, rows.Length, coeff, options, yPred, y, out _);
    }

    /// <summary>
    /// Creates a non-optimized operator instruction. <c>length</c> is the number of instructions
    /// in the subtree rooted at this instruction (the sum of the argument lengths plus one).
    /// </summary>
    private static NativeInstruction Emit(OpCode opcode, params NativeInstruction[] args) {
      var instr = new NativeInstruction();
      instr.narg = (ushort)args.Length;
      instr.length = args.Sum(argi => argi.length) + 1;
      instr.opcode = (byte)opcode;
      instr.value = 0.0;
      instr.optimize = false;
      return instr;
    }

    /// <summary>
    /// Creates a leaf constant instruction with initial value <paramref name="v"/> that is marked
    /// for optimization.
    /// </summary>
    private static NativeInstruction EmitConst(double v) {
      var instr = new NativeInstruction();
      instr.narg = 0;
      instr.length = 1;
      instr.opcode = (byte)OpCode.Constant;
      instr.value = v;
      instr.optimize = true;
      return instr;
    }

    /// <summary>
    /// Creates a leaf variable instruction that points to the data pinned by <paramref name="gch"/>.
    /// The handle must stay allocated for as long as the instruction is evaluated.
    /// <paramref name="v"/> is stored as the instruction value and is not optimized
    /// (presumably the variable weight — confirm against the native interpreter).
    /// </summary>
    private static NativeInstruction EmitVar(GCHandle gch, double v) {
      var instr = new NativeInstruction();
      instr.data = gch.AddrOfPinnedObject();
      instr.narg = 0;
      instr.length = 1;
      instr.opcode = (byte)OpCode.Variable;
      instr.value = v;
      instr.optimize = false;
      return instr;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.