| // Registers the in-place AdamATan2 optimizer step as a custom op in this extension's library. | |
| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { | |
| // Schema: five tensor lists followed by four float hyperparameters; the op returns nothing. | |
| // Each list gets its own write-alias set (a!..e!), declaring that the op mutates it in place. | |
| m.def( | |
| "adam_atan2_cuda_impl_(" | |
| "Tensor(a!)[] params, Tensor(b!)[] grads, " | |
| "Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, " | |
| "Tensor(e!)[] state_steps, " | |
| "float lr, float beta1, float beta2, float weight_decay" | |
| ") -> ()"); | |
| // Only the CUDA dispatch key is bound to an implementation here. | |
| m.impl("adam_atan2_cuda_impl_", torch::kCUDA, &adam_atan2::adam_atan2_cuda_impl_); | |
| } | |
| // Project macro (defined elsewhere) invoked with the extension name; presumably emits the | |
| // module-init glue so the library registered above can be loaded — TODO confirm against its definition. | |
| REGISTER_EXTENSION(TORCH_EXTENSION_NAME) | |