Files
FCES-native/benchmarks/bench_step.cpp
2026-05-20 00:18:23 +02:00

27 lines
685 B
C++

#include "fces/optimizer.hpp"
#include <benchmark/benchmark.h>
#include <torch/torch.h>
using namespace fces;
/// Benchmarks one full training iteration of FCESOptimizer on a single
/// torch::nn::Linear layer.
///
/// Each timed iteration runs forward -> sum-reduction loss -> backward ->
/// opt.step() -> opt.zero_grad(), so the reported time includes the autograd
/// forward/backward cost and not only the optimizer update.
///
/// @param state Google Benchmark state. state.range(0) is the layer's input
///              width; the output width is half of it (integer division).
static void BM_OptimizerStep(benchmark::State &state) {
  // Rows in the synthetic input mini-batch.
  constexpr int64_t kBatchSize = 8;
  const int64_t in_features = state.range(0);
  const int64_t out_features = in_features / 2;

  auto model = torch::nn::Linear(in_features, out_features);
  // parameters() already yields a std::vector<torch::Tensor>; initializing
  // from it directly is equivalent to copying each element in a loop.
  std::vector<torch::Tensor> params = model->parameters();
  FCESOptimizer opt(params, FCESConfig{}.set_lr(1e-3f));
  const auto x = torch::randn({kBatchSize, in_features});

  for (auto _ : state) {
    auto y = model->forward(x);
    auto loss = y.sum();
    loss.backward();
    opt.step();
    opt.zero_grad();
    benchmark::DoNotOptimize(loss);
  }

  // Report samples/s so runs at different widths expose comparable throughput.
  state.SetItemsProcessed(state.iterations() * kBatchSize);
}
// Sweep the layer input width over the powers of 4 in [64, 1024]:
// 64, 256 and 1024.
BENCHMARK(BM_OptimizerStep)->RangeMultiplier(4)->Range(64, 1024);