1
0
mirror of https://github.com/lxsang/antd-lua-plugin synced 2024-12-26 17:38:21 +01:00
antd-lua-plugin/lib/ann/ann.c
2018-09-19 15:08:49 +02:00

267 lines
15 KiB
C

/*
This file is part of fann.
fann is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
fann is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License
along with fann. If not, see <http://www.gnu.org/licenses/>.
Copyright (C) 2009 - 2013 Lucas Hermann Negri
*/
#include "lfann.h"
#include "net.c"
#include "data.c"
#include "extension.c"
/* Interface */
/*
 * Registers `methods` under `name` with luaL_register, then creates (or
 * fetches) the registry metatable `mt` and fills in three fields:
 *   __index    -> the table that received `methods`
 *   __gc       -> gc
 *   __tostring -> tostring
 *
 * The method table and the metatable are both left on the Lua stack; the
 * caller is responsible for restoring the stack top (luaopen_lfann does so
 * with lua_settop).
 */
static void priv_register(lua_State* L, const char* name, const char* mt,
const luaL_reg* methods, lua_CFunction gc, lua_CFunction tostring)
{
    int methods_idx;
    int meta_idx;

    /* Method table: remember its absolute slot so later pushes don't shift it */
    luaL_register(L, name, methods);
    methods_idx = lua_gettop(L);

    /* Metatable sits directly above the method table */
    luaL_newmetatable(L, mt);
    meta_idx = lua_gettop(L);

    /* mt.__index = method table, so lookups on the object reach the methods */
    lua_pushliteral(L, "__index");
    lua_pushvalue(L, methods_idx);
    lua_rawset(L, meta_idx);

    /* mt.__gc = gc */
    lua_pushliteral(L, "__gc");
    lua_pushcfunction(L, gc);
    lua_rawset(L, meta_idx);

    /* mt.__tostring = tostring */
    lua_pushliteral(L, "__tostring");
    lua_pushcfunction(L, tostring);
    lua_rawset(L, meta_idx);
}
/*
 * Function list for the top-level "fann" table (passed to luaL_register in
 * luaopen_lfann). It holds only the {NULL, NULL} sentinel: the module table
 * itself exports no functions, it is populated with the Net/Data sub-tables
 * and the numeric constants instead.
 */
static const struct luaL_reg ann [] =
{
{NULL, NULL}
};
/*
 * Functions registered as "fann.Data" by luaopen_lfann. The lfann_data_*
 * implementations are not defined in this file (presumably they come from
 * the included data.c). The list is terminated by the {NULL, NULL} sentinel
 * that luaL_register expects.
 */
static const struct luaL_reg data [] =
{
    /* creation and persistence */
    {"read_from_file",       lfann_data_read_from_file},
    {"create_from_callback", lfann_data_create_from_callback},
    {"save",                 lfann_data_save},
    {"save_to_fixed",        lfann_data_save_to_fixed},

    /* shuffle / scale */
    {"shuffle",              lfann_data_shuffle},
    {"scale",                lfann_data_scale},
    {"scale_input",          lfann_data_scale_input},
    {"scale_output",         lfann_data_scale_output},

    /* merge / duplicate / subset */
    {"merge",                lfann_data_merge},
    {"duplicate",            lfann_data_duplicate},
    {"subset",               lfann_data_subset},

    /* accessors */
    {"length",               lfann_data_length},
    {"num_input",            lfann_data_num_input},
    {"num_output",           lfann_data_num_output},
    {"get_row",              lfann_data_get_row},
    {"get_bounds",           lfann_data_get_bounds},
    {"get_bounds_input",     lfann_data_get_bounds_input},
    {"get_bounds_output",    lfann_data_get_bounds_output},

    {NULL, NULL}
};
/*
 * Functions registered as "fann.Net" by luaopen_lfann. The lfann_net_*
 * implementations are not defined in this file (presumably they come from
 * the included net.c). Each Lua-visible name appears exactly once; the list
 * is terminated by the {NULL, NULL} sentinel that luaL_register expects.
 */
