Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/03/16 18:54:14 (8 years ago)
Author:
gkronber
Message:

created a feature branch for #2650 (support for categorical variables in symb reg) with a first set of changes

work in progress...

Location:
branches/symbreg-factors-2650
Files:
1 edited
1 copied

Legend:

Unmodified
Added
Removed
  • branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

    r14185 r14232  
    181181      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
    182182      List<string> variableNames = new List<string>();
     183      List<string> categoricalVariableValues = new List<string>();
    183184
    184185      AutoDiff.Term func;
    185       if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out func))
     186      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out func))
    186187        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
    187       if (variableNames.Count == 0) return 0.0;
     188      if (variableNames.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
    188189
    189190      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());
    190191
    191       List<SymbolicExpressionTreeTerminalNode> terminalNodes = null;
     192      List<SymbolicExpressionTreeTerminalNode> terminalNodes = null; // gkronber only used for extraction of initial constants
    192193      if (updateVariableWeights)
    193194        terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
     
    222223      foreach (var r in rows) {
    223224        for (int col = 0; col < variableNames.Count; col++) {
    224           x[row, col] = ds.GetDoubleValue(variableNames[col], r);
     225          if (ds.VariableHasType<double>(variableNames[col])) {
     226            x[row, col] = ds.GetDoubleValue(variableNames[col], r);
     227          } else if (ds.VariableHasType<string>(variableNames[col])) {
     228            x[row, col] = ds.GetStringValue(variableNames[col], r) == categoricalVariableValues[col] ? 1 : 0;
     229          } else throw new InvalidProgramException("found a variable of unknown type");
    225230        }
    226231        row++;
     
    286291    }
    287292
    288     private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, bool updateVariableWeights, out AutoDiff.Term term) {
     293    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters,
     294      List<string> variableNames, List<string> categoricalVariableValues, bool updateVariableWeights, out AutoDiff.Term term) {
    289295      if (node.Symbol is Constant) {
    290296        var var = new AutoDiff.Variable();
     
    298304        parameters.Add(par);
    299305        variableNames.Add(varNode.VariableName);
     306        categoricalVariableValues.Add(string.Empty);   // as a value as placeholder (variableNames.Length == catVariableValues.Length)
    300307
    301308        if (updateVariableWeights) {
     
    308315        return true;
    309316      }
     317      if (node.Symbol is FactorVariable) {
     318        // nothing to update in this case (like a variable without a weight)
     319        // values are only 0 or 1 and set in x accordingly
     320        var factorNode = node as FactorVariableTreeNode;
     321        var par = new AutoDiff.Variable();
     322        parameters.Add(par);
     323        variableNames.Add(factorNode.VariableName);
     324        categoricalVariableValues.Add(factorNode.VariableValue);
     325        term = par;
     326        return true;
     327      }
    310328      if (node.Symbol is Addition) {
    311329        List<AutoDiff.Term> terms = new List<Term>();
    312330        foreach (var subTree in node.Subtrees) {
    313331          AutoDiff.Term t;
    314           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, updateVariableWeights, out t)) {
     332          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    315333            term = null;
    316334            return false;
     
    325343        for (int i = 0; i < node.SubtreeCount; i++) {
    326344          AutoDiff.Term t;
    327           if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, updateVariableWeights, out t)) {
     345          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    328346            term = null;
    329347            return false;
     
    340358        foreach (var subTree in node.Subtrees) {
    341359          AutoDiff.Term t;
    342           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, updateVariableWeights, out t)) {
     360          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    343361            term = null;
    344362            return false;
     
    355373        foreach (var subTree in node.Subtrees) {
    356374          AutoDiff.Term t;
    357           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, updateVariableWeights, out t)) {
     375          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    358376            term = null;
    359377            return false;
     
    367385      if (node.Symbol is Logarithm) {
    368386        AutoDiff.Term t;
    369         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
     387        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    370388          term = null;
    371389          return false;
     
    377395      if (node.Symbol is Exponential) {
    378396        AutoDiff.Term t;
    379         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
     397        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    380398          term = null;
    381399          return false;
     
    387405      if (node.Symbol is Square) {
    388406        AutoDiff.Term t;
    389         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
     407        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    390408          term = null;
    391409          return false;
     
    397415      if (node.Symbol is SquareRoot) {
    398416        AutoDiff.Term t;
    399         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
     417        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    400418          term = null;
    401419          return false;
     
    407425      if (node.Symbol is Sine) {
    408426        AutoDiff.Term t;
    409         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
     427        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    410428          term = null;
    411429          return false;
     
    417435      if (node.Symbol is Cosine) {
    418436        AutoDiff.Term t;
    419         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
     437        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    420438          term = null;
    421439          return false;
     
    427445      if (node.Symbol is Tangent) {
    428446        AutoDiff.Term t;
    429         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
     447        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    430448          term = null;
    431449          return false;
     
    437455      if (node.Symbol is Erf) {
    438456        AutoDiff.Term t;
    439         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
     457        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    440458          term = null;
    441459          return false;
     
    447465      if (node.Symbol is Norm) {
    448466        AutoDiff.Term t;
    449         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out t)) {
     467        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
    450468          term = null;
    451469          return false;
     
    461479        variables.Add(alpha);
    462480        AutoDiff.Term branchTerm;
    463         if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out branchTerm)) {
     481        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out branchTerm)) {
    464482          term = branchTerm * alpha + beta;
    465483          return true;
     
    478496        where
    479497         !(n.Symbol is Variable) &&
     498         !(n.Symbol is FactorVariable) &&
    480499         !(n.Symbol is Constant) &&
    481500         !(n.Symbol is Addition) &&
Note: See TracChangeset for help on using the changeset viewer.