forked from phoenix-oss/llama-stack-mirror
chore: remove dependency on llama_models completely (#1344)
This commit is contained in:
parent
7131d5ddeb
commit
8bbd52bb9f
43 changed files with 131358 additions and 202 deletions
|
@ -11,16 +11,128 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
# import all for backwards compatibility
|
||||
from llama_models.datatypes import * # noqa: F403
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
# The goal is that these set of types are relevant for all Llama models.
|
||||
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
|
||||
# the llama3 series of models.
|
||||
|
||||
|
||||
class Role(Enum):
|
||||
system = "system"
|
||||
user = "user"
|
||||
assistant = "assistant"
|
||||
tool = "tool"
|
||||
|
||||
|
||||
class BuiltinTool(Enum):
|
||||
brave_search = "brave_search"
|
||||
wolfram_alpha = "wolfram_alpha"
|
||||
photogen = "photogen"
|
||||
code_interpreter = "code_interpreter"
|
||||
|
||||
|
||||
Primitive = Union[str, int, float, bool, None]
|
||||
RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
call_id: str
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
arguments: Dict[str, RecursiveType]
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
class ToolPromptFormat(Enum):
|
||||
"""Prompt format for calling custom / zero shot tools.
|
||||
|
||||
:cvar json: JSON format for calling tools. It takes the form:
|
||||
{
|
||||
"type": "function",
|
||||
"function" : {
|
||||
"name": "function_name",
|
||||
"description": "function_description",
|
||||
"parameters": {...}
|
||||
}
|
||||
}
|
||||
:cvar function_tag: Function tag format, pseudo-XML. This looks like:
|
||||
<function=function_name>(parameters)</function>
|
||||
|
||||
:cvar python_list: Python list. The output is a valid Python expression that can be
|
||||
evaluated to a list. Each element in the list is a function call. Example:
|
||||
["function_name(param1, param2)", "function_name(param1, param2)"]
|
||||
"""
|
||||
|
||||
json = "json"
|
||||
function_tag = "function_tag"
|
||||
python_list = "python_list"
|
||||
|
||||
|
||||
class StopReason(Enum):
|
||||
end_of_turn = "end_of_turn"
|
||||
end_of_message = "end_of_message"
|
||||
out_of_tokens = "out_of_tokens"
|
||||
|
||||
|
||||
class RawMediaItem(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
data: bytes | BytesIO
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@field_serializer("data")
|
||||
def serialize_data(self, data: Optional[bytes], _info):
|
||||
if data is None:
|
||||
return None
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
@field_validator("data", mode="before")
|
||||
@classmethod
|
||||
def validate_data(cls, v):
|
||||
if isinstance(v, str):
|
||||
return base64.b64decode(v)
|
||||
return v
|
||||
|
||||
|
||||
class RawTextItem(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
|
||||
|
||||
RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
|
||||
|
||||
RawContent = str | RawContentItem | List[RawContentItem]
|
||||
|
||||
|
||||
class RawMessage(BaseModel):
|
||||
role: Literal["user"] | Literal["system"] | Literal["tool"] | Literal["assistant"]
|
||||
content: RawContent
|
||||
|
||||
# This is for RAG but likely should be absorbed into content
|
||||
context: Optional[RawContent] = None
|
||||
|
||||
# These are for the output message coming from the assistant
|
||||
stop_reason: Optional[StopReason] = None
|
||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema(ToolCall)
|
||||
|
||||
|
||||
|
|
5
llama_stack/models/llama/llama3/__init__.py
Normal file
5
llama_stack/models/llama/llama3/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
82
llama_stack/models/llama/llama3/args.py
Normal file
82
llama_stack/models/llama/llama3/args.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class QuantizationScheme(Enum):
|
||||
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizationArgs:
|
||||
scheme: Optional[QuantizationScheme] = None
|
||||
group_size: Optional[int] = None
|
||||
spinquant: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if k == "scheme":
|
||||
setattr(self, k, QuantizationScheme(v))
|
||||
else:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAArgs:
|
||||
rank: int
|
||||
scale: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
dim: int = 4096
|
||||
n_layers: int = 32
|
||||
n_heads: int = 32
|
||||
n_kv_heads: Optional[int] = None
|
||||
vocab_size: int = -1
|
||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||
ffn_dim_multiplier: Optional[float] = None
|
||||
norm_eps: float = 1e-5
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
|
||||
max_batch_size: int = 32
|
||||
max_seq_len: int = 2048
|
||||
|
||||
# vision model params
|
||||
vision_chunk_size: int = -1 # image resolution for image models
|
||||
vision_max_num_chunks: int = 4
|
||||
vision_num_cross_attention_layers: int = -1
|
||||
|
||||
quantization_args: Optional[QuantizationArgs] = None
|
||||
lora_args: Optional[LoRAArgs] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if k == "lora_args":
|
||||
setattr(self, k, LoRAArgs(**v))
|
||||
elif k == "quantization_args":
|
||||
setattr(self, k, QuantizationArgs(**v))
|
||||
else:
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
|
||||
if self.n_kv_heads is None:
|
||||
self.n_kv_heads = self.n_heads
|
||||
assert self.n_kv_heads <= self.n_heads
|
||||
assert self.n_heads % self.n_kv_heads == 0
|
||||
assert self.dim % self.n_heads == 0
|
282
llama_stack/models/llama/llama3/chat_format.py
Normal file
282
llama_stack/models/llama/llama3/chat_format.py
Normal file
|
@ -0,0 +1,282 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import io
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
RawTextItem,
|
||||
Role,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .tool_utils import ToolUtils
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionInput:
|
||||
mask: List[List[int]]
|
||||
images: List[PIL_Image.Image]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInput:
|
||||
tokens: List[int]
|
||||
vision: Optional[VisionInput] = None
|
||||
|
||||
|
||||
def role_str(role: Role) -> str:
|
||||
role_strs = {
|
||||
Role.user: "user",
|
||||
Role.system: "system",
|
||||
Role.tool: "ipython", # special
|
||||
Role.assistant: "assistant",
|
||||
}
|
||||
return role_strs[role]
|
||||
|
||||
|
||||
class ChatFormat:
|
||||
possible_headers: Dict[Role, str]
|
||||
|
||||
def __init__(self, tokenizer: Tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
|
||||
self.vision_token = self.tokenizer.special_tokens["<|image|>"]
|
||||
|
||||
def _encode_header(self, role: str) -> List[int]:
|
||||
tokens = []
|
||||
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
|
||||
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
|
||||
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
|
||||
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
|
||||
return tokens
|
||||
|
||||
def encode_content(self, content: RawContent) -> LLMInput:
|
||||
tokens, images = self._encode_content(content, bos=True)
|
||||
return self._model_input_from_tokens_images(tokens, images)
|
||||
|
||||
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||
tokens = []
|
||||
images = []
|
||||
|
||||
added_bos = False
|
||||
|
||||
def _process(c):
|
||||
nonlocal added_bos, bos
|
||||
|
||||
if isinstance(c, str) or isinstance(c, RawTextItem):
|
||||
if isinstance(c, RawTextItem):
|
||||
c = c.text
|
||||
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
|
||||
added_bos = True
|
||||
|
||||
elif isinstance(c, RawMediaItem):
|
||||
bos = False if added_bos else bos
|
||||
if bos:
|
||||
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||
added_bos = True
|
||||
tokens.append(self.vision_token)
|
||||
|
||||
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
||||
image = PIL_Image.open(bytes_io)
|
||||
image = image.convert("RGB")
|
||||
images.append(image)
|
||||
|
||||
if isinstance(content, list):
|
||||
for c in content:
|
||||
_process(c)
|
||||
else:
|
||||
_process(content)
|
||||
|
||||
return tokens, images
|
||||
|
||||
def encode_message(
|
||||
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
|
||||
) -> Tuple[List[int], List[PIL_Image.Image]]:
|
||||
tokens = self._encode_header(message.role)
|
||||
images = []
|
||||
|
||||
def _process_content(c):
|
||||
toks, imgs = self._encode_content(c)
|
||||
tokens.extend(toks)
|
||||
images.extend(imgs)
|
||||
|
||||
if (
|
||||
message.role == "assistant"
|
||||
and len(message.tool_calls) > 0
|
||||
and message.tool_calls[0].tool_name == BuiltinTool.code_interpreter
|
||||
):
|
||||
tokens.append(self.tokenizer.special_tokens["<|python_tag|>"])
|
||||
|
||||
_process_content(message.content)
|
||||
|
||||
if message.role == "user" and message.context is not None:
|
||||
# This is RAG context; why is it here in the chat format? I don't think
|
||||
# this is needed and can be moved upwards
|
||||
_process_content("\n\n")
|
||||
_process_content(message.context)
|
||||
|
||||
if message.role == "assistant":
|
||||
for t in message.tool_calls:
|
||||
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
|
||||
_process_content(content)
|
||||
|
||||
eom = False
|
||||
if message.role == "assistant":
|
||||
eom = message.stop_reason == StopReason.end_of_message
|
||||
|
||||
tokens.append(self.tokenizer.special_tokens["<|eom_id|>" if eom else "<|eot_id|>"])
|
||||
return tokens, images
|
||||
|
||||
def encode_dialog_prompt(
|
||||
self,
|
||||
messages: List[RawMessage],
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
) -> LLMInput:
|
||||
tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
|
||||
tokens = []
|
||||
images = []
|
||||
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||
for message in messages:
|
||||
toks, imgs = self.encode_message(message, tool_prompt_format)
|
||||
tokens.extend(toks)
|
||||
images.extend(imgs)
|
||||
|
||||
# Add the start of an assistant message for the model to complete.
|
||||
tokens.extend(self._encode_header("assistant"))
|
||||
|
||||
return self._model_input_from_tokens_images(tokens, images)
|
||||
|
||||
# TODO(this should be generic, not only for assistant messages)
|
||||
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
|
||||
content = self.tokenizer.decode(tokens)
|
||||
|
||||
return self.decode_assistant_message_from_content(content, stop_reason)
|
||||
|
||||
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
|
||||
content = content.strip(" ")
|
||||
header_str = self.possible_headers[Role.assistant]
|
||||
if content.startswith(header_str):
|
||||
content = content[len(header_str) :]
|
||||
|
||||
ipython = content.startswith("<|python_tag|>")
|
||||
if ipython:
|
||||
content = content[len("<|python_tag|>") :]
|
||||
|
||||
if content.endswith("<|eot_id|>"):
|
||||
content = content[: -len("<|eot_id|>")]
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif content.endswith("<|eom_id|>"):
|
||||
content = content[: -len("<|eom_id|>")]
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
tool_name = None
|
||||
tool_arguments = {}
|
||||
|
||||
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
||||
if custom_tool_info is not None:
|
||||
tool_name, tool_arguments = custom_tool_info
|
||||
# Sometimes when agent has custom tools alongside builin tools
|
||||
# Agent responds for builtin tool calls in the format of the custom tools
|
||||
# This code tries to handle that case
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
tool_arguments = {
|
||||
"query": list(tool_arguments.values())[0],
|
||||
}
|
||||
else:
|
||||
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
|
||||
if builtin_tool_info is not None:
|
||||
tool_name, query = builtin_tool_info
|
||||
tool_arguments = {
|
||||
"query": query,
|
||||
}
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
elif ipython:
|
||||
tool_name = BuiltinTool.code_interpreter
|
||||
tool_arguments = {
|
||||
"code": content,
|
||||
}
|
||||
|
||||
tool_calls = []
|
||||
if tool_name is not None and tool_arguments is not None:
|
||||
call_id = str(uuid.uuid4())
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
||||
return RawMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
|
||||
vision_input = None
|
||||
if len(images) > 0:
|
||||
vision_input = VisionInput(
|
||||
mask=create_vision_mask(tokens, self.vision_token),
|
||||
images=images,
|
||||
)
|
||||
|
||||
return LLMInput(
|
||||
tokens=[128256 if token == self.vision_token else token for token in tokens],
|
||||
vision=vision_input,
|
||||
)
|
||||
|
||||
|
||||
def create_vision_mask(
|
||||
tokens: List[int],
|
||||
vision_token: int,
|
||||
) -> List[List[int]]:
|
||||
vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
|
||||
if len(vision_token_locations) == 0:
|
||||
return []
|
||||
|
||||
if len(vision_token_locations) == 1:
|
||||
# only one image present, unmask until end of sequence
|
||||
return [[vision_token_locations[0], -1]]
|
||||
vision_masks = [
|
||||
[loc1, loc2] for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:], strict=False)
|
||||
]
|
||||
# last image will attend to all subsequent text
|
||||
vision_masks.append([vision_token_locations[-1], len(tokens)])
|
||||
|
||||
# if there are two or more consecutive vision tokens,
|
||||
# they should all attend to all subsequent
|
||||
# text present
|
||||
last_mask_end = vision_masks[-1][1]
|
||||
for vision_mask in vision_masks[::-1]:
|
||||
if vision_mask[0] == vision_mask[1] - 1:
|
||||
vision_mask[1] = last_mask_end
|
||||
last_mask_end = vision_mask[1]
|
||||
return vision_masks
|
|
@ -14,20 +14,19 @@
|
|||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
|
||||
from . import template_data
|
||||
from .chat_format import ChatFormat
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
FunctionTagCustomToolGenerator,
|
||||
|
@ -35,6 +34,7 @@ from .prompt_templates import (
|
|||
SystemDefaultGenerator,
|
||||
ToolResponseGenerator,
|
||||
)
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
|
315
llama_stack/models/llama/llama3/model.py
Normal file
315
llama_stack/models/llama/llama3/model.py
Normal file
|
@ -0,0 +1,315 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from ..api import ModelArgs
|
||||
|
||||
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
|
||||
# dependencies. These dependencies are not part of the default dependencies
|
||||
# (requirements.txt) of the `llama-models` package.
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
|
||||
# Values obtained from grid search
|
||||
scale_factor = 8
|
||||
low_freq_factor = 1
|
||||
high_freq_factor = 4
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
wavelen = 2 * torch.pi / freqs
|
||||
new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
return torch.where(
|
||||
(wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
|
||||
(1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
|
||||
new_freqs,
|
||||
)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||
bs, slen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
return (
|
||||
x[:, :, :, None, :]
|
||||
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
||||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
self.wq = ColumnParallelLinear(
|
||||
args.dim,
|
||||
args.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
args.n_heads * self.head_dim,
|
||||
args.dim,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
|
||||
self.cache_k = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
self.cache_v = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
self.cache_k = self.cache_k.to(xq)
|
||||
self.cache_v = self.cache_v.to(xq)
|
||||
|
||||
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||
|
||||
keys = self.cache_k[:bsz, : start_pos + seqlen]
|
||||
values = self.cache_v[:bsz, : start_pos + seqlen]
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
|
||||
|
||||
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
||||
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
|
||||
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
if mask is not None:
|
||||
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
||||
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
||||
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
|
||||
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
return self.wo(output)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float],
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
|
||||
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.attention = Attention(args)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=args.dim,
|
||||
hidden_dim=4 * args.dim,
|
||||
multiple_of=args.multiple_of,
|
||||
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
):
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, params: ModelArgs):
|
||||
super().__init__()
|
||||
self.params = params
|
||||
self.vocab_size = params.vocab_size
|
||||
self.n_layers = params.n_layers
|
||||
|
||||
self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(params.n_layers):
|
||||
self.layers.append(TransformerBlock(layer_id, params))
|
||||
|
||||
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
||||
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
|
||||
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
params.dim // params.n_heads,
|
||||
params.max_seq_len * 2,
|
||||
params.rope_theta,
|
||||
params.use_scaled_rope,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, tokens: torch.Tensor, start_pos: int):
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||
|
||||
mask = None
|
||||
if seqlen > 1:
|
||||
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/100005
|
||||
# torch.triu is buggy when the device is mps: filled values are
|
||||
# nan instead of 0.
|
||||
if mask.device.type == torch.device("mps").type:
|
||||
mask = torch.nan_to_num(mask, nan=0.0)
|
||||
|
||||
# When performing key-value caching, we compute the attention scores
|
||||
# only for the new sequence. Thus, the matrix of scores is of size
|
||||
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
|
||||
# j > cache_len + i, since row i corresponds to token cache_len + i.
|
||||
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
|
||||
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, mask)
|
||||
h = self.norm(h)
|
||||
output = self.output(h).float()
|
||||
return output
|
12
llama_stack/models/llama/llama3/multimodal/__init__.py
Normal file
12
llama_stack/models/llama/llama3/multimodal/__init__.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
179
llama_stack/models/llama/llama3/multimodal/encoder_utils.py
Normal file
179
llama_stack/models/llama/llama3/multimodal/encoder_utils.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and its affiliates.
|
||||
import math
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import get_negative_inf_value, to_2tuple
|
||||
|
||||
logger = getLogger()
|
||||
|
||||
|
||||
def resize_local_position_embedding(orig_pos_embed, grid_size):
|
||||
"""
|
||||
Resize position embedding for vision encoder.
|
||||
Original position embedding is [n_tiles * n_tiles + 1, dim]
|
||||
New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
|
||||
"""
|
||||
new_grid_size = to_2tuple(grid_size)
|
||||
orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
|
||||
|
||||
new_pos_emb_tok, new_pos_emb_img = (
|
||||
orig_pos_embed[:1],
|
||||
orig_pos_embed[1:],
|
||||
)
|
||||
logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
|
||||
|
||||
new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
|
||||
|
||||
new_pos_emb_img = F.interpolate(
|
||||
new_pos_emb_img,
|
||||
size=new_grid_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
|
||||
new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
|
||||
return new_pos_embed
|
||||
|
||||
|
||||
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||
"""
|
||||
Takes a local position embedding for vision encoder and uses it
|
||||
to initialize the global position embedding.
|
||||
Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
|
||||
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||
"""
|
||||
pos_embed = pos_and_cls_embed[1:]
|
||||
cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
|
||||
grid_size = to_2tuple(grid_size)
|
||||
new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
|
||||
new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
|
||||
new_pos_emb_img = F.interpolate(
|
||||
new_pos_emb_img,
|
||||
size=new_grid_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
|
||||
new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
|
||||
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
|
||||
new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
|
||||
cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
|
||||
pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
|
||||
return pos_and_cls_embed
|
||||
|
||||
|
||||
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
|
||||
"""
|
||||
Takes a global position embedding for vision encoder and resizes it to new size.
|
||||
Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
|
||||
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
|
||||
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
|
||||
"""
|
||||
# first remove cls token
|
||||
pos_embed = pos_and_cls_embed[:, :, 1:]
|
||||
cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
|
||||
|
||||
xs_old, ys_old, ntok, dim = pos_embed.shape
|
||||
old_grid_size = int(math.sqrt(ntok))
|
||||
|
||||
# move to correct form for interpolation
|
||||
pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
|
||||
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
|
||||
pos_embed = pos_embed.unsqueeze(0)
|
||||
|
||||
# interpolate
|
||||
new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
|
||||
pos_embed = pos_embed.permute(0, 3, 1, 2)
|
||||
pos_embed_resized = F.interpolate(
|
||||
pos_embed,
|
||||
size=new_size,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
|
||||
|
||||
# move it back in place
|
||||
pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
|
||||
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
|
||||
pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
|
||||
|
||||
# interpolate cls token
|
||||
cls_embed = cls_embed.permute(2, 3, 0, 1)
|
||||
cls_embed_resized = F.interpolate(
|
||||
cls_embed,
|
||||
size=(x_scale, y_scale),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
|
||||
# add cls token back in
|
||||
pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
|
||||
|
||||
return pos_and_cls_embed
|
||||
|
||||
|
||||
def build_encoder_attention_mask(
|
||||
x: torch.Tensor,
|
||||
ar: torch.Tensor,
|
||||
ntok: int,
|
||||
num_chunks: int,
|
||||
n_heads: int,
|
||||
):
|
||||
"""
|
||||
Build vision encoder attention mask that omits padding tokens.
|
||||
"""
|
||||
masks = []
|
||||
for arx in ar:
|
||||
mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
|
||||
mask_i[: arx[0] * arx[1], :ntok] = 0
|
||||
mask_i = mask_i.view(num_chunks * x.shape[2], -1)
|
||||
mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
|
||||
mask_i = mask_i.unsqueeze(0)
|
||||
masks.append(mask_i)
|
||||
masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
|
||||
return masks
|
||||
|
||||
|
||||
def expand_num_tokens_to_mult8(x):
|
||||
num_pad_tokens = 8 - (x.shape[-2] % 8)
|
||||
if num_pad_tokens == 0:
|
||||
return x, 0
|
||||
else:
|
||||
return (
|
||||
torch.cat(
|
||||
[
|
||||
x,
|
||||
torch.zeros(
|
||||
(x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
),
|
||||
],
|
||||
dim=-2,
|
||||
),
|
||||
num_pad_tokens,
|
||||
)
|
||||
|
||||
|
||||
def contract_num_tokens_from_mult8(x, num_pad_tokens):
|
||||
if num_pad_tokens == 0:
|
||||
return x
|
||||
return x[:, :, :-num_pad_tokens]
|
408
llama_stack/models/llama/llama3/multimodal/image_transform.py
Normal file
408
llama_stack/models/llama/llama3/multimodal/image_transform.py
Normal file
|
@ -0,0 +1,408 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from logging import getLogger
|
||||
from typing import Any, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv
|
||||
from PIL import Image
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
IMAGE_RES = 224
|
||||
|
||||
logger = getLogger()
|
||||
|
||||
|
||||
class VariableSizeImageTransform(object):
|
||||
"""
|
||||
This class accepts images of any size and dynamically resize, pads and chunks it
|
||||
based on the image aspect ratio and the number of image chunks we allow.
|
||||
|
||||
The algorithm will NOT distort the image fit a certain aspect ratio, because
|
||||
that leads to a significant degradation in image quality.
|
||||
|
||||
It can be summarized in 6 steps:
|
||||
1. Find all possible canvas combinations of max_num_chunks;
|
||||
2. Find the best canvas to fit the image;
|
||||
3. Resize without distortion
|
||||
4. Pad
|
||||
5. Normalize
|
||||
6. Chunk
|
||||
|
||||
For example, if an input image is of size 300x800, patch_size of 224,
|
||||
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||
is allowed within 8 image chunks, with some restrictions.
|
||||
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||
giving a total of 8 chunks.
|
||||
|
||||
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||
where we maintain the original aspect ratio and pad with zeros value for the rest.
|
||||
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||
|
||||
However, if limit_upscaling_to_patch_size is set to True,
|
||||
the upscaling will be limited to the patch size. In the example above,
|
||||
the image would remain 300x800 (no upscaling), and then padded to 448:896.
|
||||
|
||||
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||
patches are coming from the resizing and chunking.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||
self.size = size
|
||||
logger.info(f"VariableSizeImageTransform size: {self.size}")
|
||||
self.to_tensor = tv.ToTensor()
|
||||
self._mean = (0.48145466, 0.4578275, 0.40821073)
|
||||
self._std = (0.26862954, 0.26130258, 0.27577711)
|
||||
self.normalize = tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
)
|
||||
self.resample = tv.InterpolationMode.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_factors(n: int) -> Set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. a dividor that leaves
|
||||
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
||||
Args:
|
||||
n (int): The number to find factors for.
|
||||
|
||||
Returns:
|
||||
set: A set containing all factors of the number.
|
||||
"""
|
||||
factors_set = set()
|
||||
|
||||
for i in range(1, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
factors_set.add(i)
|
||||
factors_set.add(n // i)
|
||||
return factors_set
|
||||
|
||||
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Computes all of the allowed resoltuions for a fixed number of chunks
|
||||
and patch_size. Useful for when dividing an image into chunks.
|
||||
|
||||
Args:
|
||||
max_num_chunks (int): Maximum number of chunks for processing.
|
||||
patch_size (int): Size of the side of the patch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||
|
||||
Example:
|
||||
>>> max_num_chunks = 5
|
||||
>>> patch_size = 224
|
||||
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||
(672, 224), (224, 448), (448, 224)])
|
||||
|
||||
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||
{
|
||||
0.25: [(1, 4)],
|
||||
1.0: [(2, 2), (1, 1)],
|
||||
4.0: [(4, 1)],
|
||||
0.33: [(1, 3)],
|
||||
3.0: [(3, 1)],
|
||||
0.5: [(1, 2)],
|
||||
2.0: [(2, 1)]
|
||||
}
|
||||
|
||||
and return the resolutions multiplied by the patch_size:
|
||||
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||
"""
|
||||
asp_dict = defaultdict(list)
|
||||
for chunk_size in range(max_num_chunks, 0, -1):
|
||||
_factors = sorted(self.get_factors(chunk_size))
|
||||
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||
for height, width in _asp_ratios:
|
||||
ratio_float = height / width
|
||||
asp_dict[ratio_float].append((height, width))
|
||||
|
||||
# get the resolutions multiplied by the patch_size
|
||||
possible_resolutions = []
|
||||
for value in asp_dict.values():
|
||||
for height, depth in value:
|
||||
possible_resolutions.append((height * patch_size, depth * patch_size))
|
||||
|
||||
return possible_resolutions
|
||||
|
||||
@staticmethod
|
||||
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized to without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||
Returns:
|
||||
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||
Example:
|
||||
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||
(134, 200)
|
||||
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||
(450, 338)
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
target_width, target_height = target_size
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||
|
||||
return new_width, new_height
|
||||
|
||||
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||
new_width, new_height = target_size
|
||||
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||
new_im.paste(image)
|
||||
return new_im
|
||||
|
||||
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||
# Split image into number of required tiles (width x height)
|
||||
num_channels, height, width = image.size()
|
||||
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||
# Permute dimensions to reorder the axes
|
||||
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||
return image
|
||||
|
||||
def resize_without_distortion(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_size: Tuple[int, int],
|
||||
max_upscaling_size: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Used to resize an image to target_resolution, without distortion.
|
||||
|
||||
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||
modifying target_size works as a boundary for the image's largest side.
|
||||
|
||||
Args:
|
||||
resample (str): Resampling method used when resizing images.
|
||||
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||
If None, there is no limit.
|
||||
Examples:
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(600, 300) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (2000, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 100) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 2000
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = None
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
"""
|
||||
|
||||
image_width, image_height = image.size
|
||||
image_size = (image_width, image_height)
|
||||
|
||||
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||
if max_upscaling_size is not None:
|
||||
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||
target_size = (new_target_width, new_target_height)
|
||||
|
||||
# resize to target_size while preserving aspect ratio
|
||||
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||
|
||||
image = F.resize(
|
||||
image,
|
||||
(new_size_without_distortion[1], new_size_without_distortion[0]),
|
||||
interpolation=self.resample,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def get_best_fit(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
possible_resolutions: torch.Tensor,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the best canvas possible from a list of possible resolutions to, without distortion,
|
||||
resize an image to.
|
||||
|
||||
For each possible resolution, calculates the scaling factors for
|
||||
width and height, and selects the smallest one, which is the limiting side.
|
||||
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||
|
||||
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
|
||||
|
||||
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||
reduce downscaling as much as possible.
|
||||
|
||||
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||
has more padding.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
|
||||
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
||||
row represents a possible resolution (height, width).
|
||||
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
|
||||
|
||||
Returns:
|
||||
List[int]: The best resolution [height, width] for the given image.
|
||||
|
||||
Example:
|
||||
>>> image_size = (200, 300)
|
||||
>>> possible_resolutions = torch.tensor([[224, 672],
|
||||
... [672, 224],
|
||||
... [224, 448],
|
||||
... [448, 224],
|
||||
... [224, 224]])
|
||||
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
|
||||
[224, 448]
|
||||
|
||||
We have:
|
||||
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
||||
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
||||
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
||||
Only one of the scales > 1:
|
||||
upscaling_possible = tensor([1.1200, 1.1200])
|
||||
smallest_rescale = tensor(1.1200)
|
||||
So we pick the resolution with the smallest smallest area:
|
||||
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
|
||||
optimal_canvas = tensor([224, 448])
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
|
||||
# get all possible resolutions heights/widths
|
||||
target_widths, target_heights = (
|
||||
possible_resolutions[:, 0],
|
||||
possible_resolutions[:, 1],
|
||||
)
|
||||
|
||||
# get scaling factors to resize the image without distortion
|
||||
scale_w = target_widths / original_width
|
||||
scale_h = target_heights / original_height
|
||||
|
||||
# get the min scale between width and height (limiting side -> no distortion)
|
||||
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
||||
|
||||
# filter only scales that allow upscaling
|
||||
upscaling_options = scales[scales >= 1]
|
||||
if len(upscaling_options) > 0:
|
||||
if resize_to_max_canvas:
|
||||
selected_scale = torch.max(upscaling_options)
|
||||
else:
|
||||
selected_scale = torch.min(upscaling_options)
|
||||
else:
|
||||
# no upscaling possible,
|
||||
# get the minimum downscaling (max scale for scales<1)
|
||||
downscaling_options = scales[scales < 1]
|
||||
selected_scale = torch.max(downscaling_options)
|
||||
|
||||
# get all resolutions that support this scaling factor,
|
||||
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||
chosen_canvas = possible_resolutions[scales == selected_scale]
|
||||
|
||||
# if there are multiple resolutions,
|
||||
# get the one with minimum area to reduce padding
|
||||
if len(chosen_canvas) > 1:
|
||||
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||
optimal_idx = torch.argmin(areas)
|
||||
optimal_canvas = chosen_canvas[optimal_idx]
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return tuple(optimal_canvas.tolist())
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
max_num_chunks: int,
|
||||
normalize_img: bool = True,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Args:
|
||||
image (PIL.Image): Image to be resized.
|
||||
max_num_chunks (int): Maximum number of chunks to split the image into.
|
||||
normalize_img (bool): Whether to normalize the image.
|
||||
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
|
||||
If True, picks the canvas the allows the largest resizing without distortion.
|
||||
If False, downsample as little as possible, including no resizing at all,
|
||||
but never upsample, unless the image is smaller than the patch size.
|
||||
"""
|
||||
assert max_num_chunks > 0
|
||||
assert isinstance(image, Image.Image), type(image)
|
||||
w, h = image.size
|
||||
|
||||
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
|
||||
possible_resolutions = torch.tensor(possible_resolutions)
|
||||
|
||||
best_resolution = self.get_best_fit(
|
||||
image_size=(w, h),
|
||||
possible_resolutions=possible_resolutions,
|
||||
resize_to_max_canvas=resize_to_max_canvas,
|
||||
)
|
||||
|
||||
max_upscaling_size = None if resize_to_max_canvas else self.size
|
||||
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
|
||||
image = self._pad(image, best_resolution)
|
||||
|
||||
image = self.to_tensor(image)
|
||||
|
||||
if normalize_img:
|
||||
image = self.normalize(image)
|
||||
|
||||
ratio_w, ratio_h = (
|
||||
best_resolution[0] // self.size,
|
||||
best_resolution[1] // self.size,
|
||||
)
|
||||
|
||||
image = self._split(image, ratio_w, ratio_h) # type: ignore
|
||||
|
||||
ar = (ratio_h, ratio_w)
|
||||
return image, ar
|
1435
llama_stack/models/llama/llama3/multimodal/model.py
Normal file
1435
llama_stack/models/llama/llama3/multimodal/model.py
Normal file
File diff suppressed because it is too large
Load diff
26
llama_stack/models/llama/llama3/multimodal/utils.py
Normal file
26
llama_stack/models/llama/llama3/multimodal/utils.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import collections
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_negative_inf_value(dtype):
|
||||
return torch.finfo(dtype).min
|
||||
|
||||
|
||||
def to_2tuple(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return (x, x)
|
|
@ -15,11 +15,8 @@ import textwrap
|
|||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_models.datatypes import (
|
||||
BuiltinTool,
|
||||
)
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
|
|
128000
llama_stack/models/llama/llama3/tokenizer.model
Normal file
128000
llama_stack/models/llama/llama3/tokenizer.model
Normal file
File diff suppressed because it is too large
Load diff
214
llama_stack/models/llama/llama3/tokenizer.py
Normal file
214
llama_stack/models/llama/llama3/tokenizer.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
# The tiktoken tokenizer can handle <=400k chars without
|
||||
# pyo3_runtime.PanicException.
|
||||
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||
|
||||
# https://github.com/openai/tiktoken/issues/195
|
||||
# Here we iterate over subsequences and split if we exceed the limit
|
||||
# of max consecutive non-whitespace or whitespace characters.
|
||||
MAX_NO_WHITESPACES_CHARS = 25_000
|
||||
|
||||
|
||||
_INSTANCE = None
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
||||
"""
|
||||
|
||||
special_tokens: Dict[str, int]
|
||||
|
||||
num_reserved_special_tokens = 256
|
||||
|
||||
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
global _INSTANCE
|
||||
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
|
||||
return _INSTANCE
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
"""
|
||||
Initializes the Tokenizer with a Tiktoken model.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the Tiktoken model file.
|
||||
"""
|
||||
assert os.path.isfile(model_path), model_path
|
||||
|
||||
mergeable_ranks = load_tiktoken_bpe(model_path)
|
||||
num_base_tokens = len(mergeable_ranks)
|
||||
special_tokens = [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
"<|reserved_special_token_0|>",
|
||||
"<|reserved_special_token_1|>",
|
||||
"<|finetune_right_pad_id|>",
|
||||
"<|step_id|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"<|eom_id|>", # end of message
|
||||
"<|eot_id|>", # end of turn
|
||||
"<|python_tag|>",
|
||||
"<|image|>",
|
||||
]
|
||||
reserved_tokens = [
|
||||
f"<|reserved_special_token_{2 + i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
|
||||
]
|
||||
special_tokens = special_tokens + reserved_tokens
|
||||
|
||||
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
|
||||
self.model = tiktoken.Encoding(
|
||||
name=Path(model_path).name,
|
||||
pat_str=self.pat_str,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=self.special_tokens,
|
||||
)
|
||||
|
||||
self.n_words: int = num_base_tokens + len(special_tokens)
|
||||
# BOS / EOS token IDs
|
||||
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
|
||||
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
|
||||
self.eot_id: int = self.special_tokens["<|eot_id|>"]
|
||||
self.eom_id: int = self.special_tokens["<|eom_id|>"]
|
||||
self.python_tag_id = self.special_tokens["<|python_tag|>"]
|
||||
self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
|
||||
self.stop_tokens = [
|
||||
self.eos_id,
|
||||
self.special_tokens["<|eom_id|>"],
|
||||
self.special_tokens["<|eot_id|>"],
|
||||
]
|
||||
|
||||
def encode(
|
||||
self,
|
||||
s: str,
|
||||
*,
|
||||
bos: bool,
|
||||
eos: bool,
|
||||
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = (),
|
||||
) -> List[int]:
|
||||
"""
|
||||
Encodes a string into a list of token IDs.
|
||||
|
||||
Args:
|
||||
s (str): The input string to be encoded.
|
||||
bos (bool): Whether to prepend the beginning-of-sequence token.
|
||||
eos (bool): Whether to append the end-of-sequence token.
|
||||
allowed_special ("all"|set[str]): allowed special tokens in string
|
||||
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
|
||||
|
||||
Returns:
|
||||
list[int]: A list of token IDs.
|
||||
|
||||
By default, setting disallowed_special=() encodes a string by ignoring
|
||||
special tokens. Specifically:
|
||||
- Setting `disallowed_special` to () will cause all text corresponding
|
||||
to special tokens to be encoded as natural text (insteading of raising
|
||||
an error).
|
||||
- Setting `allowed_special` to "all" will treat all text corresponding
|
||||
to special tokens to be encoded as special tokens.
|
||||
"""
|
||||
if allowed_special is None:
|
||||
allowed_special = set()
|
||||
assert type(s) is str
|
||||
|
||||
substrs = (
|
||||
substr
|
||||
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
|
||||
for substr in self._split_whitespaces_or_nonwhitespaces(
|
||||
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||
)
|
||||
)
|
||||
t: List[int] = []
|
||||
for substr in substrs:
|
||||
t.extend(
|
||||
self.model.encode(
|
||||
substr,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
)
|
||||
if bos:
|
||||
t.insert(0, self.bos_id)
|
||||
if eos:
|
||||
t.append(self.eos_id)
|
||||
return t
|
||||
|
||||
def decode(self, t: Sequence[int]) -> str:
|
||||
"""
|
||||
Decodes a list of token IDs into a string.
|
||||
|
||||
Args:
|
||||
t (List[int]): The list of token IDs to be decoded.
|
||||
|
||||
Returns:
|
||||
str: The decoded string.
|
||||
"""
|
||||
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
||||
return self.model.decode(cast(List[int], t))
|
||||
|
||||
@staticmethod
|
||||
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
|
||||
"""
|
||||
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
||||
consecutive whitespaces or consecutive non-whitespaces.
|
||||
"""
|
||||
current_slice_len = 0
|
||||
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
||||
slice_start = 0
|
||||
|
||||
for i in range(len(s)):
|
||||
is_now_space = s[i].isspace()
|
||||
|
||||
if current_slice_is_space ^ is_now_space:
|
||||
current_slice_len = 1
|
||||
current_slice_is_space = is_now_space
|
||||
else:
|
||||
current_slice_len += 1
|
||||
if current_slice_len > max_consecutive_slice_len:
|
||||
yield s[slice_start:i]
|
||||
slice_start = i
|
||||
current_slice_len = 1
|
||||
yield s[slice_start:]
|
199
llama_stack/models/llama/llama3/tool_utils.py
Normal file
199
llama_stack/models/llama/llama3/tool_utils.py
Normal file
|
@ -0,0 +1,199 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||
|
||||
|
||||
def is_json(s):
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
# Return True for valid objects and not for ints, strings, etc
|
||||
return isinstance(parsed, dict)
|
||||
except json.JSONDecodeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_python_list(input_string):
|
||||
"""Check if the input string is a valid Python list of function calls"""
|
||||
try:
|
||||
# Try to parse the string
|
||||
tree = ast.parse(input_string)
|
||||
|
||||
# Check if it's a single expression
|
||||
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
|
||||
return False
|
||||
|
||||
# Check if the expression is a list
|
||||
expr = tree.body[0].value
|
||||
if not isinstance(expr, ast.List):
|
||||
return False
|
||||
|
||||
# Check if the list is empty
|
||||
if len(expr.elts) == 0:
|
||||
return False
|
||||
|
||||
# Check if all elements in the list are function calls
|
||||
for element in expr.elts:
|
||||
if not isinstance(element, ast.Call):
|
||||
return False
|
||||
|
||||
# Check if the function call has a valid name
|
||||
if not isinstance(element.func, ast.Name):
|
||||
return False
|
||||
|
||||
# Check if all arguments are keyword arguments
|
||||
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except SyntaxError:
|
||||
# If parsing fails, it's not a valid Python expression
|
||||
return False
|
||||
|
||||
|
||||
def parse_python_list_for_function_calls(input_string):
|
||||
"""
|
||||
Parse a Python list of function calls and
|
||||
return a list of tuples containing the function name and arguments
|
||||
"""
|
||||
# Parse the string into an AST
|
||||
tree = ast.parse(input_string)
|
||||
|
||||
# Ensure the input is a list
|
||||
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
|
||||
raise ValueError("Input must be a list of function calls")
|
||||
|
||||
result = []
|
||||
|
||||
# Iterate through each function call in the list
|
||||
for node in tree.body[0].value.elts:
|
||||
if isinstance(node, ast.Call):
|
||||
function_name = node.func.id
|
||||
function_args = {}
|
||||
|
||||
# Extract keyword arguments
|
||||
for keyword in node.keywords:
|
||||
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
||||
|
||||
result.append((function_name, function_args))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ToolUtils:
|
||||
@staticmethod
|
||||
def is_builtin_tool_call(message_body: str) -> bool:
|
||||
match = re.search(ToolUtils.BUILTIN_TOOL_PATTERN, message_body)
|
||||
return match is not None
|
||||
|
||||
@staticmethod
|
||||
def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
||||
# Find the first match in the text
|
||||
match = re.search(BUILTIN_TOOL_PATTERN, message_body)
|
||||
|
||||
# Check if a match is found and return it
|
||||
if match:
|
||||
tool_name = match.group("tool_name")
|
||||
query = match.group("query")
|
||||
return tool_name, query
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
|
||||
# NOTE: Custom function too calls are still experimental
|
||||
# Sometimes, response is of the form
|
||||
# {"type": "function", "name": "function_name", "parameters": {...}
|
||||
# and some times
|
||||
# <function=function_name>(parameters)</function>
|
||||
|
||||
# Find the first match in the text
|
||||
match = re.search(CUSTOM_TOOL_CALL_PATTERN, message_body)
|
||||
if match:
|
||||
tool_name = match.group("function_name")
|
||||
query = match.group("args")
|
||||
try:
|
||||
return tool_name, json.loads(query.replace("'", '"'))
|
||||
except Exception as e:
|
||||
print("Exception while parsing json query for custom tool call", query, e)
|
||||
return None
|
||||
elif is_json(message_body):
|
||||
response = json.loads(message_body)
|
||||
if ("type" in response and response["type"] == "function") or ("name" in response):
|
||||
function_name = response["name"]
|
||||
args = response["parameters"]
|
||||
return function_name, args
|
||||
else:
|
||||
return None
|
||||
elif is_valid_python_list(message_body):
|
||||
res = parse_python_list_for_function_calls(message_body)
|
||||
# FIXME: Enable multiple tool calls
|
||||
return res[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def encode_tool_call(t: ToolCall, tool_prompt_format: ToolPromptFormat) -> str:
|
||||
if t.tool_name == BuiltinTool.brave_search:
|
||||
q = t.arguments["query"]
|
||||
return f'brave_search.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.wolfram_alpha:
|
||||
q = t.arguments["query"]
|
||||
return f'wolfram_alpha.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.photogen:
|
||||
q = t.arguments["query"]
|
||||
return f'photogen.call(query="{q}")'
|
||||
elif t.tool_name == BuiltinTool.code_interpreter:
|
||||
return t.arguments["code"]
|
||||
else:
|
||||
fname = t.tool_name
|
||||
|
||||
if tool_prompt_format == ToolPromptFormat.json:
|
||||
return json.dumps(
|
||||
{
|
||||
"type": "function",
|
||||
"name": fname,
|
||||
"parameters": t.arguments,
|
||||
}
|
||||
)
|
||||
elif tool_prompt_format == ToolPromptFormat.function_tag:
|
||||
args = json.dumps(t.arguments)
|
||||
return f"<function={fname}>{args}</function>"
|
||||
|
||||
elif tool_prompt_format == ToolPromptFormat.python_list:
|
||||
|
||||
def format_value(value: RecursiveType) -> str:
|
||||
if isinstance(value, str):
|
||||
return f'"{value}"'
|
||||
elif isinstance(value, (int, float, bool)) or value is None:
|
||||
return str(value)
|
||||
elif isinstance(value, list):
|
||||
return f"[{', '.join(format_value(v) for v in value)}]"
|
||||
elif isinstance(value, dict):
|
||||
return f"{{{', '.join(f'{k}={format_value(v)}' for k, v in value.items())}}}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported type: {type(value)}")
|
||||
|
||||
args_str = ", ".join(f"{k}={format_value(v)}" for k, v in t.arguments.items())
|
||||
return f"[{fname}({args_str})]"
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool prompt format: {tool_prompt_format}")
|
|
@ -14,7 +14,7 @@
|
|||
import textwrap
|
||||
from typing import List
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
import json
|
||||
import textwrap
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
RawTextItem,
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
import textwrap
|
||||
from typing import List
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
|
|
|
@ -16,7 +16,9 @@ import textwrap
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from llama_models.datatypes import (
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
|
@ -25,7 +27,6 @@ from llama_models.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .llama3.interface import LLama31Interface
|
||||
from .llama3.template_data import (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue