diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index c2d2757e13..544e755ea9 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -96,6 +96,7 @@ class Llama: def __init__( self, model_path: str, + clip_model_path: Optional[str] = None, *, # Model Params n_gpu_layers: Union[int, Literal["auto", "all"]] = "auto", @@ -171,6 +172,7 @@ def __init__( log_filters: Optional[Sequence[str]] = None, log_filters_case_sensitive: bool = True, # Extra Params + chat_handler_kwargs: Dict[str, Any] = {}, **kwargs, # type: ignore ): """Load a llama.cpp model from `model_path`. @@ -706,6 +708,18 @@ def __init__( print(f"Failed to load metadata: {e}", file=sys.stderr) if self.verbose: + print(f"Model metadata: {self.metadata}", file=sys.stderr) + + if clip_model_path is not None: + if self.chat_handler is not None and self.verbose: + print("Warning: Both `chat_handler` and `clip_model_path` are not null. Chat handler will be overwritten.", flush = True) + + self.chat_handler = llama_chat_format.GenericMTMDChatHandler( + gguf_metadata = self.metadata, + clip_model_path = clip_model_path, + verbose = self.verbose, + **chat_handler_kwargs + ) print(f"Model desc: {self.model_desc}, " f"Model size: {self.model_size / (1024 * 1024):.2f} MB, " f"Model metadata: {self.metadata}", diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 0365d8f871..254195f95a 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2887,10 +2887,14 @@ def __init__( raise ValueError(f"{self.log_prefix}(__init__): Clip model path does not exist: {clip_model_path}") # Pre-compile Jinja template + if not hasattr(self, "chat_format") or self.chat_format is None: + self.chat_format = self.CHAT_FORMAT + + self._chat_format_parser_tags = [] self.chat_template = ImmutableSandboxedEnvironment( trim_blocks=True, lstrip_blocks=True, - ).from_string(self.CHAT_FORMAT) + ).from_string(self.chat_format) self._exit_stack = ExitStack() @@ -2992,13 +2996,13 @@ def _get_media_items(self, messages: List[llama_types.ChatCompletionRequestMessa media_items.append({"url": url, "type": "image"}) # 2. Audio Processing - elif content_type in ["audio_url", "input_audio"]: + elif content_type in ["audio", "audio_url", "input_audio"]: if not self.is_support_audio: raise ValueError(f"{self.log_prefix}: This mmproj model instance does not support audio inputs.") # Case A: Handle custom/forward-compatible audio_url format - if content_type == "audio_url": - audio_url = content["audio_url"] + if content_type == "audio_url" or content_type == "audio": + audio_url = content[content_type] url = audio_url if isinstance(audio_url, str) else audio_url["url"] media_items.append({"url": url, "type": "audio"}) # Case B: Handle OpenAI standard input_audio format @@ -3117,6 +3121,13 @@ def _process_mtmd_prompt( tool_choice=tool_choice, **getattr(self, 'extra_template_arguments', {}) ) + + for tag in self._chat_format_parser_tags: + if tag not in text: + continue + + text = text.replace(tag, media_marker) + # Replace image_url by media_marker in text for item in media_items: text = text.replace(item["url"], media_marker) @@ -3828,6 +3839,43 @@ def from_pretrained( **kwargs, ) +class GenericMTMDChatHandler(MTMDChatHandler): + KNOWN_MEDIA_TAGS = [ + "<|image_pad|>", + "<|audio_pad|>", + "<|video_pad|>", + "<|image|>", + "<|audio|>", + "<|video|>", + "[IMG]" + ] + + def __init__( + self, + gguf_metadata: Dict[str, Any], + clip_model_path: str, + verbose: bool = True, + **kwargs + ) -> None: + self.model_metadata = gguf_metadata + self.chat_format = self.model_metadata.get("tokenizer.chat_template", None) + + if verbose: + print(f"Got chat template from model:\n```jinja\n{self.chat_format}\n```", flush = True) + + if self.chat_format is None: + raise ValueError("Failed to get model chat template automatically.") + + super().__init__(clip_model_path = clip_model_path, verbose = verbose, **kwargs) + + def __call__(self, **kwargs): + self._chat_format_parser_tags = [tag for tag in self.KNOWN_MEDIA_TAGS if tag in self.chat_format] + + if self.verbose: + print(f"{self.log_prefix} - Start processing") + + # Use parent implementation + return super().__call__(**kwargs) class Llava15ChatHandler(MTMDChatHandler): CHAT_FORMAT = (