Add Stable Diffusion 1.5 support with data-parallel inference #418
csgoogle wants to merge 1 commit into
This PR successfully adds support for Stable Diffusion 1.5 to MaxDiffusion and implements genuine data-parallel inference by sharding latents and context along the data mesh axis. The changes are elegant, well-structured, and provide significant performance improvements by avoiding redundant recomputation of batches across devices.
🔍 General Feedback
- Excellent Alignment with Checkpoint Semantics: Iterating over the actual timesteps shape in the loop instead of hardcoded steps ensures perfect alignment with schedulers like PNDM which emit an extra timestep when `skip_prk_steps` is enabled.
- Robust Parallelism: Forcing genuine batch-sharding constraints is a clean and robust approach to propagation of data parallelism across UNet and VAE layers.
- Configuration-driven Design: Transitioning scheduler instantiation to use config-driven `create_scheduler` mirrors the SDXL pipeline beautifully and increases code reuse.
- Robustness Improvement: An issue was identified in the config-override fallback mechanism where explicit falsy overrides (such as `False`) were ignored, and checkpoint defaults were bypassed. An inline code suggestion has been provided to resolve this.
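The falsy-override pitfall in the last point is easy to reproduce in isolation. The sketch below is not the PR's `override_scheduler_config`: the function names, the plain-dict arguments, and the use of `None` as the "unset" marker are all assumptions made for illustration.

```python
def apply_overrides_buggy(checkpoint_config, overrides):
    """Truthiness-based fallback: an explicit False or 0 override is silently dropped."""
    merged = dict(checkpoint_config)
    for key, value in overrides.items():
        # `or` falls back whenever value is falsy, not only when it is unset.
        merged[key] = value or checkpoint_config.get(key)
    return merged


def apply_overrides(checkpoint_config, overrides):
    """Treat only None as 'unset' (assumption) and tolerate keys the checkpoint omits."""
    merged = dict(checkpoint_config)
    for key, value in overrides.items():
        if value is not None:
            merged[key] = value
    return merged


# Hypothetical configs: the checkpoint omits prediction_type entirely.
checkpoint = {"skip_prk_steps": True, "steps_offset": 1}
overrides = {"skip_prk_steps": False, "prediction_type": "epsilon", "steps_offset": None}

buggy = apply_overrides_buggy(checkpoint, overrides)
fixed = apply_overrides(checkpoint, overrides)
```

In this sketch, `buggy["skip_prk_steps"]` stays `True` even though the caller asked for `False`, while `fixed` honours the explicit `False`, keeps the checkpoint's `steps_offset`, and adds the missing `prediction_type` key.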
Add base15.yml for SD 1.5 (PyTorch weights via from_pt, PNDM/epsilon scheduler) and wire generate.py to it:
- Build the sampler from the checkpoint's scheduler config via create_scheduler instead of a hardcoded DDIM scheduler, and iterate the full PNDM schedule (skip_prk_steps emits one extra timestep).
- Shard the latent batch over the data axis with sharding constraints plus out_shardings so inference runs data parallel instead of replicating the whole batch on every device. Sub-device batches replicate.
- Make override_scheduler_config tolerant of scheduler configs that omit keys (e.g. SD 1.5's PNDM config).
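The schedule-length point can be illustrated with a toy loop. This is a minimal sketch, not the code in `generate.py`: the `timesteps` array is a stand-in with one more entry than `num_inference_steps` (it does not reproduce the real PNDM schedule), and `run_loop` with its placeholder update is hypothetical.

```python
import jax
import jax.numpy as jnp

num_inference_steps = 20

# Stand-in schedule (assumption): 21 descending entries, 1000 down to 0,
# mimicking a scheduler that emits one extra timestep.
timesteps = jnp.arange(num_inference_steps, -1, -1) * 50


def run_loop(latents, timesteps, num_iters):
    """Run num_iters denoising iterations and return (latents, last timestep visited)."""

    def body(i, carry):
        latents, _ = carry
        t = timesteps[i]
        # Placeholder update; a real step would call the UNet and the scheduler.
        return latents * 0.9, t

    return jax.lax.fori_loop(0, num_iters, body, (latents, timesteps[0]))


latents = jnp.ones((2, 4, 8, 8))

# A hardcoded step count stops one entry short of the schedule.
_, last_hardcoded = run_loop(latents, timesteps, num_inference_steps)
# Using the schedule's own length visits every timestep.
_, last_full = run_loop(latents, timesteps, timesteps.shape[0])
```

Here the hardcoded count ends on the second-to-last entry of the stand-in schedule, whereas looping over `timesteps.shape[0]` reaches the final timestep.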
Summary
Adds Stable Diffusion 1.5 to MaxDiffusion and makes its inference data parallel.
- `configs/base15.yml` (new): SD 1.5 config. Mirrors `base14.yml` (same architecture, different weights), points at the `stable-diffusion-v1-5` checkpoint, sets `from_pt: True` (upstream ships PyTorch weights only), and defaults to the checkpoint's PNDM / epsilon scheduler.
- `generate.py`:
  - Build the scheduler via `create_scheduler` (config-driven, mirrors the SDXL path) instead of a hardcoded `FlaxDDIMScheduler`, and iterate the full PNDM schedule (`skip_prk_steps` emits one extra timestep).
  - Shard the latent batch over the `data` axis using `with_sharding_constraint` + a batch-sharded `out_shardings`, so GSPMD propagates data parallelism through the whole UNet/VAE instead of replicating the entire batch on every device. A single `get_batch_sharding` helper is the source of truth and replicates for sub-device batches (`per_device_batch_size < 1`).
- `override_scheduler_config` is now tolerant of scheduler configs that omit keys (e.g. SD 1.5's older PNDM config).

Why
The previous generate path declared a `data` mesh axis, but the program ran fully replicated (`device_put` inside the jit is a weak hint that GSPMD ignored), so each chip recomputed the whole batch. Forcing the batch sharding makes inference genuinely data parallel.

Performance
TPU7x (8 chips), SD 1.5 at 1024px, 20-step PNDM, 8 images: 0.130s. Scales cleanly to larger batches.
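The batch-sharding approach described under Why can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's implementation: the one-dimensional mesh, the body of `get_batch_sharding` (including its divisibility check), and the stand-in `denoise_step` are all hypothetical. Only `with_sharding_constraint`, `out_shardings`, and the `data` axis name come from the description.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1-D mesh over all local devices with a single "data" axis.
mesh = Mesh(np.array(jax.devices()), ("data",))


def get_batch_sharding(mesh, batch_size, ndim):
    """Hypothetical helper: shard dim 0 over 'data', else replicate."""
    n_data = mesh.shape["data"]
    # Sub-device (or unevenly divisible) batches fall back to replication.
    if batch_size < n_data or batch_size % n_data != 0:
        return NamedSharding(mesh, P())
    return NamedSharding(mesh, P("data", *([None] * (ndim - 1))))


batch = 2 * len(jax.devices())
latents = jnp.ones((batch, 4, 8, 8))
sharding = get_batch_sharding(mesh, batch, latents.ndim)


def denoise_step(x):
    # Constrain the layout inside the jitted function so the compiler
    # partitions the computation along the batch dimension.
    x = jax.lax.with_sharding_constraint(x, sharding)
    return x * 0.5  # stand-in for the UNet / VAE work


# out_shardings keeps the result batch-sharded as well.
step = jax.jit(denoise_step, out_shardings=sharding)
out = step(latents)
```

On a single-device host this degenerates to an ordinary jitted call; with several devices, each one holds `batch / n_data` rows of `out` rather than a full copy.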
Test plan
Prompt: "A cinematic photo of a glass greenhouse on a snowy mountain"


