# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import io
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from PIL import Image as PIL_Image

from llama_stack.models.llama.datatypes import (
    BuiltinTool,
    RawContent,
    RawMediaItem,
    RawMessage,
    RawTextItem,
    Role,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)

from .tokenizer import Tokenizer
from .tool_utils import ToolUtils


@dataclass
class VisionInput:
    mask: List[List[int]]
    images: List[PIL_Image.Image]


@dataclass
class LLMInput:
    tokens: List[int]
    vision: Optional[VisionInput] = None


def role_str(role: Role) -> str:
    role_strs = {
        Role.user: "user",
        Role.system: "system",
        Role.tool: "ipython",  # special
        Role.assistant: "assistant",
    }
    return role_strs[role]


class ChatFormat:
    possible_headers: Dict[Role, str]

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer
        self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
        self.vision_token = self.tokenizer.special_tokens["<|image|>"]

    def _encode_header(self, role: str) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_content(self, content: RawContent) -> LLMInput:
        tokens, images = self._encode_content(content, bos=True)
        return self._model_input_from_tokens_images(tokens, images)

    def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
        tokens = []
        images = []

        added_bos = False

        def _process(c):
            nonlocal added_bos, bos

            if isinstance(c, (str, RawTextItem)):
                if isinstance(c, RawTextItem):
                    c = c.text
                tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
                added_bos = True
            elif isinstance(c, RawMediaItem):
                bos = False if added_bos else bos
                if bos:
                    tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
                    added_bos = True

                tokens.append(self.vision_token)
                bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
                image = PIL_Image.open(bytes_io)
                image = image.convert("RGB")
                images.append(image)

        if isinstance(content, list):
            for c in content:
                _process(c)
        else:
            _process(content)

        return tokens, images
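
    # Illustrative sketch (not part of the original source): for mixed content
    # such as [RawMediaItem(data=...), RawTextItem(text="hi")] encoded with
    # bos=True, _encode_content yields a token stream shaped roughly as
    #   [<|begin_of_text|>, <|image|>, ...tokens for "hi"...]
    # with the decoded PIL image collected in `images`. The <|image|> token is
    # only a placeholder; it is remapped in _model_input_from_tokens_images.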

    def encode_message(
        self, message: RawMessage, tool_prompt_format: ToolPromptFormat
    ) -> Tuple[List[int], List[PIL_Image.Image]]:
        tokens = self._encode_header(message.role)
        images = []

        def _process_content(c):
            toks, imgs = self._encode_content(c)
            tokens.extend(toks)
            images.extend(imgs)

        if (
            message.role == "assistant"
            and len(message.tool_calls) > 0
            and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
        ):
            tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])

        _process_content(message.content)

        if message.role == "user" and message.context is not None:
            # This is RAG context; why is it here in the chat format? I don't
            # think this is needed and can be moved upwards
            _process_content("\n\n")
            _process_content(message.context)

        if message.role == "assistant":
            for t in message.tool_calls:
                content = ToolUtils.encode_tool_call(t, tool_prompt_format)
                _process_content(content)

        eom = False
        if message.role == "assistant":
            eom = message.stop_reason == StopReason.end_of_message

        tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
        return tokens, images

    def encode_dialog_prompt(
        self,
        messages: List[RawMessage],
        tool_prompt_format: Optional[ToolPromptFormat] = None,
    ) -> LLMInput:
        tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
        tokens = []
        images = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        for message in messages:
            toks, imgs = self.encode_message(message, tool_prompt_format)
            tokens.extend(toks)
            images.extend(imgs)

        # Add the start of an assistant message for the model to complete.
        tokens.extend(self._encode_header("assistant"))

        return self._model_input_from_tokens_images(tokens, images)

    # TODO: this should be generic, not only for assistant messages
    def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
        content = self.tokenizer.decode(tokens)
        return self.decode_assistant_message_from_content(content, stop_reason)

    def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
        content = content.strip(" ")
        header_str = self.possible_headers[Role.assistant]
        if content.startswith(header_str):
            content = content[len(header_str) :]

        ipython = content.startswith("<|python_tag|>")
        if ipython:
            content = content[len("<|python_tag|>") :]

        if content.endswith("<|eot_id|>"):
            content = content[: -len("<|eot_id|>")]
            stop_reason = StopReason.end_of_turn
        elif content.endswith("<|eom_id|>"):
            content = content[: -len("<|eom_id|>")]
            stop_reason = StopReason.end_of_message

        tool_name = None
        tool_arguments = {}

        custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
        if custom_tool_info is not None:
            tool_name, tool_arguments = custom_tool_info
            # Sometimes when the agent has custom tools alongside builtin tools,
            # it responds to builtin tool calls in the format of the custom tools.
            # This code tries to handle that case.
            if tool_name in BuiltinTool.__members__:
                tool_name = BuiltinTool[tool_name]
                tool_arguments = {
                    "query": list(tool_arguments.values())[0],
                }
        else:
            builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
            if builtin_tool_info is not None:
                tool_name, query = builtin_tool_info
                tool_arguments = {
                    "query": query,
                }
                if tool_name in BuiltinTool.__members__:
                    tool_name = BuiltinTool[tool_name]
            elif ipython:
                tool_name = BuiltinTool.code_interpreter
                tool_arguments = {
                    "code": content,
                }

        tool_calls = []
        if tool_name is not None and tool_arguments is not None:
            call_id = str(uuid.uuid4())
            tool_calls.append(
                ToolCall(
                    call_id=call_id,
                    tool_name=tool_name,
                    arguments=tool_arguments,
                )
            )
            content = ""

        return RawMessage(
            role="assistant",
            content=content,
            stop_reason=stop_reason,
            tool_calls=tool_calls,
        )

    def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
        vision_input = None
        if len(images) > 0:
            vision_input = VisionInput(
                mask=create_vision_mask(tokens, self.vision_token),
                images=images,
            )

        return LLMInput(
            # The vision model expects a fixed image-placeholder id (128256)
            # in place of the tokenizer's <|image|> special token.
            tokens=[128256 if token == self.vision_token else token for token in tokens],
            vision=vision_input,
        )
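

# Illustrative sketch (not part of the original source): for a single user
# turn, encode_dialog_prompt produces a token stream that decodes to roughly
#   <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
#   Hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n
# i.e. each message is a role header, its content, and an <|eot_id|> (or
# <|eom_id|> for an assistant message that stopped at end_of_message), with a
# trailing assistant header left open for the model to complete.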


def create_vision_mask(
    tokens: List[int],
    vision_token: int,
) -> List[List[int]]:
    vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
    if len(vision_token_locations) == 0:
        return []

    if len(vision_token_locations) == 1:
        # only one image present, unmask until end of sequence
        return [[vision_token_locations[0], -1]]

    vision_masks = [
        [loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
    ]
    # last image will attend to all subsequent text
    vision_masks.append([vision_token_locations[-1], len(tokens)])

    # if there are two or more consecutive vision tokens,
    # they should all attend to all subsequent text present
    last_mask_end = vision_masks[-1][1]
    for vision_mask in vision_masks[::-1]:
        if vision_mask[0] == vision_mask[1] - 1:
            vision_mask[1] = last_mask_end
        last_mask_end = vision_mask[1]
    return vision_masks
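

if __name__ == "__main__":
    # Minimal self-check sketch (not part of the original source). It exercises
    # only create_vision_mask with arbitrary integer ids; 7 stands in for the
    # vision token. Constructing a full ChatFormat would additionally require a
    # real Tokenizer instance, which this sketch does not assume.
    toks = [7, 1, 2, 7, 7, 3, 4]
    # The first image attends up to the next image token; the two consecutive
    # trailing image tokens both attend through the end of the sequence.
    print(create_vision_mask(toks, vision_token=7))  # [[0, 3], [3, 7], [4, 7]]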