// Micro-benchmark for FCESOptimizer driving a single torch::nn::Linear layer.
#include "fces/optimizer.hpp"

#include <benchmark/benchmark.h>
#include <torch/torch.h>

#include <cstdint>
#include <vector>

using namespace fces;

/// Times one training iteration of a Linear(N -> N/2) layer optimized by
/// FCESOptimizer, where N = state.range(0).
///
/// Note: each timed iteration covers the forward pass, loss.backward(),
/// opt.step() and opt.zero_grad() together. It is not a measurement of
/// opt.step() in isolation.
///
/// @param state Google Benchmark state; range(0) is the layer's input width.
static void BM_OptimizerStep(benchmark::State &state) {
  const int64_t in_features = state.range(0);
  constexpr int64_t kBatchSize = 8;  // rows in the synthetic input batch

  auto model = torch::nn::Linear(in_features, in_features / 2);

  // Collect every trainable tensor of the layer for the optimizer.
  std::vector<torch::Tensor> params = model->parameters();
  FCESOptimizer opt(params, FCESConfig{}.set_lr(1e-3f));

  // The input is created once, outside the timed loop.
  const auto x = torch::randn({kBatchSize, in_features});

  for (auto _ : state) {
    auto y = model->forward(x);
    auto loss = y.sum();
    loss.backward();
    opt.step();
    // Presumably clears the parameters' gradients so backward() does not
    // accumulate across iterations — confirm against FCESOptimizer.
    opt.zero_grad();
    benchmark::DoNotOptimize(loss);
  }

  // Report throughput (samples/s) alongside wall time.
  state.SetItemsProcessed(state.iterations() * kBatchSize);
}
BENCHMARK(BM_OptimizerStep)->Arg(64)->Arg(256)->Arg(1024);

// NOTE(review): there is no BENCHMARK_MAIN() in this file. This assumes the
// target links benchmark::benchmark_main (or defines main elsewhere) — confirm.