Feat: Activation Checkpointing #154

Open
Chamberlain0w0 wants to merge 3 commits into master from feat/activation_checkpointing

Conversation

@Chamberlain0w0

Contributor

No description provided.

Chamberlain0w0 force-pushed the feat/activation_checkpointing branch from d87926e to ba8db8f on June 3, 2026 08:45
Chamberlain0w0 changed the title from "[WIP] Feat: Activation Checkpointing" to "Feat: Activation Checkpointing" on Jun 5, 2026
Comment thread example/gpt2/checkpoint_loader.h Outdated
namespace gpt2 {
std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
std::shared_ptr<infini_train::nn::TransformerModel>
LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &runtime_config);

Contributor

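Is this overload needed? Inside namespace gpt2, could LoadFromLLMC simply hard-code gpt2::GPT2Config()? Also, with the current logic, when the user passes llmc_path, the config stored in the .bin file should be used to correct model_config. Would that overwrite the runtime_config passed in here?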

Contributor Author

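My intent was only to pass in the runtime config. Both before and now, LoadFromLLMC in namespace gpt2 does hard-code gpt2::GPT2Config(); the argument passed here only updates the runtime part of the config. Since a later change in this PR introduces ActivationRecomputeOptions, I now pass an ActivationRecomputeOptions directly as the parameter to avoid confusion.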
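As a rough illustration of the split discussed in this thread, the sketch below lets the checkpoint file own the architecture fields while the caller supplies only the recompute settings. The struct fields, `ReadConfigFromHeader`, and `BuildConfig` are hypothetical stand-ins, not the definitions used in this PR.

```cpp
#include <cstdint>
#include <string>

// Hypothetical stand-in; the field names are assumptions, not the PR's definitions.
struct ActivationRecomputeOptions {
    bool enabled = false;   // turn activation recompute on or off
    int64_t num_layers = 0; // how many layers to recompute (0 = none)
};

struct TransformerConfig {
    int64_t n_layer = 12;                 // architecture field, owned by the checkpoint file
    int64_t n_embd = 768;                 // architecture field, owned by the checkpoint file
    ActivationRecomputeOptions recompute; // runtime-only field, owned by the caller
};

// Pretends to parse the header of a .bin file. A real loader would read these
// values from disk; the numbers here are made up for the example.
TransformerConfig ReadConfigFromHeader(const std::string & /*filepath*/) {
    TransformerConfig config;
    config.n_layer = 24;
    config.n_embd = 1024;
    return config;
}

// Architecture comes from the file; only the recompute settings come from the caller,
// so the file contents can never clobber them and they can never clobber the file.
TransformerConfig BuildConfig(const std::string &filepath, const ActivationRecomputeOptions &options) {
    TransformerConfig config = ReadConfigFromHeader(filepath);
    config.recompute = options;
    return config;
}
```

Passing the narrow options type, rather than a whole config, makes it visible at the call site that nothing about the model shape is being overridden.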

Comment thread example/gpt2/checkpoint_loader.cc Outdated
}
}

void ApplyRuntimeRecomputeConfig(nn::TransformerConfig *config, const nn::TransformerConfig &runtime_config) {

Contributor

GPT2 and LLaMA3 both contain this duplicated ApplyRuntimeRecomputeConfig. Can it be extracted into a shared helper? Also, should the parameters be wrapped in a CheckpointOptions struct?

Contributor Author

Extracted. I named it ActivationRecomputeOptions for now, because CheckpointOptions is easy to confuse with model checkpoints (ckpt).

const std::vector<std::shared_ptr<Tensor>> &inputs, bool use_reentrant,
bool preserve_rng_state, bool determinism_check, bool early_stop) {
if (preserve_rng_state) {
// TODO(zbl): Preserve and restore RNG state for CPU/CUDA.

Contributor

Let's put a LOG(FATAL) here.

Contributor Author

OK, done.
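The point of failing loudly is that a caller who sets preserve_rng_state would otherwise get a recompute pass whose random draws (dropout masks, for example) may differ from the original forward, with no warning. A minimal sketch of the fail-fast idea follows; it throws an exception instead of calling LOG(FATAL) so that it runs without a logging library, and `CheckCheckpointOptions` is a hypothetical name.

```cpp
#include <stdexcept>

// Hypothetical helper: rejects an option that the signature accepts but the
// implementation does not support yet. LOG(FATAL) would abort the process;
// this sketch throws instead so it stays self-contained.
void CheckCheckpointOptions(bool preserve_rng_state) {
    if (preserve_rng_state) {
        throw std::logic_error("checkpoint: preserve_rng_state is not implemented yet");
    }
}
```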

Comment thread infini_train/include/utils/checkpoint.h Outdated

namespace infini_train::utils::checkpoint {

class CheckpointFunction : public autograd::Function {

Contributor

This class lives in utils, but it doesn't feel like something that should be exposed as a public API. How about moving it into an anonymous namespace?

Contributor Author

Makes sense, done.
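For readers unfamiliar with the pattern, here is a small sketch of what hiding the class in an anonymous namespace could look like: the implementation class gets internal linkage inside the .cc file, and only a free function remains visible to other translation units. The `demo::checkpoint` namespace, `CheckpointFunctionImpl`, and the `Checkpoint` signature are illustrative assumptions, not the PR's code.

```cpp
#include <functional>
#include <utility>
#include <vector>

namespace demo::checkpoint {

using Fn = std::function<std::vector<float>(const std::vector<float> &)>;

namespace {
// Internal linkage: this class cannot be named from any other translation unit,
// so it is free to change without affecting callers.
class CheckpointFunctionImpl {
public:
    explicit CheckpointFunctionImpl(Fn fn) : fn_(std::move(fn)) {}

    std::vector<float> Apply(const std::vector<float> &inputs) const { return fn_(inputs); }

private:
    Fn fn_;
};
} // namespace

// The only symbol the header would need to declare.
std::vector<float> Checkpoint(const Fn &fn, const std::vector<float> &inputs) {
    return CheckpointFunctionImpl(fn).Apply(inputs);
}

} // namespace demo::checkpoint
```

With this layout the header shrinks to the declaration of `Checkpoint`, and the class definition moves out of the installed include directory entirely.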

@@ -1,12 +1,15 @@
#include "infini_train/include/nn/modules/transformer/transformer.h"

Contributor

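Once recompute is extended with other features (such as selective recompute), keeping everything in transformer.cc will no longer be appropriate; it can be split out into activation_recompute.cc later. No change is needed in this PR.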

Contributor Author

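Agreed, I haven't fully thought this part through. The main reason is that Megatron's implementation also folds much of the recompute logic into the transformer model layer. We can discuss it further later.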

Comment thread infini_train/src/utils/checkpoint.cc Outdated
}
frame->autocast_state = GetAutocastState();

autograd::Function::SavedTensorHooks hooks;

Contributor

Since there are two sets of SavedTensorHooks, could this hooks variable be renamed to placeholder_hooks? That would make it easier to follow.

Contributor Author

OK, done.
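To make the name concrete, the sketch below shows one plausible reading of what a placeholder hook pair does: during the original forward, the pack hook discards the value and returns a slot index, and the first unpack during backward reruns the region to materialize it. This is an assumption about the design, loosely modelled on how hook-based (non-reentrant) checkpointing is usually described; `SavedValueHooks`, `CheckpointFrame`, and `MakePlaceholderHooks` are hypothetical and use plain doubles instead of tensors.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical hook pair, not the PR's SavedTensorHooks.
struct SavedValueHooks {
    std::function<std::size_t(double)> pack;   // called when forward wants to save a value
    std::function<double(std::size_t)> unpack; // called when backward reads it back
};

struct CheckpointFrame {
    std::function<double(double)> forward_fn; // the checkpointed region
    double input = 0.0;                       // the only thing kept alive between passes
    std::vector<double> recomputed;           // filled lazily on the first unpack
    std::size_t next_slot = 0;
    int recompute_count = 0;
};

// Placeholder hooks: store nothing at pack time, recompute on first unpack.
// In a fuller design, a second hook set would be active during the recompute
// to capture the real values; here that step is a single push_back.
// The frame must outlive the returned hooks, since they capture it by reference.
SavedValueHooks MakePlaceholderHooks(CheckpointFrame &frame) {
    SavedValueHooks hooks;
    hooks.pack = [&frame](double /*value*/) { return frame.next_slot++; };
    hooks.unpack = [&frame](std::size_t slot) {
        if (frame.recomputed.empty()) {
            ++frame.recompute_count;
            frame.recomputed.push_back(frame.forward_fn(frame.input));
        }
        return frame.recomputed.at(slot);
    };
    return hooks;
}
```

The sketch handles a single saved value per frame; repeated unpack calls reuse the recomputed result instead of running the region again.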

tls_autocast_context.autocast_dtype};
}

inline void SetAutocastState(const AutocastState &state) {

Contributor

Could this SetAutocastState also be written as a guard?

Contributor Author

OK, done.
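A guard here means an RAII object that installs a state in its constructor and restores the previous one in its destructor, so the restore also happens on early return or exception. The following is a minimal sketch under stated assumptions: the `AutocastState` fields, the thread-local variable, and the `AutocastStateGuard` name are stand-ins and may not match what the PR ended up with.

```cpp
// Hypothetical stand-in for the thread-local autocast state.
struct AutocastState {
    bool enabled = false;
    int dtype = 0; // placeholder for a real dtype enum
};

thread_local AutocastState tls_autocast_state;

AutocastState GetAutocastState() { return tls_autocast_state; }
void SetAutocastState(const AutocastState &state) { tls_autocast_state = state; }

// RAII guard: applies `state` for the lifetime of the object and puts the
// previous state back when the scope exits, however it exits.
class AutocastStateGuard {
public:
    explicit AutocastStateGuard(const AutocastState &state) : previous_(GetAutocastState()) {
        SetAutocastState(state);
    }
    ~AutocastStateGuard() { SetAutocastState(previous_); }

    // Copying a guard would restore the old state twice, so forbid it.
    AutocastStateGuard(const AutocastStateGuard &) = delete;
    AutocastStateGuard &operator=(const AutocastStateGuard &) = delete;

private:
    AutocastState previous_;
};
```

In a recompute path, the guard would wrap the rerun of the forward region with the state captured at checkpoint time, for example `AutocastStateGuard guard(frame->autocast_state);` at the top of the recompute scope.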

Comment thread infini_train/src/nn/modules/module.cc Outdated

std::vector<std::shared_ptr<Tensor>> Module::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
std::vector<std::shared_ptr<Tensor>>
Module::ForwardWithHooks(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {

@chen2021673, Jun 9, 2026

Contributor

Could this be named Module::Call? As I understand it, this is the real general entry point, matching the semantics of Module.__call__ in PyTorch.

Contributor Author

Sure, done.
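The naming argument is easier to see with the call structure written out: `operator()` is only sugar, while `Call` is the entry point that runs hooks around `Forward`. Below is a simplified sketch with a float vector standing in for tensors; the hook registration names and the `Scale` module are hypothetical and not taken from the PR.

```cpp
#include <functional>
#include <utility>
#include <vector>

using Tensors = std::vector<float>; // stand-in for a vector of tensor pointers

class Module {
public:
    using Hook = std::function<void(const Tensors &)>;

    virtual ~Module() = default;

    // Pure computation; subclasses override this and nothing else.
    virtual Tensors Forward(const Tensors &inputs) = 0;

    // General entry point: pre-hooks, Forward, post-hooks.
    Tensors Call(const Tensors &inputs) {
        for (const auto &hook : pre_hooks_) hook(inputs);
        Tensors outputs = Forward(inputs);
        for (const auto &hook : post_hooks_) hook(outputs);
        return outputs;
    }

    // Call-operator sugar that simply delegates to Call.
    Tensors operator()(const Tensors &inputs) { return Call(inputs); }

    void RegisterForwardPreHook(Hook hook) { pre_hooks_.push_back(std::move(hook)); }
    void RegisterForwardPostHook(Hook hook) { post_hooks_.push_back(std::move(hook)); }

private:
    std::vector<Hook> pre_hooks_;
    std::vector<Hook> post_hooks_;
};

// Tiny example module that doubles every element.
class Scale : public Module {
public:
    Tensors Forward(const Tensors &inputs) override {
        Tensors outputs;
        for (float x : inputs) outputs.push_back(2.0f * x);
        return outputs;
    }
};
```

Calling `layer(inputs)` or `layer.Call(inputs)` fires the hooks, whereas calling `layer.Forward(inputs)` directly bypasses them, which is the same distinction PyTorch draws between `module(x)` and `module.forward(x)`.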

Chamberlain0w0 force-pushed the feat/activation_checkpointing branch from 9cdf73c to a38a4d3 on June 11, 2026 01:21

2 participants