Free cookie consent management tool by TermsFeed Policy Generator

Changeset 13724


Ignore:
Timestamp:
03/24/16 11:05:12 (9 years ago)
Author:
mkommend
Message:

#2591: Reversed merged accidentally commited changes to GradientBoostingRegressionAlgorithm (r13721).

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GBM/GradientBoostingRegressionAlgorithm.cs

    r13721 r13724  
    317317          // just produce an ensemble solution for now (TODO: correct scaling or linear regression for ensemble model weights)
    318318
    319           var ensembleSolution = CreateEnsembleSolution(models, (IRegressionProblemData)problemData.Clone());
     319          var ensembleModel = new RegressionEnsembleModel(models) { AverageModelEstimates = false };
     320          var ensembleSolution = ensembleModel.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    320321          Results.Add(new Result("EnsembleSolution", ensembleSolution));
    321322        }
     
    325326        alg.Prepare(true);
    326327      }
    327     }
    328 
    329     private static IRegressionEnsembleSolution CreateEnsembleSolution(List<IRegressionModel> models,
    330       IRegressionProblemData problemData) {
    331       var rows = problemData.TrainingPartition.Size;
    332       var features = models.Count;
    333       double[,] inputMatrix = new double[rows, features + 1];
    334 
    335       //add model estimates
    336       for (int m = 0; m < models.Count; m++) {
    337         var model = models[m];
    338         var estimates = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices);
    339         int estimatesCounter = 0;
    340         foreach (var estimate in estimates) {
    341           inputMatrix[estimatesCounter, m] = estimate;
    342           estimatesCounter++;
    343         }
    344       }
    345 
    346       //add target
    347       var targets = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices);
    348       int targetCounter = 0;
    349       foreach (var target in targets) {
    350         inputMatrix[targetCounter, models.Count] = target;
    351         targetCounter++;
    352       }
    353 
    354       alglib.linearmodel lm = new alglib.linearmodel();
    355       alglib.lrreport ar = new alglib.lrreport();
    356       double[] coefficients;
    357       int retVal = 1;
    358       alglib.lrbuildz(inputMatrix, rows, features, out retVal, out lm, out ar);
    359       if (retVal != 1) throw new ArgumentException("Error in calculation of linear regression solution");
    360 
    361       alglib.lrunpack(lm, out coefficients, out features);
    362 
    363       var ensembleModel = new RegressionEnsembleModel(models, coefficients.Take(models.Count)) { AverageModelEstimates = false };
    364       var ensembleSolution = ensembleModel.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    365       return ensembleSolution;
    366328    }
    367329
Note: See TracChangeset for help on using the changeset viewer.