00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017 #ifndef __TRAININGALGORITHM_H__
00018 #define __TRAININGALGORITHM_H__
00019
#include <string>

#include "../MultilayerPerceptron/MultilayerPerceptron.h"
#include "../ObjectiveFunctional/ObjectiveFunctional.h"
00022
00023
00024 namespace Flood
00025 {
00026
00030
00031 class TrainingAlgorithm
00032 {
00033
00034 public:
00035
00036
00037
00039
00040 enum TrainingRateMethod{Fixed, GoldenSection, BrentMethod};
00041
00042
00043
00044
00045 explicit TrainingAlgorithm(ObjectiveFunctional*);
00046
00047
00048
00049
00050 explicit TrainingAlgorithm(void);
00051
00052
00053
00054
00055 virtual ~TrainingAlgorithm(void);
00056
00057
00058
00059
00060
00061
00062 ObjectiveFunctional* get_objective_functional_pointer(void);
00063
00064
00065
00066 TrainingAlgorithm::TrainingRateMethod get_training_rate_method(void);
00067 std::string get_training_rate_method_name(void);
00068
00069
00070
00071 double get_first_training_rate(void);
00072 double get_bracketing_factor(void);
00073 double get_training_rate_tolerance(void);
00074
00075 double get_warning_parameters_norm(void);
00076 double get_warning_gradient_norm(void);
00077 double get_warning_training_rate(void);
00078
00079 double get_error_parameters_norm(void);
00080 double get_error_gradient_norm(void);
00081 double get_error_training_rate(void);
00082
00083
00084
00085 double get_minimum_parameters_increment_norm(void);
00086
00087 double get_minimum_evaluation_improvement(void);
00088 double get_evaluation_goal(void);
00089 double get_gradient_norm_goal(void);
00090
00091 int get_maximum_epochs_number(void);
00092 double get_maximum_time(void);
00093
00094
00095
00096 bool get_early_stopping(void);
00097
00098
00099
00100 bool get_reserve_parameters_history(void);
00101 bool get_reserve_parameters_norm_history(void);
00102
00103 bool get_reserve_evaluation_history(void);
00104 bool get_reserve_gradient_history(void);
00105 bool get_reserve_gradient_norm_history(void);
00106 bool get_reserve_inverse_Hessian_history(void);
00107 bool get_reserve_validation_error_history(void);
00108
00109 bool get_reserve_training_direction_history(void);
00110 bool get_reserve_training_rate_history(void);
00111 bool get_reserve_elapsed_time_history(void);
00112
00113
00114
00115 Vector< Vector<double> >& get_parameters_history(void);
00116 Vector<double>& get_parameters_norm_history(void);
00117
00118 Vector<double>& get_evaluation_history(void);
00119 Vector< Vector<double> >& get_gradient_history(void);
00120 Vector<double>& get_gradient_norm_history(void);
00121 Vector< Matrix<double> >& get_inverse_Hessian_history(void);
00122 Vector<double>& get_validation_error_history(void);
00123
00124 Vector< Vector<double> >& get_training_direction_history(void);
00125 Vector<double>& get_training_rate_history(void);
00126 Vector<double>& get_elapsed_time_history(void);
00127
00128
00129
00130 bool get_display(void);
00131 int get_display_period(void);
00132
00133
00134
00135 void set(void);
00136 void set(ObjectiveFunctional*);
00137 virtual void set_default(void);
00138
00139 void set_objective_functional_pointer(ObjectiveFunctional*);
00140
00141
00142
00143 void set_training_rate_method(const TrainingRateMethod&);
00144 void set_training_rate_method(const std::string&);
00145
00146
00147
00148 void set_first_training_rate(double);
00149 void set_bracketing_factor(double);
00150 void set_training_rate_tolerance(double);
00151
00152 void set_warning_parameters_norm(double);
00153 void set_warning_gradient_norm(double);
00154 void set_warning_training_rate(double);
00155
00156 void set_error_parameters_norm(double);
00157 void set_error_gradient_norm(double);
00158 void set_error_training_rate(double);
00159
00160
00161
00162 void set_minimum_parameters_increment_norm(double);
00163
00164 void set_minimum_evaluation_improvement(double);
00165 void set_evaluation_goal(double);
00166 void set_gradient_norm_goal(double);
00167
00168 void set_maximum_epochs_number(int);
00169 void set_maximum_time(double);
00170
00171
00172
00173 void set_early_stopping(bool);
00174
00175
00176
00177 void set_reserve_parameters_history(bool);
00178 void set_reserve_parameters_norm_history(bool);
00179
00180 void set_reserve_evaluation_history(bool);
00181 void set_reserve_gradient_history(bool);
00182 void set_reserve_gradient_norm_history(bool);
00183 void set_reserve_inverse_Hessian_history(bool);
00184 void set_reserve_validation_error_history(bool);
00185
00186 void set_reserve_training_direction_history(bool);
00187 void set_reserve_training_rate_history(bool);
00188 void set_reserve_elapsed_time_history(bool);
00189
00191
00192 virtual void set_reserve_all_training_history(bool);
00193
00194
00195
00196 void set_parameters_history(const Vector< Vector<double> >&);
00197 void set_parameters_norm_history(const Vector<double>&);
00198
00199 void set_evaluation_history(const Vector<double>&);
00200 void set_gradient_history(const Vector< Vector<double> >&);
00201 void set_gradient_norm_history(const Vector<double>&);
00202 void set_inverse_Hessian_history(const Vector< Matrix<double> >&);
00203
00204 void set_training_direction_history(const Vector< Vector<double> >&);
00205 void set_training_rate_history(const Vector<double>&);
00206 void set_elapsed_time_history(const Vector<double>&);
00207
00208 void set_validation_error_history(const Vector<double>&);
00209
00210
00211
00212 void set_display(bool);
00213 void set_display_period(int);
00214
00215
00216
00218
00219 virtual void train(void) = 0;
00220
00221
00222
00223 Vector<double> calculate_training_rate_evaluation(double, const Vector<double>&, double);
00224
00225 Vector<double> calculate_fixed_training_rate_evaluation(double, const Vector<double>&, double);
00226 Vector<double> calculate_golden_section_training_rate_evaluation(double, const Vector<double>&, double);
00227 Vector<double> calculate_Brent_method_training_rate_evaluation(double, const Vector<double>&, double);
00228
00229 Vector<double> calculate_bracketing_training_rate_evaluation(double, const Vector<double>&, double);
00230
00231
00232
00233 virtual void resize_training_history(int);
00234
00235 virtual std::string get_training_history_XML(bool);
00236 void print_training_history(void);
00237 void save_training_history(const char*);
00238
00239
00240
00241 virtual std::string to_XML(bool);
00242 void print(void);
00243 void save(const char*);
00244 virtual void load(const char*);
00245
00246
00247 protected:
00248
00249
00250
00252
00253 ObjectiveFunctional* objective_functional_pointer;
00254
00255
00256
00258
00259 TrainingRateMethod training_rate_method;
00260
00262
00263 double bracketing_factor;
00264
00266
00267 double first_training_rate;
00268
00270
00271 double training_rate_tolerance;
00272
00274
00275 double warning_parameters_norm;
00276
00278
00279 double warning_gradient_norm;
00280
00282
00283 double warning_training_rate;
00284
00286
00287 double error_parameters_norm;
00288
00290
00291 double error_gradient_norm;
00292
00294
00295 double error_training_rate;
00296
00297
00298
00299
00301
00302 double minimum_parameters_increment_norm;
00303
00305
00306 double minimum_evaluation_improvement;
00307
00309
00310 double evaluation_goal;
00311
00313
00314 double gradient_norm_goal;
00315
00317
00318 int maximum_epochs_number;
00319
00321
00322 double maximum_time;
00323
00325
00326 bool early_stopping;
00327
00328
00329
00331
00332 bool reserve_parameters_history;
00333
00335
00336 bool reserve_parameters_norm_history;
00337
00339
00340 bool reserve_evaluation_history;
00341
00343
00344 bool reserve_gradient_history;
00345
00347
00348 bool reserve_gradient_norm_history;
00349
00351
00352 bool reserve_inverse_Hessian_history;
00353
00355
00356 bool reserve_training_direction_history;
00357
00359
00360 bool reserve_training_rate_history;
00361
00363
00364 bool reserve_elapsed_time_history;
00365
00367
00368 bool reserve_validation_error_history;
00369
00371
00372 Vector< Vector<double> > parameters_history;
00373
00375
00376 Vector<double> parameters_norm_history;
00377
00379
00380 Vector<double> evaluation_history;
00381
00383
00384 Vector< Vector<double> > gradient_history;
00385
00387
00388 Vector<double> gradient_norm_history;
00389
00391
00392 Vector< Matrix<double> > inverse_Hessian_history;
00393
00395
00396 Vector< Vector<double> > training_direction_history;
00397
00399
00400 Vector<double> training_rate_history;
00401
00403
00404 Vector<double> elapsed_time_history;
00405
00407
00408 Vector<double> validation_error_history;
00409
00410
00411
00413
00414 bool display;
00415
00417
00418 int display_period;
00419 };
00420
00421 }
00422
00423 #endif
00424
00425
00426
00427
00428
00429
00430
00431
00432
00433
00434
00435
00436
00437
00438
00439
00440
00441