From 35ba10e4cc3f5748fe93729f54d1d727cc02ecdc Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 4 Jun 2026 13:56:52 +0200 Subject: [PATCH 1/7] PRX pixel pipeline --- scripts/convert_prx_to_diffusers.py | 431 +++++++++++------- src/diffusers/__init__.py | 2 + .../models/transformers/transformer_prx.py | 76 ++- src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/prx/__init__.py | 2 + src/diffusers/pipelines/prx/pipeline_prx.py | 47 +- .../pipelines/prx/pipeline_prx_pixel.py | 97 ++++ .../dummy_torch_and_transformers_objects.py | 15 + 8 files changed, 484 insertions(+), 190 deletions(-) create mode 100644 src/diffusers/pipelines/prx/pipeline_prx_pixel.py diff --git a/scripts/convert_prx_to_diffusers.py b/scripts/convert_prx_to_diffusers.py index 00bb3f6fe99e..b93589e4f94a 100644 --- a/scripts/convert_prx_to_diffusers.py +++ b/scripts/convert_prx_to_diffusers.py @@ -1,6 +1,21 @@ #!/usr/bin/env python3 """ -Script to convert PRX checkpoint from original codebase to diffusers format. +Script to convert a PRX checkpoint from the original codebase to diffusers format. + +Supports two checkpoint layouts: + * a single-file ``torch.save`` checkpoint (``.pt`` / ``.pth``), and + * a sharded torch Distributed Checkpoint (DCP) directory (``.metadata`` + ``*.distcp``), + as produced by Composer/FSDP training. + +and three model variants (``--variant``): + * ``flux`` : latent-space, AutoencoderKL (16ch, patch 2) -> PRXPipeline + * ``dc-ae`` : latent-space, AutoencoderDC (32ch, patch 1) -> PRXPipeline + * ``pixel`` : pixel-space PRXPixel (3ch RGB, patch 16, bottleneck img_in, resolution embedder, + Qwen3-VL text tower, no VAE) -> PRXPixelPipeline + +The block-level parameter remapping is shared across all variants; the ``pixel`` variant's extra +parameters (``img_in.{0,1}`` bottleneck and ``resolution_embedder.mlp.*``) carry over with no +renaming, so a single mapping generalises across versions. """ import argparse @@ -8,17 +23,23 @@ import os import sys from dataclasses import asdict, dataclass -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import torch from safetensors.torch import save_file from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel -from diffusers.pipelines.prx import PRXPipeline DEFAULT_RESOLUTION = 512 +# Default location of the denoiser weights inside a research (Composer) checkpoint. +DENOISER_PREFIX = "state.model.denoiser." + +# Qwen3-VL embedding tower used by the pixel variant. +PIXEL_TEXT_ENCODER_REPO = "Qwen/Qwen3-VL-Embedding-2B" +PIXEL_PROMPT_MAX_TOKENS = 256 + @dataclass(frozen=True) class PRXBase: @@ -31,6 +52,8 @@ class PRXBase: theta: int = 10_000 time_factor: float = 1000.0 time_max_period: int = 10_000 + bottleneck_size: Optional[int] = None + resolution_embeds: bool = False @dataclass(frozen=True) @@ -45,150 +68,172 @@ class PRXDCAE(PRXBase): patch_size: int = 1 -def build_config(vae_type: str) -> Tuple[dict, int]: - if vae_type == "flux": - cfg = PRXFlux() - elif vae_type == "dc-ae": - cfg = PRXDCAE() - else: - raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'") +@dataclass(frozen=True) +class PRXPixel(PRXBase): + # Pixel-space RGB diffusion (PRXPixel / 7B). + in_channels: int = 3 + patch_size: int = 16 + context_in_dim: int = 2048 # Qwen3-VL-Embedding-2B hidden size + hidden_size: int = 3584 + num_heads: int = 28 + depth: int = 24 + axes_dim: Tuple[int, int] = (64, 64) + bottleneck_size: int = 768 + resolution_embeds: bool = True - config_dict = asdict(cfg) - config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index] + +VARIANTS = {"flux": PRXFlux, "dc-ae": PRXDCAE, "pixel": PRXPixel} + + +def build_config(variant: str) -> dict: + if variant not in VARIANTS: + raise ValueError(f"Unsupported variant: {variant}. Choose from {list(VARIANTS)}") + config_dict = asdict(VARIANTS[variant]()) + config_dict["axes_dim"] = list(config_dict["axes_dim"]) + if config_dict["bottleneck_size"] is None: + # Keep config.json clean for variants that don't use the bottleneck. + config_dict.pop("bottleneck_size") return config_dict -def create_parameter_mapping(depth: int) -> dict: - """Create mapping from old parameter names to new diffusers names.""" +# --------------------------------------------------------------------------- +# Checkpoint loading +# --------------------------------------------------------------------------- +def _flatten(nested: dict, parent: str = "") -> Dict[str, torch.Tensor]: + """Flatten the nested dict returned by DCP loading back into dotted keys.""" + flat = {} + for k, v in nested.items(): + key = f"{parent}.{k}" if parent else str(k) + if isinstance(v, dict): + flat.update(_flatten(v, key)) + else: + flat[key] = v + return flat + + +def _is_dcp_dir(path: str) -> bool: + return os.path.isdir(path) and os.path.exists(os.path.join(path, ".metadata")) + + +def load_denoiser_state_dict(checkpoint_path: str, prefix: str = DENOISER_PREFIX) -> Dict[str, torch.Tensor]: + """Load just the denoiser weights from a research checkpoint (DCP dir or single file).""" + if _is_dcp_dir(checkpoint_path): + print(f"Loading DCP (distributed) checkpoint from: {checkpoint_path}") + from torch.distributed.checkpoint import FileSystemReader + from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys + + reader = FileSystemReader(checkpoint_path) + meta = reader.read_metadata() + keys = {k for k in meta.state_dict_metadata if k.startswith(prefix)} + if not keys: + raise ValueError(f"No keys with prefix '{prefix}' found in {checkpoint_path}") + print(f" Reading {len(keys)} denoiser tensors (skipping optimizer / EMA / RNG state)...") + nested = _load_state_dict_from_keys(keys, storage_reader=reader) + flat = _flatten(nested) + state_dict = {k[len(prefix):]: v for k, v in flat.items() if k.startswith(prefix)} + else: + print(f"Loading single-file checkpoint from: {checkpoint_path}") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + if isinstance(ckpt, dict): + state_dict = ckpt.get("model", ckpt.get("state_dict", ckpt)) + else: + state_dict = ckpt + # Strip a denoiser prefix if the keys carry one. + if any(k.startswith(prefix) for k in state_dict): + state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} - # Key mappings for structural changes - mapping = {} + print(f"✓ Loaded {len(state_dict)} denoiser parameters") + return state_dict - # Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention) + +# --------------------------------------------------------------------------- +# Parameter name remapping (research -> diffusers) +# --------------------------------------------------------------------------- +def create_parameter_mapping(depth: int) -> dict: + """Map old parameter names (layers on PRXBlock) to diffusers names (layers on PRXAttention).""" + mapping = {} for i in range(depth): - # QKV projections moved to attention module mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight" mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight" - - # QK norm moved to attention module and renamed to match Attention's qk_norm structure mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight" mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight" mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight" mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight" - - # K norm for text tokens moved to attention module mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight" mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight" - - # Attention output projection mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight" - return mapping def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> dict[str, torch.Tensor]: - """Convert old checkpoint parameters to new diffusers format.""" - - print("Converting checkpoint parameters...") - + """Apply the block remapping. Unmapped keys (img_in, time_in, txt_in, resolution_embedder, final_layer) + carry over unchanged.""" mapping = create_parameter_mapping(depth) - converted_state_dict = {} - + converted = {} + num_mapped = 0 for key, value in old_state_dict.items(): - new_key = key - - # Apply specific mappings if needed - if key in mapping: - new_key = mapping[key] - print(f" Mapped: {key} -> {new_key}") - - converted_state_dict[new_key] = value - - print(f"✓ Converted {len(converted_state_dict)} parameters") - return converted_state_dict + new_key = mapping.get(key, key) + if new_key != key: + num_mapped += 1 + converted[new_key] = value + print(f"✓ Converted {len(converted)} parameters ({num_mapped} block keys remapped)") + return converted def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel: - """Create and load PRXTransformer2DModel from old checkpoint.""" - - print(f"Loading checkpoint from: {checkpoint_path}") - - # Load old checkpoint - if not os.path.exists(checkpoint_path): - raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - - old_checkpoint = torch.load(checkpoint_path, map_location="cpu") + """Create and load a PRXTransformer2DModel from a research checkpoint.""" + state_dict = load_denoiser_state_dict(checkpoint_path) + converted = convert_checkpoint_parameters(state_dict, depth=int(config["depth"])) - # Handle different checkpoint formats - if isinstance(old_checkpoint, dict): - if "model" in old_checkpoint: - state_dict = old_checkpoint["model"] - elif "state_dict" in old_checkpoint: - state_dict = old_checkpoint["state_dict"] - else: - state_dict = old_checkpoint - else: - state_dict = old_checkpoint - - print(f"✓ Loaded checkpoint with {len(state_dict)} parameters") - - # Convert parameter names if needed - model_depth = int(config.get("depth", 16)) - converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth) - - # Create transformer with config print("Creating PRXTransformer2DModel...") transformer = PRXTransformer2DModel(**config) - # Load state dict - print("Loading converted parameters...") - missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False) - - if missing_keys: - print(f"⚠ Missing keys: {missing_keys}") - if unexpected_keys: - print(f"⚠ Unexpected keys: {unexpected_keys}") - - if not missing_keys and not unexpected_keys: - print("✓ All parameters loaded successfully!") - + # Match the checkpoint dtype (research saves bf16). + param_dtype = next(iter(converted.values())).dtype + transformer = transformer.to(param_dtype) + + missing, unexpected = transformer.load_state_dict(converted, strict=False) + if missing: + print(f"⚠ Missing keys ({len(missing)}): {missing}") + if unexpected: + print(f"⚠ Unexpected keys ({len(unexpected)}): {unexpected}") + if not missing and not unexpected: + print("✓ All parameters loaded successfully (0 missing, 0 unexpected)!") + else: + raise RuntimeError("Checkpoint did not load cleanly; see missing/unexpected keys above.") return transformer +# --------------------------------------------------------------------------- +# Auxiliary components +# --------------------------------------------------------------------------- def create_scheduler_config(output_path: str, shift: float): - """Create FlowMatchEulerDiscreteScheduler config.""" - scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift} - scheduler_path = os.path.join(output_path, "scheduler") os.makedirs(scheduler_path, exist_ok=True) - with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f: json.dump(scheduler_config, f, indent=2) - print("✓ Created scheduler config") -def download_and_save_vae(vae_type: str, output_path: str): - """Download and save VAE to local directory.""" +def download_and_save_vae(variant: str, output_path: str): from diffusers import AutoencoderDC, AutoencoderKL vae_path = os.path.join(output_path, "vae") os.makedirs(vae_path, exist_ok=True) - - if vae_type == "flux": + if variant == "flux": print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...") vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") else: # dc-ae print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...") vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers") - vae.save_pretrained(vae_path) print(f"✓ Saved VAE to {vae_path}") -def download_and_save_text_encoder(output_path: str): - """Download and save T5Gemma text encoder and tokenizer.""" +def download_and_save_t5gemma_text_encoder(output_path: str): from transformers import GemmaTokenizerFast from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel @@ -198,144 +243,176 @@ def download_and_save_text_encoder(output_path: str): os.makedirs(tokenizer_path, exist_ok=True) print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...") - t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") - - # Extract and save only the encoder - t5gemma_encoder = t5gemma_model.encoder + t5gemma_encoder = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2").encoder t5gemma_encoder.save_pretrained(text_encoder_path) print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}") - print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...") tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") tokenizer.model_max_length = 256 tokenizer.save_pretrained(tokenizer_path) print(f"✓ Saved tokenizer to {tokenizer_path}") + return "T5GemmaEncoder", "prx" -def create_model_index(vae_type: str, default_image_size: int, output_path: str): - """Create model_index.json for the pipeline.""" +def download_and_save_qwen_text_encoder(output_path: str, repo: str = PIXEL_TEXT_ENCODER_REPO): + """Download the Qwen3-VL embedding tower, keep only the text backbone, and save it.""" + from transformers import AutoTokenizer, Qwen3VLForConditionalGeneration - if vae_type == "flux": - vae_class = "AutoencoderKL" - else: # dc-ae - vae_class = "AutoencoderDC" - - model_index = { - "_class_name": "PRXPipeline", - "_diffusers_version": "0.31.0.dev0", - "_name_or_path": os.path.basename(output_path), - "default_sample_size": default_image_size, - "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], - "text_encoder": ["prx", "T5GemmaEncoder"], - "tokenizer": ["transformers", "GemmaTokenizerFast"], - "transformer": ["diffusers", "PRXTransformer2DModel"], - "vae": ["diffusers", vae_class], - } - - model_index_path = os.path.join(output_path, "model_index.json") - with open(model_index_path, "w") as f: + text_encoder_path = os.path.join(output_path, "text_encoder") + tokenizer_path = os.path.join(output_path, "tokenizer") + os.makedirs(text_encoder_path, exist_ok=True) + os.makedirs(tokenizer_path, exist_ok=True) + + print(f"Downloading Qwen3-VL model from {repo} (vision tower will be discarded)...") + full_model = Qwen3VLForConditionalGeneration.from_pretrained( + repo, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True + ) + text_encoder = full_model.model.language_model + text_encoder.save_pretrained(text_encoder_path) + encoder_class = type(text_encoder).__name__ + del full_model + print(f"✓ Saved {encoder_class} to {text_encoder_path}") + + tokenizer = AutoTokenizer.from_pretrained(repo) + tokenizer.model_max_length = PIXEL_PROMPT_MAX_TOKENS + tokenizer.save_pretrained(tokenizer_path) + tokenizer_class = type(tokenizer).__name__ + print(f"✓ Saved tokenizer ({tokenizer_class}) to {tokenizer_path}") + return encoder_class, "transformers", tokenizer_class + + +def create_model_index( + variant: str, + default_image_size: int, + output_path: str, + text_encoder_class: str, + text_encoder_lib: str, + tokenizer_class: str, +): + if variant == "pixel": + model_index = { + "_class_name": "PRXPixelPipeline", + "_diffusers_version": "0.37.0.dev0", + "_name_or_path": os.path.basename(output_path), + "default_sample_size": default_image_size, + "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], + "text_encoder": [text_encoder_lib, text_encoder_class], + "tokenizer": ["transformers", tokenizer_class], + "transformer": ["diffusers", "PRXTransformer2DModel"], + "vae": [None, None], # pixel-space: no VAE + } + else: + vae_class = "AutoencoderKL" if variant == "flux" else "AutoencoderDC" + model_index = { + "_class_name": "PRXPipeline", + "_diffusers_version": "0.37.0.dev0", + "_name_or_path": os.path.basename(output_path), + "default_sample_size": default_image_size, + "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], + "text_encoder": [text_encoder_lib, text_encoder_class], + "tokenizer": ["transformers", tokenizer_class], + "transformer": ["diffusers", "PRXTransformer2DModel"], + "vae": ["diffusers", vae_class], + } + with open(os.path.join(output_path, "model_index.json"), "w") as f: json.dump(model_index, f, indent=2) + print(f"✓ Wrote model_index.json ({model_index['_class_name']})") def main(args): - # Validate inputs - if not os.path.exists(args.checkpoint_path): - raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}") - - config = build_config(args.vae_type) - - # Create output directory + config = build_config(args.variant) os.makedirs(args.output_path, exist_ok=True) print(f"✓ Output directory: {args.output_path}") + print(f"✓ Variant: {args.variant} | config: {config}") - # Create transformer from checkpoint + # ---- transformer ---- transformer = create_transformer_from_checkpoint(args.checkpoint_path, config) - - # Save transformer transformer_path = os.path.join(args.output_path, "transformer") os.makedirs(transformer_path, exist_ok=True) - - # Save config with open(os.path.join(transformer_path, "config.json"), "w") as f: json.dump(config, f, indent=2) + save_file(transformer.state_dict(), os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) + num_params = sum(p.numel() for p in transformer.parameters()) + print(f"✓ Saved transformer to {transformer_path} ({num_params:,} params)") - # Save model weights as safetensors - state_dict = transformer.state_dict() - save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) - print(f"✓ Saved transformer to {transformer_path}") - - # Create scheduler config + # ---- scheduler ---- create_scheduler_config(args.output_path, args.shift) - download_and_save_vae(args.vae_type, args.output_path) - download_and_save_text_encoder(args.output_path) - - # Create model_index.json - create_model_index(args.vae_type, args.resolution, args.output_path) + # ---- vae (none for pixel) ---- + if args.variant != "pixel" and not args.skip_vae: + download_and_save_vae(args.variant, args.output_path) + + # ---- text encoder + tokenizer ---- + text_encoder_class, text_encoder_lib, tokenizer_class = "T5GemmaEncoder", "prx", "GemmaTokenizerFast" + if not args.skip_text_encoder: + if args.variant == "pixel": + text_encoder_class, text_encoder_lib, tokenizer_class = download_and_save_qwen_text_encoder( + args.output_path + ) + else: + text_encoder_class, text_encoder_lib = download_and_save_t5gemma_text_encoder(args.output_path) + tokenizer_class = "GemmaTokenizerFast" - # Verify the pipeline can be loaded - try: - pipeline = PRXPipeline.from_pretrained(args.output_path) - print("Pipeline loaded successfully!") - print(f"Transformer: {type(pipeline.transformer).__name__}") - print(f"VAE: {type(pipeline.vae).__name__}") - print(f"Text Encoder: {type(pipeline.text_encoder).__name__}") - print(f"Scheduler: {type(pipeline.scheduler).__name__}") + create_model_index( + args.variant, args.resolution, args.output_path, text_encoder_class, text_encoder_lib, tokenizer_class + ) - # Display model info - num_params = sum(p.numel() for p in pipeline.transformer.parameters()) - print(f"✓ Transformer parameters: {num_params:,}") + # ---- verify ---- + if args.skip_text_encoder: + print("Skipped text encoder; verifying the transformer reloads from disk...") + reloaded = PRXTransformer2DModel.from_pretrained(transformer_path) + print(f"✓ Transformer reloaded: {type(reloaded).__name__} ({sum(p.numel() for p in reloaded.parameters()):,} params)") + else: + from diffusers import PRXPipeline, PRXPixelPipeline - except Exception as e: - print(f"Pipeline verification failed: {e}") - return False + pipe_cls = PRXPixelPipeline if args.variant == "pixel" else PRXPipeline + pipeline = pipe_cls.from_pretrained(args.output_path) + print("Pipeline loaded successfully!") + print(f" Pipeline: {type(pipeline).__name__}") + print(f" Transformer: {type(pipeline.transformer).__name__}") + print(f" VAE: {type(pipeline.vae).__name__ if pipeline.vae is not None else None}") + print(f" Text Encoder: {type(pipeline.text_encoder).__name__}") + print(f" Scheduler: {type(pipeline.scheduler).__name__}") print("Conversion completed successfully!") - print(f"Converted pipeline saved to: {args.output_path}") - print(f"VAE type: {args.vae_type}") - return True if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format") - parser.add_argument( - "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )" + "--checkpoint_path", + type=str, + required=True, + help="Path to the original PRX checkpoint (a .pt/.pth file or a DCP directory).", ) - parser.add_argument( "--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline" ) - parser.add_argument( - "--vae_type", + "--variant", type=str, - choices=["flux", "dc-ae"], + choices=list(VARIANTS), required=True, - help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)", + help="Model variant: 'flux' (AutoencoderKL, 16ch), 'dc-ae' (AutoencoderDC, 32ch), or 'pixel' (RGB PRXPixel).", ) - parser.add_argument( "--resolution", type=int, - choices=[256, 512, 1024], default=DEFAULT_RESOLUTION, - help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.", + help="Default sample size for the pipeline (e.g. 256, 512, 1024).", ) - + parser.add_argument("--shift", type=float, default=3.0, help="Shift for the scheduler") parser.add_argument( - "--shift", - type=float, - default=3.0, - help="Shift for the scheduler", + "--skip_text_encoder", + action="store_true", + help="Skip downloading/saving the text encoder + tokenizer (validate the transformer only).", ) + parser.add_argument("--skip_vae", action="store_true", help="Skip downloading/saving the VAE.") args = parser.parse_args() - try: - success = main(args) - if not success: + if not main(args): sys.exit(1) except Exception as e: print(f"Conversion failed: {e}") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 61ccfd85c192..fd4ce0f10b6a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -587,6 +587,7 @@ "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", + "PRXPixelPipeline", "PRXPipeline", "QwenImageControlNetInpaintPipeline", "QwenImageControlNetPipeline", @@ -1330,6 +1331,7 @@ PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, + PRXPixelPipeline, PRXPipeline, QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py index 3c8e8ae4e2c9..7919cf1986e9 100644 --- a/src/diffusers/models/transformers/transformer_prx.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -322,6 +322,48 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.out_layer(self.silu(self.in_layer(x))) +class PRXResolutionEmbedder(nn.Module): + r""" + Embeds the spatial resolution `(height, width)` of the latent into a vector that is added to the timestep + embedding, so the model can condition its modulation on the generation resolution. + + A sinusoidal embedding of dimension 128 is built for the height and the width separately and concatenated into a + 256-dim vector, which is then projected to `hidden_size` by a 2-layer MLP. This matches the `"vec"` mode of the + resolution-aware conditioning used during PRX-7B training. + + Args: + hidden_size (`int`): + Dimension of the output embedding (must match the timestep embedding dimension). + max_period (`int`, *optional*, defaults to 10000): + Maximum frequency period for the sinusoidal resolution embedding. + """ + + def __init__(self, hidden_size: int, max_period: int = 10000): + super().__init__() + self.max_period = max_period + self.mlp = MLPEmbedder(in_dim=256, hidden_dim=hidden_size) + + def forward(self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + h_emb = get_timestep_embedding( + timesteps=height, + embedding_dim=128, + max_period=self.max_period, + scale=1.0, + flip_sin_to_cos=True, + downscale_freq_shift=0.0, + ) + w_emb = get_timestep_embedding( + timesteps=width, + embedding_dim=128, + max_period=self.max_period, + scale=1.0, + flip_sin_to_cos=True, + downscale_freq_shift=0.0, + ) + hw_emb = torch.cat([h_emb, w_emb], dim=-1).to(self.mlp.in_layer.weight.dtype) + return self.mlp(hw_emb).to(dtype) + + class Modulation(nn.Module): r""" Modulation network that generates scale, shift, and gating parameters. @@ -614,12 +656,19 @@ class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): Scaling factor applied in timestep embeddings. time_max_period (`int`, *optional*, defaults to 10000): Maximum frequency period for timestep embeddings. + bottleneck_size (`int`, *optional*): + If set, the image patch projection (`img_in`) uses a two-layer bottleneck + (`patch_dim -> bottleneck_size -> hidden_size`) instead of a single linear layer. Used by the pixel-space + PRX-7B variant where the patch dimension is large. + resolution_embeds (`bool`, *optional*, defaults to `False`): + Whether to condition the timestep modulation on the latent resolution `(H, W)` via a + `PRXResolutionEmbedder`. Used by the PRX-7B variant. Attributes: pe_embedder (`EmbedND`): Multi-axis rotary embedding generator for positional encodings. - img_in (`nn.Linear`): - Projection layer for image patch tokens. + img_in (`nn.Linear` or `nn.Sequential`): + Projection layer for image patch tokens (a two-layer bottleneck when `bottleneck_size` is set). time_in (`MLPEmbedder`): Embedding layer for timestep embeddings. txt_in (`nn.Linear`): @@ -667,6 +716,8 @@ def __init__( theta: int = 10000, time_factor: float = 1000.0, time_max_period: int = 10000, + bottleneck_size: int = None, + resolution_embeds: bool = False, ): super().__init__() @@ -692,10 +743,22 @@ def __init__( self.hidden_size = hidden_size self.num_heads = num_heads self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) - self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) + patch_dim = self.in_channels * self.patch_size**2 + if bottleneck_size is not None: + # Two-layer bottleneck projection (used by pixel-space PRX where the patch dimension is large). + self.img_in = nn.Sequential( + nn.Linear(patch_dim, bottleneck_size, bias=True), + nn.Linear(bottleneck_size, self.hidden_size, bias=True), + ) + else: + self.img_in = nn.Linear(patch_dim, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) self.txt_in = nn.Linear(context_in_dim, self.hidden_size) + self.resolution_embedder = ( + PRXResolutionEmbedder(self.hidden_size, max_period=time_max_period) if resolution_embeds else None + ) + self.blocks = nn.ModuleList( [ PRXBlock( @@ -772,6 +835,13 @@ def forward( # Compute time embedding vec = self._compute_timestep_embedding(timestep, dtype=img.dtype) + # Add resolution conditioning (PRX-7B "vec" mode): embed the latent (H, W) and add it to the timestep vector + # so every block's modulation is resolution-aware. + if self.resolution_embedder is not None: + height = torch.full((bs,), h, device=hidden_states.device, dtype=torch.float32) + width = torch.full((bs,), w, device=hidden_states.device, dtype=torch.float32) + vec = vec + self.resolution_embedder(height, width, dtype=vec.dtype) + # Apply transformer blocks for block in self.blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index cfa1f8d92558..9b135b889a82 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -147,7 +147,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] - _import_structure["prx"] = ["PRXPipeline"] + _import_structure["prx"] = ["PRXPixelPipeline", "PRXPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -780,7 +780,7 @@ from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline - from .prx import PRXPipeline + from .prx import PRXPixelPipeline, PRXPipeline from .qwenimage import ( QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, diff --git a/src/diffusers/pipelines/prx/__init__.py b/src/diffusers/pipelines/prx/__init__.py index 87aaefbd1368..c85e0665d2eb 100644 --- a/src/diffusers/pipelines/prx/__init__.py +++ b/src/diffusers/pipelines/prx/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_prx"] = ["PRXPipeline"] + _import_structure["pipeline_prx_pixel"] = ["PRXPixelPipeline"] # Import T5GemmaEncoder for pipeline loading compatibility try: @@ -46,6 +47,7 @@ else: from .pipeline_output import PRXPipelineOutput from .pipeline_prx import PRXPipeline + from .pipeline_prx_pixel import PRXPixelPipeline else: import sys diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py index 0171c9a42a40..cd179f58e52b 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx.py +++ b/src/diffusers/pipelines/prx/pipeline_prx.py @@ -126,6 +126,12 @@ def __init__(self): + r"]{1,}" ) + def basic_clean(self, text: str) -> str: + """Light cleaning (fix mojibake + unescape HTML). Used by encoders trained without DeepFloyd cleaning.""" + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + def clean_text(self, text: str) -> str: """Clean text using comprehensive text processing logic.""" # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py @@ -298,6 +304,20 @@ def __init__( self.text_preprocessor = TextPreprocessor() self.default_sample_size = default_sample_size self._guidance_scale = 1.0 + # Max number of text tokens. When None, falls back to the tokenizer's own ``model_max_length``. + # Subclasses (e.g. the PRXPixel pipeline, whose Qwen tokenizer has a very large model_max_length) + # can pin this to the value used at training time. + self.tokenizer_max_length = None + # When True, prompts get only light cleaning (``basic_clean``) instead of the DeepFloyd ``clean_text``. + # Set by subclasses whose text encoder was trained without the heavy cleaning (e.g. the Qwen tower). + self.skip_text_cleaning = False + # What the transformer predicts. "flow_matching" -> velocity (consumed directly by the scheduler). + # "x_prediction_flow_matching" -> the clean sample x0, converted to velocity before each scheduler step + # (see "Back to Basics: Let Denoising Generative Models Denoise", https://arxiv.org/abs/2511.13720). + self.prediction_type = "flow_matching" + # Standard deviation of the initial noise. Some PRX variants train with a non-unit noise scale and must + # start sampling from `randn * noise_scale` to match the learned flow-matching trajectory. + self.noise_scale = 1.0 self.register_modules( transformer=transformer, @@ -363,7 +383,7 @@ def prepare_latents( width // spatial_compression, ) shape = (batch_size, num_channels_latents, latent_height, latent_width) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.noise_scale else: latents = latents.to(device) return latents @@ -422,11 +442,13 @@ def encode_prompt( def _tokenize_prompts(self, prompts: list[str], device: torch.device): """Tokenize and clean prompts.""" - cleaned = [self.text_preprocessor.clean_text(text) for text in prompts] + clean_fn = self.text_preprocessor.basic_clean if self.skip_text_cleaning else self.text_preprocessor.clean_text + cleaned = [clean_fn(text) for text in prompts] + max_length = self.tokenizer_max_length or self.tokenizer.model_max_length tokens = self.tokenizer( cleaned, padding="max_length", - max_length=self.tokenizer.model_max_length, + max_length=max_length, truncation=True, return_attention_mask=True, return_tensors="pt", @@ -627,12 +649,15 @@ def __call__( height = height or default_resolution width = width or default_resolution + if use_resolution_binning and self.image_processor is None: + # Pixel-space / no-VAE pipelines have no image_processor and cannot bin; disable it transparently. + logger.warning( + "Resolution binning requires a VAE with image_processor, but none is available; " + "proceeding with use_resolution_binning=False." + ) + use_resolution_binning = False + if use_resolution_binning: - if self.image_processor is None: - raise ValueError( - "Resolution binning requires a VAE with image_processor, but VAE is not available. " - "Set use_resolution_binning=False or provide a VAE." - ) if self.default_sample_size not in ASPECT_RATIO_BINS: raise ValueError( f"Resolution binning is only supported for default_sample_size in {list(ASPECT_RATIO_BINS.keys())}, " @@ -763,6 +788,12 @@ def __call__( noise_uncond, noise_text = noise_pred.chunk(2, dim=0) noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + # If the model predicts the clean sample x0, convert it to the flow-matching velocity + # the scheduler expects: v = (x_t - x0) / t (t = normalized noise level, clamped for stability). + if self.prediction_type == "x_prediction_flow_matching": + t_x = torch.clamp(t.float() / self.scheduler.config.num_train_timesteps, min=0.05) + noise_pred = (latents - noise_pred) / t_x + # Compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample diff --git a/src/diffusers/pipelines/prx/pipeline_prx_pixel.py b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py new file mode 100644 index 000000000000..f565d02480c0 --- /dev/null +++ b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py @@ -0,0 +1,97 @@ +# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase + +from diffusers.models import AutoencoderDC, AutoencoderKL +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel +from diffusers.pipelines.prx.pipeline_prx import PRXPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# PRXPixel is a 1024px model. +PRX_PIXEL_DEFAULT_RESOLUTION = 1024 +# Number of text tokens used at training time (the Qwen tokenizer's own ``model_max_length`` is far larger). +PRX_PIXEL_DEFAULT_MAX_TOKENS = 256 + + +class PRXPixelPipeline(PRXPipeline): + r""" + Pipeline for text-to-image generation with the PRXPixel model. + + PRXPixel is a pixel-space variant of [`PRXPipeline`]: it denoises raw RGB directly (the VAE is an identity / + absent component), conditions on a Qwen3-VL text encoder rather than T5Gemma, and feeds the latent resolution + into the timestep modulation (`resolution_embeds=True` on the [`PRXTransformer2DModel`]). The denoising loop, + prompt encoding, latent preparation and CFG handling are all inherited from [`PRXPipeline`]; only the component + types, the text-token budget, the (lighter) prompt cleaning, and the default resolution differ. + + This pipeline inherits from [`PRXPipeline`]. Check the superclass documentation for the generic methods (text + encoding, latent preparation, the `__call__` signature, ...). + + Args: + transformer ([`PRXTransformer2DModel`]): + The PRX denoiser. For PRXPixel this is built with `in_channels=3`, a bottleneck `img_in`, and + `resolution_embeds=True`. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + Flow-matching scheduler used to denoise the (pixel-space) latents. + text_encoder ([`PreTrainedModel`]): + The Qwen3-VL text backbone used to encode prompts (the vision tower is discarded). Must return a + `last_hidden_state`. + tokenizer ([`PreTrainedTokenizerBase`]): + Tokenizer for `text_encoder` (typically loaded via `AutoTokenizer`). + vae ([`AutoencoderKL`] or [`AutoencoderDC`], *optional*): + Optional VAE. PRXPixel operates in pixel space, so this is usually `None` (an identity VAE). + default_sample_size (`int`, *optional*, defaults to 1024): + Default height/width used when none is provided to `__call__`. + prompt_max_tokens (`int`, *optional*, defaults to 256): + Number of text tokens the prompt is padded/truncated to before encoding. + """ + + def __init__( + self, + transformer: PRXTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer | PreTrainedTokenizerBase, + vae: AutoencoderKL | AutoencoderDC | None = None, + default_sample_size: int | None = PRX_PIXEL_DEFAULT_RESOLUTION, + prompt_max_tokens: int = PRX_PIXEL_DEFAULT_MAX_TOKENS, + noise_scale: float = 2.0, + ): + super().__init__( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + default_sample_size=default_sample_size, + ) + # Pin the text-token budget; the Qwen tokenizer's model_max_length is otherwise far too large. + self.tokenizer_max_length = prompt_max_tokens + # The Qwen3-VL embedding tower was trained without the DeepFloyd cleaning; use light cleaning only. + self.skip_text_cleaning = True + # PRXPixel predicts the clean sample x0 (converted to velocity each step), not the velocity directly. + self.prediction_type = "x_prediction_flow_matching" + # PRXPixel trains with a non-unit initial-noise scale; sampling must start from randn * noise_scale. + self.noise_scale = noise_scale + + @property + def vae_scale_factor(self): + # PRXPixel operates directly in RGB pixel space (identity / no VAE): no spatial compression. + if self.vae is None: + return 1 + return super().vae_scale_factor diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8758c549ca77..18f6b45e2904 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2492,6 +2492,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class PRXPixelPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class PRXPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 785857ee404dd3d49081bf2c39e7a24ad945c8b0 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 11 Jun 2026 21:34:22 +0200 Subject: [PATCH 2/7] Apply diffusers review conventions to PRX pixel pipeline - Use relative imports in pipeline_prx_pixel.py - Register prompt_max_tokens/noise_scale to config for save/load round-trip - Fix Optional[int] annotation for bottleneck_size - Fix PRXResolutionEmbedder dtype handling (cast to compute dtype, fixes layerwise casting with float8 storage) - Fix import ordering (PRXPipeline before PRXPixelPipeline) in __init__ files and dummy objects - Add PRXPixelPipeline autodoc entry and pixel-variant mention to docs - Add fast pipeline tests (tests/pipelines/prx/test_pipeline_prx_pixel.py) - make style/quality/fix-copies clean Co-Authored-By: Claude Fable 5 --- docs/source/en/api/pipelines/prx.md | 8 + scripts/convert_prx_to_diffusers.py | 8 +- src/diffusers/__init__.py | 4 +- .../models/transformers/transformer_prx.py | 14 +- src/diffusers/pipelines/__init__.py | 4 +- .../pipelines/prx/pipeline_prx_pixel.py | 25 +- .../dummy_torch_and_transformers_objects.py | 4 +- .../pipelines/prx/test_pipeline_prx_pixel.py | 276 ++++++++++++++++++ 8 files changed, 317 insertions(+), 26 deletions(-) create mode 100644 tests/pipelines/prx/test_pipeline_prx_pixel.py diff --git a/docs/source/en/api/pipelines/prx.md b/docs/source/en/api/pipelines/prx.md index 16670f4bfc86..be553878a501 100644 --- a/docs/source/en/api/pipelines/prx.md +++ b/docs/source/en/api/pipelines/prx.md @@ -17,6 +17,8 @@ PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. +[`PRXPipeline`] runs in the VAE latent space, while [`PRXPixelPipeline`] is a pixel-space variant that operates directly on RGB pixels without a VAE. + ## Available models PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts. @@ -126,6 +128,12 @@ pipe.enable_sequential_cpu_offload() - all - __call__ +## PRXPixelPipeline + +[[autodoc]] PRXPixelPipeline + - all + - __call__ + ## PRXPipelineOutput [[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput diff --git a/scripts/convert_prx_to_diffusers.py b/scripts/convert_prx_to_diffusers.py index b93589e4f94a..7b1677d0d3ae 100644 --- a/scripts/convert_prx_to_diffusers.py +++ b/scripts/convert_prx_to_diffusers.py @@ -130,7 +130,7 @@ def load_denoiser_state_dict(checkpoint_path: str, prefix: str = DENOISER_PREFIX print(f" Reading {len(keys)} denoiser tensors (skipping optimizer / EMA / RNG state)...") nested = _load_state_dict_from_keys(keys, storage_reader=reader) flat = _flatten(nested) - state_dict = {k[len(prefix):]: v for k, v in flat.items() if k.startswith(prefix)} + state_dict = {k[len(prefix) :]: v for k, v in flat.items() if k.startswith(prefix)} else: print(f"Loading single-file checkpoint from: {checkpoint_path}") if not os.path.exists(checkpoint_path): @@ -142,7 +142,7 @@ def load_denoiser_state_dict(checkpoint_path: str, prefix: str = DENOISER_PREFIX state_dict = ckpt # Strip a denoiser prefix if the keys carry one. if any(k.startswith(prefix) for k in state_dict): - state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} + state_dict = {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)} print(f"✓ Loaded {len(state_dict)} denoiser parameters") return state_dict @@ -361,7 +361,9 @@ def main(args): if args.skip_text_encoder: print("Skipped text encoder; verifying the transformer reloads from disk...") reloaded = PRXTransformer2DModel.from_pretrained(transformer_path) - print(f"✓ Transformer reloaded: {type(reloaded).__name__} ({sum(p.numel() for p in reloaded.parameters()):,} params)") + print( + f"✓ Transformer reloaded: {type(reloaded).__name__} ({sum(p.numel() for p in reloaded.parameters()):,} params)" + ) else: from diffusers import PRXPipeline, PRXPixelPipeline diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fd4ce0f10b6a..5f3846877651 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -587,8 +587,8 @@ "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", - "PRXPixelPipeline", "PRXPipeline", + "PRXPixelPipeline", "QwenImageControlNetInpaintPipeline", "QwenImageControlNetPipeline", "QwenImageEditInpaintPipeline", @@ -1331,8 +1331,8 @@ PixArtAlphaPipeline, PixArtSigmaPAGPipeline, PixArtSigmaPipeline, - PRXPixelPipeline, PRXPipeline, + PRXPixelPipeline, QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py index 7919cf1986e9..d7115aed3d9d 100644 --- a/src/diffusers/models/transformers/transformer_prx.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import torch from torch import nn @@ -360,8 +360,8 @@ def forward(self, height: torch.Tensor, width: torch.Tensor, dtype: torch.dtype) flip_sin_to_cos=True, downscale_freq_shift=0.0, ) - hw_emb = torch.cat([h_emb, w_emb], dim=-1).to(self.mlp.in_layer.weight.dtype) - return self.mlp(hw_emb).to(dtype) + hw_emb = torch.cat([h_emb, w_emb], dim=-1).to(dtype) + return self.mlp(hw_emb) class Modulation(nn.Module): @@ -657,9 +657,9 @@ class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): time_max_period (`int`, *optional*, defaults to 10000): Maximum frequency period for timestep embeddings. bottleneck_size (`int`, *optional*): - If set, the image patch projection (`img_in`) uses a two-layer bottleneck - (`patch_dim -> bottleneck_size -> hidden_size`) instead of a single linear layer. Used by the pixel-space - PRX-7B variant where the patch dimension is large. + If set, the image patch projection (`img_in`) uses a two-layer bottleneck (`patch_dim -> bottleneck_size -> + hidden_size`) instead of a single linear layer. Used by the pixel-space PRX-7B variant where the patch + dimension is large. resolution_embeds (`bool`, *optional*, defaults to `False`): Whether to condition the timestep modulation on the latent resolution `(H, W)` via a `PRXResolutionEmbedder`. Used by the PRX-7B variant. @@ -716,7 +716,7 @@ def __init__( theta: int = 10000, time_factor: float = 1000.0, time_max_period: int = 10000, - bottleneck_size: int = None, + bottleneck_size: Optional[int] = None, resolution_embeds: bool = False, ): super().__init__() diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9b135b889a82..c9c152d0d03b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -147,7 +147,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] - _import_structure["prx"] = ["PRXPixelPipeline", "PRXPipeline"] + _import_structure["prx"] = ["PRXPipeline", "PRXPixelPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -780,7 +780,7 @@ from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline - from .prx import PRXPixelPipeline, PRXPipeline + from .prx import PRXPipeline, PRXPixelPipeline from .qwenimage import ( QwenImageControlNetInpaintPipeline, QwenImageControlNetPipeline, diff --git a/src/diffusers/pipelines/prx/pipeline_prx_pixel.py b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py index f565d02480c0..7903263f0198 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx_pixel.py +++ b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py @@ -14,11 +14,11 @@ from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase -from diffusers.models import AutoencoderDC, AutoencoderKL -from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel -from diffusers.pipelines.prx.pipeline_prx import PRXPipeline -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from diffusers.utils import logging +from ...models import AutoencoderDC, AutoencoderKL +from ...models.transformers.transformer_prx import PRXTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from .pipeline_prx import PRXPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -33,11 +33,11 @@ class PRXPixelPipeline(PRXPipeline): r""" Pipeline for text-to-image generation with the PRXPixel model. - PRXPixel is a pixel-space variant of [`PRXPipeline`]: it denoises raw RGB directly (the VAE is an identity / - absent component), conditions on a Qwen3-VL text encoder rather than T5Gemma, and feeds the latent resolution - into the timestep modulation (`resolution_embeds=True` on the [`PRXTransformer2DModel`]). The denoising loop, - prompt encoding, latent preparation and CFG handling are all inherited from [`PRXPipeline`]; only the component - types, the text-token budget, the (lighter) prompt cleaning, and the default resolution differ. + PRXPixel is a pixel-space variant of [`PRXPipeline`]: it denoises raw RGB directly (the VAE is an identity / absent + component), conditions on a Qwen3-VL text encoder rather than T5Gemma, and feeds the latent resolution into the + timestep modulation (`resolution_embeds=True` on the [`PRXTransformer2DModel`]). The denoising loop, prompt + encoding, latent preparation and CFG handling are all inherited from [`PRXPipeline`]; only the component types, the + text-token budget, the (lighter) prompt cleaning, and the default resolution differ. This pipeline inherits from [`PRXPipeline`]. Check the superclass documentation for the generic methods (text encoding, latent preparation, the `__call__` signature, ...). @@ -89,6 +89,11 @@ def __init__( # PRXPixel trains with a non-unit initial-noise scale; sampling must start from randn * noise_scale. self.noise_scale = noise_scale + # `super().__init__` already registered `default_sample_size`; register the extra scalar __init__ args too so + # they are written to `model_index.json` and restored on `from_pretrained` (otherwise they silently fall back + # to the constructor defaults). + self.register_to_config(prompt_max_tokens=prompt_max_tokens, noise_scale=noise_scale) + @property def vae_scale_factor(self): # PRXPixel operates directly in RGB pixel space (identity / no VAE): no spatial compression. diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 18f6b45e2904..325b174e25c8 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2492,7 +2492,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PRXPixelPipeline(metaclass=DummyObject): +class PRXPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -2507,7 +2507,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class PRXPipeline(metaclass=DummyObject): +class PRXPixelPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): diff --git a/tests/pipelines/prx/test_pipeline_prx_pixel.py b/tests/pipelines/prx/test_pipeline_prx_pixel.py new file mode 100644 index 000000000000..9aa85f01898c --- /dev/null +++ b/tests/pipelines/prx/test_pipeline_prx_pixel.py @@ -0,0 +1,276 @@ +import unittest + +import numpy as np +import pytest +import torch +from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model + +from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel +from diffusers.pipelines.prx.pipeline_prx_pixel import PRXPixelPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import is_transformers_version + +from ..pipeline_params import TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +@pytest.mark.xfail( + condition=is_transformers_version(">", "4.57.1"), + reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", + strict=False, +) +class PRXPixelPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = PRXPixelPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"]) + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + # Pixel-space PRX has no VAE, so PIL/np output paths are unavailable; outputs are raw RGB tensors ("pt"). + output_type = "pt" + + @classmethod + def setUpClass(cls): + # Ensure PRXPixelPipeline has an _execution_device property expected by __call__ + if not isinstance(getattr(PRXPixelPipeline, "_execution_device", None), property): + try: + setattr(PRXPixelPipeline, "_execution_device", property(lambda self: torch.device("cpu"))) + except Exception: + pass + + def get_dummy_components(self): + torch.manual_seed(0) + # Pixel-space PRX: in_channels=3 (RGB), bottleneck img_in, resolution_embeds=True. + # context_in_dim must match the text encoder hidden_size (16). + transformer = PRXTransformer2DModel( + patch_size=1, + in_channels=3, + context_in_dim=16, + hidden_size=8, + mlp_ratio=2.0, + num_heads=2, + depth=1, + axes_dim=[2, 2], + bottleneck_size=8, + resolution_embeds=True, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + # Tiny Qwen3 text encoder returning `last_hidden_state` (Qwen3-VL-style backbone). + torch.manual_seed(0) + config = Qwen3Config( + hidden_size=16, + intermediate_size=16, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=151936, + max_position_embeddings=512, + ) + text_encoder = Qwen3Model(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + return { + "transformer": transformer, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + # Pixel-space: no VAE. Passed explicitly (as None) so it appears in init_components and matches + # pipe.components (which always registers the optional `vae`). + "vae": None, + "prompt_max_tokens": 16, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + return { + "prompt": "", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "output_type": "pt", + "use_resolution_binning": False, + } + + def _build_pipe(self, device="cpu"): + components = self.get_dummy_components() + pipe = PRXPixelPipeline(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + try: + pipe.register_to_config(_execution_device=device) + except Exception: + pass + return pipe + + def test_inference(self): + device = "cpu" + pipe = self._build_pipe(device) + + # No VAE -> identity pixel space, vae_scale_factor == 1. + self.assertIsNone(pipe.vae) + self.assertEqual(pipe.vae_scale_factor, 1) + self.assertIsNone(pipe.image_processor) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs)[0] + generated_image = image[0] + + # Output is raw RGB at the requested resolution. + self.assertEqual(generated_image.shape, (3, 32, 32)) + expected_image = torch.zeros(3, 32, 32) + max_diff = np.abs(generated_image.cpu().numpy() - expected_image.numpy()).max() + self.assertLessEqual(max_diff, 1e10) + + def test_inference_batch(self): + device = "cpu" + pipe = self._build_pipe(device) + + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = ["", ""] + inputs["negative_prompt"] = ["", ""] + image = pipe(**inputs)[0] + + self.assertEqual(image.shape[0], 2) + self.assertEqual(tuple(image.shape[1:]), (3, 32, 32)) + + def test_inference_with_cfg(self): + device = "cpu" + pipe = self._build_pipe(device) + + # CFG off. + inputs = self.get_dummy_inputs(device) + inputs["guidance_scale"] = 1.0 + out_no_cfg = pipe(**inputs)[0] + self.assertFalse(pipe.do_classifier_free_guidance) + self.assertEqual(out_no_cfg[0].shape, (3, 32, 32)) + + # CFG on. + inputs = self.get_dummy_inputs(device) + inputs["guidance_scale"] = 5.0 + out_cfg = pipe(**inputs)[0] + self.assertTrue(pipe.do_classifier_free_guidance) + self.assertEqual(out_cfg[0].shape, (3, 32, 32)) + + # Guidance should actually change the output. + max_diff = np.abs(out_no_cfg.cpu().numpy() - out_cfg.cpu().numpy()).max() + self.assertGreater(max_diff, 0.0) + + def test_inference_with_prompt_embeds(self): + device = "cpu" + pipe = self._build_pipe(device) + + # Precompute embeddings via the public encode_prompt API (CFG on so we get negatives too). + prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + pipe.encode_prompt( + prompt="a prompt", + device=device, + do_classifier_free_guidance=True, + negative_prompt="", + ) + ) + + inputs = self.get_dummy_inputs(device) + inputs.pop("prompt") + inputs.pop("negative_prompt") + inputs["guidance_scale"] = 5.0 + inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["prompt_attention_mask"] = prompt_attention_mask + inputs["negative_prompt_attention_mask"] = negative_prompt_attention_mask + + image = pipe(**inputs)[0] + self.assertEqual(image[0].shape, (3, 32, 32)) + + def test_resolution_binning_disabled_without_image_processor(self): + # Pixel-space pipelines have no image_processor; use_resolution_binning=True must be + # transparently disabled (warning) rather than raising. + device = "cpu" + pipe = self._build_pipe(device) + self.assertIsNone(pipe.image_processor) + + inputs = self.get_dummy_inputs(device) + inputs["use_resolution_binning"] = True + # Should not raise despite default_sample_size=1024 not being a binnable scenario without a processor. + image = pipe(**inputs)[0] + self.assertEqual(image[0].shape, (3, 32, 32)) + + def test_callback_inputs(self): + device = "cpu" + pipe = self._build_pipe(device) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {PRXPixelPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its" + " callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + inputs = self.get_dummy_inputs(device) + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + _ = pipe(**inputs)[0] + + inputs = self.get_dummy_inputs(device) + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + _ = pipe(**inputs)[0] + + def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3): + # Overridden: the mixin version calls assert_mean_pixel_difference, which assumes HWC image + # arrays. Pixel-space PRX has no VAE and returns raw (C, H, W) tensors ("pt"), so we compare + # tensors directly instead of going through PIL. + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + def to_np_local(tensor): + if isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + return tensor + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max() + max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max() + self.assertLess(max(max_diff1, max_diff2), expected_max_diff) + + @unittest.skip("Slow original-vs-diffusers parity test is optional and intentionally skipped for fast CI.") + def test_prx_pixel_original_parity(self): + pass From 72e60ff0799a1efffb50020136047ec4b8725b53 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Thu, 11 Jun 2026 21:56:03 +0200 Subject: [PATCH 3/7] Support default output_type='pil' in PRXPixelPipeline without a VAE The inherited PRXPipeline.__call__ raised for output_type='pil'/'np' when no VAE was loaded. Pixel-space outputs are already images in [-1, 1], so PRXPixelPipeline now creates a PixArtImageProcessor (vae_scale_factor=1) and the base post-processing denormalizes the denoised latents directly instead of requiring a VAE decode. This also enables resolution binning for the pixel pipeline. Co-Authored-By: Claude Fable 5 --- src/diffusers/pipelines/prx/pipeline_prx.py | 36 ++++++++++--------- .../pipelines/prx/pipeline_prx_pixel.py | 6 ++++ .../pipelines/prx/test_pipeline_prx_pixel.py | 30 +++++++++------- 3 files changed, 43 insertions(+), 29 deletions(-) diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py index cd179f58e52b..8b0dee2f2f9d 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx.py +++ b/src/diffusers/pipelines/prx/pipeline_prx.py @@ -18,7 +18,6 @@ import urllib.parse as ul from typing import Callable -import ftfy import torch from transformers import ( AutoTokenizer, @@ -34,13 +33,13 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from diffusers.utils import ( - logging, - replace_example_docstring, -) +from diffusers.utils import is_ftfy_available, logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor +if is_ftfy_available(): + import ftfy + DEFAULT_RESOLUTION = 512 ASPECT_RATIO_256_BIN = { @@ -650,9 +649,10 @@ def __call__( width = width or default_resolution if use_resolution_binning and self.image_processor is None: - # Pixel-space / no-VAE pipelines have no image_processor and cannot bin; disable it transparently. + # Latent-space pipelines constructed without a VAE have no image_processor and cannot bin; + # disable it transparently. logger.warning( - "Resolution binning requires a VAE with image_processor, but none is available; " + "Resolution binning requires an image processor, but none is available; " "proceeding with use_resolution_binning=False." ) use_resolution_binning = False @@ -681,9 +681,9 @@ def __call__( negative_prompt_embeds, ) - if self.vae is None and output_type not in ["latent", "pt"]: + if self.vae is None and self.image_processor is None and output_type not in ["latent", "pt"]: raise ValueError( - f"VAE is required for output_type='{output_type}' but it is not available. " + f"output_type='{output_type}' requires a VAE or an image processor, but neither is available. " "Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs." ) @@ -808,15 +808,19 @@ def __call__( progress_bar.update() # 8. Post-processing - if output_type == "latent" or (output_type == "pt" and self.vae is None): + if output_type == "latent" or (output_type == "pt" and self.image_processor is None): image = latents else: - # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) - scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) - shift_factor = getattr(self.vae.config, "shift_factor", 0.0) - latents = (latents / scaling_factor) + shift_factor - # Decode using VAE (AutoencoderKL or AutoencoderDC) - image = self.vae.decode(latents, return_dict=False)[0] + if self.vae is not None: + # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) + scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) + shift_factor = getattr(self.vae.config, "shift_factor", 0.0) + latents = (latents / scaling_factor) + shift_factor + # Decode using VAE (AutoencoderKL or AutoencoderDC) + image = self.vae.decode(latents, return_dict=False)[0] + else: + # Pixel-space pipelines have no VAE: the denoised latents are already images in [-1, 1]. + image = latents # Resize back to original resolution if using binning if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) diff --git a/src/diffusers/pipelines/prx/pipeline_prx_pixel.py b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py index 7903263f0198..ea29a3d4f1c3 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx_pixel.py +++ b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py @@ -14,6 +14,7 @@ from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase +from ...image_processor import PixArtImageProcessor from ...models import AutoencoderDC, AutoencoderKL from ...models.transformers.transformer_prx import PRXTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -94,6 +95,11 @@ def __init__( # to the constructor defaults). self.register_to_config(prompt_max_tokens=prompt_max_tokens, noise_scale=noise_scale) + if self.image_processor is None: + # Without a VAE the denoised latents are already images in [-1, 1]; an image processor with + # vae_scale_factor=1 is all that is needed to support output_type="pil"/"np" and resolution binning. + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + @property def vae_scale_factor(self): # PRXPixel operates directly in RGB pixel space (identity / no VAE): no spatial compression. diff --git a/tests/pipelines/prx/test_pipeline_prx_pixel.py b/tests/pipelines/prx/test_pipeline_prx_pixel.py index 9aa85f01898c..82c69e08ca23 100644 --- a/tests/pipelines/prx/test_pipeline_prx_pixel.py +++ b/tests/pipelines/prx/test_pipeline_prx_pixel.py @@ -27,9 +27,6 @@ class PRXPixelPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_layerwise_casting = True test_group_offloading = True - # Pixel-space PRX has no VAE, so PIL/np output paths are unavailable; outputs are raw RGB tensors ("pt"). - output_type = "pt" - @classmethod def setUpClass(cls): # Ensure PRXPixelPipeline has an _execution_device property expected by __call__ @@ -116,10 +113,11 @@ def test_inference(self): device = "cpu" pipe = self._build_pipe(device) - # No VAE -> identity pixel space, vae_scale_factor == 1. + # No VAE -> identity pixel space, vae_scale_factor == 1, but postprocessing still works + # through an image processor so output_type="pil"/"np" are supported. self.assertIsNone(pipe.vae) self.assertEqual(pipe.vae_scale_factor, 1) - self.assertIsNone(pipe.image_processor) + self.assertIsNotNone(pipe.image_processor) inputs = self.get_dummy_inputs(device) image = pipe(**inputs)[0] @@ -191,18 +189,24 @@ def test_inference_with_prompt_embeds(self): image = pipe(**inputs)[0] self.assertEqual(image[0].shape, (3, 32, 32)) - def test_resolution_binning_disabled_without_image_processor(self): - # Pixel-space pipelines have no image_processor; use_resolution_binning=True must be - # transparently disabled (warning) rather than raising. + def test_inference_pil_and_np_output(self): + # The default output_type="pil" must work without a VAE: the denoised pixels are denormalized + # directly by the image processor instead of being decoded. device = "cpu" pipe = self._build_pipe(device) - self.assertIsNone(pipe.image_processor) inputs = self.get_dummy_inputs(device) - inputs["use_resolution_binning"] = True - # Should not raise despite default_sample_size=1024 not being a binnable scenario without a processor. - image = pipe(**inputs)[0] - self.assertEqual(image[0].shape, (3, 32, 32)) + inputs.pop("output_type") # default is "pil" + images = pipe(**inputs).images + self.assertEqual(len(images), 1) + self.assertEqual(images[0].size, (32, 32)) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "np" + images = pipe(**inputs).images + self.assertEqual(images.shape, (1, 32, 32, 3)) + self.assertGreaterEqual(images.min(), 0.0) + self.assertLessEqual(images.max(), 1.0) def test_callback_inputs(self): device = "cpu" From 8f2121577b7a96c28f60e8b7b00e786faeba2c82 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Fri, 12 Jun 2026 09:32:31 +0200 Subject: [PATCH 4/7] Validate height/width against vae_scale_factor * transformer patch_size check_inputs only checked divisibility by vae_scale_factor, which is 1 for the pixel pipeline (and ignores the patch size for latent ones), so sizes like 1000px passed validation and crashed mid-denoising with an opaque reshape RuntimeError in img2seq. Co-Authored-By: Claude Fable 5 --- src/diffusers/pipelines/prx/pipeline_prx.py | 8 +++--- .../pipelines/prx/test_pipeline_prx_pixel.py | 27 +++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py index 8b0dee2f2f9d..ce5bff2dee27 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx.py +++ b/src/diffusers/pipelines/prx/pipeline_prx.py @@ -523,10 +523,12 @@ def check_inputs( "`negative_prompt_embeds` must also be provided for classifier-free guidance." ) - spatial_compression = self.vae_scale_factor - if height % spatial_compression != 0 or width % spatial_compression != 0: + # The latents must be divisible by the transformer's patch size after VAE compression. + dimension_multiple = self.vae_scale_factor * self.transformer.config.patch_size + if height % dimension_multiple != 0 or width % dimension_multiple != 0: raise ValueError( - f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}." + f"`height` and `width` have to be divisible by {dimension_multiple} (vae_scale_factor *" + f" transformer patch_size) but are {height} and {width}." ) if guidance_scale < 1.0: diff --git a/tests/pipelines/prx/test_pipeline_prx_pixel.py b/tests/pipelines/prx/test_pipeline_prx_pixel.py index 82c69e08ca23..7163de0a66c9 100644 --- a/tests/pipelines/prx/test_pipeline_prx_pixel.py +++ b/tests/pipelines/prx/test_pipeline_prx_pixel.py @@ -208,6 +208,33 @@ def test_inference_pil_and_np_output(self): self.assertGreaterEqual(images.min(), 0.0) self.assertLessEqual(images.max(), 1.0) + def test_non_multiple_size_raises(self): + # height/width must be divisible by vae_scale_factor * transformer patch_size; check_inputs must raise + # a clear ValueError instead of letting the transformer fail on an invalid reshape mid-denoising. + device = "cpu" + components = self.get_dummy_components() + torch.manual_seed(0) + components["transformer"] = PRXTransformer2DModel( + patch_size=2, + in_channels=3, + context_in_dim=16, + hidden_size=8, + mlp_ratio=2.0, + num_heads=2, + depth=1, + axes_dim=[2, 2], + bottleneck_size=8, + resolution_embeds=True, + ) + pipe = PRXPixelPipeline(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["height"] = 31 # vae_scale_factor (1) * patch_size (2) = 2; 31 is not a multiple + with self.assertRaisesRegex(ValueError, "divisible"): + pipe(**inputs) + def test_callback_inputs(self): device = "cpu" pipe = self._build_pipe(device) From 6f7c1cefad32c6d6ea02653e5e16e983613ae53f Mon Sep 17 00:00:00 2001 From: DavidBert Date: Fri, 12 Jun 2026 12:04:42 +0200 Subject: [PATCH 5/7] Document public PRXPixel checkpoint (Photoroom/prxpixel-t2i) Add the pixel model to the available-models table, a pixel-space loading example in the docs, and an Examples block in the PRXPixelPipeline docstring now that the weights are public. Verified end-to-end: from_pretrained + 1024px generation. Co-Authored-By: Claude Fable 5 --- docs/source/en/api/pipelines/prx.md | 19 ++++++++++++++++++- .../pipelines/prx/pipeline_prx_pixel.py | 13 +++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/prx.md b/docs/source/en/api/pipelines/prx.md index be553878a501..d0889bbed5b9 100644 --- a/docs/source/en/api/pipelines/prx.md +++ b/docs/source/en/api/pipelines/prx.md @@ -33,7 +33,8 @@ PRX offers multiple variants with different VAE configurations, each optimized f | [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | | [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s +| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | +| [`Photoroom/prxpixel-t2i`](https://huggingface.co/Photoroom/prxpixel-t2i)| 1024 | No | No | Pixel-space model (~7B transformer, no VAE) with a Qwen3-VL text encoder, loaded with [`PRXPixelPipeline`] | Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information. @@ -53,6 +54,22 @@ image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] image.save("prx_output.png") ``` +### Pixel-space generation + +[`PRXPixelPipeline`] denoises raw RGB directly, so no VAE is loaded or needed. It requires `transformers >= 4.57` (the version that introduced `Qwen3VLTextModel`). + +```py +import torch +from diffusers import PRXPixelPipeline + +pipe = PRXPixelPipeline.from_pretrained("Photoroom/prxpixel-t2i", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "A front-facing portrait of a lion in the golden savanna at sunset." +image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] +image.save("prxpixel_output.png") +``` + ### Manual Component Loading Load components individually to customize the pipeline for instance to use quantized models. diff --git a/src/diffusers/pipelines/prx/pipeline_prx_pixel.py b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py index ea29a3d4f1c3..98ec59da1d80 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx_pixel.py +++ b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py @@ -43,6 +43,19 @@ class PRXPixelPipeline(PRXPipeline): This pipeline inherits from [`PRXPipeline`]. Check the superclass documentation for the generic methods (text encoding, latent preparation, the `__call__` signature, ...). + Examples: + ```py + >>> import torch + >>> from diffusers import PRXPixelPipeline + + >>> pipe = PRXPixelPipeline.from_pretrained("Photoroom/prxpixel-t2i", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A front-facing portrait of a lion in the golden savanna at sunset." + >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] + >>> image.save("prxpixel_output.png") + ``` + Args: transformer ([`PRXTransformer2DModel`]): The PRX denoiser. For PRXPixel this is built with `in_channels=3`, a bottleneck `img_in`, and From 9b03b6f3bbd924619a7a3750a63090b52dc89631 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Fri, 12 Jun 2026 12:11:18 +0200 Subject: [PATCH 6/7] Restructure PRX docs around PRXPixel as the flagship model Lead the intro, models table, loading examples, and autodoc sections with the pixel-space model; present the latent-space checkpoints as earlier PRX versions. Co-Authored-By: Claude Fable 5 --- docs/source/en/api/pipelines/prx.md | 45 ++++++++++++++--------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/docs/source/en/api/pipelines/prx.md b/docs/source/en/api/pipelines/prx.md index d0889bbed5b9..dcb262437e52 100644 --- a/docs/source/en/api/pipelines/prx.md +++ b/docs/source/en/api/pipelines/prx.md @@ -15,17 +15,17 @@ # PRX -PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. +PRX is a family of efficient text-to-image diffusion models by Photoroom. The flagship model, **PRXPixel** ([`PRXPixelPipeline`]), generates 1024px images directly in pixel space: a ~7B transformer denoises raw RGB without any VAE, conditioned on a Qwen3-VL text encoder, and feeds the generation resolution into the timestep modulation. It uses flow matching and predicts the clean image at each step (x-prediction). -[`PRXPipeline`] runs in the VAE latent space, while [`PRXPixelPipeline`] is a pixel-space variant that operates directly on RGB pixels without a VAE. +Earlier PRX versions ([`PRXPipeline`]) operate in a VAE latent space (Flux VAE with 8x compression, or DC-AE with 32x compression) with a ~1.3B simplified MMDIT transformer where text tokens don't update through the blocks, and Google's T5Gemma-2B-2B-UL2 model for text encoding. ## Available models -PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts. - +**PRXPixel is the flagship model.** The other checkpoints are earlier latent-space versions of PRX at 256/512px with different VAE configurations; the distilled variants generate in 8 steps. | Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype | |:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:| +| [`Photoroom/prxpixel-t2i`](https://huggingface.co/Photoroom/prxpixel-t2i)| 1024 | No | No | Flagship pixel-space model (~7B transformer, no VAE) with a Qwen3-VL text encoder, loaded with [`PRXPixelPipeline`] | Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | @@ -34,45 +34,44 @@ PRX offers multiple variants with different VAE configurations, each optimized f | [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | -| [`Photoroom/prxpixel-t2i`](https://huggingface.co/Photoroom/prxpixel-t2i)| 1024 | No | No | Pixel-space model (~7B transformer, no VAE) with a Qwen3-VL text encoder, loaded with [`PRXPixelPipeline`] | Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information. ## Loading the pipeline -Load the pipeline with [`~DiffusionPipeline.from_pretrained`]. +Load the pipeline with [`~DiffusionPipeline.from_pretrained`]. [`PRXPixelPipeline`] denoises raw RGB directly, so no VAE is loaded or needed. It requires `transformers >= 4.57` (the version that introduced `Qwen3VLTextModel`). ```py -from diffusers.pipelines.prx import PRXPipeline +import torch +from diffusers import PRXPixelPipeline -# Load pipeline - VAE and text encoder will be loaded from HuggingFace -pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16) +pipe = PRXPixelPipeline.from_pretrained("Photoroom/prxpixel-t2i", torch_dtype=torch.bfloat16) pipe.to("cuda") -prompt = "A front-facing portrait of a lion the golden savanna at sunset." +prompt = "A front-facing portrait of a lion in the golden savanna at sunset." image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] -image.save("prx_output.png") +image.save("prxpixel_output.png") ``` -### Pixel-space generation +### Latent-space generation (earlier PRX versions) -[`PRXPixelPipeline`] denoises raw RGB directly, so no VAE is loaded or needed. It requires `transformers >= 4.57` (the version that introduced `Qwen3VLTextModel`). +Use [`PRXPipeline`] for the earlier latent-space checkpoints; the VAE and text encoder are loaded as part of the pipeline. ```py import torch -from diffusers import PRXPixelPipeline +from diffusers import PRXPipeline -pipe = PRXPixelPipeline.from_pretrained("Photoroom/prxpixel-t2i", torch_dtype=torch.bfloat16) +pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16) pipe.to("cuda") prompt = "A front-facing portrait of a lion in the golden savanna at sunset." image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] -image.save("prxpixel_output.png") +image.save("prx_output.png") ``` ### Manual Component Loading -Load components individually to customize the pipeline for instance to use quantized models. +Load components individually to customize the pipeline, for instance to use quantized models (shown here for the latent-space [`PRXPipeline`]). ```py import torch @@ -130,24 +129,24 @@ For memory-constrained environments: ```py import torch -from diffusers.pipelines.prx import PRXPipeline +from diffusers import PRXPixelPipeline -pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16) +pipe = PRXPixelPipeline.from_pretrained("Photoroom/prxpixel-t2i", torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() # Offload components to CPU when not in use # Or use sequential CPU offload for even lower memory pipe.enable_sequential_cpu_offload() ``` -## PRXPipeline +## PRXPixelPipeline -[[autodoc]] PRXPipeline +[[autodoc]] PRXPixelPipeline - all - __call__ -## PRXPixelPipeline +## PRXPipeline -[[autodoc]] PRXPixelPipeline +[[autodoc]] PRXPipeline - all - __call__ From 9a9f456189d8425b10b5b82d5020298bcb385029 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Mon, 15 Jun 2026 15:40:57 +0200 Subject: [PATCH 7/7] Address dg845 review: standalone PRXPixelPipeline + __call__ args - PRXPixelPipeline now inherits DiffusionPipeline directly (not PRXPipeline); shared methods copied via # Copied from, __call__ and pixel-specific methods reimplemented standalone - tokenizer_max_length and skip_text_cleaning added as explicit __call__ and encode_prompt args in PRXPipeline (per comment 1) - prediction_type removed entirely (baked per-class); noise_scale is a proper PRXPixelPipeline __init__ arg registered to config (per comment 2) - Remove xfail mark from pixel tests (per comment 4) - Add docs/source/en/api/pipelines/prx_pixel.md + toctree entry; restore prx.md to pre-PR state Co-Authored-By: Claude Fable 5 --- docs/source/en/_toctree.yml | 148 ++--- docs/source/en/api/pipelines/prx.md | 46 +- src/diffusers/pipelines/prx/pipeline_prx.py | 104 ++- .../pipelines/prx/pipeline_prx_pixel.py | 624 +++++++++++++++++- .../pipelines/prx/test_pipeline_prx_pixel.py | 31 +- 5 files changed, 739 insertions(+), 214 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 64a4222845b0..953542044d72 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -22,6 +22,8 @@ title: Reproducibility - local: using-diffusers/schedulers title: Schedulers + - local: using-diffusers/guiders + title: Guiders - local: using-diffusers/automodel title: AutoModel - local: using-diffusers/other-formats @@ -110,8 +112,8 @@ title: ModularPipeline - local: modular_diffusers/components_manager title: ComponentsManager - - local: modular_diffusers/guiders - title: Guiders + - local: modular_diffusers/auto_docstring + title: Auto docstring and parameter templates - local: modular_diffusers/custom_blocks title: Building Custom Blocks - local: modular_diffusers/mellon @@ -161,6 +163,8 @@ - local: training/ddpo title: Reinforcement learning training with DDPO title: Methods + - local: training/nemo_automodel + title: NeMo Automodel title: Training - isExpanded: false sections: @@ -192,33 +196,6 @@ - local: optimization/neuron title: AWS Neuron title: Model accelerators and hardware -- isExpanded: false - sections: - - local: using-diffusers/consisid - title: ConsisID - - local: using-diffusers/sdxl - title: Stable Diffusion XL - - local: using-diffusers/sdxl_turbo - title: SDXL Turbo - - local: using-diffusers/kandinsky - title: Kandinsky - - local: using-diffusers/omnigen - title: OmniGen - - local: using-diffusers/pag - title: PAG - - local: using-diffusers/inference_with_lcm - title: Latent Consistency Model - - local: using-diffusers/shap-e - title: Shap-E - - local: using-diffusers/diffedit - title: DiffEdit - - local: using-diffusers/inference_with_tcd_lora - title: Trajectory Consistency Distillation-LoRA - - local: using-diffusers/svd - title: Stable Video Diffusion - - local: using-diffusers/marigold_usage - title: Marigold Computer Vision - title: Specific pipeline examples - isExpanded: false sections: - sections: @@ -318,8 +295,14 @@ title: SparseControlNetModel title: ControlNets - sections: + - local: api/models/ace_step_transformer + title: AceStepTransformer1DModel - local: api/models/allegro_transformer3d title: AllegroTransformer3DModel + - local: api/models/anyflow_far_transformer3d + title: AnyFlowFARTransformer3DModel + - local: api/models/anyflow_transformer3d + title: AnyFlowTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel - local: api/models/transformer_bria_fibo @@ -338,18 +321,24 @@ title: CogView4Transformer2DModel - local: api/models/consisid_transformer3d title: ConsisIDTransformer3DModel + - local: api/models/cosmos3_omni_transformer + title: Cosmos3OmniTransformer - local: api/models/cosmos_transformer3d title: CosmosTransformer3DModel - local: api/models/dit_transformer2d title: DiTTransformer2DModel - local: api/models/easyanimate_transformer3d title: EasyAnimateTransformer3DModel + - local: api/models/ernie_image_transformer2d + title: ErnieImageTransformer2DModel - local: api/models/flux2_transformer title: Flux2Transformer2DModel - local: api/models/flux_transformer title: FluxTransformer2DModel - local: api/models/glm_image_transformer2d title: GlmImageTransformer2DModel + - local: api/models/helios_transformer3d + title: HeliosTransformer3DModel - local: api/models/hidream_image_transformer title: HiDreamImageTransformer2DModel - local: api/models/hunyuan_transformer2d @@ -360,6 +349,10 @@ title: HunyuanVideo15Transformer3DModel - local: api/models/hunyuan_video_transformer_3d title: HunyuanVideoTransformer3DModel + - local: api/models/ideogram4_transformer2d + title: Ideogram4Transformer2DModel + - local: api/models/transformer_joyimage + title: JoyImageEditTransformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel - local: api/models/longcat_image_transformer2d @@ -374,6 +367,8 @@ title: LuminaNextDiT2DModel - local: api/models/mochi_transformer3d title: MochiTransformer3DModel + - local: api/models/motif_video_transformer_3d + title: MotifVideoTransformer3DModel - local: api/models/omnigen_transformer title: OmniGenTransformer2DModel - local: api/models/ovisimage_transformer2d @@ -442,6 +437,10 @@ title: AutoencoderKLHunyuanVideo - local: api/models/autoencoder_kl_hunyuan_video15 title: AutoencoderKLHunyuanVideo15 + - local: api/models/autoencoder_kl_kvae + title: AutoencoderKLKVAE + - local: api/models/autoencoder_kl_kvae_video + title: AutoencoderKLKVAEVideo - local: api/models/autoencoderkl_audio_ltx_2 title: AutoencoderKLLTX2Audio - local: api/models/autoencoderkl_ltx_2 @@ -456,6 +455,8 @@ title: AutoencoderKLQwenImage - local: api/models/autoencoder_kl_wan title: AutoencoderKLWan + - local: api/models/autoencoder_rae + title: AutoencoderRAE - local: api/models/consistency_decoder_vae title: ConsistencyDecoderVAE - local: api/models/autoencoder_oobleck @@ -472,28 +473,22 @@ - local: api/pipelines/auto_pipeline title: AutoPipeline - sections: - - local: api/pipelines/audioldm - title: AudioLDM + - local: api/pipelines/ace_step + title: ACE-Step - local: api/pipelines/audioldm2 title: AudioLDM 2 - - local: api/pipelines/dance_diffusion - title: Dance Diffusion - - local: api/pipelines/musicldm - title: MusicLDM + - local: api/pipelines/longcat_audio_dit + title: LongCat-AudioDiT - local: api/pipelines/stable_audio title: Stable Audio title: Audio - sections: - - local: api/pipelines/amused - title: aMUSEd + - local: api/pipelines/anima + title: Anima - local: api/pipelines/animatediff title: AnimateDiff - - local: api/pipelines/attend_and_excite - title: Attend-and-Excite - local: api/pipelines/aura_flow title: AuraFlow - - local: api/pipelines/blip_diffusion - title: BLIP-Diffusion - local: api/pipelines/bria_3_2 title: Bria 3.2 - local: api/pipelines/bria_fibo @@ -520,26 +515,20 @@ title: ControlNet with Stable Diffusion XL - local: api/pipelines/controlnet_sana title: ControlNet-Sana - - local: api/pipelines/controlnetxs - title: ControlNet-XS - - local: api/pipelines/controlnetxs_sdxl - title: ControlNet-XS with Stable Diffusion XL - local: api/pipelines/controlnet_union title: ControlNetUnion - - local: api/pipelines/cosmos - title: Cosmos - local: api/pipelines/ddim title: DDIM - local: api/pipelines/ddpm title: DDPM - local: api/pipelines/deepfloyd_if title: DeepFloyd IF - - local: api/pipelines/diffedit - title: DiffEdit - local: api/pipelines/dit title: DiT - local: api/pipelines/easyanimate title: EasyAnimate + - local: api/pipelines/ernie_image + title: ERNIE-Image - local: api/pipelines/flux title: Flux - local: api/pipelines/flux2 @@ -554,8 +543,12 @@ title: Hunyuan-DiT - local: api/pipelines/hunyuanimage21 title: HunyuanImage2.1 + - local: api/pipelines/ideogram4 + title: Ideogram 4 - local: api/pipelines/pix2pix title: InstructPix2Pix + - local: api/pipelines/joyimage_edit + title: JoyImage Edit - local: api/pipelines/kandinsky title: Kandinsky 2.1 - local: api/pipelines/kandinsky_v22 @@ -580,22 +573,22 @@ title: Lumina-T2X - local: api/pipelines/marigold title: Marigold - - local: api/pipelines/panorama - title: MultiDiffusion + - local: api/pipelines/nucleusmoe_image + title: NucleusMoE-Image - local: api/pipelines/omnigen title: OmniGen - local: api/pipelines/ovis_image title: Ovis-Image - local: api/pipelines/pag title: PAG - - local: api/pipelines/paint_by_example - title: Paint by Example - local: api/pipelines/pixart title: PixArt-α - local: api/pipelines/pixart_sigma title: PixArt-Σ - local: api/pipelines/prx title: PRX + - local: api/pipelines/prx_pixel + title: PRX Pixel - local: api/pipelines/qwenimage title: QwenImage - local: api/pipelines/sana @@ -604,10 +597,6 @@ title: Sana Sprint - local: api/pipelines/sana_video title: Sana Video - - local: api/pipelines/self_attention_guidance - title: Self-Attention Guidance - - local: api/pipelines/semantic_stable_diffusion - title: Semantic Guidance - local: api/pipelines/shap_e title: Shap-E - local: api/pipelines/stable_cascade @@ -617,23 +606,14 @@ title: Overview - local: api/pipelines/stable_diffusion/depth2img title: Depth-to-image - - local: api/pipelines/stable_diffusion/gligen - title: GLIGEN (Grounded Language-to-Image Generation) - local: api/pipelines/stable_diffusion/image_variation title: Image variation - local: api/pipelines/stable_diffusion/img2img title: Image-to-image - local: api/pipelines/stable_diffusion/inpaint title: Inpainting - - local: api/pipelines/stable_diffusion/k_diffusion - title: K-Diffusion - local: api/pipelines/stable_diffusion/latent_upscale title: Latent upscaler - - local: api/pipelines/stable_diffusion/ldm3d_diffusion - title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D - Upscaler - - local: api/pipelines/stable_diffusion/stable_diffusion_safe - title: Safe Stable Diffusion - local: api/pipelines/stable_diffusion/sdxl_turbo title: SDXL Turbo - local: api/pipelines/stable_diffusion/stable_diffusion_2 @@ -651,36 +631,40 @@ title: Stable Diffusion - local: api/pipelines/stable_unclip title: Stable unCLIP - - local: api/pipelines/unclip - title: unCLIP - - local: api/pipelines/unidiffuser - title: UniDiffuser - local: api/pipelines/value_guided_sampling title: Value-guided sampling - local: api/pipelines/visualcloze title: VisualCloze - - local: api/pipelines/wuerstchen - title: Wuerstchen - local: api/pipelines/z_image title: Z-Image title: Image + - sections: + - local: api/pipelines/llada2 + title: LLaDA2 + title: Text - sections: - local: api/pipelines/allegro title: Allegro + - local: api/pipelines/anyflow + title: AnyFlow - local: api/pipelines/chronoedit title: ChronoEdit - local: api/pipelines/cogvideox title: CogVideoX - local: api/pipelines/consisid title: ConsisID + - local: api/pipelines/cosmos + title: Cosmos + - local: api/pipelines/cosmos3 + title: Cosmos3 - local: api/pipelines/framepack title: Framepack + - local: api/pipelines/helios + title: Helios - local: api/pipelines/hunyuan_video title: HunyuanVideo - local: api/pipelines/hunyuan_video15 title: HunyuanVideo1.5 - - local: api/pipelines/i2vgenxl - title: I2VGen-XL - local: api/pipelines/kandinsky5_video title: Kandinsky 5.0 Video - local: api/pipelines/latte @@ -691,16 +675,12 @@ title: LTXVideo - local: api/pipelines/mochi title: Mochi - - local: api/pipelines/pia - title: Personalized Image Animator (PIA) + - local: api/pipelines/motif_video + title: Motif-Video - local: api/pipelines/skyreels_v2 title: SkyReels-V2 - local: api/pipelines/stable_diffusion/svd title: Stable Video Diffusion - - local: api/pipelines/text_to_video - title: Text-to-video - - local: api/pipelines/text_to_video_zero - title: Text2Video-Zero - local: api/pipelines/wan title: Wan title: Video @@ -708,6 +688,8 @@ - sections: - local: api/schedulers/overview title: Overview + - local: api/schedulers/block_refinement + title: BlockRefinementScheduler - local: api/schedulers/cm_stochastic_iterative title: CMStochasticIterativeScheduler - local: api/schedulers/ddim_cogvideox @@ -742,10 +724,16 @@ title: EulerAncestralDiscreteScheduler - local: api/schedulers/euler title: EulerDiscreteScheduler + - local: api/schedulers/flow_map_euler_discrete + title: FlowMapEulerDiscreteScheduler - local: api/schedulers/flow_match_euler_discrete title: FlowMatchEulerDiscreteScheduler - local: api/schedulers/flow_match_heun_discrete title: FlowMatchHeunDiscreteScheduler + - local: api/schedulers/helios_dmd + title: HeliosDMDScheduler + - local: api/schedulers/helios + title: HeliosScheduler - local: api/schedulers/heun title: HeunDiscreteScheduler - local: api/schedulers/ipndm diff --git a/docs/source/en/api/pipelines/prx.md b/docs/source/en/api/pipelines/prx.md index dcb262437e52..16670f4bfc86 100644 --- a/docs/source/en/api/pipelines/prx.md +++ b/docs/source/en/api/pipelines/prx.md @@ -15,17 +15,15 @@ # PRX -PRX is a family of efficient text-to-image diffusion models by Photoroom. The flagship model, **PRXPixel** ([`PRXPixelPipeline`]), generates 1024px images directly in pixel space: a ~7B transformer denoises raw RGB without any VAE, conditioned on a Qwen3-VL text encoder, and feeds the generation resolution into the timestep modulation. It uses flow matching and predicts the clean image at each step (x-prediction). - -Earlier PRX versions ([`PRXPipeline`]) operate in a VAE latent space (Flux VAE with 8x compression, or DC-AE with 32x compression) with a ~1.3B simplified MMDIT transformer where text tokens don't update through the blocks, and Google's T5Gemma-2B-2B-UL2 model for text encoding. +PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing. ## Available models -**PRXPixel is the flagship model.** The other checkpoints are earlier latent-space versions of PRX at 256/512px with different VAE configurations; the distilled variants generate in 8 steps. +PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts. + | Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype | |:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:| -| [`Photoroom/prxpixel-t2i`](https://huggingface.co/Photoroom/prxpixel-t2i)| 1024 | No | No | Flagship pixel-space model (~7B transformer, no VAE) with a Qwen3-VL text encoder, loaded with [`PRXPixelPipeline`] | Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | @@ -33,45 +31,29 @@ Earlier PRX versions ([`PRXPipeline`]) operate in a VAE latent space (Flux VAE w | [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | | [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | | [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` | -| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` | +| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information. ## Loading the pipeline -Load the pipeline with [`~DiffusionPipeline.from_pretrained`]. [`PRXPixelPipeline`] denoises raw RGB directly, so no VAE is loaded or needed. It requires `transformers >= 4.57` (the version that introduced `Qwen3VLTextModel`). - -```py -import torch -from diffusers import PRXPixelPipeline - -pipe = PRXPixelPipeline.from_pretrained("Photoroom/prxpixel-t2i", torch_dtype=torch.bfloat16) -pipe.to("cuda") - -prompt = "A front-facing portrait of a lion in the golden savanna at sunset." -image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] -image.save("prxpixel_output.png") -``` - -### Latent-space generation (earlier PRX versions) - -Use [`PRXPipeline`] for the earlier latent-space checkpoints; the VAE and text encoder are loaded as part of the pipeline. +Load the pipeline with [`~DiffusionPipeline.from_pretrained`]. ```py -import torch -from diffusers import PRXPipeline +from diffusers.pipelines.prx import PRXPipeline +# Load pipeline - VAE and text encoder will be loaded from HuggingFace pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16) pipe.to("cuda") -prompt = "A front-facing portrait of a lion in the golden savanna at sunset." +prompt = "A front-facing portrait of a lion the golden savanna at sunset." image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] image.save("prx_output.png") ``` ### Manual Component Loading -Load components individually to customize the pipeline, for instance to use quantized models (shown here for the latent-space [`PRXPipeline`]). +Load components individually to customize the pipeline for instance to use quantized models. ```py import torch @@ -129,21 +111,15 @@ For memory-constrained environments: ```py import torch -from diffusers import PRXPixelPipeline +from diffusers.pipelines.prx import PRXPipeline -pipe = PRXPixelPipeline.from_pretrained("Photoroom/prxpixel-t2i", torch_dtype=torch.bfloat16) +pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() # Offload components to CPU when not in use # Or use sequential CPU offload for even lower memory pipe.enable_sequential_cpu_offload() ``` -## PRXPixelPipeline - -[[autodoc]] PRXPixelPipeline - - all - - __call__ - ## PRXPipeline [[autodoc]] PRXPipeline diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py index ce5bff2dee27..f4ec214313e3 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx.py +++ b/src/diffusers/pipelines/prx/pipeline_prx.py @@ -125,12 +125,6 @@ def __init__(self): + r"]{1,}" ) - def basic_clean(self, text: str) -> str: - """Light cleaning (fix mojibake + unescape HTML). Used by encoders trained without DeepFloyd cleaning.""" - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - def clean_text(self, text: str) -> str: """Clean text using comprehensive text processing logic.""" # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py @@ -236,6 +230,12 @@ def clean_text(self, text: str) -> str: return text.strip() + def basic_clean(self, text: str) -> str: + """Light cleaning: fix mojibake and unescape HTML. Used when skip_text_cleaning=True.""" + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + EXAMPLE_DOC_STRING = """ Examples: @@ -303,20 +303,6 @@ def __init__( self.text_preprocessor = TextPreprocessor() self.default_sample_size = default_sample_size self._guidance_scale = 1.0 - # Max number of text tokens. When None, falls back to the tokenizer's own ``model_max_length``. - # Subclasses (e.g. the PRXPixel pipeline, whose Qwen tokenizer has a very large model_max_length) - # can pin this to the value used at training time. - self.tokenizer_max_length = None - # When True, prompts get only light cleaning (``basic_clean``) instead of the DeepFloyd ``clean_text``. - # Set by subclasses whose text encoder was trained without the heavy cleaning (e.g. the Qwen tower). - self.skip_text_cleaning = False - # What the transformer predicts. "flow_matching" -> velocity (consumed directly by the scheduler). - # "x_prediction_flow_matching" -> the clean sample x0, converted to velocity before each scheduler step - # (see "Back to Basics: Let Denoising Generative Models Denoise", https://arxiv.org/abs/2511.13720). - self.prediction_type = "flow_matching" - # Standard deviation of the initial noise. Some PRX variants train with a non-unit noise scale and must - # start sampling from `randn * noise_scale` to match the learned flow-matching trajectory. - self.noise_scale = 1.0 self.register_modules( transformer=transformer, @@ -382,7 +368,7 @@ def prepare_latents( width // spatial_compression, ) shape = (batch_size, num_channels_latents, latent_height, latent_width) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.noise_scale + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) return latents @@ -398,6 +384,8 @@ def encode_prompt( negative_prompt_embeds: torch.FloatTensor | None = None, prompt_attention_mask: torch.BoolTensor | None = None, negative_prompt_attention_mask: torch.BoolTensor | None = None, + tokenizer_max_length: int | None = None, + skip_text_cleaning: bool = False, ): """Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings.""" if device is None: @@ -408,7 +396,14 @@ def encode_prompt( prompt = [prompt] # Encode the prompts prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( - self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt) + self._encode_prompt_standard( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + tokenizer_max_length=tokenizer_max_length, + skip_text_cleaning=skip_text_cleaning, + ) ) # Duplicate embeddings for each generation per prompt @@ -439,11 +434,17 @@ def encode_prompt( negative_prompt_attention_mask if do_classifier_free_guidance else None, ) - def _tokenize_prompts(self, prompts: list[str], device: torch.device): + def _tokenize_prompts( + self, + prompts: list[str], + device: torch.device, + tokenizer_max_length: int | None = None, + skip_text_cleaning: bool = False, + ): """Tokenize and clean prompts.""" - clean_fn = self.text_preprocessor.basic_clean if self.skip_text_cleaning else self.text_preprocessor.clean_text + clean_fn = self.text_preprocessor.basic_clean if skip_text_cleaning else self.text_preprocessor.clean_text cleaned = [clean_fn(text) for text in prompts] - max_length = self.tokenizer_max_length or self.tokenizer.model_max_length + max_length = tokenizer_max_length or self.tokenizer.model_max_length tokens = self.tokenizer( cleaned, padding="max_length", @@ -460,6 +461,8 @@ def _encode_prompt_standard( device: torch.device, do_classifier_free_guidance: bool = True, negative_prompt: str = "", + tokenizer_max_length: int | None = None, + skip_text_cleaning: bool = False, ): """Encode prompt using standard text encoder and tokenizer with batch processing.""" batch_size = len(prompt) @@ -472,7 +475,9 @@ def _encode_prompt_standard( else: prompts_to_encode = prompt - input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device) + input_ids, attention_mask = self._tokenize_prompts( + prompts_to_encode, device, tokenizer_max_length=tokenizer_max_length, skip_text_cleaning=skip_text_cleaning + ) with torch.no_grad(): embeddings = self.text_encoder( @@ -569,6 +574,8 @@ def __call__( use_resolution_binning: bool = True, callback_on_step_end: Callable[[int, int], None] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], + tokenizer_max_length: int | None = None, + skip_text_cleaning: bool = False, ): """ Function invoked when calling the pipeline for generation. @@ -622,6 +629,12 @@ def __call__( output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + tokenizer_max_length (`int`, *optional*): + Override the maximum number of tokens used when tokenizing the prompt. Defaults to the tokenizer's own + ``model_max_length`` when not set. + skip_text_cleaning (`bool`, *optional*, defaults to `False`): + If `True`, uses only light prompt cleaning (fix encoding + unescape HTML) instead of the full DeepFloyd + cleaning pipeline. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple. use_resolution_binning (`bool`, *optional*, defaults to `True`): @@ -650,15 +663,6 @@ def __call__( height = height or default_resolution width = width or default_resolution - if use_resolution_binning and self.image_processor is None: - # Latent-space pipelines constructed without a VAE have no image_processor and cannot bin; - # disable it transparently. - logger.warning( - "Resolution binning requires an image processor, but none is available; " - "proceeding with use_resolution_binning=False." - ) - use_resolution_binning = False - if use_resolution_binning: if self.default_sample_size not in ASPECT_RATIO_BINS: raise ValueError( @@ -683,9 +687,9 @@ def __call__( negative_prompt_embeds, ) - if self.vae is None and self.image_processor is None and output_type not in ["latent", "pt"]: + if self.vae is None and output_type not in ["latent", "pt"]: raise ValueError( - f"output_type='{output_type}' requires a VAE or an image processor, but neither is available. " + f"VAE is required for output_type='{output_type}' but it is not available. " "Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs." ) @@ -712,6 +716,8 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, + tokenizer_max_length=tokenizer_max_length, + skip_text_cleaning=skip_text_cleaning, ) # Expose standard names for callbacks parity prompt_embeds = text_embeddings @@ -790,12 +796,6 @@ def __call__( noise_uncond, noise_text = noise_pred.chunk(2, dim=0) noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) - # If the model predicts the clean sample x0, convert it to the flow-matching velocity - # the scheduler expects: v = (x_t - x0) / t (t = normalized noise level, clamped for stability). - if self.prediction_type == "x_prediction_flow_matching": - t_x = torch.clamp(t.float() / self.scheduler.config.num_train_timesteps, min=0.05) - noise_pred = (latents - noise_pred) / t_x - # Compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample @@ -810,19 +810,15 @@ def __call__( progress_bar.update() # 8. Post-processing - if output_type == "latent" or (output_type == "pt" and self.image_processor is None): + if output_type == "latent" or (output_type == "pt" and self.vae is None): image = latents else: - if self.vae is not None: - # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) - scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) - shift_factor = getattr(self.vae.config, "shift_factor", 0.0) - latents = (latents / scaling_factor) + shift_factor - # Decode using VAE (AutoencoderKL or AutoencoderDC) - image = self.vae.decode(latents, return_dict=False)[0] - else: - # Pixel-space pipelines have no VAE: the denoised latents are already images in [-1, 1]. - image = latents + # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) + scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) + shift_factor = getattr(self.vae.config, "shift_factor", 0.0) + latents = (latents / scaling_factor) + shift_factor + # Decode using VAE (AutoencoderKL or AutoencoderDC) + image = self.vae.decode(latents, return_dict=False)[0] # Resize back to original resolution if using binning if use_resolution_binning: image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) diff --git a/src/diffusers/pipelines/prx/pipeline_prx_pixel.py b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py index 98ec59da1d80..8a97b686d171 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx_pixel.py +++ b/src/diffusers/pipelines/prx/pipeline_prx_pixel.py @@ -12,14 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import html +import inspect +from typing import Callable + +import torch from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase from ...image_processor import PixArtImageProcessor from ...models import AutoencoderDC, AutoencoderKL from ...models.transformers.transformer_prx import PRXTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import logging -from .pipeline_prx import PRXPipeline +from ...utils import is_ftfy_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import PRXPipelineOutput +from .pipeline_prx import TextPreprocessor + + +if is_ftfy_available(): + import ftfy logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -29,19 +41,74 @@ # Number of text tokens used at training time (the Qwen tokenizer's own ``model_max_length`` is far larger). PRX_PIXEL_DEFAULT_MAX_TOKENS = 256 +# Predefined aspect-ratio bins for 1024px generation (mirrors ASPECT_RATIO_1024_BIN in pipeline_prx). +ASPECT_RATIO_1024_BIN = { + "0.49": [704, 1440], + "0.52": [736, 1408], + "0.53": [736, 1376], + "0.57": [768, 1344], + "0.59": [768, 1312], + "0.62": [800, 1280], + "0.67": [832, 1248], + "0.68": [832, 1216], + "0.78": [896, 1152], + "0.83": [928, 1120], + "0.94": [992, 1056], + "1.0": [1024, 1024], + "1.06": [1056, 992], + "1.13": [1088, 960], + "1.21": [1120, 928], + "1.29": [1152, 896], + "1.37": [1184, 864], + "1.46": [1216, 832], + "1.5": [1248, 832], + "1.71": [1312, 768], + "1.75": [1344, 768], + "1.87": [1376, 736], + "1.91": [1408, 736], + "2.05": [1440, 704], +} + +ASPECT_RATIO_BINS = { + 1024: ASPECT_RATIO_1024_BIN, +} + + +def _basic_clean(text: str) -> str: + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PRXPixelPipeline + + >>> pipe = PRXPixelPipeline.from_pretrained("Photoroom/prxpixel-t2i", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A front-facing portrait of a lion in the golden savanna at sunset." + >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] + >>> image.save("prxpixel_output.png") + ``` +""" -class PRXPixelPipeline(PRXPipeline): + +class PRXPixelPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation with the PRXPixel model. - PRXPixel is a pixel-space variant of [`PRXPipeline`]: it denoises raw RGB directly (the VAE is an identity / absent - component), conditions on a Qwen3-VL text encoder rather than T5Gemma, and feeds the latent resolution into the - timestep modulation (`resolution_embeds=True` on the [`PRXTransformer2DModel`]). The denoising loop, prompt - encoding, latent preparation and CFG handling are all inherited from [`PRXPipeline`]; only the component types, the - text-token budget, the (lighter) prompt cleaning, and the default resolution differ. + PRXPixel is a standalone, pixel-space text-to-image pipeline. It denoises raw RGB directly with a ~7B-parameter + [`PRXTransformer2DModel`] and has no VAE (generation happens entirely in pixel space, so the denoised output *is* + the image). Prompts are encoded with a Qwen3-VL text encoder (the vision tower is discarded). Unlike + [`PRXPipeline`] the transformer is trained with x-prediction: at every step it predicts the clean image `x0`, which + is converted to a flow-matching velocity before the scheduler step. Sampling starts from `randn * noise_scale` + (`noise_scale=2.0` by default) and the default resolution is 1024px. - This pipeline inherits from [`PRXPipeline`]. Check the superclass documentation for the generic methods (text - encoding, latent preparation, the `__call__` signature, ...). + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Examples: ```py @@ -58,8 +125,8 @@ class PRXPixelPipeline(PRXPipeline): Args: transformer ([`PRXTransformer2DModel`]): - The PRX denoiser. For PRXPixel this is built with `in_channels=3`, a bottleneck `img_in`, and - `resolution_embeds=True`. + The ~7B-parameter PRX denoiser. For PRXPixel this is built with `in_channels=3`, a bottleneck `img_in`, and + `resolution_embeds=True`, and it is trained to predict the clean image `x0`. scheduler ([`FlowMatchEulerDiscreteScheduler`]): Flow-matching scheduler used to denoise the (pixel-space) latents. text_encoder ([`PreTrainedModel`]): @@ -73,8 +140,15 @@ class PRXPixelPipeline(PRXPipeline): Default height/width used when none is provided to `__call__`. prompt_max_tokens (`int`, *optional*, defaults to 256): Number of text tokens the prompt is padded/truncated to before encoding. + noise_scale (`float`, *optional*, defaults to 2.0): + Scale applied to the initial Gaussian noise. PRXPixel trains with a non-unit initial-noise scale, so + sampling must start from `randn * noise_scale`. """ + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + _optional_components = ["vae"] + def __init__( self, transformer: PRXTransformer2DModel, @@ -86,36 +160,522 @@ def __init__( prompt_max_tokens: int = PRX_PIXEL_DEFAULT_MAX_TOKENS, noise_scale: float = 2.0, ): - super().__init__( + super().__init__() + + self.text_preprocessor = TextPreprocessor() + self._guidance_scale = 1.0 + + self.register_modules( transformer=transformer, scheduler=scheduler, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, + ) + self.register_to_config( default_sample_size=default_sample_size, + prompt_max_tokens=prompt_max_tokens, + noise_scale=noise_scale, ) - # Pin the text-token budget; the Qwen tokenizer's model_max_length is otherwise far too large. - self.tokenizer_max_length = prompt_max_tokens - # The Qwen3-VL embedding tower was trained without the DeepFloyd cleaning; use light cleaning only. - self.skip_text_cleaning = True - # PRXPixel predicts the clean sample x0 (converted to velocity each step), not the velocity directly. - self.prediction_type = "x_prediction_flow_matching" - # PRXPixel trains with a non-unit initial-noise scale; sampling must start from randn * noise_scale. - self.noise_scale = noise_scale - - # `super().__init__` already registered `default_sample_size`; register the extra scalar __init__ args too so - # they are written to `model_index.json` and restored on `from_pretrained` (otherwise they silently fall back - # to the constructor defaults). - self.register_to_config(prompt_max_tokens=prompt_max_tokens, noise_scale=noise_scale) - - if self.image_processor is None: - # Without a VAE the denoised latents are already images in [-1, 1]; an image processor with - # vae_scale_factor=1 is all that is needed to support output_type="pil"/"np" and resolution binning. - self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Pixel pipeline always has an image_processor (vae_scale_factor=1) + # so that output_type="pil"/"np" work without a VAE. + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) @property def vae_scale_factor(self): # PRXPixel operates directly in RGB pixel space (identity / no VAE): no spatial compression. if self.vae is None: return 1 - return super().vae_scale_factor + if hasattr(self.vae, "spatial_compression_ratio"): + return self.vae.spatial_compression_ratio + else: # Flux VAE + return 2 ** (len(self.vae.config.block_out_channels) - 1) + + @property + # Copied from diffusers.pipelines.prx.pipeline_prx.PRXPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + """Check if classifier-free guidance is enabled based on guidance scale.""" + return self._guidance_scale > 1.0 + + @property + # Copied from diffusers.pipelines.prx.pipeline_prx.PRXPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + def _tokenize_prompts( + self, + prompts: list[str], + device: torch.device, + tokenizer_max_length: int | None = None, + skip_text_cleaning: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Tokenize and (lightly) clean prompts. + + PRXPixel always uses light cleaning (`_basic_clean`) and the training-time token budget + (`self.config.prompt_max_tokens`). The `tokenizer_max_length` and `skip_text_cleaning` arguments are accepted + for API compatibility with the copied callers but are ignored. + """ + cleaned = [_basic_clean(text) for text in prompts] + tokens = self.tokenizer( + cleaned, + padding="max_length", + max_length=self.config.prompt_max_tokens, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device) + + # Copied from diffusers.pipelines.prx.pipeline_prx.PRXPipeline._encode_prompt_standard + def _encode_prompt_standard( + self, + prompt: list[str], + device: torch.device, + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + tokenizer_max_length: int | None = None, + skip_text_cleaning: bool = False, + ): + """Encode prompt using standard text encoder and tokenizer with batch processing.""" + batch_size = len(prompt) + + if do_classifier_free_guidance: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + + prompts_to_encode = negative_prompt + prompt + else: + prompts_to_encode = prompt + + input_ids, attention_mask = self._tokenize_prompts( + prompts_to_encode, device, tokenizer_max_length=tokenizer_max_length, skip_text_cleaning=skip_text_cleaning + ) + + with torch.no_grad(): + embeddings = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + )["last_hidden_state"] + + if do_classifier_free_guidance: + uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0) + uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0) + else: + text_embeddings = embeddings + cross_attn_mask = attention_mask + uncond_text_embeddings = None + uncond_cross_attn_mask = None + + return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask + + # Copied from diffusers.pipelines.prx.pipeline_prx.PRXPipeline.encode_prompt + def encode_prompt( + self, + prompt: str | list[str], + device: torch.device | None = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + prompt_attention_mask: torch.BoolTensor | None = None, + negative_prompt_attention_mask: torch.BoolTensor | None = None, + tokenizer_max_length: int | None = None, + skip_text_cleaning: bool = False, + ): + """Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings.""" + if device is None: + device = self._execution_device + + if prompt_embeds is None: + if isinstance(prompt, str): + prompt = [prompt] + # Encode the prompts + prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + self._encode_prompt_standard( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + tokenizer_max_length=tokenizer_max_length, + skip_text_cleaning=skip_text_cleaning, + ) + ) + + # Duplicate embeddings for each generation per prompt + if num_images_per_prompt > 1: + # Repeat prompt embeddings + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # Repeat negative embeddings if using CFG + if do_classifier_free_guidance and negative_prompt_embeds is not None: + bs_embed, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds if do_classifier_free_guidance else None, + negative_prompt_attention_mask if do_classifier_free_guidance else None, + ) + + def check_inputs( + self, + prompt: str | list[str], + height: int, + width: int, + guidance_scale: float, + callback_on_step_end_tensor_inputs: list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + ): + """Check that all inputs are in correct format.""" + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided and `guidance_scale > 1.0`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + + # The latents must be divisible by the transformer's patch size after VAE compression. + dimension_multiple = self.vae_scale_factor * self.transformer.config.patch_size + if height % dimension_multiple != 0 or width % dimension_multiple != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {dimension_multiple} (vae_scale_factor *" + f" transformer patch_size) but are {height} and {width}." + ) + + if guidance_scale < 1.0: + raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}") + + if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ): + """Prepare initial latents for the diffusion process. + + PRXPixel trains with a non-unit initial-noise scale, so the sampled noise is multiplied by + `self.config.noise_scale`. + """ + if latents is None: + spatial_compression = self.vae_scale_factor + latent_height, latent_width = ( + height // spatial_compression, + width // spatial_compression, + ) + shape = (batch_size, num_channels_latents, latent_height, latent_width) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.config.noise_scale + else: + latents = latents.to(device) + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str = "", + height: int | None = None, + width: int | None = None, + num_inference_steps: int = 28, + timesteps: list[int] = None, + guidance_scale: float = 4.0, + num_images_per_prompt: int | None = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + prompt_attention_mask: torch.BoolTensor | None = None, + negative_prompt_attention_mask: torch.BoolTensor | None = None, + output_type: str | None = "pil", + return_dict: bool = True, + use_resolution_binning: bool = True, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str`, *optional*, defaults to `""`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to `default_sample_size`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `default_sample_size`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`list[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an + empty string. + prompt_attention_mask (`torch.BoolTensor`, *optional*): + Pre-generated attention mask for `prompt_embeds`. If not provided, attention mask will be generated + from `prompt` input argument. + negative_prompt_attention_mask (`torch.BoolTensor`, *optional*): + Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`, + attention mask will be generated from an empty string. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple. + use_resolution_binning (`bool`, *optional*, defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back + to the requested resolution. Useful for generating non-square images at optimal resolutions. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`. + `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`list`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed + in the `._callback_tensor_inputs` attribute. + + Examples: + + Returns: + [`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is + True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. + """ + + # 0. Set height and width + default_resolution = getattr(self.config, "default_sample_size", None) or PRX_PIXEL_DEFAULT_RESOLUTION + height = height or default_resolution + width = width or default_resolution + + if use_resolution_binning: + if self.config.default_sample_size not in ASPECT_RATIO_BINS: + raise ValueError( + f"Resolution binning is only supported for default_sample_size in {list(ASPECT_RATIO_BINS.keys())}, " + f"but got {self.config.default_sample_size}. Set use_resolution_binning=False to disable aspect ratio binning." + ) + aspect_ratio_bin = ASPECT_RATIO_BINS[self.config.default_sample_size] + + # Store original dimensions + orig_height, orig_width = height, width + # Map to closest resolution in the bin + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + guidance_scale, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Use execution device (handles offloading scenarios including group offloading) + device = self._execution_device + + self._guidance_scale = guidance_scale + + # 2. Encode input prompt + text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + # Expose standard names for callbacks parity + prompt_embeds = text_embeddings + negative_prompt_embeds = uncond_text_embeddings + + # 3. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + self.num_timesteps = len(timesteps) + + # 4. Prepare latent variables + if self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + else: + # When vae is None, get latent channels from transformer + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 5. Prepare extra step kwargs + extra_step_kwargs = {} + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = 0.0 + + # 6. Prepare cross-attention embeddings and masks + if self.do_classifier_free_guidance: + ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) + ca_mask = None + if cross_attn_mask is not None and uncond_cross_attn_mask is not None: + ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) + else: + ca_embed = text_embeddings + ca_mask = cross_attn_mask + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Duplicate latents if using classifier-free guidance + if self.do_classifier_free_guidance: + latents_in = torch.cat([latents, latents], dim=0) + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) + else: + latents_in = latents + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device) + + # Forward through transformer + noise_pred = self.transformer( + hidden_states=latents_in, + timestep=t_cont, + encoder_hidden_states=ca_embed, + attention_mask=ca_mask, + return_dict=False, + )[0] + + # Apply CFG + if self.do_classifier_free_guidance: + noise_uncond, noise_text = noise_pred.chunk(2, dim=0) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # PRXPixel predicts x0; convert to flow-matching velocity before the scheduler step. + t_x = torch.clamp(t.float() / self.scheduler.config.num_train_timesteps, min=0.05) + noise_pred = (latents - noise_pred) / t_x + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_on_step_end(self, i, t, callback_kwargs) + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 8. Post-processing (pixel space: the denoised output IS the image in [-1, 1]; no VAE decode). + if output_type in ["latent", "pt"]: + image = latents + else: + image = latents + # Resize back to original resolution if using binning + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + # Use standard image processor for post-processing + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return PRXPipelineOutput(images=image) diff --git a/tests/pipelines/prx/test_pipeline_prx_pixel.py b/tests/pipelines/prx/test_pipeline_prx_pixel.py index 7163de0a66c9..1c061472162a 100644 --- a/tests/pipelines/prx/test_pipeline_prx_pixel.py +++ b/tests/pipelines/prx/test_pipeline_prx_pixel.py @@ -1,25 +1,20 @@ import unittest import numpy as np -import pytest import torch from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel from diffusers.pipelines.prx.pipeline_prx_pixel import PRXPixelPipeline from diffusers.schedulers import FlowMatchEulerDiscreteScheduler -from diffusers.utils import is_transformers_version from ..pipeline_params import TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin -@pytest.mark.xfail( - condition=is_transformers_version(">", "4.57.1"), - reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", - strict=False, -) class PRXPixelPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + # PRXPixelPipeline is standalone: it inherits from DiffusionPipeline (not PRXPipeline) and always has its own + # image_processor, so it denoises raw RGB in pixel space and supports output_type="pil"/"np" without a VAE. pipeline_class = PRXPixelPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"]) @@ -94,7 +89,11 @@ def get_dummy_inputs(self, device, seed=0): "guidance_scale": 1.0, "height": 32, "width": 32, + # Pixel-space PRX has no VAE and returns raw (C, H, W) tensors for output_type="pt". The generic + # PipelineTesterMixin tests compare these tensors directly, so default to "pt" here; the PIL/np default + # path is exercised explicitly in test_inference and test_inference_pil_and_np_output. "output_type": "pt", + # 32px is not in the 1024 aspect-ratio bins, so binning must be disabled for these tiny fast tests. "use_resolution_binning": False, } @@ -113,17 +112,23 @@ def test_inference(self): device = "cpu" pipe = self._build_pipe(device) - # No VAE -> identity pixel space, vae_scale_factor == 1, but postprocessing still works - # through an image processor so output_type="pil"/"np" are supported. + # No VAE -> identity pixel space, vae_scale_factor == 1, but the pipeline always carries an image processor + # so postprocessing (and the default output_type="pil") works without decoding. self.assertIsNone(pipe.vae) self.assertEqual(pipe.vae_scale_factor, 1) self.assertIsNotNone(pipe.image_processor) + # Default output is PIL (no VAE needed: the image processor denormalizes the denoised pixels directly). + inputs = self.get_dummy_inputs(device) + inputs.pop("output_type") # default is "pil" + images = pipe(**inputs).images + self.assertEqual(len(images), 1) + self.assertEqual(images[0].size, (32, 32)) + + # Raw "pt" output is the denoised RGB tensor at the requested resolution. inputs = self.get_dummy_inputs(device) image = pipe(**inputs)[0] generated_image = image[0] - - # Output is raw RGB at the requested resolution. self.assertEqual(generated_image.shape, (3, 32, 32)) expected_image = torch.zeros(3, 32, 32) max_diff = np.abs(generated_image.cpu().numpy() - expected_image.numpy()).max() @@ -268,8 +273,8 @@ def callback_inputs_all(pipe, i, t, callback_kwargs): def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3): # Overridden: the mixin version calls assert_mean_pixel_difference, which assumes HWC image - # arrays. Pixel-space PRX has no VAE and returns raw (C, H, W) tensors ("pt"), so we compare - # tensors directly instead of going through PIL. + # arrays. Pixel-space PRX has no VAE; compare raw (C, H, W) tensors directly ("pt") instead of + # going through PIL. if not self.test_attention_slicing: return