fix: correct LoRA initialization and forward pass under tensor parallelism#150
Open
chen2021673 wants to merge 4 commits into
Open
fix: correct LoRA initialization and forward pass under tensor parallelism#150chen2021673 wants to merge 4 commits into
chen2021673 wants to merge 4 commits into
Conversation
8459156 to
f918ae3
Compare
7bfc67d to
d2c8f7a
Compare
…ence Inline base and LoRA matmuls, add locally, then issue a single AllGather/AllReduce instead of two separate collective ops. The prior two-collective approach caused floating-point divergence in DDP loss. Also fix LoadLoRAWeights to slice sharded tensors by tp_rank when the checkpoint shape differs from the partitioned model shape.
Introduce new multi-stream Broadcast and Scatter APIs that take a root_rank_in_group argument, and rename the legacy single-stream variants to BroadCast_/Scatter_ to disambiguate.
d2c8f7a to
f4f2220
Compare
Chamberlain0w0
requested changes
Jun 11, 2026
| int64_t shard_size = dims[shard_dim] / tp_size; | ||
| int64_t start = parallel::tp_rank * shard_size; | ||
| auto sliced = cpu_tensor->Slice(shard_dim, start, start + shard_size); | ||
| dst->CopyFrom(sliced); |
Contributor
There was a problem hiding this comment.
现在模型实现中,Attention 里面第一个 Linear 把 QKV 合并了(实际排列上是 [Q | K | V]),LoadFromLLMC 里面有对应的处理(如 TP=4,则切成 [Q0 Q1 Q2 Q3 | K0 K1 K2 K3 | V0 V1 V2 V3],然后依次拼接 [Qi | Ki | Vi] load 到每个 tp rank 上)。所以这块的切分逻辑对于这种情况没法适用,得想想看怎么能特别处理一下
|
|
||
| if (tp_rank == 0) { | ||
| if (config_.use_kaiming_a) { | ||
| init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); |
Contributor
There was a problem hiding this comment.
在这里做各个 tp rank 各自 init,可能会导致 DP 组之间无法达到权重一致的结果。现在的情况可能会因为都用默认 seed 导致恰好能对上,但是从代码看,语义上还是各初始化各的,DP 组之间有可能权重对不齐。原则上应该保证 DDP model 最后再做一步参数的广播/复制,从而确保在实际执行 forward 前各个能对应上的 dp rank 上面模型参数是相同的,但这里我知道改的话会不会影响其他 lora 基建;最简单的就是 tp_group->Broadcast({parameters_[kParamLoraAName]}, 0); 后面再加个 dp_group 的 broadcast。
| } | ||
| } | ||
|
|
||
| void SaveLoRAWeights(const std::shared_ptr<Module> &model, const std::string &filepath) { |
Contributor
There was a problem hiding this comment.
save 似乎 save 的是 tp local shard tensor,跟 load 逻辑对不上了?load 应该默认读 full tensor 吧?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fix LoRA loss divergence under tensor/data parallel training by making LoRA tensor-parallel initialization deterministic and aligning the forward collective order between base linear and LoRA linear paths.
This PR changes TP LoRA parallel linear modules to compute the base shard and LoRA shard locally first, add them before communication, and then run a single TP collective on the combined output. It also adds rank-aware
Broadcast/ScatterProcessGroup APIs and updates LoRA weight loading to support loading full saved LoRA tensors into TP-sharded model parameters.Motivation
In tensor-parallel LoRA training, the base linear output and the LoRA update together form one logical linear output. Previously, the base path and LoRA path could run separate TP collectives and add their outputs afterward:
This changes the floating-point communication/reduction order compared with computing the logical local contribution first:
The separate-collective path can introduce numerical divergence across TP/DDP configurations.
Replicated or sharded LoRA parameters need rank-consistent initialization so every TP rank starts from the same logical LoRA weights.
Key Changes
LoRA parallel linear forward
Update
LoRAColumnParallelLinearto:base_shard + scaled_lora_shardbefore TP gather;gather_output_is enabled.Update
LoRARowParallelLinearto:base_shard + scaled_lora_shardbefore TP reduce/reduce-scatter;RowParallelLinearbehavior.This makes the LoRA path follow the same logical communication boundary as the base parallel linear layer and avoids running separate collectives for base and LoRA outputs.
Deterministic TP LoRA initialization
lora_Aconsistent across TP ranks.lora_Afrom TP rank 0 and distribute the correct local shards to each TP rank.ProcessGroup communication APIs
Add rank-aware communication APIs:
ProcessGroup::Broadcast(tensors, root_rank_in_group, async_op)ProcessGroup::Scatter(output_tensors, input_tensors, root_rank_in_group, async_op)Rename the old APIs to
BroadCast_/Scatter_to avoid semantic ambiguity.Update existing autograd communication wrappers to call the legacy renamed APIs where appropriate.
LoRA weight loading
LoadLoRAWeightsto handle shape mismatch between saved full LoRA tensors and TP-sharded destination tensors.tp_rankbefore copying into the local model parameter.Test
The loss values in the LoRA tests fluctuate slightly, which is expected because the tests now use saved fixed initialization values.
The performance variance is concentrated in the
lora/bfloat16/distopttests and is treated as normal fluctuation.