forked from phoenix-oss/llama-stack-mirror
320 lines
12 KiB
Python
320 lines
12 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import io
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
from PIL import Image as PIL_Image
|
|
|
|
# TODO: either fork these or move them to the common package
|
|
from llama_stack.models.llama.datatypes import (
|
|
BuiltinTool,
|
|
RawContent,
|
|
RawMediaItem,
|
|
RawMessage,
|
|
RawTextItem,
|
|
Role,
|
|
StopReason,
|
|
ToolCall,
|
|
ToolPromptFormat,
|
|
)
|
|
from llama_stack.models.llama.llama3.tool_utils import ToolUtils
|
|
from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs
|
|
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
|
|
LLMInput,
|
|
)
|
|
from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import (
|
|
ResizeNormalizeImageTransform,
|
|
VariableSizeImageTransform,
|
|
)
|
|
|
|
from .tokenizer import Tokenizer
|
|
|
|
|
|
def role_str(role: Role) -> str:
|
|
role_strs = {
|
|
Role.user: "user",
|
|
Role.system: "system",
|
|
Role.tool: "ipython", # special
|
|
Role.assistant: "assistant",
|
|
}
|
|
return role_strs[role]
|
|
|
|
|
|
@dataclass
|
|
class TransformedImage:
|
|
image_tiles: torch.Tensor
|
|
# is the aspect ratio needed anywhere?
|
|
aspect_ratio: Tuple[int, int]
|
|
|
|
|
|
def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
|
if image.mode == "RGBA":
|
|
image.load() # for png.split()
|
|
new_img = PIL_Image.new("RGB", image.size, bg)
|
|
new_img.paste(image, mask=image.split()[3]) # 3 is the alpha channel
|
|
return new_img
|
|
return image.convert("RGB")
|
|
|
|
|
|
class ChatFormat:
|
|
possible_headers: Dict[Role, str]
|
|
|
|
def __init__(
|
|
self,
|
|
tokenizer: Tokenizer,
|
|
vision_args: Optional[VisionArgs] = None,
|
|
max_num_chunks: int = 16,
|
|
):
|
|
self.tokenizer = tokenizer
|
|
self.vision_args = vision_args
|
|
self.max_num_chunks = max_num_chunks
|
|
|
|
self.possible_headers = {role: f"<|header_start|>{role_str(role)}<|header_end|>\n\n" for role in Role}
|
|
|
|
self.image_transform = None
|
|
self.dynamic_image_transform = None
|
|
if vision_args:
|
|
self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
|
|
self.image_transform = ResizeNormalizeImageTransform(
|
|
vision_args.image_size.width, vision_args.image_size.height
|
|
)
|
|
|
|
def _encode_header(self, role: str) -> List[int]:
|
|
tokens = []
|
|
tokens.append(self.tokenizer.special_tokens["<|header_start|>"])
|
|
|
|
# TODO: need to check if this is correct
|
|
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
|
|
tokens.append(self.tokenizer.special_tokens["<|header_end|>"])
|
|
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
|
|
return tokens
|
|
|
|
def encode_content(self, content: RawContent) -> LLMInput:
|
|
tokens, images = self._encode_content(content, bos=True)
|
|
return self._model_input_from_tokens_images(tokens, images)
|
|
|
|
def _encode_image(
|
|
self,
|
|
transformed_image: TransformedImage,
|
|
) -> List[int]:
|
|
assert self.vision_args is not None, "The model is not vision-enabled"
|
|
|
|
image_tensor = transformed_image.image_tiles
|
|
image_channels = image_tensor.shape[-3]
|
|
image_height = image_tensor.shape[-2]
|
|
image_width = image_tensor.shape[-1]
|
|
image_chunks = image_tensor.view(-1, image_channels, image_height, image_width).shape[0]
|
|
|
|
patch_height = self.vision_args.patch_size.height
|
|
patch_width = self.vision_args.patch_size.width
|
|
|
|
if image_height % patch_height != 0:
|
|
raise ValueError(f"{image_height=} not divisible by {patch_height=}")
|
|
if image_width % patch_width != 0:
|
|
raise ValueError(f"{image_width=} not divisible by {patch_width=}")
|
|
|
|
ds_ratio = int(round(1.0 / (self.vision_args.pixel_shuffle_ratio**2)))
|
|
n_patches_per_chunk = int((image_height // patch_height) * (image_width // patch_width) // ds_ratio)
|
|
|
|
image_ar = transformed_image.aspect_ratio
|
|
tokens = [self.tokenizer.special_tokens["<|image_start|>"]]
|
|
if image_chunks == 1:
|
|
tokens += [self.tokenizer.special_tokens["<|image|>"]]
|
|
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
|
|
tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
|
|
else:
|
|
ratio_h, ratio_w = image_ar
|
|
for _ in range(ratio_h):
|
|
for xx in range(ratio_w):
|
|
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
|
|
if xx < ratio_w - 1:
|
|
tokens.append(self.tokenizer.special_tokens["<|tile_x_separator|>"])
|
|
|
|
tokens.append(self.tokenizer.special_tokens["<|tile_y_separator|>"])
|
|
|
|
tokens += [self.tokenizer.special_tokens["<|image|>"]]
|
|
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
|
|
tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
|
|
|
|
return tokens
|
|
|
|
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
|
|
tokens = []
|
|
tranformed_images = []
|
|
|
|
added_bos = False
|
|
|
|
def _process(c):
|
|
nonlocal added_bos, bos
|
|
|
|
if isinstance(c, str) or isinstance(c, RawTextItem):
|
|
if isinstance(c, RawTextItem):
|
|
c = c.text
|
|
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
|
|
added_bos = True
|
|
|
|
elif isinstance(c, RawMediaItem):
|
|
if not self.vision_args:
|
|
raise ValueError("The model is not vision-enabled, but a media item was found")
|
|
|
|
bos = False if added_bos else bos
|
|
if bos:
|
|
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
|
added_bos = True
|
|
|
|
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
|
image = PIL_Image.open(bytes_io)
|
|
image = convert_rgba_to_rgb(image)
|
|
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
|
|
|
|
if image_tiles.shape[0] > 1:
|
|
image_global = self.image_transform(image)
|
|
image_global = image_global.unsqueeze(0)
|
|
image_combine = torch.cat((image_tiles, image_global), dim=0)
|
|
image_tiles = image_combine
|
|
|
|
transformed_image = TransformedImage(image_tiles=image_tiles, aspect_ratio=ar)
|
|
tokens.extend(self._encode_image(transformed_image))
|
|
tranformed_images.append(transformed_image)
|
|
|
|
if isinstance(content, list):
|
|
for c in content:
|
|
_process(c)
|
|
else:
|
|
_process(content)
|
|
|
|
return tokens, tranformed_images
|
|
|
|
def encode_message(
|
|
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
|
|
) -> Tuple[List[int], List[TransformedImage]]:
|
|
tokens = self._encode_header(message.role)
|
|
images = []
|
|
|
|
def _process_content(c):
|
|
toks, imgs = self._encode_content(c)
|
|
tokens.extend(toks)
|
|
images.extend(imgs)
|
|
|
|
_process_content(message.content)
|
|
|
|
if message.role == "user" and message.context is not None:
|
|
# This is RAG context; why is it here in the chat format? I don't think
|
|
# this is needed and can be moved upwards
|
|
_process_content("\n\n")
|
|
_process_content(message.context)
|
|
|
|
if message.role == "assistant":
|
|
for t in message.tool_calls:
|
|
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
|
|
_process_content(content)
|
|
|
|
eom = False
|
|
if message.role == "assistant":
|
|
eom = message.stop_reason == StopReason.end_of_message
|
|
|
|
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
|
|
return tokens, images
|
|
|
|
def encode_dialog_prompt(
|
|
self,
|
|
messages: List[RawMessage],
|
|
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
|
) -> LLMInput:
|
|
tokens = []
|
|
images = []
|
|
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
|
for message in messages:
|
|
toks, imgs = self.encode_message(message, tool_prompt_format)
|
|
tokens.extend(toks)
|
|
images.extend(imgs)
|
|
|
|
# Add the start of an assistant message for the model to complete.
|
|
tokens.extend(self._encode_header("assistant"))
|
|
|
|
return self._model_input_from_tokens_images(tokens, images)
|
|
|
|
# TODO(this should be generic, not only for assistant messages)
|
|
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
|
|
content = self.tokenizer.decode(tokens)
|
|
|
|
return self.decode_assistant_message_from_content(content, stop_reason)
|
|
|
|
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
|
|
content = content.strip(" ")
|
|
header_str = self.possible_headers[Role.assistant]
|
|
if content.startswith(header_str):
|
|
content = content[len(header_str) :]
|
|
|
|
ipython = content.startswith("<|python_start|>")
|
|
if ipython:
|
|
content = content[len("<|python_start|>") :]
|
|
content = content.replace("<|python_end|>", "")
|
|
|
|
if content.endswith("<|eot|>"):
|
|
content = content[: -len("<|eot|>")]
|
|
stop_reason = StopReason.end_of_turn
|
|
elif content.endswith("<|eom|>"):
|
|
content = content[: -len("<|eom|>")]
|
|
stop_reason = StopReason.end_of_message
|
|
|
|
tool_name = None
|
|
tool_arguments = {}
|
|
|
|
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
|
if custom_tool_info is not None:
|
|
tool_name, tool_arguments = custom_tool_info
|
|
# Sometimes when agent has custom tools alongside builin tools
|
|
# Agent responds for builtin tool calls in the format of the custom tools
|
|
# This code tries to handle that case
|
|
if tool_name in BuiltinTool.__members__:
|
|
tool_name = BuiltinTool[tool_name]
|
|
tool_arguments = {
|
|
"query": list(tool_arguments.values())[0],
|
|
}
|
|
else:
|
|
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
|
|
if builtin_tool_info is not None:
|
|
tool_name, query = builtin_tool_info
|
|
tool_arguments = {
|
|
"query": query,
|
|
}
|
|
if tool_name in BuiltinTool.__members__:
|
|
tool_name = BuiltinTool[tool_name]
|
|
elif ipython:
|
|
tool_name = BuiltinTool.code_interpreter
|
|
tool_arguments = {
|
|
"code": content,
|
|
}
|
|
|
|
tool_calls = []
|
|
if tool_name is not None and tool_arguments is not None:
|
|
call_id = str(uuid.uuid4())
|
|
tool_calls.append(
|
|
ToolCall(
|
|
call_id=call_id,
|
|
tool_name=tool_name,
|
|
arguments=tool_arguments,
|
|
)
|
|
)
|
|
content = ""
|
|
|
|
return RawMessage(
|
|
role="assistant",
|
|
content=content,
|
|
stop_reason=stop_reason,
|
|
tool_calls=tool_calls,
|
|
)
|
|
|
|
def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
|
|
return LLMInput(
|
|
tokens=tokens,
|
|
images=[x.image_tiles for x in images] if len(images) > 0 else None,
|
|
)
|