Skip to content

fix: correct LoRA initialization and forward pass under tensor parallelism#150

Open
chen2021673 wants to merge 4 commits into
masterfrom
lora_ddp_loss
Open

fix: correct LoRA initialization and forward pass under tensor parallelism#150
chen2021673 wants to merge 4 commits into
masterfrom
lora_ddp_loss

Conversation

@chen2021673

@chen2021673 chen2021673 commented Apr 30, 2026

Copy link
Copy Markdown
Contributor

Summary

Fix LoRA loss divergence under tensor/data parallel training by making LoRA tensor-parallel initialization deterministic and aligning the forward collective order between base linear and LoRA linear paths.

This PR changes TP LoRA parallel linear modules to compute the base shard and LoRA shard locally first, add them before communication, and then run a single TP collective on the combined output. It also adds rank-aware Broadcast/Scatter ProcessGroup APIs and updates LoRA weight loading to support loading full saved LoRA tensors into TP-sharded model parameters.

Motivation

In tensor-parallel LoRA training, the base linear output and the LoRA update together form one logical linear output. Previously, the base path and LoRA path could run separate TP collectives and add their outputs afterward:

base = TPCollective(base_i);
lora = TPCollective(lora_i);
out = base + lora;

This changes the floating-point communication/reduction order compared with computing the logical local contribution first:

out_i = base_i + lora_i;
out = TPCollective(out_i);

The separate-collective path can introduce numerical divergence across TP/DDP configurations.

Replicated or sharded LoRA parameters need rank-consistent initialization so every TP rank starts from the same logical LoRA weights.

Key Changes

LoRA parallel linear forward

  • Update LoRAColumnParallelLinear to:

    • compute the base shard locally;
    • compute the LoRA shard locally;
    • add base_shard + scaled_lora_shard before TP gather;
    • run a single gather when gather_output_ is enabled.
  • Update LoRARowParallelLinear to:

    • compute the base shard locally;
    • compute the LoRA shard locally;
    • add base_shard + scaled_lora_shard before TP reduce/reduce-scatter;
    • apply bias after the collective, matching the base RowParallelLinear behavior.

This makes the LoRA path follow the same logical communication boundary as the base parallel linear layer and avoids running separate collectives for base and LoRA outputs.

Deterministic TP LoRA initialization

  • Make replicated ColumnParallel lora_A consistent across TP ranks.
  • Initialize the logical RowParallel lora_A from TP rank 0 and distribute the correct local shards to each TP rank.
  • Ensure TP ranks observe the same logical LoRA initialization instead of independently sampling incompatible parameter values.

ProcessGroup communication APIs

  • Add rank-aware communication APIs:

    • ProcessGroup::Broadcast(tensors, root_rank_in_group, async_op)
    • ProcessGroup::Scatter(output_tensors, input_tensors, root_rank_in_group, async_op)
  • Rename the old APIs to BroadCast_ / Scatter_ to avoid semantic ambiguity.

  • Update existing autograd communication wrappers to call the legacy renamed APIs where appropriate.

LoRA weight loading

  • Extend LoadLoRAWeights to handle shape mismatch between saved full LoRA tensors and TP-sharded destination tensors.
  • When the destination tensor is sharded, slice the loaded full tensor according to tp_rank before copying into the local model parameter.

Test

The loss values in the LoRA tests fluctuate slightly, which is expected because the tests now use saved fixed initialization values.

image

The performance variance is concentrated in the lora/bfloat16/distopt tests and is treated as normal fluctuation.

image

…ence

Inline base and LoRA matmuls, add locally, then issue a single
AllGather/AllReduce instead of two separate collective ops. The prior
two-collective approach caused floating-point divergence in DDP loss.

Also fix LoadLoRAWeights to slice sharded tensors by tp_rank when the
checkpoint shape differs from the partitioned model shape.
Introduce new multi-stream Broadcast and Scatter APIs that take a
root_rank_in_group argument, and rename the legacy single-stream
variants to BroadCast_/Scatter_ to disambiguate.
int64_t shard_size = dims[shard_dim] / tp_size;
int64_t start = parallel::tp_rank * shard_size;
auto sliced = cpu_tensor->Slice(shard_dim, start, start + shard_size);
dst->CopyFrom(sliced);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在模型实现中,Attention 里面第一个 Linear 把 QKV 合并了(实际排列上是 [Q | K | V]),LoadFromLLMC 里面有对应的处理(如 TP=4,则切成 [Q0 Q1 Q2 Q3 | K0 K1 K2 K3 | V0 V1 V2 V3],然后依次拼接 [Qi | Ki | Vi] load 到每个 tp rank 上)。所以这块的切分逻辑对于这种情况没法适用,得想想看怎么能特别处理一下


if (tp_rank == 0) {
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在这里做各个 tp rank 各自 init,可能会导致 DP 组之间无法达到权重一致的结果。现在的情况可能会因为都用默认 seed 导致恰好能对上,但是从代码看,语义上还是各初始化各的,DP 组之间有可能权重对不齐。原则上应该保证 DDP model 最后再做一步参数的广播/复制,从而确保在实际执行 forward 前各个能对应上的 dp rank 上面模型参数是相同的,但这里我知道改的话会不会影响其他 lora 基建;最简单的就是 tp_group->Broadcast({parameters_[kParamLoraAName]}, 0); 后面再加个 dp_group 的 broadcast。

}
}

void SaveLoRAWeights(const std::shared_ptr<Module> &model, const std::string &filepath) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save 似乎 save 的是 tp local shard tensor,跟 load 逻辑对不上了?load 应该默认读 full tensor 吧?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants