Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/21/19 11:06:40 (5 years ago)
Author:
lleko
Message:

#3022 add generateUnivariateBases(), add BasisFunction class

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3022-FastFunctionExtraction/FFX/FastFunctionExtraction.cs

    r17218 r17219  
    11using System;
    22using System.Threading;
     3using System.Linq;
    34using HeuristicLab.Common; // required for parameters collection
    45using HeuristicLab.Core; // required for parameters collection
     
    1011using HeuristicLab.Random; // MersenneTwister
    1112using HEAL.Attic;
     13using HeuristicLab.Algorithms.DataAnalysis.Glmnet;
     14using HeuristicLab.Problems.DataAnalysis;
     15using System.Collections.Generic;
     16using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    1217
    13 namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction {
    14   // each HL item needs to have a name and a description (BasicAlgorithm is an Item)
    15   // The name and description of items is shown in the GUI<<<<
    16   [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
     18namespace HeuristicLab.Algorithms.DataAnalysis.FastFunctionExtraction
     19{
    1720
    18   // If the algorithm should be shown in the "New..." dialog it must be creatable. Entries in the new dialog are grouped to categories and ordered by priorities
    19   [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
     21    [Item(Name = "FastFunctionExtraction", Description = "An FFX algorithm.")]
     22    [Creatable(Category = CreatableAttribute.Categories.Algorithms, Priority = 999)]
     23    [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing your algorithm to a files or transfer to HeuristicLab.Hive
     24    public sealed class FastFunctionExtraction : FixedDataAnalysisAlgorithm<RegressionProblem>
     25    {
     26        private enum Operator { Abs, Log };
     27        private static readonly double[] exponents = { 0.5, 1, 2 };
    2028
    21   [StorableType("689280F7-E371-44A2-98A5-FCEDF22CA343")] // for persistence (storing your algorithm to a files or transfer to HeuristicLab.Hive
    22   public class FastFunctionExtraction : BasicAlgorithm {
    23     // This algorithm only works for BinaryProblems.
    24     // Overriding the ProblemType property has the effect that only BinaryProblems can be set as problem
    25     // for the algorithm in the GUI
    26     public override Type ProblemType { get { return typeof(BinaryProblem); } }
    27     public new BinaryProblem Problem { get { return (BinaryProblem)base.Problem; } }
    28    
    29     #region parameters
    30     // If an algorithm has parameters then we usually also add properties to access these parameters.
    31     // This is not strictly required but considered good shape.
    32     private IFixedValueParameter<IntValue> MaxIterationsParameter {
    33       get { return (IFixedValueParameter<IntValue>)Parameters["MaxIterations"]; }
    34     }
    35     public int MaxIterations {
    36       get { return MaxIterationsParameter.Value.Value; }
    37       set { MaxIterationsParameter.Value.Value = value; }
    38     }
    39     #endregion
     29        private const string PenaltyParameterName = "Penalty";
     30        private const string ConsiderInteractionsParameterName = "Consider Interactions";
     31        private const string ConsiderDenominationParameterName = "Consider Denomination";
     32        private const string ConsiderExponentiationParameterName = "Consider Exponentiation";
     33        private const string ConsiderNonlinearFuncsParameterName = "Consider Nonlinear functions";
     34        private const string ConsiderHingeFuncsParameterName = "Consider Hinge Functions";
    4035
    41     // createable items must have a default ctor
    42     public FastFunctionExtraction() {
    43       // algorithm parameters are shown in the GUI
    44       Parameters.Add(new FixedValueParameter<IntValue>("MaxIterations", new IntValue(10000)));
    45     }
     36        #region parameters
     37        public IValueParameter<BoolValue> ConsiderInteractionsParameter
     38        {
     39            get { return (IValueParameter<BoolValue>)Parameters[ConsiderInteractionsParameterName]; }
     40        }
     41        #endregion
    4642
    47     // Persistence uses this ctor to improve deserialization efficiency.
    48     // If we would use the default ctor instead this would completely initialize the object (e.g. creating parameters)
    49     // even though the data is later overwritten by the stored data.
    50     [StorableConstructor]
    51     public FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
     43        #region properties
     44        public bool ConsiderInteractions
     45        {
     46            get { return ConsiderInteractionsParameter.Value.Value; }
     47            set { ConsiderInteractionsParameter.Value.Value = value; }
     48        }
     49        #endregion
    5250
    53     // Each clonable item must have a cloning ctor (deep cloning, the cloner is used to handle cyclic object references)
    54     public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner) {
    55       // Don't forget to call the cloning ctor of the base class
    56       // This class does not have fields, therefore we don't need to actually clone anything
    57     }
     51        [StorableConstructor]
     52        private FastFunctionExtraction(StorableConstructorFlag _) : base(_) { }
     53        public FastFunctionExtraction(FastFunctionExtraction original, Cloner cloner) : base(original, cloner)
     54        {
     55            // Don't forget to call the cloning ctor of the base class 
     56            // This class does not have fields, therefore we don't need to actually clone anything
     57        }
     58        public FastFunctionExtraction() : base()
     59        {
     60            // algorithm parameters are shown in the GUI
     61            Parameters.Add(new FixedValueParameter<DoubleValue>(PenaltyParameterName, "Penalty factor (alpha) for balancing between ridge (0.0) and lasso (1.0) regression", new DoubleValue(0.5)));
     62            Parameters.Add(new ValueParameter<BoolValue>(ConsiderInteractionsParameterName, "True if you want to consider interactions, otherwise false.", new BoolValue(true)));
     63            Parameters.Add(new ValueParameter<BoolValue>(ConsiderDenominationParameterName, "True if you want to consider denominations, otherwise false.", new BoolValue(true)));
     64            Parameters.Add(new ValueParameter<BoolValue>(ConsiderExponentiationParameterName, "True if you want to consider exponentiation, otherwise false.", new BoolValue(true)));
     65            Parameters.Add(new ValueParameter<BoolValue>(ConsiderNonlinearFuncsParameterName, "True if you want to consider nonlinear functions(abs, log,...), otherwise false.", new BoolValue(true)));
     66            Parameters.Add(new ValueParameter<BoolValue>(ConsiderHingeFuncsParameterName, "True if you want to consider Hinge Functions, otherwise false.", new BoolValue(true)));
     67        }
    5868
    59     public override IDeepCloneable Clone(Cloner cloner) {
    60       return new FastFunctionExtraction(this, cloner);
    61     }
     69        [StorableHook(HookType.AfterDeserialization)]
     70        private void AfterDeserialization() { }
    6271
    63     protected override void Run(CancellationToken cancellationToken) {
    64       int maxIters = MaxIterations;
    65       var problem = Problem;
    66       var rand = new MersenneTwister(1234);
     72        public override IDeepCloneable Clone(Cloner cloner)
     73        {
     74            return new FastFunctionExtraction(this, cloner);
     75        }
    6776
    68       var bestQuality = problem.Maximization ? double.MinValue : double.MaxValue;
    69 
    70       var curItersItem = new IntValue();
    71       var bestQualityItem = new DoubleValue(bestQuality);
    72       var curItersResult = new Result("Iteration", curItersItem);
    73       var bestQualityResult = new Result("Best quality", bestQualityItem);
    74       Results.Add(curItersResult);
    75       Results.Add(bestQualityResult);
    76 
    77       var funcs = generateBasisFunctions();
    78 
    79       for (int i = 0; i < maxIters; i++) {
    80         curItersItem.Value = i;
    81 
    82         // -----------------------------
    83         // IMPLEMENT YOUR ALGORITHM HERE
    84         // -----------------------------
     77        public override Type ProblemType { get { return typeof(RegressionProblem); } }
     78        public new RegressionProblem Problem { get { return (RegressionProblem)base.Problem; } }
    8579
    8680
    87         // this is an example for random search
    88         // for a more elaborate algorithm check the source code of "HeuristicLab.Algorithms.ParameterlessPopulationPyramid"
    89         var cand = new BinaryVector(problem.Length, rand);
    90         var quality = problem.Evaluate(cand, rand); // calling Evaluate like this is not possible for all problems...
    91         if (problem.Maximization) bestQuality = Math.Max(bestQuality, quality);
    92         else bestQuality = Math.Min(quality, bestQuality);
    93         bestQualityItem.Value = bestQuality;
     81        protected override void Run(CancellationToken cancellationToken)
     82        {
     83            var basisFunctions = generateBasisFunctions(Problem.ProblemData);
     84            var x = Problem.ProblemData.AllowedInputsTrainingValues;
     85            List<SymbolicExpressionTree> trees = new List<SymbolicExpressionTree>();
    9486
    95         // check the cancellation token to see if the used clicked "Stop"
    96         if (cancellationToken.IsCancellationRequested) break;
    97       }
    9887
    99       Results.Add(new Result("Execution time", new TimeSpanValue(this.ExecutionTime)));
    100     }
     88            foreach (var basisFunc in basisFunctions)
     89            {
     90                // add tree representation of basisFunc to trees
     91                trees.Add(generateSymbolicExpressionTree(basisFunc));
     92            }
    10193
    102        
    103         private object generateBasisFunctions()
     94            foreach (var tree in trees)
     95            {
     96                // create new data through the help of the Interpreter
     97                //IEnumerable<double> responses =
     98            }
     99
     100            var coefficientVectorSet = findCoefficientValues(basisFunctions);
     101            var paretoFront = nondominatedFilter(coefficientVectorSet);
     102        }
     103
     104        private SymbolicExpressionTree generateSymbolicExpressionTree(KeyValuePair<string, double[]> basisFunc)
    104105        {
    105106            throw new NotImplementedException();
    106107        }
    107108
    108         public override bool SupportsPause {
    109       get { return false; }
     109        // generate all possible models
     110        private static Dictionary<string, double[]> generateBasisFunctions(IRegressionProblemData problemData)
     111        {
     112            var basisFunctions = generateUnivariateBases(problemData);
     113            return basisFunctions;
     114        }
     115
     116        private static Dictionary<string, double[]> generateUnivariateBases(IRegressionProblemData problemData)
     117        {
     118
     119            var dataset = problemData.Dataset;
     120            var rows = problemData.TrainingIndices;
     121            var B1 = new Dictionary<string, double[]>();
     122
     123            foreach (var variableName in dataset.VariableNames)
     124            {
     125                foreach (var exp in new[] { 0.5, 1, 2 })
     126                {
     127                    var name = variableName + " ** " + exp;
     128                    var data = dataset.GetDoubleValues(variableName, rows).Select(x => Math.Pow(x, exp)).ToArray();
     129                    B1.Add(name, data);
     130                    foreach (Operator op in Enum.GetValues(typeof(Operator)))
     131                    {
     132                        var inner_name = op.ToString() + "(" + name + ")";
     133                        var inner_data = data.Select(x => executeOperator(x, op)).ToArray();
     134                        B1.Add(inner_name, inner_data);
     135                    }
     136                }
     137            }
     138           
     139            return B1;
     140        }
     141
     142        private static double executeOperator(double x, Operator op)
     143        {
     144            switch (op)
     145            {
     146                case Operator.Abs:
     147                    return x > 0 ? x : -x;
     148                case Operator.Log:
     149                    return Math.Log10(x);
     150                default:
     151                    throw new NotImplementedException();
     152            }
     153        }
     154
     155        private static Dictionary<string, double[]> generateMultiVariateBases(Dictionary<string, double[]> B1)
     156        {
     157            var B2 = new Dictionary<string, double[]>();
     158            for(int i = 1; i <= B1.Count(); i++ )
     159            {
     160                var b_i = B1.ElementAt(i);
     161                for (int j = 1; j < i; i++)
     162                {
     163                    var b_j = B1.ElementAt(j);
     164                }
     165            }
     166
     167            // return union of B1 and B2
     168            return B2.Concat(B1).ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
     169        }
     170
     171        private static object findCoefficientValues(IEnumerable<KeyValuePair<string, double[]>> basisFunctions)
     172        {
     173            return new object();
     174        }
     175
     176        private static object nondominatedFilter(object coefficientVectorSet)
     177        {
     178            return new object();
     179        }
     180
     181        public override bool SupportsPause
     182        {
     183            get { return false; }
     184        }
    110185    }
    111   }
    112186}
Note: See TracChangeset for help on using the changeset viewer.