Free cookie consent management tool by TermsFeed Policy Generator

source: branches/sluengo/HeuristicLab.Problems.TradeRules/Solution/TradeRulesSolutionBase.cs @ 9386

Last change on this file since 9386 was 9386, checked in by sluengo, 11 years ago
File size: 11.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using HeuristicLab.Problems.DataAnalysis;
6using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
7using HeuristicLab.Data;
8using HeuristicLab.Common;
9using HeuristicLab.Optimization;
10using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
11
namespace HeuristicLab.Problems.TradeRules
{
    /// <summary>
    /// Abstract base class for trade-rules solutions. It stores trading results (cash after the
    /// training/test periods, trade days, number of trades, total trades) in the solution's result
    /// collection and exposes the <see cref="IRegressionSolution"/> error-metric properties, which
    /// also read their values from the result collection.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the constructor registers only the trading results. The mean squared error,
    /// Pearson's R², average relative error and normalized mean squared error results are never
    /// added in this class, so their property getters will fail unless a derived class (or
    /// deserialized state) registers them. Confirm against derived classes.
    /// </remarks>
    [StorableClass]
    public abstract class TradeRulesSolutionBase : DataAnalysisSolution, IRegressionSolution
    {
        private const string TrainingCashResultName = "Cash after the operation (training)";
        private const string TestCashResultName = "Cash after the operation (test)";
        private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
        private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
        private const string TrainingMeanAbsoluteErrorResultName = "Mean absolute error (training)";
        private const string TestMeanAbsoluteErrorResultName = "Mean absolute error (test)";
        private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)";
        private const string TestSquaredCorrelationResultName = "Pearson's R² (test)";
        private const string TrainingRelativeErrorResultName = "Average relative error (training)";
        private const string TestRelativeErrorResultName = "Average relative error (test)";
        private const string TrainingNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (training)";
        private const string TestNormalizedMeanSquaredErrorResultName = "Normalized mean squared error (test)";
        private const string TrainingMeanErrorResultName = "Mean error (training)";
        private const string TestMeanErrorResultName = "Mean error (test)";
        private const string TestTradeDaysResultName = "Trade days";
        private const string TestNumberTradesResultName = "Number of trades";
        private const string TestTotalTradesResultName = "Total trades";

        /// <summary>The regression model of this solution (the base model cast to <see cref="IRegressionModel"/>).</summary>
        public new IRegressionModel Model
        {
            get { return (IRegressionModel)base.Model; }
            protected set { base.Model = value; }
        }

        /// <summary>The problem data of this solution (the base problem data cast to <see cref="IRegressionProblemData"/>).</summary>
        public new IRegressionProblemData ProblemData
        {
            get { return (IRegressionProblemData)base.ProblemData; }
            set { base.ProblemData = value; }
        }

        public abstract IEnumerable<double> EstimatedValues { get; }
        public abstract IEnumerable<double> EstimatedTrainingValues { get; }
        public abstract IEnumerable<double> EstimatedTestValues { get; }
        public abstract IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows);

        #region Results
        // Every property below is a typed view onto a DoubleValue stored in the result collection
        // under the corresponding *ResultName key; reading a key that was never added fails.

        /// <summary>Cash after trading over the training partition.</summary>
        public double TrainingCash
        {
            get { return ((DoubleValue)this[TrainingCashResultName].Value).Value; }
            private set { ((DoubleValue)this[TrainingCashResultName].Value).Value = value; }
        }
        /// <summary>Cash after trading over the test partition.</summary>
        public double TestCash
        {
            get { return ((DoubleValue)this[TestCashResultName].Value).Value; }
            private set { ((DoubleValue)this[TestCashResultName].Value).Value = value; }
        }
        /// <summary>Number of trading days reported by the trade-rules calculator.</summary>
        public double TradeDays
        {
            get { return ((DoubleValue)this[TestTradeDaysResultName].Value).Value; }
            private set { ((DoubleValue)this[TestTradeDaysResultName].Value).Value = value; }
        }
        /// <summary>Number of trades reported by the trade-rules calculator.</summary>
        public double NumberTrades
        {
            get { return ((DoubleValue)this[TestNumberTradesResultName].Value).Value; }
            private set { ((DoubleValue)this[TestNumberTradesResultName].Value).Value = value; }
        }
        /// <summary>Total trades reported by the trade-rules calculator.</summary>
        public double TotalTrades
        {
            get { return ((DoubleValue)this[TestTotalTradesResultName].Value).Value; }
            private set { ((DoubleValue)this[TestTotalTradesResultName].Value).Value = value; }
        }
        public double TrainingMeanSquaredError
        {
            get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
        }
        public double TestMeanSquaredError
        {
            get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
        }
        public double TrainingMeanAbsoluteError
        {
            get { return ((DoubleValue)this[TrainingMeanAbsoluteErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TrainingMeanAbsoluteErrorResultName].Value).Value = value; }
        }
        public double TestMeanAbsoluteError
        {
            get { return ((DoubleValue)this[TestMeanAbsoluteErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TestMeanAbsoluteErrorResultName].Value).Value = value; }
        }
        public double TrainingRSquared
        {
            get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }
            private set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
        }
        public double TestRSquared
        {
            get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; }
            private set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
        }
        public double TrainingRelativeError
        {
            get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
        }
        public double TestRelativeError
        {
            get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
        }
        public double TrainingNormalizedMeanSquaredError
        {
            get { return ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TrainingNormalizedMeanSquaredErrorResultName].Value).Value = value; }
        }
        public double TestNormalizedMeanSquaredError
        {
            get { return ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TestNormalizedMeanSquaredErrorResultName].Value).Value = value; }
        }
        public double TrainingMeanError
        {
            get { return ((DoubleValue)this[TrainingMeanErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TrainingMeanErrorResultName].Value).Value = value; }
        }
        public double TestMeanError
        {
            get { return ((DoubleValue)this[TestMeanErrorResultName].Value).Value; }
            private set { ((DoubleValue)this[TestMeanErrorResultName].Value).Value = value; }
        }
        #endregion

        [StorableConstructor]
        protected TradeRulesSolutionBase(bool deserializing) : base(deserializing) { }

        protected TradeRulesSolutionBase(TradeRulesSolutionBase original, Cloner cloner)
            : base(original, cloner)
        {
        }

        /// <summary>
        /// Creates a solution for <paramref name="model"/> and <paramref name="problemData"/> and
        /// registers the trading results (cash for training/test, trade days, number of trades,
        /// total trades) as fresh <see cref="DoubleValue"/> entries.
        /// </summary>
        /// <param name="model">The regression model that produces the trading signal.</param>
        /// <param name="problemData">The regression problem data the model is evaluated on.</param>
        protected TradeRulesSolutionBase(IRegressionModel model, IRegressionProblemData problemData)
            : base(model, problemData)
        {
            Add(new Result(TrainingCashResultName, "Cash obtained after training period in the stock market", new DoubleValue()));
            Add(new Result(TestCashResultName, "Cash obtained after test period in the stock market", new DoubleValue()));
            Add(new Result(TestTradeDaysResultName, "Number of trading days", new DoubleValue()));
            Add(new Result(TestNumberTradesResultName, "Number of trades", new DoubleValue()));
            Add(new Result(TestTotalTradesResultName, "Total trades", new DoubleValue()));
        }

        /// <summary>
        /// Adds the mean absolute error and mean error results (training and test) when the
        /// deserialized solution lacks them, computing each from the current model estimates.
        /// A metric whose calculator reports an error state is stored as <see cref="double.NaN"/>.
        /// </summary>
        /// <remarks>
        /// The constructor of this class does not register these results, so this code also runs
        /// for every deserialized instance of this class, not only for old files.
        /// </remarks>
        [StorableHook(HookType.AfterDeserialization)]
        private void AfterDeserialization()
        {
            // BackwardsCompatibility3.4

            #region Backwards compatible code, remove with 3.5

            // NOTE: the HeuristicLab online calculators take (originalValues, estimatedValues).
            // Passing the arguments swapped negates the mean error; MAE is symmetric and unaffected.

            if (!ContainsKey(TrainingMeanAbsoluteErrorResultName))
            {
                OnlineCalculatorError errorState;
                Add(new Result(TrainingMeanAbsoluteErrorResultName, "Mean of absolute errors of the model on the training partition", new DoubleValue()));
                IEnumerable<double> estimated = EstimatedTrainingValues;
                IEnumerable<double> original = GetTargetValues(ProblemData.TrainingIndices);
                double trainingMAE = OnlineMeanAbsoluteErrorCalculator.Calculate(original, estimated, out errorState);
                TrainingMeanAbsoluteError = errorState == OnlineCalculatorError.None ? trainingMAE : double.NaN;
            }

            if (!ContainsKey(TestMeanAbsoluteErrorResultName))
            {
                OnlineCalculatorError errorState;
                Add(new Result(TestMeanAbsoluteErrorResultName, "Mean of absolute errors of the model on the test partition", new DoubleValue()));
                IEnumerable<double> estimated = EstimatedTestValues;
                IEnumerable<double> original = GetTargetValues(ProblemData.TestIndices);
                double testMAE = OnlineMeanAbsoluteErrorCalculator.Calculate(original, estimated, out errorState);
                TestMeanAbsoluteError = errorState == OnlineCalculatorError.None ? testMAE : double.NaN;
            }

            if (!ContainsKey(TrainingMeanErrorResultName))
            {
                OnlineCalculatorError errorState;
                Add(new Result(TrainingMeanErrorResultName, "Mean of errors of the model on the training partition", new DoubleValue()));
                IEnumerable<double> estimated = EstimatedTrainingValues;
                IEnumerable<double> original = GetTargetValues(ProblemData.TrainingIndices);
                double trainingME = OnlineMeanErrorCalculator.Calculate(original, estimated, out errorState);
                TrainingMeanError = errorState == OnlineCalculatorError.None ? trainingME : double.NaN;
            }

            if (!ContainsKey(TestMeanErrorResultName))
            {
                OnlineCalculatorError errorState;
                Add(new Result(TestMeanErrorResultName, "Mean of errors of the model on the test partition", new DoubleValue()));
                IEnumerable<double> estimated = EstimatedTestValues;
                IEnumerable<double> original = GetTargetValues(ProblemData.TestIndices);
                double testME = OnlineMeanErrorCalculator.Calculate(original, estimated, out errorState);
                TestMeanError = errorState == OnlineCalculatorError.None ? testME : double.NaN;
            }

            #endregion
        }

        /// <summary>
        /// Returns the values of the problem's target variable for <paramref name="rows"/>.
        /// </summary>
        private IEnumerable<double> GetTargetValues(IEnumerable<int> rows)
        {
            return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, rows);
        }

        /// <summary>
        /// Evaluates the trade rules on the training and test partitions and stores the resulting
        /// cash values and trade statistics in the result collection.
        /// </summary>
        protected void CalculateResults()
        {
            IEnumerable<double> estimatedTrainingValues = EstimatedTrainingValues;
            IEnumerable<double> estimatedTestValues = EstimatedTestValues;

            TrainingCash = OnlineTradeRulesCalculator.Calculate(estimatedTrainingValues, ProblemData, ProblemData.TrainingIndices);
            TestCash = OnlineTradeRulesCalculator.Calculate(estimatedTestValues, ProblemData, ProblemData.TestIndices);

            // The trade statistics come from static getters read right after the test-partition
            // Calculate call; presumably they describe that most recent call (test partition),
            // so keep this ordering. TODO confirm in OnlineTradeRulesCalculator.
            TradeDays = OnlineTradeRulesCalculator.getTradeDays();
            NumberTrades = OnlineTradeRulesCalculator.getNumberTrades();
            TotalTrades = OnlineTradeRulesCalculator.getTotalTradesDays();
        }
    }
}
Note: See TracBrowser for help on using the repository browser.