forked from phoenix-oss/llama-stack-mirror
		
	# What does this PR do? ## Test Plan LLAMA_STACK_CONFIG=http://localhost:5002 pytest -s -v tests/integration/inference --safety-shield meta-llama/Llama-Guard-3-8B --vision-model meta-llama/Llama-4-Scout-17B-16E-Instruct --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
		
			
				
	
	
		
			319 lines
		
	
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			319 lines
		
	
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) Meta Platforms, Inc. and affiliates.
 | |
| # All rights reserved.
 | |
| #
 | |
| # This source code is licensed under the terms described in the LICENSE file in
 | |
| # the root directory of this source tree.
 | |
| 
 | |
| import io
 | |
| import json
 | |
| import uuid
 | |
| from dataclasses import dataclass
 | |
| from typing import Dict, List, Optional, Tuple
 | |
| 
 | |
| import torch
 | |
| from PIL import Image as PIL_Image
 | |
| 
 | |
| # TODO: either fork these or move them to the common package
 | |
| from ..datatypes import (
 | |
|     BuiltinTool,
 | |
|     RawContent,
 | |
|     RawMediaItem,
 | |
|     RawMessage,
 | |
|     RawTextItem,
 | |
|     Role,
 | |
|     StopReason,
 | |
|     ToolCall,
 | |
|     ToolPromptFormat,
 | |
| )
 | |
| from ..llama3.tool_utils import ToolUtils
 | |
| from .args import VisionArgs
 | |
| from .datatypes import LLMInput
 | |
| from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
 | |
| from .tokenizer import Tokenizer
 | |
| 
 | |
| 
 | |
def role_str(role: Role) -> str:
    """Return the name written into a message header for ``role``.

    Every role is written under its own name except ``Role.tool``, whose
    messages are labelled ``"ipython"``. A role not listed below raises
    ``KeyError``.
    """
    header_names = (
        (Role.system, "system"),
        (Role.user, "user"),
        (Role.assistant, "assistant"),
        (Role.tool, "ipython"),  # special
    )
    return dict(header_names)[role]
 | |
| 
 | |
| 
 | |
@dataclass
class TransformedImage:
    """Tiles produced from one media item, plus the tile-grid shape.

    ``image_tiles`` stacks tiles along dim 0. ``ChatFormat._encode_image`` reads
    its last three dims as (channels, height, width). When an image is split
    into more than one tile, ``ChatFormat`` appends a resized whole-image tile
    as the final entry along dim 0.
    """

    image_tiles: torch.Tensor
    # (ratio_h, ratio_w) tile grid. ChatFormat._encode_image unpacks it to lay out
    # per-tile patch tokens with x/y tile separators whenever there is more than one tile.
    aspect_ratio: Tuple[int, int]
 | |
| 
 | |
| 
 | |
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
    """Return an RGB version of ``image``, compositing any transparency onto ``bg``.

    Images with an alpha band (RGBA, LA, PA), and P/L/RGB images that carry a
    ``transparency`` entry in ``image.info``, are pasted onto a solid
    ``bg``-coloured canvas with their alpha as the mask. Transparent pixels
    therefore become ``bg`` rather than whatever colour lies under the alpha.
    All other images are converted with ``Image.convert("RGB")``.

    Args:
        image: Source image.
        bg: Background colour for transparent pixels, as an (R, G, B) tuple.

    Returns:
        An RGB image.
    """
    has_alpha = image.mode in ("RGBA", "LA", "PA") or (
        image.mode in ("P", "L", "RGB") and "transparency" in image.info
    )
    if not has_alpha:
        return image.convert("RGB")

    if image.mode == "RGBA":
        image.load()  # for png.split()
    else:
        # Promote every other alpha-bearing mode to RGBA so band 3 is the alpha;
        # Pillow turns palette/colour-key transparency into a real alpha band here.
        image = image.convert("RGBA")

    new_img = PIL_Image.new("RGB", image.size, bg)
    new_img.paste(image, mask=image.split()[3])  # 3 is the alpha channel
    return new_img
 | |
| 
 | |
| 
 | |
