Feat: Activation Checkpointing #154
d87926e to ba8db8f
| namespace gpt2 { | ||
| std::shared_ptr<infini_train::nn::TransformerModel> LoadFromLLMC(const std::string &filepath); | ||
| std::shared_ptr<infini_train::nn::TransformerModel> | ||
| LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &runtime_config); |
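Is this overload needed? Under namespace gpt2, could LoadFromLLMC simply hardcode gpt2::GPT2Config()? Also, with the current logic, when the user passes llmc_path, the config stored in the .bin file should be used to correct model_config. Would that overwrite the runtime_config passed here?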
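My intent was only to pass in the runtime config. Both before and now, LoadFromLLMC under namespace gpt2 does hardcode gpt2::GPT2Config(); the value passed here only updates the runtime-config part. Since a later change introduces ActivationRecomputeOptions, I now pass an ActivationRecomputeOptions parameter directly to avoid confusion.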
| } | ||
| } | ||
|
|
||
| void ApplyRuntimeRecomputeConfig(nn::TransformerConfig *config, const nn::TransformerConfig &runtime_config) { |
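Both GPT2 and LLaMA3 contain this duplicated ApplyRuntimeRecomputeConfig. Can it be factored out into a shared helper? Also, should the parameters be wrapped in a CheckpointOptions struct?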
Factored out, but named ActivationRecomputeOptions for now. CheckpointOptions seems too easy to confuse with model checkpoints (ckpt).
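A minimal sketch of the shared helper discussed in this thread, assuming a small options struct that carries only the runtime recompute settings. The field names (`enabled`, `num_layers`, `recompute_enabled`, `recompute_num_layers`), the trimmed-down `TransformerConfig`, and the helper name `ApplyActivationRecomputeOptions` are hypothetical stand-ins, not the PR's actual definitions.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for nn::TransformerConfig; only the fields needed here.
struct TransformerConfig {
    int64_t n_layer = 12;             // architecture field, owned by the checkpoint file
    bool recompute_enabled = false;   // runtime-only field (name assumed)
    int64_t recompute_num_layers = 0; // runtime-only field (name assumed)
};

// Hypothetical options struct: carries runtime recompute settings and nothing else.
struct ActivationRecomputeOptions {
    bool enabled = false;
    int64_t num_layers = 0; // number of layers to recompute; ignored when disabled
};

// Shared helper that both the GPT-2 and LLaMA-3 loaders could call after the
// architecture fields have been read from the checkpoint. It touches only the
// runtime recompute fields, so it cannot clobber values loaded from the file.
inline void ApplyActivationRecomputeOptions(TransformerConfig *config,
                                            const ActivationRecomputeOptions &options) {
    assert(config != nullptr);
    config->recompute_enabled = options.enabled;
    config->recompute_num_layers = options.enabled ? options.num_layers : 0;
}
```

Applying the options after the file config is read keeps the ordering question from the earlier thread simple: the .bin file decides the architecture, and the options decide only the recompute behavior.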
| const std::vector<std::shared_ptr<Tensor>> &inputs, bool use_reentrant, | ||
| bool preserve_rng_state, bool determinism_check, bool early_stop) { | ||
| if (preserve_rng_state) { | ||
| // TODO(zbl): Preserve and restore RNG state for CPU/CUDA. |
|
|
||
| namespace infini_train::utils::checkpoint { | ||
|
|
||
| class CheckpointFunction : public autograd::Function { |
This class lives in utils, but it does not feel like something that should be exposed as external API. How about moving it into an anonymous namespace?
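For illustration, a sketch of what the anonymous-namespace suggestion could look like. `CheckpointFunctionSketch`, `Checkpoint`, `Values`, and `Fn` are hypothetical names using plain `std::vector<float>` instead of tensors; the point is only the linkage pattern, not the PR's autograd logic.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

namespace checkpoint_sketch {

using Values = std::vector<float>;                // stand-in for a list of tensors
using Fn = std::function<Values(const Values &)>; // stand-in for the wrapped forward

namespace {
// Internal linkage: code outside this translation unit cannot name this class.
// Hypothetical stand-in for the PR's CheckpointFunction.
class CheckpointFunctionSketch {
public:
    explicit CheckpointFunctionSketch(Fn fn) : fn_(std::move(fn)) {}

    // Keeps the inputs (not intermediate activations) and runs the function.
    Values Forward(const Values &inputs) {
        assert(fn_);
        saved_inputs_ = inputs;
        return fn_(saved_inputs_);
    }

private:
    Fn fn_;
    Values saved_inputs_;
};
} // namespace

// The only symbol meant for callers. In a real layout this declaration would
// sit in the header while the class above stays in the .cc file.
Values Checkpoint(const Fn &fn, const Values &inputs) {
    CheckpointFunctionSketch function(fn);
    return function.Forward(inputs);
}

} // namespace checkpoint_sketch
```

With this layout the header exposes one free function, so the class can change freely without affecting callers.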
| @@ -1,12 +1,15 @@ | |||
| #include "infini_train/include/nn/modules/transformer/transformer.h" | |||
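If recompute is later extended with other features (such as selective recompute), keeping everything in transformer.cc will no longer be appropriate; it can be split out into activation_recompute.cc at that point. No change is needed in this PR.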
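Agreed, I have not fully thought this part through. The main reason is that Megatron's implementation also folds much of the recompute logic into the transformer model layer. We can discuss this further later.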
| } | ||
| frame->autocast_state = GetAutocastState(); | ||
|
|
||
| autograd::Function::SavedTensorHooks hooks; |
Since there are two sets of SavedTensorHooks in play, could this hooks variable be renamed placeholder_hooks to make it easier to follow?
| tls_autocast_context.autocast_dtype}; | ||
| } | ||
|
|
||
| inline void SetAutocastState(const AutocastState &state) { |
Could SetAutocastState also be written as a guard?
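A rough sketch of the guard form being asked about, assuming the state is a small copyable struct held in thread-local storage. The `AutocastState` fields, the `tls_autocast_state` variable, and the `AutocastStateGuard` name are hypothetical stand-ins; only the Get/Set function names come from the diff.

```cpp
#include <cassert>

// Hypothetical stand-in for the PR's AutocastState.
struct AutocastState {
    bool enabled;
    int dtype; // stand-in for the autocast data type
};

// Hypothetical thread-local storage for the current state.
thread_local AutocastState tls_autocast_state = {false, 0};

inline AutocastState GetAutocastState() { return tls_autocast_state; }
inline void SetAutocastState(const AutocastState &state) { tls_autocast_state = state; }

// RAII guard: installs `state` on construction and restores the previous
// state on destruction, including when the scope exits via an exception.
class AutocastStateGuard {
public:
    explicit AutocastStateGuard(const AutocastState &state) : previous_(GetAutocastState()) {
        SetAutocastState(state);
    }
    ~AutocastStateGuard() { SetAutocastState(previous_); }

    // Non-copyable so the restore happens exactly once.
    AutocastStateGuard(const AutocastStateGuard &) = delete;
    AutocastStateGuard &operator=(const AutocastStateGuard &) = delete;

private:
    AutocastState previous_;
};
```

During recomputation the caller would construct the guard from the state captured in the forward pass, and the surrounding state comes back automatically when the recompute scope ends.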
|
|
||
| std::vector<std::shared_ptr<Tensor>> Module::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensors) { | ||
| std::vector<std::shared_ptr<Tensor>> | ||
| Module::ForwardWithHooks(const std::vector<std::shared_ptr<Tensor>> &input_tensors) { |
Could this be named Module::Call? As I understand it, this is the real generic entry point, corresponding to the Module.__call__ semantics in PyTorch.
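To make the naming suggestion concrete, here is a sketch of the shape it implies, assuming `operator()` simply forwards to a `Call` method that runs the hooks around `Forward`. `ModuleSketch`, `ScaleByTwo`, the hook registration methods, and the `Values` alias are hypothetical and are not the PR's Module class.

```cpp
#include <cassert>
#include <functional>
#include <utility>
#include <vector>

using Values = std::vector<float>; // stand-in for a list of tensors

class ModuleSketch {
public:
    using PreHook = std::function<void(const Values &)>;
    using PostHook = std::function<void(const Values &, const Values &)>;

    virtual ~ModuleSketch() = default;

    void RegisterForwardPreHook(PreHook hook) {
        assert(hook);
        pre_hooks_.push_back(std::move(hook));
    }
    void RegisterForwardPostHook(PostHook hook) {
        assert(hook);
        post_hooks_.push_back(std::move(hook));
    }

    // Generic entry point: pre-hooks, then Forward, then post-hooks.
    Values Call(const Values &inputs) {
        for (const auto &hook : pre_hooks_) { hook(inputs); }
        Values outputs = Forward(inputs);
        for (const auto &hook : post_hooks_) { hook(inputs, outputs); }
        return outputs;
    }

    // operator() is only sugar over Call.
    Values operator()(const Values &inputs) { return Call(inputs); }

protected:
    virtual Values Forward(const Values &inputs) = 0;

private:
    std::vector<PreHook> pre_hooks_;
    std::vector<PostHook> post_hooks_;
};

// Example module used only to exercise the sketch.
class ScaleByTwo : public ModuleSketch {
protected:
    Values Forward(const Values &inputs) override {
        Values outputs;
        for (float v : inputs) { outputs.push_back(v * 2.0f); }
        return outputs;
    }
};
```

Under this shape, code that needs to re-enter a module during recomputation can call `Call` directly and still get the same hook behavior as a normal `module(inputs)` invocation.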
9cdf73c to a38a4d3