static const struct luaL_reg net [] =
{
    /* creation, copy, persistence */
    {"copy", lfann_net_copy},
    {"create_standard", lfann_net_create_standard},
    {"create_sparse", lfann_net_create_sparse},
    {"create_shortcut", lfann_net_create_shortcut},
    {"create_from_file", lfann_net_create_from_file},
    {"print_connections", lfann_net_print_connections},

    /* train / run / save */
    {"train_on_file", lfann_net_train_on_file},
    {"train_on_data", lfann_net_train_on_data},
    {"train_epoch", lfann_net_train_epoch},
    {"save", lfann_net_save},
    {"save_to_fixed", lfann_net_save_to_fixed},
    {"run", lfann_net_run},
    {"init_weights", lfann_net_init_weights},
    {"randomize_weights", lfann_net_randomize_weights},
    {"print_parameters", lfann_net_print_parameters},
    {"set_weight", lfann_net_set_weight},
    {"get_MSE", lfann_net_get_MSE},
    {"get_training_algorithm", lfann_net_get_training_algorithm},
    {"set_training_algorithm", lfann_net_set_training_algorithm},
    {"reset_MSE", lfann_net_reset_MSE},
    {"get_bit_fail", lfann_net_get_bit_fail},
    {"set_callback", lfann_net_set_callback},
    {"test_data", lfann_net_test_data},
    {"get_network_type", lfann_net_get_network_type},

    /* learning rate / momentum */
    {"set_learning_rate", lfann_net_set_learning_rate},
    {"get_learning_rate", lfann_net_get_learning_rate},
    {"set_learning_momentum", lfann_net_set_learning_momentum},
    {"get_learning_momentum", lfann_net_get_learning_momentum},

    /* activation function / steepness */
    {"set_activation_function", lfann_net_set_activation_function},
    {"get_activation_function", lfann_net_get_activation_function},
    {"set_activation_function_hidden", lfann_net_set_activation_function_hidden},
    {"set_activation_function_layer", lfann_net_set_activation_function_layer},
    {"set_activation_function_output", lfann_net_set_activation_function_output},
    {"set_activation_steepness", lfann_net_set_activation_steepness},
    {"get_activation_steepness", lfann_net_get_activation_steepness},
    {"set_activation_steepness_hidden", lfann_net_set_activation_steepness_hidden},
    {"set_activation_steepness_layer", lfann_net_set_activation_steepness_layer},
    {"set_activation_steepness_output", lfann_net_set_activation_steepness_output},

    /* error / stop functions */
    {"set_train_error_function", lfann_net_set_train_error_function},
    {"get_train_error_function", lfann_net_get_train_error_function},
    {"set_train_stop_function", lfann_net_set_train_stop_function},
    {"get_train_stop_function", lfann_net_get_train_stop_function},
    {"set_bit_fail_limit", lfann_net_set_bit_fail_limit},
    {"get_bit_fail_limit", lfann_net_get_bit_fail_limit},

    /* quickprop_* */
    {"set_quickprop_decay", lfann_net_set_quickprop_decay},
    {"get_quickprop_decay", lfann_net_get_quickprop_decay},
    {"set_quickprop_mu", lfann_net_set_quickprop_mu},
    {"get_quickprop_mu", lfann_net_get_quickprop_mu},

    /* rprop_* */
    {"set_rprop_increase_factor", lfann_net_set_rprop_increase_factor},
    {"get_rprop_increase_factor", lfann_net_get_rprop_increase_factor},
    {"set_rprop_decrease_factor", lfann_net_set_rprop_decrease_factor},
    {"get_rprop_decrease_factor", lfann_net_get_rprop_decrease_factor},
    {"get_rprop_delta_min", lfann_net_get_rprop_delta_min},
    {"set_rprop_delta_min", lfann_net_set_rprop_delta_min},
    {"get_rprop_delta_max", lfann_net_get_rprop_delta_max},
    {"set_rprop_delta_max", lfann_net_set_rprop_delta_max},
    {"set_rprop_delta_zero", lfann_net_set_rprop_delta_zero},
    {"get_rprop_delta_zero", lfann_net_get_rprop_delta_zero},

    /* topology accessors */
    {"get_num_input", lfann_net_get_num_input},
    {"get_num_output", lfann_net_get_num_output},
    {"get_total_neurons", lfann_net_get_total_neurons},
    {"get_total_connections", lfann_net_get_total_connections},
    {"get_connection_rate", lfann_net_get_connection_rate},
    {"get_num_layers", lfann_net_get_num_layers},
    {"get_layer_array", lfann_net_get_layer_array},
    {"get_bias_array", lfann_net_get_bias_array},
    {"get_connection_array", lfann_net_get_connection_array},

    /* cascade_* */
    {"cascade_train_on_data", lfann_net_cascade_train_on_data},
    {"cascade_train_on_file", lfann_net_cascade_train_on_file},
    {"set_cascade_output_change_fraction", lfann_net_set_cascade_output_change_fraction},
    {"get_cascade_output_change_fraction", lfann_net_get_cascade_output_change_fraction},
    {"set_cascade_output_stagnation_epochs", lfann_net_set_cascade_output_stagnation_epochs},
    {"get_cascade_output_stagnation_epochs", lfann_net_get_cascade_output_stagnation_epochs},
    {"set_cascade_candidate_change_fraction", lfann_net_set_cascade_candidate_change_fraction},
    {"get_cascade_candidate_change_fraction", lfann_net_get_cascade_candidate_change_fraction},
    {"set_cascade_candidate_stagnation_epochs", lfann_net_set_cascade_candidate_stagnation_epochs},
    {"get_cascade_candidate_stagnation_epochs", lfann_net_get_cascade_candidate_stagnation_epochs},
    {"set_cascade_weight_multiplier", lfann_net_set_cascade_weight_multiplier},
    {"get_cascade_weight_multiplier", lfann_net_get_cascade_weight_multiplier},
    {"set_cascade_candidate_limit", lfann_net_set_cascade_candidate_limit},
    {"get_cascade_candidate_limit", lfann_net_get_cascade_candidate_limit},
    {"get_cascade_max_cand_epochs", lfann_net_get_cascade_max_cand_epochs},
    {"set_cascade_max_cand_epochs", lfann_net_set_cascade_max_cand_epochs},
    {"get_cascade_max_out_epochs", lfann_net_get_cascade_max_out_epochs},
    {"set_cascade_max_out_epochs", lfann_net_set_cascade_max_out_epochs},
    {"get_cascade_num_candidates", lfann_net_get_cascade_num_candidates},
    {"get_cascade_activation_functions_count", lfann_net_get_cascade_activation_functions_count},
    {"get_cascade_activation_functions", lfann_net_get_cascade_activation_functions},
    {"set_cascade_activation_functions", lfann_net_set_cascade_activation_functions},
    {"get_cascade_activation_steepnesses_count", lfann_net_get_cascade_activation_steepnesses_count},
    {"get_cascade_activation_steepnesses", lfann_net_get_cascade_activation_steepnesses},
    {"set_cascade_activation_steepnesses", lfann_net_set_cascade_activation_steepnesses},
    {"set_cascade_num_candidate_groups", lfann_net_set_cascade_num_candidate_groups},
    {"get_cascade_num_candidate_groups", lfann_net_get_cascade_num_candidate_groups},
    {"get_cascade_min_out_epochs", lfann_net_get_cascade_min_out_epochs},
    {"set_cascade_min_out_epochs", lfann_net_set_cascade_min_out_epochs},
    {"get_cascade_min_cand_epochs", lfann_net_get_cascade_min_cand_epochs},
    {"set_cascade_min_cand_epochs", lfann_net_set_cascade_min_cand_epochs},

    /* sarprop_* */
    {"get_sarprop_weight_decay_shift", lfann_net_get_sarprop_weight_decay_shift},
    {"set_sarprop_weight_decay_shift", lfann_net_set_sarprop_weight_decay_shift},
    {"get_sarprop_step_error_threshold_factor", lfann_net_get_sarprop_step_error_threshold_factor},
    {"set_sarprop_step_error_threshold_factor", lfann_net_set_sarprop_step_error_threshold_factor},
    {"get_sarprop_step_error_shift", lfann_net_get_sarprop_step_error_shift},
    {"set_sarprop_step_error_shift", lfann_net_set_sarprop_step_error_shift},
    {"get_sarprop_temperature", lfann_net_get_sarprop_temperature},
    {"set_sarprop_temperature", lfann_net_set_sarprop_temperature},

    {NULL, NULL}
};
#ifdef _MSC_VER
__declspec(dllexport)
#endif
int luaopen_lfann(lua_State *L)
{
int top;
fann_set_error_log(&err_handler, NULL);
luaL_register(L, "fann", ann);
top = lua_gettop(L);
priv_register(L, "fann.Net", "lfannNeT", net, lfann_net_gc, lfann_net_tostring);
priv_register(L, "fann.Data", "lfannDaTa", data, lfann_data_gc, lfann_data_tostring);
lua_settop(L, top);
/* activation_functipn_enum */
lua_pushliteral(L, "LINEAR"); lua_pushnumber(L, FANN_LINEAR); lua_rawset(L, -3);
lua_pushliteral(L, "THRESHOLD"); lua_pushnumber(L, FANN_THRESHOLD); lua_rawset(L, -3);
lua_pushliteral(L, "THRESHOLD_SYMMETRIC"); lua_pushnumber(L, FANN_THRESHOLD_SYMMETRIC); lua_rawset(L, -3);
lua_pushliteral(L, "SIGMOID"); lua_pushnumber(L, FANN_SIGMOID); lua_rawset(L, -3);
lua_pushliteral(L, "SIGMOID_STEPWISE"); lua_pushnumber(L, FANN_SIGMOID_STEPWISE); lua_rawset(L, -3);
lua_pushliteral(L, "SIGMOID_SYMMETRIC"); lua_pushnumber(L, FANN_SIGMOID_SYMMETRIC); lua_rawset(L, -3);
lua_pushliteral(L, "GAUSSIAN"); lua_pushnumber(L, FANN_GAUSSIAN); lua_rawset(L, -3);
lua_pushliteral(L, "GAUSSIAN_SYMMETRIC"); lua_pushnumber(L, FANN_GAUSSIAN_SYMMETRIC); lua_rawset(L, -3);
lua_pushliteral(L, "ELLIOT"); lua_pushnumber(L, FANN_ELLIOT); lua_rawset(L, -3);
lua_pushliteral(L, "ELLIOT_SYMMETRIC"); lua_pushnumber(L, FANN_ELLIOT_SYMMETRIC); lua_rawset(L, -3);
lua_pushliteral(L, "LINEAR_PIECE"); lua_pushnumber(L, FANN_LINEAR_PIECE); lua_rawset(L, -3);
lua_pushliteral(L, "LINEAR_PIECE_SYMMETRIC"); lua_pushnumber(L, FANN_LINEAR_PIECE_SYMMETRIC); lua_rawset(L, -3);
lua_pushliteral(L, "SIN"); lua_pushnumber(L, FANN_SIN); lua_rawset(L, -3);
lua_pushliteral(L, "SIN_SYMMETRIC"); lua_pushnumber(L, FANN_SIN_SYMMETRIC); lua_rawset(L, -3);
lua_pushliteral(L, "COS"); lua_pushnumber(L, FANN_COS); lua_rawset(L, -3);
lua_pushliteral(L, "COS_SYMMETRIC"); lua_pushnumber(L, FANN_COS_SYMMETRIC); lua_rawset(L, -3);
/* training_algorithm_enum */
lua_pushliteral(L, "TRAIN_INCREMENTAL"); lua_pushnumber(L, FANN_TRAIN_INCREMENTAL); lua_rawset(L, -3);
lua_pushliteral(L, "TRAIN_BATCH"); lua_pushnumber(L, FANN_TRAIN_BATCH); lua_rawset(L, -3);
lua_pushliteral(L, "TRAIN_RPROP"); lua_pushnumber(L, FANN_TRAIN_RPROP); lua_rawset(L, -3);
lua_pushliteral(L, "TRAIN_QUICKPROP"); lua_pushnumber(L, FANN_TRAIN_QUICKPROP); lua_rawset(L, -3);
lua_pushliteral(L, "TRAIN_SARPROP"); lua_pushnumber(L, FANN_TRAIN_SARPROP); lua_rawset(L, -3);
/* error_function_enum */
lua_pushliteral(L, "ERRORFUNC_LINEAR"); lua_pushnumber(L, FANN_ERRORFUNC_LINEAR); lua_rawset(L, -3);
lua_pushliteral(L, "ERRORFUNC_TANH"); lua_pushnumber(L, FANN_ERRORFUNC_TANH); lua_rawset(L, -3);
/* stop_function_enum */
lua_pushliteral(L, "STOPFUNC_MSE"); lua_pushnumber(L, FANN_STOPFUNC_MSE); lua_rawset(L, -3);
lua_pushliteral(L, "STOPFUNC_BIT"); lua_pushnumber(L, FANN_STOPFUNC_BIT); lua_rawset(L, -3);
/* network_type_enum */
lua_pushliteral(L, "NETTYPE_LAYER"); lua_pushnumber(L, FANN_NETTYPE_LAYER); lua_rawset(L, -3);
lua_pushliteral(L, "NETTYPE_SHORTCUT"); lua_pushnumber(L, FANN_NETTYPE_SHORTCUT); lua_rawset(L, -3);
/* lfann_errno_enum */
lua_pushliteral(L, "E_NO_ERROR"); lua_pushnumber(L, FANN_E_NO_ERROR); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_OPEN_CONFIG_R"); lua_pushnumber(L, FANN_E_CANT_OPEN_CONFIG_R); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_OPEN_CONFIG_W"); lua_pushnumber(L, FANN_E_CANT_OPEN_CONFIG_W); lua_rawset(L, -3);
lua_pushliteral(L, "E_WRONG_CONFIG_VERSION"); lua_pushnumber(L, FANN_E_WRONG_CONFIG_VERSION); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_READ_CONFIG"); lua_pushnumber(L, FANN_E_CANT_READ_CONFIG); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_READ_NEURON"); lua_pushnumber(L, FANN_E_CANT_READ_NEURON); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_READ_CONNECTIONS"); lua_pushnumber(L, FANN_E_CANT_READ_CONNECTIONS); lua_rawset(L, -3);
lua_pushliteral(L, "E_WRONG_NUM_CONNECTIONS"); lua_pushnumber(L, FANN_E_WRONG_NUM_CONNECTIONS); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_OPEN_TD_W"); lua_pushnumber(L, FANN_E_CANT_OPEN_TD_W); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_OPEN_TD_R"); lua_pushnumber(L, FANN_E_CANT_OPEN_TD_R); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_READ_TD"); lua_pushnumber(L, FANN_E_CANT_READ_TD); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_ALLOCATE_MEM"); lua_pushnumber(L, FANN_E_CANT_ALLOCATE_MEM); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_TRAIN_ACTIVATION"); lua_pushnumber(L, FANN_E_CANT_TRAIN_ACTIVATION); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_USE_ACTIVATION"); lua_pushnumber(L, FANN_E_CANT_USE_ACTIVATION); lua_rawset(L, -3);
lua_pushliteral(L, "E_TRAIN_DATA_MISMATCH"); lua_pushnumber(L, FANN_E_TRAIN_DATA_MISMATCH); lua_rawset(L, -3);
lua_pushliteral(L, "E_CANT_USE_TRAIN_ALG"); lua_pushnumber(L, FANN_E_CANT_USE_TRAIN_ALG); lua_rawset(L, -3);
lua_pushliteral(L, "E_TRAIN_DATA_SUBSET"); lua_pushnumber(L, FANN_E_TRAIN_DATA_SUBSET); lua_rawset(L, -3);
lua_pushliteral(L, "E_INDEX_OUT_OF_BOUND"); lua_pushnumber(L, FANN_E_INDEX_OUT_OF_BOUND); lua_rawset(L, -3);
lua_pushliteral(L, "E_SCALE_NOT_PRESENT"); lua_pushnumber(L, FANN_E_SCALE_NOT_PRESENT); lua_rawset(L, -3);
return 1;
}