Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)

chore: remove dependency on llama_models completely

parent 7131d5ddeb
commit 7529cbfcc9
30 changed files with 131325 additions and 53 deletions
@@ -11,16 +11,128 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

-from typing import Any, Dict, Literal, Optional, Union
-# import all for backwards compatibility
-from llama_models.datatypes import *  # noqa: F403
-from pydantic import BaseModel, ConfigDict, Field, field_validator

import base64
from enum import Enum
from io import BytesIO
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated

from llama_stack.schema_utils import json_schema_type, register_schema

# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.


class Role(Enum):
    system = "system"
    user = "user"
    assistant = "assistant"
    tool = "tool"


class BuiltinTool(Enum):
    brave_search = "brave_search"
    wolfram_alpha = "wolfram_alpha"
    photogen = "photogen"
    code_interpreter = "code_interpreter"


Primitive = Union[str, int, float, bool, None]
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]


class ToolCall(BaseModel):
    call_id: str
    tool_name: Union[BuiltinTool, str]
    arguments: Dict[str, RecursiveType]

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v


class ToolPromptFormat(Enum):
    """Prompt format for calling custom / zero shot tools.

    :cvar json: JSON format for calling tools. It takes the form:
        {
            "type": "function",
            "function" : {
                "name": "function_name",
                "description": "function_description",
                "parameters": {...}
            }
        }
    :cvar function_tag: Function tag format, pseudo-XML. This looks like:
        <function=function_name>(parameters)</function>

    :cvar python_list: Python list. The output is a valid Python expression that can be
        evaluated to a list. Each element in the list is a function call. Example:
        ["function_name(param1, param2)", "function_name(param1, param2)"]
    """

    json = "json"
    function_tag = "function_tag"
    python_list = "python_list"


class StopReason(Enum):
    end_of_turn = "end_of_turn"
    end_of_message = "end_of_message"
    out_of_tokens = "out_of_tokens"


class RawMediaItem(BaseModel):
    type: Literal["image"] = "image"
    data: bytes | BytesIO

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_serializer("data")
    def serialize_data(self, data: Optional[bytes], _info):
        if data is None:
            return None
        return base64.b64encode(data).decode("utf-8")

    @field_validator("data", mode="before")
    @classmethod
    def validate_data(cls, v):
        if isinstance(v, str):
            return base64.b64decode(v)
        return v


class RawTextItem(BaseModel):
    type: Literal["text"] = "text"
    text: str


RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]

RawContent = str | RawContentItem | List[RawContentItem]


class RawMessage(BaseModel):
    role: Literal["user"] | Literal["system"] | Literal["tool"] | Literal["assistant"]
    content: RawContent

    # This is for RAG but likely should be absorbed into content
    context: Optional[RawContent] = None

    # These are for the output message coming from the assistant
    stop_reason: Optional[StopReason] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


register_schema(ToolCall)
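The ToolCall validator above coerces a plain string into a BuiltinTool member when it matches one of the enum values, and leaves it as a raw string for custom tools. A minimal usage sketch of that behaviour (not part of the commit; it assumes these types live at llama_stack.models.llama.datatypes, as the chat_format.py import further down suggests):

# Illustrative only: exercising the ToolCall validator and ToolPromptFormat enum.
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat

# "brave_search" matches a BuiltinTool value, so the validator coerces it to the enum member.
builtin_call = ToolCall(call_id="1", tool_name="brave_search", arguments={"query": "llama"})
assert builtin_call.tool_name is BuiltinTool.brave_search

# "get_weather" is not a builtin, so it stays a plain string (a custom tool).
custom_call = ToolCall(call_id="2", tool_name="get_weather", arguments={"city": "Paris"})
assert custom_call.tool_name == "get_weather"

# ToolPromptFormat only selects how custom-tool definitions are rendered in the prompt.
assert ToolPromptFormat("json") is ToolPromptFormat.json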
llama_stack/models/llama/llama3/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
llama_stack/models/llama/llama3/args.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class QuantizationScheme(Enum):
    int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"


@dataclass
class QuantizationArgs:
    scheme: Optional[QuantizationScheme] = None
    group_size: Optional[int] = None
    spinquant: bool = False

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if k == "scheme":
                setattr(self, k, QuantizationScheme(v))
            else:
                if hasattr(self, k):
                    setattr(self, k, v)


@dataclass
class LoRAArgs:
    rank: int
    scale: float


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000
    use_scaled_rope: bool = False

    max_batch_size: int = 32
    max_seq_len: int = 2048

    # vision model params
    vision_chunk_size: int = -1  # image resolution for image models
    vision_max_num_chunks: int = 4
    vision_num_cross_attention_layers: int = -1

    quantization_args: Optional[QuantizationArgs] = None
    lora_args: Optional[LoRAArgs] = None

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if k == "lora_args":
                setattr(self, k, LoRAArgs(**v))
            elif k == "quantization_args":
                setattr(self, k, QuantizationArgs(**v))
            else:
                if hasattr(self, k):
                    setattr(self, k, v)

        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads
        assert self.n_kv_heads <= self.n_heads
        assert self.n_heads % self.n_kv_heads == 0
        assert self.dim % self.n_heads == 0
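ModelArgs overrides the generated dataclass __init__ so it can be built directly from a params.json-style dict, silently ignoring unknown keys and re-hydrating the nested quantization/LoRA sections. A small sketch of that pattern follows; the parameter values are illustrative, not taken from any real checkpoint:

# Illustrative only: constructing ModelArgs from a dict as loaded from a params.json file.
import json

from llama_stack.models.llama.llama3.args import ModelArgs

params = json.loads(
    '{"dim": 4096, "n_layers": 32, "n_heads": 32, "n_kv_heads": 8, '
    '"vocab_size": 128256, "rope_theta": 500000.0, "use_scaled_rope": true, '
    '"some_unknown_key": 123}'
)

args = ModelArgs(**params)           # unknown keys are ignored by the custom __init__
assert args.n_kv_heads == 8          # grouped-query attention: 32 query heads share 8 KV heads
assert args.dim % args.n_heads == 0  # enforced by the asserts at the end of __init__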
llama_stack/models/llama/llama3/chat_format.py (new file, 282 lines)
@@ -0,0 +1,282 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import io
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from PIL import Image as PIL_Image

from llama_stack.models.llama.datatypes import (
    BuiltinTool,
    RawContent,
    RawMediaItem,
    RawMessage,
    RawTextItem,
    Role,
    StopReason,
    ToolCall,
    ToolPromptFormat,
)

from .tokenizer import Tokenizer
from .tool_utils import ToolUtils


@dataclass
class VisionInput:
    mask: List[List[int]]
    images: List[PIL_Image.Image]


@dataclass
class LLMInput:
    tokens: List[int]
    vision: Optional[VisionInput] = None


def role_str(role: Role) -> str:
    role_strs = {
        Role.user: "user",
        Role.system: "system",
        Role.tool: "ipython",  # special
        Role.assistant: "assistant",
    }
    return role_strs[role]


class ChatFormat:
    possible_headers: Dict[Role, str]

    def __init__(self, tokenizer: Tokenizer):
        self.tokenizer = tokenizer

        self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
        self.vision_token = self.tokenizer.special_tokens["<|image|>"]

    def _encode_header(self, role: str) -> List[int]:
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_content(self, content: RawContent) -> LLMInput:
        tokens, images = self._encode_content(content, bos=True)
        return self._model_input_from_tokens_images(tokens, images)

    def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
        tokens = []
        images = []

        added_bos = False

        def _process(c):
            nonlocal added_bos, bos

            if isinstance(c, str) or isinstance(c, RawTextItem):
                if isinstance(c, RawTextItem):
                    c = c.text
                tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
                added_bos = True

            elif isinstance(c, RawMediaItem):
                bos = False if added_bos else bos
                if bos:
                    tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
                    added_bos = True
                tokens.append(self.vision_token)

                bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
                image = PIL_Image.open(bytes_io)
                image = image.convert("RGB")
                images.append(image)

        if isinstance(content, list):
            for c in content:
                _process(c)
        else:
            _process(content)

        return tokens, images

    def encode_message(
        self, message: RawMessage, tool_prompt_format: ToolPromptFormat
    ) -> Tuple[List[int], List[PIL_Image.Image]]:
        tokens = self._encode_header(message.role)
        images = []

        def _process_content(c):
            toks, imgs = self._encode_content(c)
            tokens.extend(toks)
            images.extend(imgs)

        if (
            message.role == "assistant"
            and len(message.tool_calls) > 0
            and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
        ):
            tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])

        _process_content(message.content)

        if message.role == "user" and message.context is not None:
            # This is RAG context; why is it here in the chat format? I don't think
            # this is needed and can be moved upwards
            _process_content("\n\n")
            _process_content(message.context)

        if message.role == "assistant":
            for t in message.tool_calls:
                content = ToolUtils.encode_tool_call(t, tool_prompt_format)
                _process_content(content)

        eom = False
        if message.role == "assistant":
            eom = message.stop_reason == StopReason.end_of_message

        tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
        return tokens, images

    def encode_dialog_prompt(
        self,
        messages: List[RawMessage],
        tool_prompt_format: Optional[ToolPromptFormat] = None,
    ) -> LLMInput:
        tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
        tokens = []
        images = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        for message in messages:
            toks, imgs = self.encode_message(message, tool_prompt_format)
            tokens.extend(toks)
            images.extend(imgs)

        # Add the start of an assistant message for the model to complete.
        tokens.extend(self._encode_header("assistant"))

        return self._model_input_from_tokens_images(tokens, images)

    # TODO(this should be generic, not only for assistant messages)
    def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
        content = self.tokenizer.decode(tokens)

        return self.decode_assistant_message_from_content(content, stop_reason)

    def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
        content = content.strip(" ")
        header_str = self.possible_headers[Role.assistant]
        if content.startswith(header_str):
            content = content[len(header_str) :]

        ipython = content.startswith("<|python_tag|>")
        if ipython:
            content = content[len("<|python_tag|>") :]

        if content.endswith("<|eot_id|>"):
            content = content[: -len("<|eot_id|>")]
            stop_reason = StopReason.end_of_turn
        elif content.endswith("<|eom_id|>"):
            content = content[: -len("<|eom_id|>")]
            stop_reason = StopReason.end_of_message

        tool_name = None
        tool_arguments = {}

        custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
        if custom_tool_info is not None:
            tool_name, tool_arguments = custom_tool_info
            # Sometimes when agent has custom tools alongside builin tools
            # Agent responds for builtin tool calls in the format of the custom tools
            # This code tries to handle that case
            if tool_name in BuiltinTool.__members__:
                tool_name = BuiltinTool[tool_name]
                tool_arguments = {
                    "query": list(tool_arguments.values())[0],
                }
        else:
            builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
            if builtin_tool_info is not None:
                tool_name, query = builtin_tool_info
                tool_arguments = {
                    "query": query,
                }
                if tool_name in BuiltinTool.__members__:
                    tool_name = BuiltinTool[tool_name]
            elif ipython:
                tool_name = BuiltinTool.code_interpreter
                tool_arguments = {
                    "code": content,
                }

        tool_calls = []
        if tool_name is not None and tool_arguments is not None:
            call_id = str(uuid.uuid4())
            tool_calls.append(
                ToolCall(
                    call_id=call_id,
                    tool_name=tool_name,
                    arguments=tool_arguments,
                )
            )
            content = ""

        return RawMessage(
            role="assistant",
            content=content,
            stop_reason=stop_reason,
            tool_calls=tool_calls,
        )

    def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
        vision_input = None
        if len(images) > 0:
            vision_input = VisionInput(
                mask=create_vision_mask(tokens, self.vision_token),
                images=images,
            )

        return LLMInput(
            tokens=[128256 if token == self.vision_token else token for token in tokens],
            vision=vision_input,
        )


def create_vision_mask(
    tokens: List[int],
    vision_token: int,
) -> List[List[int]]:
    vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
    if len(vision_token_locations) == 0:
        return []

    if len(vision_token_locations) == 1:
        # only one image present, unmask until end of sequence
        return [[vision_token_locations[0], -1]]
    vision_masks = [
        [loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
    ]
    # last image will attend to all subsequent text
    vision_masks.append([vision_token_locations[-1], len(tokens)])

    # if there are two or more consecutive vision tokens,
    # they should all attend to all subsequent
    # text present
    last_mask_end = vision_masks[-1][1]
    for vision_mask in vision_masks[::-1]:
        if vision_mask[0] == vision_mask[1] - 1:
            vision_mask[1] = last_mask_end
        last_mask_end = vision_mask[1]
    return vision_masks
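For reference, a minimal sketch of how ChatFormat is typically driven: encode a short dialog into model-ready tokens, and decode a generated completion back into a structured RawMessage. This is illustrative rather than part of the diff; in particular, Tokenizer.get_instance() is assumed to exist on the bundled llama3 tokenizer.

# Illustrative usage of ChatFormat (not part of the commit).
# Assumption: the llama3 Tokenizer exposes a get_instance() constructor that loads
# its bundled tokenizer model; adjust if the actual constructor differs.
from llama_stack.models.llama.datatypes import RawMessage, StopReason
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer

tokenizer = Tokenizer.get_instance()
chat_format = ChatFormat(tokenizer)

dialog = [
    RawMessage(role="system", content="You are a helpful assistant."),
    RawMessage(role="user", content="What is the capital of France?"),
]

# Tokens ready to feed to the model; the prompt ends with an open assistant header.
model_input = chat_format.encode_dialog_prompt(dialog)
print(len(model_input.tokens), model_input.vision)  # vision is None for text-only dialogs

# Going the other way: turn raw generated text back into a structured RawMessage.
message = chat_format.decode_assistant_message_from_content(
    "The capital of France is Paris.<|eot_id|>", StopReason.end_of_turn
)
print(message.content, message.stop_reason)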
@@ -14,20 +14,19 @@
 from pathlib import Path
 from typing import List, Optional

-from llama_models.datatypes import (
+from termcolor import colored
+
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     RawMessage,
     StopReason,
     ToolCall,
+    ToolDefinition,
     ToolPromptFormat,
 )
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
-from termcolor import colored
-
-from llama_stack.models.llama.datatypes import ToolDefinition

 from . import template_data
+from .chat_format import ChatFormat
 from .prompt_templates import (
     BuiltinToolGenerator,
     FunctionTagCustomToolGenerator,
@@ -35,6 +34,7 @@ from .prompt_templates import (
     SystemDefaultGenerator,
     ToolResponseGenerator,
 )
+from .tokenizer import Tokenizer

 THIS_DIR = Path(__file__).parent
llama_stack/models/llama/llama3/model.py (new file, 315 lines)
@@ -0,0 +1,315 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import math
from typing import Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
)
from torch import nn

from ..api import ModelArgs

# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
# dependencies. These dependencies are not part of the default dependencies
# (requirements.txt) of the `llama-models` package.


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * torch.pi / freqs
    new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    return torch.where(
        (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
        (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
        new_freqs,
    )


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    if use_scaled:
        freqs = apply_scaling(freqs)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
        self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
        self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)

        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
            params.use_scaled_rope,
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)

            mask = torch.triu(mask, diagonal=1)

            # https://github.com/pytorch/pytorch/issues/100005
            # torch.triu is buggy when the device is mps: filled values are
            # nan instead of 0.
            if mask.device.type == torch.device("mps").type:
                mask = torch.nan_to_num(mask, nan=0.0)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output
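The repeat_kv helper above is the grouped-query-attention trick: each key/value head is repeated so that every group of query heads attends against its shared KV head. A quick self-contained check, assuming only torch is installed, that it matches torch.repeat_interleave as its docstring claims:

# Standalone sanity check for repeat_kv (illustrative, not part of the commit).
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as in llama_stack/models/llama/llama3/model.py above.
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


x = torch.randn(2, 5, 8, 16)  # (batch, seq, n_kv_heads, head_dim)
n_rep = 4                     # e.g. 32 query heads sharing 8 KV heads

out = repeat_kv(x, n_rep)
assert out.shape == (2, 5, 32, 16)
# Matches the equivalent torch.repeat_interleave along the head dimension.
assert torch.equal(out, torch.repeat_interleave(x, repeats=n_rep, dim=2))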
llama_stack/models/llama/llama3/multimodal/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
llama_stack/models/llama/llama3/multimodal/encoder_utils.py (new file, 179 lines)
@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

# Copyright (c) Meta Platforms, Inc. and its affiliates.
import math
from logging import getLogger

import torch
import torch.nn.functional as F

from .utils import get_negative_inf_value, to_2tuple

logger = getLogger()


def resize_local_position_embedding(orig_pos_embed, grid_size):
    """
    Resize position embedding for vision encoder.
    Original position embedding is [n_tiles * n_tiles + 1, dim]
    New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
    """
    new_grid_size = to_2tuple(grid_size)
    orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))

    new_pos_emb_tok, new_pos_emb_img = (
        orig_pos_embed[:1],
        orig_pos_embed[1:],
    )
    logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")

    new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)

    new_pos_emb_img = F.interpolate(
        new_pos_emb_img,
        size=new_grid_size,
        mode="bilinear",
        align_corners=True,
    )
    new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
    new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
    return new_pos_embed


def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
    """
    Takes a local position embedding for vision encoder and uses it
    to initialize the global position embedding.
    Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
    Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
    Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
    """
    pos_embed = pos_and_cls_embed[1:]
    cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
    grid_size = to_2tuple(grid_size)
    new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
    new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
    new_pos_emb_img = F.interpolate(
        new_pos_emb_img,
        size=new_grid_size,
        mode="bilinear",
        align_corners=True,
    )
    new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
    new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
    new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
    new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
    cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
    pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
    return pos_and_cls_embed


def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
    """
    Takes a global position embedding for vision encoder and resizes it to new size.
    Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
    Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
    Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
    """
    # first remove cls token
    pos_embed = pos_and_cls_embed[:, :, 1:]
    cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)

    xs_old, ys_old, ntok, dim = pos_embed.shape
    old_grid_size = int(math.sqrt(ntok))

    # move to correct form for interpolation
    pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
    pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
    pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
    pos_embed = pos_embed.unsqueeze(0)

    # interpolate
    new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
    pos_embed = pos_embed.permute(0, 3, 1, 2)
    pos_embed_resized = F.interpolate(
        pos_embed,
        size=new_size,
        mode="bilinear",
        align_corners=True,
    )
    pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]

    # move it back in place
    pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
    pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
    pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)

    # interpolate cls token
    cls_embed = cls_embed.permute(2, 3, 0, 1)
    cls_embed_resized = F.interpolate(
        cls_embed,
        size=(x_scale, y_scale),
        mode="bilinear",
        align_corners=True,
    )
    cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
    # add cls token back in
    pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)

    return pos_and_cls_embed


def build_encoder_attention_mask(
    x: torch.Tensor,
    ar: torch.Tensor,
    ntok: int,
    num_chunks: int,
    n_heads: int,
):
    """
    Build vision encoder attention mask that omits padding tokens.
    """
    masks = []
    for arx in ar:
        mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
        mask_i[: arx[0] * arx[1], :ntok] = 0
        mask_i = mask_i.view(num_chunks * x.shape[2], -1)
        mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
        mask_i = mask_i.unsqueeze(0)
        masks.append(mask_i)
    masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
    return masks


def expand_num_tokens_to_mult8(x):
    num_pad_tokens = 8 - (x.shape[-2] % 8)
    if num_pad_tokens == 0:
        return x, 0
    else:
        return (
            torch.cat(
                [
                    x,
                    torch.zeros(
                        (x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
                        dtype=x.dtype,
                        device=x.device,
                    ),
                ],
                dim=-2,
            ),
            num_pad_tokens,
        )


def contract_num_tokens_from_mult8(x, num_pad_tokens):
    if num_pad_tokens == 0:
        return x
    return x[:, :, :-num_pad_tokens]
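The last two helpers pad the token dimension of a vision-encoder tensor up to a multiple of 8 and then strip that padding again. A small round-trip sketch, illustrative only and assuming just torch plus the file path shown in the header above:

# Illustrative round trip through the padding helpers above (not part of the commit).
import torch

from llama_stack.models.llama.llama3.multimodal.encoder_utils import (
    contract_num_tokens_from_mult8,
    expand_num_tokens_to_mult8,
)

# (batch, num_chunks, ntok, dim) -- 1025 tokens is not a multiple of 8.
x = torch.randn(1, 4, 1025, 1280)

padded, num_pad_tokens = expand_num_tokens_to_mult8(x)
assert padded.shape[-2] % 8 == 0  # token dimension rounded up for the encoder
assert num_pad_tokens == 7

restored = contract_num_tokens_from_mult8(padded, num_pad_tokens)
assert torch.equal(restored, x)   # padding is stripped exactly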
408
llama_stack/models/llama/llama3/multimodal/image_transform.py
Normal file
408
llama_stack/models/llama/llama3/multimodal/image_transform.py
Normal file
|
@ -0,0 +1,408 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections import defaultdict
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Any, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as tv
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
IMAGE_RES = 224
|
||||||
|
|
||||||
|
logger = getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class VariableSizeImageTransform(object):
|
||||||
|
"""
|
||||||
|
This class accepts images of any size and dynamically resize, pads and chunks it
|
||||||
|
based on the image aspect ratio and the number of image chunks we allow.
|
||||||
|
|
||||||
|
The algorithm will NOT distort the image fit a certain aspect ratio, because
|
||||||
|
that leads to a significant degradation in image quality.
|
||||||
|
|
||||||
|
It can be summarized in 6 steps:
|
||||||
|
1. Find all possible canvas combinations of max_num_chunks;
|
||||||
|
2. Find the best canvas to fit the image;
|
||||||
|
3. Resize without distortion
|
||||||
|
4. Pad
|
||||||
|
5. Normalize
|
||||||
|
6. Chunk
|
||||||
|
|
||||||
|
For example, if an input image is of size 300x800, patch_size of 224,
|
||||||
|
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||||
|
is allowed within 8 image chunks, with some restrictions.
|
||||||
|
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||||
|
giving a total of 8 chunks.
|
||||||
|
|
||||||
|
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||||
|
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||||
|
where we maintain the original aspect ratio and pad with zeros value for the rest.
|
||||||
|
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||||
|
|
||||||
|
However, if limit_upscaling_to_patch_size is set to True,
|
||||||
|
the upscaling will be limited to the patch size. In the example above,
|
||||||
|
the image would remain 300x800 (no upscaling), and then padded to 448:896.
|
||||||
|
|
||||||
|
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||||
|
patches are coming from the resizing and chunking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||||
|
self.size = size
|
||||||
|
logger.info(f"VariableSizeImageTransform size: {self.size}")
|
||||||
|
self.to_tensor = tv.ToTensor()
|
||||||
|
self._mean = (0.48145466, 0.4578275, 0.40821073)
|
||||||
|
self._std = (0.26862954, 0.26130258, 0.27577711)
|
||||||
|
self.normalize = tv.Normalize(
|
||||||
|
mean=self._mean,
|
||||||
|
std=self._std,
|
||||||
|
inplace=True,
|
||||||
|
)
|
||||||
|
self.resample = tv.InterpolationMode.BILINEAR
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_factors(n: int) -> Set[int]:
|
||||||
|
"""
|
||||||
|
Calculate all factors of a given number, i.e. a dividor that leaves
|
||||||
|
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): The number to find factors for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
set: A set containing all factors of the number.
|
||||||
|
"""
|
||||||
|
factors_set = set()
|
||||||
|
|
||||||
|
for i in range(1, int(n**0.5) + 1):
|
||||||
|
if n % i == 0:
|
||||||
|
factors_set.add(i)
|
||||||
|
factors_set.add(n // i)
|
||||||
|
return factors_set
|
||||||
|
|
||||||
|
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Computes all of the allowed resoltuions for a fixed number of chunks
|
||||||
|
and patch_size. Useful for when dividing an image into chunks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_num_chunks (int): Maximum number of chunks for processing.
|
||||||
|
patch_size (int): Size of the side of the patch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> max_num_chunks = 5
|
||||||
|
>>> patch_size = 224
|
||||||
|
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||||
|
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||||
|
(672, 224), (224, 448), (448, 224)])
|
||||||
|
|
||||||
|
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||||
|
{
|
||||||
|
0.25: [(1, 4)],
|
||||||
|
1.0: [(2, 2), (1, 1)],
|
||||||
|
4.0: [(4, 1)],
|
||||||
|
0.33: [(1, 3)],
|
||||||
|
3.0: [(3, 1)],
|
||||||
|
0.5: [(1, 2)],
|
||||||
|
2.0: [(2, 1)]
|
||||||
|
}
|
||||||
|
|
||||||
|
and return the resolutions multiplied by the patch_size:
|
||||||
|
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||||
|
"""
|
||||||
|
asp_dict = defaultdict(list)
|
||||||
|
for chunk_size in range(max_num_chunks, 0, -1):
|
||||||
|
_factors = sorted(self.get_factors(chunk_size))
|
||||||
|
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||||
|
for height, width in _asp_ratios:
|
||||||
|
ratio_float = height / width
|
||||||
|
asp_dict[ratio_float].append((height, width))
|
||||||
|
|
||||||
|
# get the resolutions multiplied by the patch_size
|
||||||
|
possible_resolutions = []
|
||||||
|
for value in asp_dict.values():
|
||||||
|
for height, depth in value:
|
||||||
|
possible_resolutions.append((height * patch_size, depth * patch_size))
|
||||||
|
|
||||||
|
return possible_resolutions
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_max_res_without_distortion(
|
||||||
|
image_size: Tuple[int, int],
|
||||||
|
target_size: Tuple[int, int],
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Determines the maximum resolution to which an image can be resized to without distorting its
|
||||||
|
aspect ratio, based on the target resolution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||||
|
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||||
|
Returns:
|
||||||
|
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||||
|
Example:
|
||||||
|
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||||
|
(134, 200)
|
||||||
|
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||||
|
(450, 338)
|
||||||
|
"""
|
||||||
|
|
||||||
|
original_width, original_height = image_size
|
||||||
|
target_width, target_height = target_size
|
||||||
|
|
||||||
|
scale_w = target_width / original_width
|
||||||
|
scale_h = target_height / original_height
|
||||||
|
|
||||||
|
if scale_w < scale_h:
|
||||||
|
new_width = target_width
|
||||||
|
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||||
|
else:
|
||||||
|
new_height = target_height
|
||||||
|
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||||
|
|
||||||
|
return new_width, new_height
|
||||||
|
|
||||||
|
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||||
|
new_width, new_height = target_size
|
||||||
|
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||||
|
new_im.paste(image)
|
||||||
|
return new_im
|
||||||
|
|
||||||
|
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||||
|
# Split image into number of required tiles (width x height)
|
||||||
|
num_channels, height, width = image.size()
|
||||||
|
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||||
|
# Permute dimensions to reorder the axes
|
||||||
|
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||||
|
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||||
|
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||||
|
return image
|
||||||
|
|
||||||
|
def resize_without_distortion(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
target_size: Tuple[int, int],
|
||||||
|
max_upscaling_size: Optional[int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Used to resize an image to target_resolution, without distortion.
|
||||||
|
|
||||||
|
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||||
|
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||||
|
modifying target_size works as a boundary for the image's largest side.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resample (str): Resampling method used when resizing images.
|
||||||
|
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||||
|
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||||
|
If None, there is no limit.
|
||||||
|
Examples:
|
||||||
|
>>> target_size = (1000, 1200)
|
||||||
|
>>> max_upscaling_size = 600
|
||||||
|
>>> image_size = (400, 200)
|
||||||
|
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||||
|
(600, 300) # new_size_without_distortion
|
||||||
|
|
||||||
|
>>> target_size = (1000, 1200)
|
||||||
|
>>> max_upscaling_size = 600
|
||||||
|
>>> image_size = (2000, 200)
|
||||||
|
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||||
|
(1000, 100) # new_size_without_distortion
|
||||||
|
|
||||||
|
>>> target_size = (1000, 1200)
|
||||||
|
>>> max_upscaling_size = 2000
|
||||||
|
>>> image_size = (400, 200)
|
||||||
|
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||||
|
(1000, 500) # new_size_without_distortion
|
||||||
|
|
||||||
|
>>> target_size = (1000, 1200)
|
||||||
|
>>> max_upscaling_size = None
|
||||||
|
>>> image_size = (400, 200)
|
||||||
|
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||||
|
(1000, 500) # new_size_without_distortion
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_width, image_height = image.size
|
||||||
|
image_size = (image_width, image_height)
|
||||||
|
|
||||||
|
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||||
|
if max_upscaling_size is not None:
|
||||||
|
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||||
|
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||||
|
target_size = (new_target_width, new_target_height)
|
||||||
|
|
||||||
|
# resize to target_size while preserving aspect ratio
|
||||||
|
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||||
|
|
||||||
|
image = F.resize(
|
||||||
|
image,
|
||||||
|
(new_size_without_distortion[1], new_size_without_distortion[0]),
|
||||||
|
interpolation=self.resample,
|
||||||
|
)
|
||||||
|
|
||||||
|
return image
|
    def get_best_fit(
        self,
        image_size: Tuple[int, int],
        possible_resolutions: torch.Tensor,
        resize_to_max_canvas: bool = False,
    ) -> Tuple[int, int]:
        """
        Determines the best canvas possible from a list of possible resolutions to
        resize an image to, without distortion.

        For each possible resolution, calculates the scaling factors for
        width and height, and selects the smallest one, which is the limiting side.
        E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
        therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.

        If upscaling is possible (any of the scaling factors is greater than 1),
        then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.

        If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
        reduce downscaling as much as possible.

        If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
        to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
        has more padding.

        Args:
            image_size (Tuple[int, int]): A tuple containing the height and width of the image.
            possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
                row represents a possible resolution (height, width).
            resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.

        Returns:
            List[int]: The best resolution [height, width] for the given image.

        Example:
            >>> image_size = (200, 300)
            >>> possible_resolutions = torch.tensor([[224, 672],
            ...                                      [672, 224],
            ...                                      [224, 448],
            ...                                      [448, 224],
            ...                                      [224, 224]])
            >>> get_best_fit(image_size, possible_resolutions)
            [224, 448]

            We have:
                scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
                scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
                scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
            Two of the scales are > 1:
                upscaling_possible = tensor([1.1200, 1.1200])
                smallest_rescale = tensor(1.1200)
            So we pick the resolution with the smallest area:
                areas = tensor([150528, 100352])  # [672, 224], [224, 448]
                optimal_canvas = tensor([224, 448])
        """
        original_width, original_height = image_size

        # get all possible resolutions heights/widths
        target_widths, target_heights = (
            possible_resolutions[:, 0],
            possible_resolutions[:, 1],
        )

        # get scaling factors to resize the image without distortion
        scale_w = target_widths / original_width
        scale_h = target_heights / original_height

        # get the min scale between width and height (limiting side -> no distortion)
        scales = torch.where(scale_w > scale_h, scale_h, scale_w)

        # filter only scales that allow upscaling
        upscaling_options = scales[scales >= 1]
        if len(upscaling_options) > 0:
            if resize_to_max_canvas:
                selected_scale = torch.max(upscaling_options)
            else:
                selected_scale = torch.min(upscaling_options)
        else:
            # no upscaling possible,
            # get the minimum downscaling (max scale for scales < 1)
            downscaling_options = scales[scales < 1]
            selected_scale = torch.max(downscaling_options)

        # get all resolutions that support this scaling factor,
        # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
        chosen_canvas = possible_resolutions[scales == selected_scale]

        # if there are multiple resolutions,
        # get the one with minimum area to reduce padding
        if len(chosen_canvas) > 1:
            areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
            optimal_idx = torch.argmin(areas)
            optimal_canvas = chosen_canvas[optimal_idx]
        else:
            optimal_canvas = chosen_canvas[0]

        return tuple(optimal_canvas.tolist())
    def __call__(
        self,
        image: Image.Image,
        max_num_chunks: int,
        normalize_img: bool = True,
        resize_to_max_canvas: bool = False,
    ) -> Tuple[Any, Any]:
        """
        Args:
            image (PIL.Image): Image to be resized.
            max_num_chunks (int): Maximum number of chunks to split the image into.
            normalize_img (bool): Whether to normalize the image.
            resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
                If True, picks the canvas that allows the largest resizing without distortion.
                If False, downsample as little as possible, including no resizing at all,
                but never upsample, unless the image is smaller than the patch size.
        """
        assert max_num_chunks > 0
        assert isinstance(image, Image.Image), type(image)
        w, h = image.size

        possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
        possible_resolutions = torch.tensor(possible_resolutions)

        best_resolution = self.get_best_fit(
            image_size=(w, h),
            possible_resolutions=possible_resolutions,
            resize_to_max_canvas=resize_to_max_canvas,
        )

        max_upscaling_size = None if resize_to_max_canvas else self.size
        image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
        image = self._pad(image, best_resolution)

        image = self.to_tensor(image)

        if normalize_img:
            image = self.normalize(image)

        ratio_w, ratio_h = (
            best_resolution[0] // self.size,
            best_resolution[1] // self.size,
        )

        image = self._split(image, ratio_w, ratio_h)  # type: ignore

        ar = (ratio_h, ratio_w)
        return image, ar
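As a quick orientation, here is a small standalone sketch that reproduces the canvas-selection rule documented in get_best_fit above on the docstring's own example. It mirrors the logic rather than calling the method, and the variable names are illustrative only, not part of the diff:

import torch

image_size = (200, 300)
possible_resolutions = torch.tensor([[224, 672], [672, 224], [224, 448], [448, 224], [224, 224]])

# per-canvas limiting scale = min of the two per-side scale factors (distortion-free)
scales = torch.minimum(possible_resolutions[:, 0] / image_size[0], possible_resolutions[:, 1] / image_size[1])
upscaling = scales[scales >= 1]
selected = torch.min(upscaling) if len(upscaling) > 0 else torch.max(scales[scales < 1])

# among canvases sharing that scale, take the smallest area to minimize padding
candidates = possible_resolutions[scales == selected]
best = candidates[torch.argmin(candidates[:, 0] * candidates[:, 1])]
print(best.tolist())  # [224, 448]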
1435	llama_stack/models/llama/llama3/multimodal/model.py	Normal file
File diff suppressed because it is too large

26	llama_stack/models/llama/llama3/multimodal/utils.py	Normal file
@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import collections

import torch


def get_negative_inf_value(dtype):
    return torch.finfo(dtype).min


def to_2tuple(x):
    if isinstance(x, collections.abc.Iterable):
        return x
    return (x, x)
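A quick illustration of how these two helpers behave (a sketch that assumes torch is installed; it is not part of the diff itself):

import torch

from llama_stack.models.llama.llama3.multimodal.utils import get_negative_inf_value, to_2tuple

print(to_2tuple(7))        # (7, 7), scalars are duplicated
print(to_2tuple((3, 4)))   # (3, 4), iterables pass through unchanged
print(get_negative_inf_value(torch.float16))  # most negative finite fp16 value, useful as a masking constant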
@ -15,11 +15,8 @@ import textwrap
 from datetime import datetime
 from typing import Any, List, Optional

-from llama_models.datatypes import (
-    BuiltinTool,
-)
-
 from llama_stack.models.llama.datatypes import (
+    BuiltinTool,
     ToolDefinition,
     ToolParamDefinition,
 )
@ -11,7 +11,7 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
     ToolCall,
128000	llama_stack/models/llama/llama3/tokenizer.model	Normal file
File diff suppressed because it is too large

214	llama_stack/models/llama/llama3/tokenizer.py	Normal file
@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import os
from logging import getLogger
from pathlib import Path
from typing import (
    AbstractSet,
    Collection,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
    cast,
)

import tiktoken
from tiktoken.load import load_tiktoken_bpe

logger = getLogger(__name__)


# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000


_INSTANCE = None


class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    @classmethod
    def get_instance(cls):
        global _INSTANCE

        if _INSTANCE is None:
            _INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
        return _INSTANCE

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|finetune_right_pad_id|>",
            "<|step_id|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|eom_id|>",  # end of message
            "<|eot_id|>",  # end of turn
            "<|python_tag|>",
            "<|image|>",
        ]
        reserved_tokens = [
            f"<|reserved_special_token_{2 + i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
        ]
        special_tokens = special_tokens + reserved_tokens

        self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self.n_words: int = num_base_tokens + len(special_tokens)
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.eot_id: int = self.special_tokens["<|eot_id|>"]
        self.eom_id: int = self.special_tokens["<|eom_id|>"]
        self.python_tag_id = self.special_tokens["<|python_tag|>"]
        self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
        self.stop_tokens = [
            self.eos_id,
            self.special_tokens["<|eom_id|>"],
            self.special_tokens["<|eot_id|>"],
        ]

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): allowed special tokens in string
            disallowed_special ("all"|set[str]): special tokens that raise an error when in string

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        if allowed_special is None:
            allowed_special = set()
        assert type(s) is str

        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        t: List[int] = []
        for substr in substrs:
            t.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def decode(self, t: Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            t (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(cast(List[int], t))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
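A minimal usage sketch for the class above (the sample text is made up; only the Tokenizer API itself comes from the diff):

from llama_stack.models.llama.llama3.tokenizer import Tokenizer

tok = Tokenizer.get_instance()  # loads the tokenizer.model checked in next to this module
ids = tok.encode("Hello, Llama!", bos=True, eos=False)
print(ids[0] == tok.bos_id)     # True, BOS is prepended when bos=True
print(tok.decode(ids))          # decodes back to the text, with the BOS marker rendered literally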
199	llama_stack/models/llama/llama3/tool_utils.py	Normal file
@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import ast
import json
import re
from typing import Optional, Tuple

from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat

BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")


def is_json(s):
    try:
        parsed = json.loads(s)
        # Return True for valid objects and not for ints, strings, etc
        return isinstance(parsed, dict)
    except json.JSONDecodeError:
        return False


def is_valid_python_list(input_string):
    """Check if the input string is a valid Python list of function calls"""
    try:
        # Try to parse the string
        tree = ast.parse(input_string)

        # Check if it's a single expression
        if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
            return False

        # Check if the expression is a list
        expr = tree.body[0].value
        if not isinstance(expr, ast.List):
            return False

        # Check if the list is empty
        if len(expr.elts) == 0:
            return False

        # Check if all elements in the list are function calls
        for element in expr.elts:
            if not isinstance(element, ast.Call):
                return False

            # Check if the function call has a valid name
            if not isinstance(element.func, ast.Name):
                return False

            # Check if all arguments are keyword arguments
            if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
                return False

        return True

    except SyntaxError:
        # If parsing fails, it's not a valid Python expression
        return False


def parse_python_list_for_function_calls(input_string):
    """
    Parse a Python list of function calls and
    return a list of tuples containing the function name and arguments
    """
    # Parse the string into an AST
    tree = ast.parse(input_string)

    # Ensure the input is a list
    if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
        raise ValueError("Input must be a list of function calls")

    result = []

    # Iterate through each function call in the list
    for node in tree.body[0].value.elts:
        if isinstance(node, ast.Call):
            function_name = node.func.id
            function_args = {}

            # Extract keyword arguments
            for keyword in node.keywords:
                function_args[keyword.arg] = ast.literal_eval(keyword.value)

            result.append((function_name, function_args))

    return result


class ToolUtils:
    @staticmethod
    def is_builtin_tool_call(message_body: str) -> bool:
        match = re.search(BUILTIN_TOOL_PATTERN, message_body)
        return match is not None

    @staticmethod
    def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
        # Find the first match in the text
        match = re.search(BUILTIN_TOOL_PATTERN, message_body)

        # Check if a match is found and return it
        if match:
            tool_name = match.group("tool_name")
            query = match.group("query")
            return tool_name, query
        else:
            return None

    @staticmethod
    def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
        # NOTE: Custom function tool calls are still experimental
        # Sometimes, response is of the form
        # {"type": "function", "name": "function_name", "parameters": {...}
        # and some times
        # <function=function_name>(parameters)</function>

        # Find the first match in the text
        match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
        if match:
            tool_name = match.group("function_name")
            query = match.group("args")
            try:
                return tool_name, json.loads(query.replace("'", '"'))
            except Exception as e:
                print("Exception while parsing json query for custom tool call", query, e)
                return None
        elif is_json(message_body):
            response = json.loads(message_body)
            if ("type" in response and response["type"] == "function") or ("name" in response):
                function_name = response["name"]
                args = response["parameters"]
                return function_name, args
            else:
                return None
        elif is_valid_python_list(message_body):
            res = parse_python_list_for_function_calls(message_body)
            # FIXME: Enable multiple tool calls
            return res[0]
        else:
            return None

    @staticmethod
    def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
        if t.tool_name == BuiltinTool.brave_search:
            q = t.arguments["query"]
            return f'brave_search.call(query="{q}")'
        elif t.tool_name == BuiltinTool.wolfram_alpha:
            q = t.arguments["query"]
            return f'wolfram_alpha.call(query="{q}")'
        elif t.tool_name == BuiltinTool.photogen:
            q = t.arguments["query"]
            return f'photogen.call(query="{q}")'
        elif t.tool_name == BuiltinTool.code_interpreter:
            return t.arguments["code"]
        else:
            fname = t.tool_name

            if tool_prompt_format == ToolPromptFormat.json:
                return json.dumps(
                    {
                        "type": "function",
                        "name": fname,
                        "parameters": t.arguments,
                    }
                )
            elif tool_prompt_format == ToolPromptFormat.function_tag:
                args = json.dumps(t.arguments)
                return f"<function={fname}>{args}</function>"

            elif tool_prompt_format == ToolPromptFormat.python_list:

                def format_value(value: RecursiveType) -> str:
                    if isinstance(value, str):
                        return f'"{value}"'
                    elif isinstance(value, (int, float, bool)) or value is None:
                        return str(value)
                    elif isinstance(value, list):
                        return f"[{', '.join(format_value(v) for v in value)}]"
                    elif isinstance(value, dict):
                        return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
                    else:
                        raise ValueError(f"Unsupported type: {type(value)}")

                args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
                return f"[{fname}({args_str})]"
            else:
                raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
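To make the encoding side concrete, here is a short sketch of driving ToolUtils (the tool name and arguments are invented for illustration; only the imported classes and methods come from the diff):

from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolPromptFormat
from llama_stack.models.llama.llama3.tool_utils import ToolUtils

call = ToolCall(call_id="1", tool_name="get_weather", arguments={"city": "Paris", "days": 3})
print(ToolUtils.encode_tool_call(call, ToolPromptFormat.python_list))
# [get_weather(city="Paris", days=3)]

builtin = ToolCall(call_id="2", tool_name=BuiltinTool.brave_search, arguments={"query": "llama stack"})
print(ToolUtils.encode_tool_call(builtin, ToolPromptFormat.json))
# brave_search.call(query="llama stack"), builtin tools always use the fixed call syntax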
@ -14,7 +14,7 @@
 import textwrap
 from typing import List

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     RawMessage,
     StopReason,
@ -13,7 +13,7 @@
 import json
 import textwrap

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     RawMessage,
     StopReason,
     ToolCall,
@ -14,7 +14,7 @@
 import textwrap
 from pathlib import Path

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     RawMediaItem,
     RawMessage,
     RawTextItem,
@ -14,7 +14,7 @@
 import textwrap
 from typing import List

-from llama_models.datatypes import (
+from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     RawMessage,
     StopReason,
@ -16,7 +16,9 @@ import textwrap
 from pathlib import Path
 from typing import List

-from llama_models.datatypes import (
+from pydantic import BaseModel, Field
+
+from llama_stack.models.llama.datatypes import (
     RawContent,
     RawMediaItem,
     RawMessage,
@ -25,7 +27,6 @@ from llama_models.datatypes import (
     ToolCall,
     ToolPromptFormat,
 )
-from pydantic import BaseModel, Field

 from .llama3.interface import LLama31Interface
 from .llama3.template_data import (
@ -23,13 +23,6 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from llama_models.llama3.api.args import ModelArgs
-from llama_models.llama3.api.chat_format import ChatFormat, LLMInput
-from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.llama3.reference_impl.model import Transformer
-from llama_models.llama3.reference_impl.multimodal.model import (
-    CrossAttentionTransformer,
-)
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 from pydantic import BaseModel

@ -46,6 +39,13 @@ from llama_stack.models.llama.datatypes import (
     SamplingParams,
     TopPSamplingStrategy,
 )
+from llama_stack.models.llama.llama3.args import ModelArgs
+from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
+from llama_stack.models.llama.llama3.model import Transformer
+from llama_stack.models.llama.llama3.multimodal.model import (
+    CrossAttentionTransformer,
+)
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
@ -9,10 +9,9 @@ from copy import deepcopy
 from functools import partial
 from typing import Any, Generator

-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
-
 from llama_stack.models.llama.datatypes import Model
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
@ -15,13 +15,13 @@ import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
-from llama_models.llama3.api.args import ModelArgs
-from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

 from llama_stack.apis.inference import QuantizationType
 from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
+from llama_stack.models.llama.llama3.args import ModelArgs
+from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock
 from llama_stack.models.llama.sku_list import resolve_model

 from ..config import MetaReferenceQuantizedInferenceConfig
@ -22,11 +22,11 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
     model_parallel_is_initialized,
 )
-from llama_models.llama3.api.args import ModelArgs
-from llama_models.llama3.api.tokenizer import Tokenizer
-from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from torch.nn.parameter import Parameter

+from llama_stack.models.llama.llama3.args import ModelArgs
+from llama_stack.models.llama.llama3.model import Transformer, TransformerBlock
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
     quantize_fp8,
 )
@ -9,7 +9,6 @@ import os
 import uuid
 from typing import AsyncGenerator, List, Optional

-from llama_models.llama3.api.tokenizer import Tokenizer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
@ -36,6 +35,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
@ -7,7 +7,6 @@ import json
 import logging
 from typing import AsyncGenerator, List, Optional, Union

-from llama_models.datatypes import StopReason, ToolCall
 from openai import OpenAI

 from llama_stack.apis.common.content_types import (
@ -42,7 +41,7 @@ from llama_stack.apis.inference import (
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model, ModelType
-from llama_stack.models.llama.datatypes import BuiltinTool
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.model_registry import (
@ -13,9 +13,6 @@ import re
 from typing import List, Optional, Tuple, Union

 import httpx
-from llama_models.datatypes import StopReason
-from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.tokenizer import Tokenizer
 from PIL import Image as PIL_Image

 from llama_stack.apis.common.content_types import (
@ -44,9 +41,11 @@ from llama_stack.models.llama.datatypes import (
     RawMessage,
     RawTextItem,
     Role,
+    StopReason,
     ToolPromptFormat,
     is_multimodal,
 )
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.prompt_templates import (
     BuiltinToolGenerator,
     FunctionTagCustomToolGenerator,
@ -54,6 +53,7 @@ from llama_stack.models.llama.llama3.prompt_templates import (
     PythonListCustomToolGenerator,
     SystemDefaultGenerator,
 )
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models

@ -15,7 +15,6 @@ from urllib.parse import unquote
 import chardet
 import httpx
 import numpy as np
-from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray
 from pypdf import PdfReader

@ -27,6 +26,7 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.tools import RAGDocument
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@ -13,31 +13,38 @@

 import importlib
 from pathlib import Path
-from typing import Optional

 import fire

-# from llama_stack.models.llama.datatypes import * # noqa: F403
-from llama_models.llama3.reference_impl.generation import Llama
+from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.providers.inline.inference.meta_reference.config import MetaReferenceInferenceConfig
+from llama_stack.providers.inline.inference.meta_reference.generation import Llama

 THIS_DIR = Path(__file__).parent.resolve()


 def run_main(
-    ckpt_dir: str,
+    model_id: str,
+    checkpoint_dir: str,
     module_name: str,
     output_path: str,
-    model_parallel_size: Optional[int] = None,
 ):
     module = importlib.import_module(module_name)
     assert hasattr(module, "usecases"), f"Module {module_name} missing usecases function"
-    tokenizer_path = str(THIS_DIR.parent / "llama3/api/tokenizer.model")
-    generator = Llama.build(
-        ckpt_dir=ckpt_dir,
-        tokenizer_path=tokenizer_path,
+    config = MetaReferenceInferenceConfig(
+        model=model_id,
         max_seq_len=512,
         max_batch_size=1,
-        model_parallel_size=model_parallel_size,
+        checkpoint_dir=checkpoint_dir,
+    )
+    llama_model = resolve_model(model_id)
+    if not llama_model:
+        raise ValueError(f"Model {model_id} not found")
+    generator = Llama.build(
+        config=config,
+        model_id=model_id,
+        llama_model=llama_model,
     )

     use_cases = module.usecases()