class ChatFormat:
    """Encode dialogs into prompt tokens and decode assistant replies.

    Each message is framed as ``<|header_start|>{role}<|header_end|>`` followed by
    two newlines, the message content, and an ``<|eot|>`` or ``<|eom|>``
    terminator. When ``vision_args`` is provided, media items in message content
    are decoded, tiled, and expanded into image/patch special tokens.
    """

    # Decoded header string for every Role; used to strip an echoed assistant header.
    possible_headers: Dict[Role, str]

    def __init__(
        self,
        tokenizer: Tokenizer,
        vision_args: Optional[VisionArgs] = None,
        max_num_chunks: int = 16,
    ):
        """Create a formatter.

        Args:
            tokenizer: Provides ``encode``/``decode`` and the ``special_tokens`` table.
            vision_args: Vision settings. When None, media content is rejected.
            max_num_chunks: Passed as ``max_num_chunks`` to the variable-size image
                transform (presumably the maximum number of tiles per image).
        """
        self.tokenizer = tokenizer
        self.vision_args = vision_args
        self.max_num_chunks = max_num_chunks

        self.possible_headers = {role: f"<|header_start|>{role_str(role)}<|header_end|>\n\n" for role in Role}

        # Both transforms stay None for text-only formatters.
        self.image_transform = None
        self.dynamic_image_transform = None
        if vision_args:
            # Produces (tiles, aspect-ratio grid) for an image.
            self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
            # Single resize/normalize, used for the extra whole-image tile.
            self.image_transform = ResizeNormalizeImageTransform(
                vision_args.image_size.width, vision_args.image_size.height
            )

    def _encode_header(self, role: str) -> List[int]:
        """Return the tokens for a role header followed by two newlines.

        The ``"tool"`` role is written as ``"ipython"``, matching ``role_str``.
        """
        tokens = [self.tokenizer.special_tokens["<|header_start|>"]]
        # TODO: need to check if this is correct
        tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|header_end|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_content(self, content: RawContent) -> LLMInput:
        """Encode bare content (no header or terminator), prefixed with BOS."""
        tokens, images = self._encode_content(content, bos=True)
        return self._model_input_from_tokens_images(tokens, images)

    def _encode_image(
        self,
        transformed_image: TransformedImage,
    ) -> List[int]:
        """Expand one transformed image into image/patch special tokens.

        Each tile contributes ``(H // patch_h) * (W // patch_w) // ds_ratio``
        ``<|patch|>`` tokens, where ``ds_ratio = round(1 / pixel_shuffle_ratio**2)``.

        Single tile: ``<|image_start|> <|image|> patches <|image_end|>``.

        Several tiles: ``<|image_start|>``, then for each of the ``ratio_h`` rows
        of the aspect-ratio grid, ``ratio_w`` patch groups separated by
        ``<|tile_x_separator|>`` and followed by ``<|tile_y_separator|>``. After the
        grid come ``<|image|>``, one more patch group, and ``<|image_end|>``. The
        extra group is presumably for the whole-image tile that
        ``_transform_image`` appends last — TODO confirm.

        Raises:
            ValueError: If the tile height or width is not a multiple of the patch size.
        """
        # Internal precondition: _encode_content rejects media before reaching here
        # when the formatter has no vision_args.
        assert self.vision_args is not None, "The model is not vision-enabled"

        image_tensor = transformed_image.image_tiles
        image_channels, image_height, image_width = image_tensor.shape[-3:]
        # reshape (unlike view) also accepts non-contiguous tensors; the count is identical.
        image_chunks = image_tensor.reshape(-1, image_channels, image_height, image_width).shape[0]

        patch_height = self.vision_args.patch_size.height
        patch_width = self.vision_args.patch_size.width

        if image_height % patch_height != 0:
            raise ValueError(f"{image_height=} not divisible by {patch_height=}")
        if image_width % patch_width != 0:
            raise ValueError(f"{image_width=} not divisible by {patch_width=}")

        ds_ratio = int(round(1.0 / (self.vision_args.pixel_shuffle_ratio**2)))
        n_patches_per_chunk = int((image_height // patch_height) * (image_width // patch_width) // ds_ratio)

        special = self.tokenizer.special_tokens
        patch_tokens = [special["<|patch|>"]] * n_patches_per_chunk
        tokens = [special["<|image_start|>"]]

        if image_chunks == 1:
            tokens.append(special["<|image|>"])
            tokens.extend(patch_tokens)
            tokens.append(special["<|image_end|>"])
            return tokens

        ratio_h, ratio_w = transformed_image.aspect_ratio
        for _ in range(ratio_h):
            for xx in range(ratio_w):
                tokens.extend(patch_tokens)
                if xx < ratio_w - 1:
                    tokens.append(special["<|tile_x_separator|>"])
            tokens.append(special["<|tile_y_separator|>"])

        tokens.append(special["<|image|>"])
        tokens.extend(patch_tokens)
        tokens.append(special["<|image_end|>"])
        return tokens

    def _transform_image(self, media: RawMediaItem) -> TransformedImage:
        """Decode a media item to RGB and split it into tiles.

        ``media.data`` is wrapped in ``io.BytesIO`` when it is ``bytes``; otherwise
        it is handed to ``PIL.Image.open`` as-is. When the dynamic transform yields
        more than one tile, a resized copy of the whole image is appended as an
        extra final tile.
        """
        data = media.data
        source = io.BytesIO(data) if isinstance(data, bytes) else data
        image = convert_image_to_rgb(PIL_Image.open(source))
        image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)

        if image_tiles.shape[0] > 1:
            image_global = self.image_transform(image).unsqueeze(0)
            image_tiles = torch.cat((image_tiles, image_global), dim=0)

        return TransformedImage(image_tiles=image_tiles, aspect_ratio=ar)

    def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
        """Tokenize text and media content.

        Args:
            content: A string, a single text/media item, or a list of them. Items
                that are neither text nor media are skipped.
            bos: If True, a begin-of-text token is emitted once, in front of the
                first item.

        Returns:
            The token list and the transformed images, in the order they appear.

        Raises:
            ValueError: If a media item is found but the formatter has no ``vision_args``.
        """
        tokens: List[int] = []
        transformed_images: List[TransformedImage] = []
        # Whether a BOS token is still owed; cleared once the first text/media item is encoded.
        need_bos = bos

        def _process(c) -> None:
            nonlocal need_bos
            if isinstance(c, (str, RawTextItem)):
                text = c.text if isinstance(c, RawTextItem) else c
                tokens.extend(self.tokenizer.encode(text, bos=need_bos, eos=False))
                need_bos = False
            elif isinstance(c, RawMediaItem):
                if not self.vision_args:
                    raise ValueError("The model is not vision-enabled, but a media item was found")
                if need_bos:
                    tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
                    need_bos = False
                transformed_image = self._transform_image(c)
                tokens.extend(self._encode_image(transformed_image))
                transformed_images.append(transformed_image)

        for item in content if isinstance(content, list) else [content]:
            _process(item)

        return tokens, transformed_images

    def encode_message(
        self, message: RawMessage, tool_prompt_format: ToolPromptFormat
    ) -> Tuple[List[int], List[TransformedImage]]:
        """Encode one message: header, content, optional extras, and terminator.

        User messages with ``context`` get two newlines plus the context appended.
        Assistant messages get each tool call rendered by
        ``ToolUtils.encode_tool_call``.

        The message ends with ``<|eom|>`` for tool messages, and for assistant
        messages that stopped at end-of-message or carry tool calls. All other
        messages end with ``<|eot|>``.

        Returns:
            The message tokens and any images found in its content.
        """
        tokens = self._encode_header(message.role)
        images: List[TransformedImage] = []

        def _process_content(c) -> None:
            toks, imgs = self._encode_content(c)
            tokens.extend(toks)
            images.extend(imgs)

        _process_content(message.content)

        if message.role == "user" and message.context is not None:
            # This is RAG context; why is it here in the chat format? I don't think
            # this is needed and can be moved upwards
            _process_content("\n\n")
            _process_content(message.context)

        if message.role == "assistant":
            for t in message.tool_calls:
                _process_content(ToolUtils.encode_tool_call(t, tool_prompt_format))

        # Tool calls and Tool Response messages should be eom
        eom = False
        if message.role == "assistant":
            eom = bool(message.stop_reason == StopReason.end_of_message or message.tool_calls)
        elif message.role == "tool":
            eom = True

        tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
        return tokens, images

    def encode_dialog_prompt(
        self,
        messages: List[RawMessage],
        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
    ) -> LLMInput:
        """Encode a whole dialog as a generation prompt.

        Emits BOS, every message in order, and finally an open assistant header
        for the model to complete.
        """
        tokens = [self.tokenizer.special_tokens["<|begin_of_text|>"]]
        images: List[TransformedImage] = []
        for message in messages:
            toks, imgs = self.encode_message(message, tool_prompt_format)
            tokens.extend(toks)
            images.extend(imgs)

        # Add the start of an assistant message for the model to complete.
        tokens.extend(self._encode_header("assistant"))

        return self._model_input_from_tokens_images(tokens, images)

    # TODO(this should be generic, not only for assistant messages)
    def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
        """Detokenize ``tokens`` and parse them as an assistant message."""
        content = self.tokenizer.decode(tokens)
        return self.decode_assistant_message_from_content(content, stop_reason)

    def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
        """Parse decoded assistant text into a RawMessage, extracting any tool call.

        Processing order:

        1. Strip surrounding spaces and a leading assistant header.
        2. Remove ``<|python_start|>``/``<|python_end|>`` tags, remembering that
           they were present.
        3. Strip a trailing ``<|eot|>`` or ``<|eom|>``, which overrides
           ``stop_reason`` with ``end_of_turn`` or ``end_of_message``.
        4. Try a custom-format tool call, then a builtin tool call. Failing both,
           python-tagged content becomes code for ``BuiltinTool.code_interpreter``.

        If a tool call is found, it is returned in ``tool_calls`` and ``content``
        is emptied.
        """
        content = content.strip(" ")
        header_str = self.possible_headers[Role.assistant]
        if content.startswith(header_str):
            content = content[len(header_str) :]

        ipython = content.startswith("<|python_start|>")
        if ipython:
            content = content[len("<|python_start|>") :]
            content = content.replace("<|python_end|>", "")

        if content.endswith("<|eot|>"):
            content = content[: -len("<|eot|>")]
            stop_reason = StopReason.end_of_turn
        elif content.endswith("<|eom|>"):
            content = content[: -len("<|eom|>")]
            stop_reason = StopReason.end_of_message

        tool_name = None
        tool_arguments = {}

        custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
        if custom_tool_info is not None:
            tool_name, tool_arguments = custom_tool_info
            # Sometimes when agent has custom tools alongside builtin tools
            # Agent responds for builtin tool calls in the format of the custom tools
            # This code tries to handle that case
            if tool_name in BuiltinTool.__members__:
                tool_name = BuiltinTool[tool_name]
                # Builtin tools take a single "query"; use the first argument value.
                # Guarded so a call with no arguments no longer raises IndexError.
                if tool_arguments:
                    tool_arguments = {"query": next(iter(tool_arguments.values()))}
        else:
            builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
            if builtin_tool_info is not None:
                tool_name, query = builtin_tool_info
                tool_arguments = {"query": query}
                if tool_name in BuiltinTool.__members__:
                    tool_name = BuiltinTool[tool_name]
            elif ipython:
                tool_name = BuiltinTool.code_interpreter
                tool_arguments = {"code": content}

        tool_calls = []
        if tool_name is not None and tool_arguments is not None:
            tool_calls.append(
                ToolCall(
                    call_id=str(uuid.uuid4()),
                    tool_name=tool_name,
                    arguments=tool_arguments,
                    arguments_json=json.dumps(tool_arguments),
                )
            )
            content = ""

        return RawMessage(
            role="assistant",
            content=content,
            stop_reason=stop_reason,
            tool_calls=tool_calls,
        )

    def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
        """Wrap tokens and image tiles into an LLMInput; ``images`` is None when there are none."""
        return LLMInput(
            tokens=tokens,
            images=[x.image_tiles for x in images] if images else None,
        )
 |