mirror of
https://github.com/lxsang/antd-lua-plugin
synced 2025-01-03 21:08:22 +01:00
36 lines
1.0 KiB
C++
36 lines
1.0 KiB
C++
#include "fann_test_train.h"
|
|
|
|
using namespace std;
|
|
|
|
void FannTestTrain::SetUp() {
|
|
FannTest::SetUp();
|
|
}
|
|
|
|
void FannTestTrain::TearDown() {
|
|
FannTest::TearDown();
|
|
|
|
}
|
|
|
|
TEST_F(FannTestTrain, TrainOnDateSimpleXor) {
|
|
neural_net net(LAYER, 3, 2, 3, 1);
|
|
|
|
data.set_train_data(4, 2, xorInput, 1, xorOutput);
|
|
net.train_on_data(data, 100, 100, 0.001);
|
|
|
|
EXPECT_LT(net.get_MSE(), 0.001);
|
|
EXPECT_LT(net.test_data(data), 0.001);
|
|
}
|
|
|
|
TEST_F(FannTestTrain, TrainSimpleIncrementalXor) {
|
|
neural_net net(LAYER, 3, 2, 3, 1);
|
|
|
|
for(int i = 0; i < 100000; i++) {
|
|
net.train((fann_type*) (const fann_type[]) {0.0, 0.0}, (fann_type*) (const fann_type[]) {0.0});
|
|
net.train((fann_type*) (const fann_type[]) {1.0, 0.0}, (fann_type*) (const fann_type[]) {1.0});
|
|
net.train((fann_type*) (const fann_type[]) {0.0, 1.0}, (fann_type*) (const fann_type[]) {1.0});
|
|
net.train((fann_type*) (const fann_type[]) {1.0, 1.0}, (fann_type*) (const fann_type[]) {0.0});
|
|
}
|
|
|
|
EXPECT_LT(net.get_MSE(), 0.01);
|
|
}
|