style: run clang-format and configure pre-commit hooks
This commit is contained in:
@@ -1,42 +1,45 @@
|
||||
#include "fces/optimizer.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <torch/torch.h>
|
||||
#include "fces/optimizer.hpp"
|
||||
|
||||
using namespace fces;
|
||||
|
||||
TEST(OptimizerTest, Construction) {
|
||||
auto model = torch::nn::Linear(10, 5);
|
||||
std::vector<torch::Tensor> params;
|
||||
for (auto& p : model->parameters()) params.push_back(p);
|
||||
auto model = torch::nn::Linear(10, 5);
|
||||
std::vector<torch::Tensor> params;
|
||||
for (auto &p : model->parameters())
|
||||
params.push_back(p);
|
||||
|
||||
FCESOptimizer opt(params, FCESConfig{}.set_lr(1e-3f));
|
||||
EXPECT_EQ(opt.step_count(), 0);
|
||||
FCESOptimizer opt(params, FCESConfig{}.set_lr(1e-3f));
|
||||
EXPECT_EQ(opt.step_count(), 0);
|
||||
}
|
||||
|
||||
TEST(OptimizerTest, StepUpdatesCounter) {
|
||||
auto model = torch::nn::Linear(10, 5);
|
||||
std::vector<torch::Tensor> params;
|
||||
for (auto& p : model->parameters()) params.push_back(p);
|
||||
auto model = torch::nn::Linear(10, 5);
|
||||
std::vector<torch::Tensor> params;
|
||||
for (auto &p : model->parameters())
|
||||
params.push_back(p);
|
||||
|
||||
FCESOptimizer opt(params, FCESConfig{}.set_lr(1e-3f));
|
||||
FCESOptimizer opt(params, FCESConfig{}.set_lr(1e-3f));
|
||||
|
||||
// Simulate a training step
|
||||
auto x = torch::randn({2, 10});
|
||||
auto y = model->forward(x);
|
||||
auto loss = y.sum();
|
||||
loss.backward();
|
||||
opt.step();
|
||||
// Simulate a training step
|
||||
auto x = torch::randn({2, 10});
|
||||
auto y = model->forward(x);
|
||||
auto loss = y.sum();
|
||||
loss.backward();
|
||||
opt.step();
|
||||
|
||||
EXPECT_EQ(opt.step_count(), 1);
|
||||
EXPECT_EQ(opt.step_count(), 1);
|
||||
}
|
||||
|
||||
TEST(OptimizerTest, UpdateFitness) {
|
||||
auto model = torch::nn::Linear(10, 5);
|
||||
std::vector<torch::Tensor> params;
|
||||
for (auto& p : model->parameters()) params.push_back(p);
|
||||
auto model = torch::nn::Linear(10, 5);
|
||||
std::vector<torch::Tensor> params;
|
||||
for (auto &p : model->parameters())
|
||||
params.push_back(p);
|
||||
|
||||
FCESOptimizer opt(params);
|
||||
opt.update_fitness(3.0f);
|
||||
opt.update_fitness(2.5f);
|
||||
// Should not crash
|
||||
FCESOptimizer opt(params);
|
||||
opt.update_fitness(3.0f);
|
||||
opt.update_fitness(2.5f);
|
||||
// Should not crash
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user