llama-stack-mirror/llama_stack/models/llama/llama4/chat_format.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image as PIL_Image

# TODO: either fork these or move them to the common package
from ..datatypes import (
    BuiltinTool,
    RawContent,
    RawMediaItem,
    RawMessage,
    RawTextItem,
    Role,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)
from ..llama3.tool_utils import ToolUtils
from .args import VisionArgs
from .datatypes import LLMInput
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
from .tokenizer import Tokenizer


def role_str(role: Role) -> str:
    role_strs = {
        Role.user: "user",
        Role.system: "system",
        Role.tool: "ipython",  # special
        Role.assistant: "assistant",
    }
    return role_strs[role]


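# Output of image preprocessing: the stacked tile tensor plus the (ratio_h, ratio_w)
# tiling chosen by the variable-size transform.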
@dataclass
class TransformedImage:
    image_tiles: torch.Tensor
    # is the aspect ratio needed anywhere?
    aspect_ratio: Tuple[int, int]


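# Flatten RGBA images onto a solid background color (white by default) so the alpha
# channel does not leak into normalization; other modes are converted to RGB directly.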
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
    if image.mode == "RGBA":
        image.load()  # for png.split()
        new_img = PIL_Image.new("RGB", image.size, bg)
        new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
        return new_img
    return image.convert("RGB")


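# Encodes RawMessage dialogs into token ids (and image tiles) for the model, and
# decodes generated text back into RawMessages, including any tool calls.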
class ChatFormat:
    possible_headers: Dict[Role, str]

    def __init__(
        self,
        tokenizer: Tokenizer,
        vision_args: Optional[VisionArgs] = None,
        max_num_chunks: int = 16,
    ):
        self.tokenizer = tokenizer
        self.vision_args = vision_args
        self.max_num_chunks = max_num_chunks

        self.possible_headers = {role: f"<|header_start|>{role_str(role)}<|header_end|>\n\n" for role in Role}

        self.image_transform = None
        self.dynamic_image_transform = None
        if vision_args:
            self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
            self.image_transform = ResizeNormalizeImageTransform(
                vision_args.image_size.width, vision_args.image_size.height
            )

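    # A message header is <|header_start|> + role + <|header_end|> + "\n\n";
    # tool messages are sent under the "ipython" role string.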
    def _encode_header(self, role: str) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|header_start|>"])

        # TODO: need to check if this is correct
        tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|header_end|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_content(self, content: RawContent) -> LLMInput:
        tokens, images = self._encode_content(content, bos=True)
        return self._model_input_from_tokens_images(tokens, images)

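    # Image token layout: a single-chunk image is encoded as
    # <|image_start|> <|image|> <patch tokens> <|image_end|>; a multi-chunk image emits
    # patch tokens per tile, separated by <|tile_x_separator|> within a row and
    # <|tile_y_separator|> between rows, followed by <|image|> and the patch tokens of
    # the global view before <|image_end|>.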
    def _encode_image(
        self,
        transformed_image: TransformedImage,
    ) -> List[int]:
        assert self.vision_args is not None, "The model is not vision-enabled"

        image_tensor = transformed_image.image_tiles
        image_channels = image_tensor.shape[-3]
        image_height = image_tensor.shape[-2]
        image_width = image_tensor.shape[-1]
        image_chunks = image_tensor.view(-1, image_channels, image_height, image_width).shape[0]

        patch_height = self.vision_args.patch_size.height
        patch_width = self.vision_args.patch_size.width

        if image_height % patch_height != 0:
            raise ValueError(f"{image_height=} not divisible by {patch_height=}")
        if image_width % patch_width != 0:
            raise ValueError(f"{image_width=} not divisible by {patch_width=}")

        ds_ratio = int(round(1.0 / (self.vision_args.pixel_shuffle_ratio**2)))
        n_patches_per_chunk = int((image_height // patch_height) * (image_width // patch_width) // ds_ratio)

        image_ar = transformed_image.aspect_ratio
        tokens = [self.tokenizer.special_tokens["<|image_start|>"]]
        if image_chunks == 1:
            tokens += [self.tokenizer.special_tokens["<|image|>"]]
            tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
            tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
        else:
            ratio_h, ratio_w = image_ar
            for _ in range(ratio_h):
                for xx in range(ratio_w):
                    tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
                    if xx < ratio_w - 1:
                        tokens.append(self.tokenizer.special_tokens["<|tile_x_separator|>"])

                tokens.append(self.tokenizer.special_tokens["<|tile_y_separator|>"])

            tokens += [self.tokenizer.special_tokens["<|image|>"]]
            tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
            tokens += [self.tokenizer.special_tokens["<|image_end|>"]]

        return tokens

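    # Tokenizes text and media items in order. For media, the image is split into tiles
    # by the variable-size transform; when more than one tile is produced, a resized
    # global view of the full image is appended as an extra tile.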
    def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
        tokens = []
        transformed_images = []

        added_bos = False

        def _process(c):
            nonlocal added_bos, bos

            if isinstance(c, str) or isinstance(c, RawTextItem):
                if isinstance(c, RawTextItem):
                    c = c.text
                tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
                added_bos = True

            elif isinstance(c, RawMediaItem):
                if not self.vision_args:
                    raise ValueError("The model is not vision-enabled, but a media item was found")

                bos = False if added_bos else bos
                if bos:
                    tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
                    added_bos = True

                bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
                image = PIL_Image.open(bytes_io)
                image = convert_image_to_rgb(image)
                image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)

                if image_tiles.shape[0] > 1:
                    image_global = self.image_transform(image)
                    image_global = image_global.unsqueeze(0)
                    image_combine = torch.cat((image_tiles, image_global), dim=0)
                    image_tiles = image_combine

                transformed_image = TransformedImage(image_tiles=image_tiles, aspect_ratio=ar)
                tokens.extend(self._encode_image(transformed_image))
                transformed_images.append(transformed_image)

        if isinstance(content, list):
            for c in content:
                _process(c)
        else:
            _process(content)

        return tokens, transformed_images

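    # A message is encoded as: role header, content (plus RAG context for user
    # messages), encoded tool calls for assistant messages, then <|eom|> or <|eot|>.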
    def encode_message(
        self, message: RawMessage, tool_prompt_format: ToolPromptFormat
    ) -> Tuple[List[int], List[TransformedImage]]:
        tokens = self._encode_header(message.role)
        images = []

        def _process_content(c):
            toks, imgs = self._encode_content(c)
            tokens.extend(toks)
            images.extend(imgs)

        _process_content(message.content)

        if message.role == "user" and message.context is not None:
            # This is RAG context; why is it here in the chat format? I don't think
            # this is needed and can be moved upwards
            _process_content("\n\n")
            _process_content(message.context)

        if message.role == "assistant":
            for t in message.tool_calls:
                content = ToolUtils.encode_tool_call(t, tool_prompt_format)
                _process_content(content)

        # Tool-call and tool-response messages should end with <|eom|>
        eom = False
        if message.role == "assistant":
            eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
        elif message.role == "tool":
            eom = True

        tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
        return tokens, images

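    # The full prompt is <|begin_of_text|> followed by every encoded message and a
    # trailing assistant header for the model to complete.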
    def encode_dialog_prompt(
        self,
        messages: List[RawMessage],
        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
    ) -> LLMInput:
        tokens = []
        images = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        for message in messages:
            toks, imgs = self.encode_message(message, tool_prompt_format)
            tokens.extend(toks)
            images.extend(imgs)

        # Add the start of an assistant message for the model to complete.
        tokens.extend(self._encode_header("assistant"))
        return self._model_input_from_tokens_images(tokens, images)

    # TODO(this should be generic, not only for assistant messages)
    def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
        content = self.tokenizer.decode(tokens)
        return self.decode_assistant_message_from_content(content, stop_reason)

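    # Strips the assistant header and <|python_start|>/<|python_end|> markers, maps a
    # trailing <|eot|>/<|eom|> to the stop reason, then extracts a custom tool call, a
    # builtin tool call, or (for ipython content) a code_interpreter call.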
    def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
        content = content.strip(" ")
        header_str = self.possible_headers[Role.assistant]
        if content.startswith(header_str):
            content = content[len(header_str) :]

        ipython = content.startswith("<|python_start|>")
        if ipython:
            content = content[len("<|python_start|>") :]
            content = content.replace("<|python_end|>", "")

        if content.endswith("<|eot|>"):
            content = content[: -len("<|eot|>")]
            stop_reason = StopReason.end_of_turn
        elif content.endswith("<|eom|>"):
            content = content[: -len("<|eom|>")]
            stop_reason = StopReason.end_of_message

        tool_name = None
        tool_arguments = {}

        custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
        if custom_tool_info is not None:
            tool_name, tool_arguments = custom_tool_info
            # Sometimes when the agent has custom tools alongside builtin tools, it
            # responds to builtin tool calls in the custom-tool format. This code tries
            # to handle that case.
            if tool_name in BuiltinTool.__members__:
                tool_name = BuiltinTool[tool_name]
                tool_arguments = {
                    "query": list(tool_arguments.values())[0],
                }
        else:
            builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
            if builtin_tool_info is not None:
                tool_name, query = builtin_tool_info
                tool_arguments = {
                    "query": query,
                }
                if tool_name in BuiltinTool.__members__:
                    tool_name = BuiltinTool[tool_name]
            elif ipython:
                tool_name = BuiltinTool.code_interpreter
                tool_arguments = {
                    "code": content,
                }

        tool_calls = []
        if tool_name is not None and tool_arguments is not None:
            call_id = str(uuid.uuid4())
            tool_calls.append(
                ToolCall(
                    call_id=call_id,
                    tool_name=tool_name,
                    arguments=tool_arguments,
                    arguments_json=json.dumps(tool_arguments),
                )
            )

        return RawMessage(
            role="assistant",
            content=content,
            stop_reason=stop_reason,
            tool_calls=tool_calls,
        )

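    # Bundles the token ids and raw image tiles into the LLMInput consumed by the model.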
    def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
        return LLMInput(
            tokens=tokens,
            images=[x.image_tiles for x in images] if len(images) > 0 else None,
        )