53 lines
1.9 KiB
C++
53 lines
1.9 KiB
C++
/**
 * @file fces_native.cpp
 * @brief Python bindings for FCES-native via pybind11.
 *
 * Exposes FCESOptimizer, intended as a drop-in replacement for the Python
 * implementation.
 *
 * Usage:
 *   from fces_native import FCESConfig, FCESOptimizer
 *   config = FCESConfig()
 *   config.lr = 1.6e-3
 *   config.population_size = 200
 *   opt = FCESOptimizer(list(model.parameters()), config)
 */
|
|
|
|
#include <memory>
#include <string>
#include <vector>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include "fces/config.hpp"
#include "fces/optimizer.hpp"
|
|
|
|
namespace py = pybind11;
|
|
|
|
PYBIND11_MODULE(fces_native, m) {
|
|
m.doc() = "FCES-native: High-performance C++ FCES optimizer";
|
|
|
|
py::class_<fces::FCESConfig>(m, "FCESConfig")
|
|
.def(py::init<>())
|
|
.def_readwrite("lr", &fces::FCESConfig::lr)
|
|
.def_readwrite("population_size", &fces::FCESConfig::population_size)
|
|
.def_readwrite("total_steps", &fces::FCESConfig::total_steps)
|
|
.def_readwrite("grokking_coefficient",
|
|
&fces::FCESConfig::grokking_coefficient)
|
|
.def_readwrite("direct_construction",
|
|
&fces::FCESConfig::direct_construction);
|
|
|
|
py::class_<fces::FCESOptimizer>(m, "FCESOptimizer")
|
|
.def(py::init<std::vector<torch::Tensor>, fces::FCESConfig>(),
|
|
py::arg("params"), py::arg("config") = fces::FCESConfig{})
|
|
.def("step", [](fces::FCESOptimizer &self) { return self.step(); })
|
|
.def("update_fitness", &fces::FCESOptimizer::update_fitness)
|
|
.def("backup_to_ram", &fces::FCESOptimizer::backup_to_ram)
|
|
.def("restore_from_ram", &fces::FCESOptimizer::restore_from_ram)
|
|
.def("step_count", &fces::FCESOptimizer::step_count)
|
|
.def("calculate_sparsity", &fces::FCESOptimizer::calculate_sparsity)
|
|
.def("zero_grad", [](fces::FCESOptimizer &self) {
|
|
for (auto &group : self.param_groups()) {
|
|
for (auto &p : group.params()) {
|
|
if (p.grad().defined()) {
|
|
p.grad().zero_();
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|