forked from phoenix-oss/llama-stack-mirror
chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle UnionTypes and collections.abc.AsyncIterator. (Both are preferred for recent Python releases.)

Note to reviewers: almost all changes here were generated automatically by pyupgrade; some additional unused imports were cleaned up as well. The only change worth noting can be found under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where reflection code was updated to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
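For context, pyupgrade mechanically rewrites the older `typing` spellings into the PEP 585/604 forms seen throughout this diff. A minimal sketch of the kind of rewrite it applies (the `lookup` function is hypothetical, not taken from this change set):

```python
from typing import Dict, List, Optional, Union


# Before: typing-era spellings required on Python 3.8 and earlier.
def lookup(keys: List[str], table: Dict[str, int]) -> Optional[Union[int, float]]:
    return table.get(keys[0]) if keys else None


# After: what pyupgrade (with --py310-plus) produces, using builtin generics
# (PEP 585) and union syntax (PEP 604); the typing imports become unnecessary.
def lookup_modern(keys: list[str], table: dict[str, int]) -> int | float | None:
    return table.get(keys[0]) if keys else None
```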
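The schema-reflection adjustment mentioned above stems from the fact that `X | Y` annotations evaluate to `types.UnionType` rather than `typing.Union`, so reflection code has to recognize both spellings; the same applies to `collections.abc.AsyncIterator` replacing `typing.AsyncIterator`. A minimal sketch of that kind of check (assumes Python 3.10+; the helper names are illustrative, not the PR's actual code):

```python
import collections.abc
import types
import typing


def union_members(hint: object) -> tuple | None:
    """Return the members of a union annotation, however it was spelled."""
    origin = typing.get_origin(hint)
    # `int | str` reports its origin as types.UnionType; Union[int, str]
    # reports typing.Union, so reflection code must accept either.
    if origin is typing.Union or origin is types.UnionType:
        return typing.get_args(hint)
    return None


def is_async_iterator(hint: object) -> bool:
    # collections.abc.AsyncIterator[str] reports the ABC itself as origin.
    return typing.get_origin(hint) is collections.abc.AsyncIterator


assert union_members(int | str) == (int, str)
assert union_members(typing.Optional[int]) == (int, type(None))
assert is_async_iterator(collections.abc.AsyncIterator[str])
```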
parent ffe3d0b2cd · commit 9e6561a1ec

319 changed files with 2843 additions and 3033 deletions

@@ -7,14 +7,14 @@
 import concurrent.futures
 import re
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any

 import numpy as np
 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size


-def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
+def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> list[int]:
     """Map a new MP rank to a list of old MP ranks given a change in MP size."""
     if new_mp_size % old_mp_size == 0:
         # Read old MP shard and split it into smaller ones
@@ -31,12 +31,12 @@ def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:


 def maybe_reshard_state_dict(
-    ckpt_paths: List[Path],
+    ckpt_paths: list[Path],
     n_kv_heads: int,
-    moe_num_experts: Optional[int] = None,
-    map_location: Union[str, torch.device] = "cpu",
+    moe_num_experts: int | None = None,
+    map_location: str | torch.device = "cpu",
     mmap: bool = True,
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
     if str(map_location) == "cpu":
         torch.set_default_tensor_type(torch.BFloat16Tensor)
     else:
@@ -97,18 +97,18 @@ _MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}


 def reshard_mp(
-    state_dicts: List[Dict[str, torch.Tensor]],
+    state_dicts: list[dict[str, torch.Tensor]],
     size: int,
     rank: int,
     repeat_qk_qv: int = 1,
-) -> Dict[str, torch.Tensor]:
+) -> dict[str, torch.Tensor]:
     """
     Reshard a list of state dicts into a single state dict given a change in MP size.
     If the list has more than one state dict, we concatenate the values of the same
     key across all state dicts. Otherwise, we just slice it for the current MP rank.
     """

-    def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
+    def concat_or_chunk(tensors: list[torch.Tensor], dim: int) -> torch.Tensor:
         if len(tensors) > 1:
             return torch.cat(tensors, dim=dim)
         return tensors[0].chunk(size, dim=dim)[rank].clone()
@@ -144,7 +144,7 @@ def reshard_mp(
     column_regex = re.compile("|".join(column_keys))
     row_regex = re.compile("|".join(row_keys))

-    output: Dict[str, torch.Tensor] = {}
+    output: dict[str, torch.Tensor] = {}
     with concurrent.futures.ThreadPoolExecutor() as executor:
         # Note: only processes keys in the first state dict.
         # Assumes keys are the same across all state dicts.
@@ -154,7 +154,7 @@ def reshard_mp(
     return output


-def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
+def convert_moe_weights(state_dict: dict[str, Any], num_experts: int) -> dict[str, Any]:
     routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
     routed_regex = re.compile("|".join(routed_keys))
     keys = list(state_dict.keys())

@@ -7,10 +7,9 @@
 import base64
 from enum import Enum
 from io import BytesIO
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Annotated, Any, Literal

 from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
-from typing_extensions import Annotated

 # The goal is that these set of types are relevant for all Llama models.
 # That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
@@ -31,21 +30,21 @@ class BuiltinTool(Enum):
     code_interpreter = "code_interpreter"


-Primitive = Union[str, int, float, bool, None]
-RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
+Primitive = str | int | float | bool | None
+RecursiveType = Primitive | list[Primitive] | dict[str, Primitive]


 class ToolCall(BaseModel):
     call_id: str
-    tool_name: Union[BuiltinTool, str]
+    tool_name: BuiltinTool | str
     # Plan is to deprecate the Dict in favor of a JSON string
     # that is parsed on the client side instead of trying to manage
     # the recursive type here.
     # Making this a union so that client side can start prepping for this change.
     # Eventually, we will remove both the Dict and arguments_json field,
     # and arguments will just be a str
-    arguments: Union[str, Dict[str, RecursiveType]]
-    arguments_json: Optional[str] = None
+    arguments: str | dict[str, RecursiveType]
+    arguments_json: str | None = None

     @field_validator("tool_name", mode="before")
     @classmethod
@@ -91,15 +90,15 @@ class StopReason(Enum):

 class ToolParamDefinition(BaseModel):
     param_type: str
-    description: Optional[str] = None
-    required: Optional[bool] = True
-    default: Optional[Any] = None
+    description: str | None = None
+    required: bool | None = True
+    default: Any | None = None


 class ToolDefinition(BaseModel):
-    tool_name: Union[BuiltinTool, str]
-    description: Optional[str] = None
-    parameters: Optional[Dict[str, ToolParamDefinition]] = None
+    tool_name: BuiltinTool | str
+    description: str | None = None
+    parameters: dict[str, ToolParamDefinition] | None = None

     @field_validator("tool_name", mode="before")
     @classmethod
@@ -119,7 +118,7 @@ class RawMediaItem(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)

     @field_serializer("data")
-    def serialize_data(self, data: Optional[bytes], _info):
+    def serialize_data(self, data: bytes | None, _info):
         if data is None:
             return None
         return base64.b64encode(data).decode("utf-8")
@@ -137,9 +136,9 @@ class RawTextItem(BaseModel):
     text: str


-RawContentItem = Annotated[Union[RawTextItem, RawMediaItem], Field(discriminator="type")]
+RawContentItem = Annotated[RawTextItem | RawMediaItem, Field(discriminator="type")]

-RawContent = str | RawContentItem | List[RawContentItem]
+RawContent = str | RawContentItem | list[RawContentItem]


 class RawMessage(BaseModel):
@@ -147,17 +146,17 @@ class RawMessage(BaseModel):
     content: RawContent

     # This is for RAG but likely should be absorbed into content
-    context: Optional[RawContent] = None
+    context: RawContent | None = None

     # These are for the output message coming from the assistant
-    stop_reason: Optional[StopReason] = None
-    tool_calls: List[ToolCall] = Field(default_factory=list)
+    stop_reason: StopReason | None = None
+    tool_calls: list[ToolCall] = Field(default_factory=list)


 class GenerationResult(BaseModel):
     token: int
     text: str
-    logprobs: Optional[List[float]] = None
+    logprobs: list[float] | None = None

     source: Literal["input"] | Literal["output"]


@@ -6,7 +6,6 @@

 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional


 class QuantizationScheme(Enum):
@@ -15,8 +14,8 @@ class QuantizationScheme(Enum):

 @dataclass
 class QuantizationArgs:
-    scheme: Optional[QuantizationScheme] = None
-    group_size: Optional[int] = None
+    scheme: QuantizationScheme | None = None
+    group_size: int | None = None
     spinquant: bool = False

     def __init__(self, **kwargs):
@@ -39,10 +38,10 @@ class ModelArgs:
     dim: int = 4096
     n_layers: int = 32
     n_heads: int = 32
-    n_kv_heads: Optional[int] = None
+    n_kv_heads: int | None = None
     vocab_size: int = -1
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
-    ffn_dim_multiplier: Optional[float] = None
+    ffn_dim_multiplier: float | None = None
     norm_eps: float = 1e-5
     rope_theta: float = 500000
     use_scaled_rope: bool = False
@@ -55,8 +54,8 @@ class ModelArgs:
     vision_max_num_chunks: int = 4
     vision_num_cross_attention_layers: int = -1

-    quantization_args: Optional[QuantizationArgs] = None
-    lora_args: Optional[LoRAArgs] = None
+    quantization_args: QuantizationArgs | None = None
+    lora_args: LoRAArgs | None = None

     def __init__(self, **kwargs):
         for k, v in kwargs.items():

@@ -8,7 +8,6 @@ import io
 import json
 import uuid
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple

 from PIL import Image as PIL_Image

@@ -29,14 +28,14 @@ from .tool_utils import ToolUtils

 @dataclass
 class VisionInput:
-    mask: List[List[int]]
-    images: List[PIL_Image.Image]
+    mask: list[list[int]]
+    images: list[PIL_Image.Image]


 @dataclass
 class LLMInput:
-    tokens: List[int]
-    vision: Optional[VisionInput] = None
+    tokens: list[int]
+    vision: VisionInput | None = None


 def role_str(role: Role) -> str:
@@ -50,7 +49,7 @@ def role_str(role: Role) -> str:


 class ChatFormat:
-    possible_headers: Dict[Role, str]
+    possible_headers: dict[Role, str]

     def __init__(self, tokenizer: Tokenizer):
         self.tokenizer = tokenizer
@@ -58,7 +57,7 @@ class ChatFormat:
         self.possible_headers = {role: f"<|start_header_id|>{role_str(role)}<|end_header_id|>\n\n" for role in Role}
         self.vision_token = self.tokenizer.special_tokens["<|image|>"]

-    def _encode_header(self, role: str) -> List[int]:
+    def _encode_header(self, role: str) -> list[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
         tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
@@ -70,7 +69,7 @@ class ChatFormat:
         tokens, images = self._encode_content(content, bos=True)
         return self._model_input_from_tokens_images(tokens, images)

-    def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[PIL_Image.Image]]:
+    def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[PIL_Image.Image]]:
         tokens = []
         images = []

@@ -107,7 +106,7 @@ class ChatFormat:

     def encode_message(
         self, message: RawMessage, tool_prompt_format: ToolPromptFormat
-    ) -> Tuple[List[int], List[PIL_Image.Image]]:
+    ) -> tuple[list[int], list[PIL_Image.Image]]:
         tokens = self._encode_header(message.role)
         images = []

@@ -145,8 +144,8 @@ class ChatFormat:

     def encode_dialog_prompt(
         self,
-        messages: List[RawMessage],
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
+        messages: list[RawMessage],
+        tool_prompt_format: ToolPromptFormat | None = None,
     ) -> LLMInput:
         tool_prompt_format = tool_prompt_format or ToolPromptFormat.json
         tokens = []
@@ -163,7 +162,7 @@ class ChatFormat:
         return self._model_input_from_tokens_images(tokens, images)

     # TODO(this should be generic, not only for assistant messages)
-    def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
+    def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
         content = self.tokenizer.decode(tokens)

         return self.decode_assistant_message_from_content(content, stop_reason)
@@ -234,7 +233,7 @@ class ChatFormat:
             tool_calls=tool_calls,
         )

-    def _model_input_from_tokens_images(self, tokens: List[int], images: List[PIL_Image.Image]) -> LLMInput:
+    def _model_input_from_tokens_images(self, tokens: list[int], images: list[PIL_Image.Image]) -> LLMInput:
         vision_input = None
         if len(images) > 0:
             vision_input = VisionInput(
@@ -249,9 +248,9 @@ class ChatFormat:


 def create_vision_mask(
-    tokens: List[int],
+    tokens: list[int],
     vision_token: int,
-) -> List[List[int]]:
+) -> list[list[int]]:
     vision_token_locations = [i for i, token in enumerate(tokens) if token == vision_token]
     if len(vision_token_locations) == 0:
         return []

@@ -15,8 +15,8 @@ import json
 import os
 import sys
 import time
+from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Callable, Generator, List, Optional

 import torch
 import torch.nn.functional as F
@@ -41,8 +41,8 @@ class Llama3:
         ckpt_dir: str,
         max_seq_len: int,
         max_batch_size: int,
-        world_size: Optional[int] = None,
-        quantization_mode: Optional[QuantizationMode] = None,
+        world_size: int | None = None,
+        quantization_mode: QuantizationMode | None = None,
         seed: int = 1,
         device: str = "cuda",
     ):
@@ -82,7 +82,7 @@ class Llama3:
         ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
         print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
-        with open(Path(ckpt_dir) / "params.json", "r") as f:
+        with open(Path(ckpt_dir) / "params.json") as f:
             params = json.loads(f.read())

         model_args: ModelArgs = ModelArgs(
@@ -154,15 +154,15 @@ class Llama3:
     @torch.inference_mode()
     def generate(
         self,
-        llm_inputs: List[LLMInput],
+        llm_inputs: list[LLMInput],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         echo: bool = False,
         print_model_input: bool = False,
-        logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-    ) -> Generator[List[GenerationResult], None, None]:
+        logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+    ) -> Generator[list[GenerationResult], None, None]:
         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
             max_gen_len = self.args.max_seq_len - 1
         params = self.model.params
@@ -302,13 +302,13 @@ class Llama3:

     def completion(
         self,
-        contents: List[RawContent],
+        contents: list[RawContent],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         echo: bool = False,
-    ) -> Generator[List[GenerationResult], None, None]:
+    ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_content(c) for c in contents]
         for result in self.generate(
             model_inputs=model_inputs,
@@ -324,14 +324,14 @@ class Llama3:

     def chat_completion(
         self,
-        messages_batch: List[List[RawMessage]],
+        messages_batch: list[list[RawMessage]],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
         echo: bool = False,
-    ) -> Generator[List[GenerationResult], None, None]:
+    ) -> Generator[list[GenerationResult], None, None]:
         model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
         for result in self.generate(
             model_inputs=model_inputs,

@@ -12,7 +12,6 @@
 # the top-level of this source tree.

 from pathlib import Path
-from typing import List, Optional

 from termcolor import colored

@@ -131,7 +130,7 @@ class LLama31Interface:
         self.formatter = ChatFormat(self.tokenizer)
         self.tool_prompt_format = tool_prompt_format

-    def get_tokens(self, messages: List[RawMessage]) -> List[int]:
+    def get_tokens(self, messages: list[RawMessage]) -> list[int]:
         model_input = self.formatter.encode_dialog_prompt(
             messages,
             self.tool_prompt_format,
@@ -149,10 +148,10 @@ class LLama31Interface:

     def system_messages(
         self,
-        builtin_tools: List[BuiltinTool],
-        custom_tools: List[ToolDefinition],
-        instruction: Optional[str] = None,
-    ) -> List[RawMessage]:
+        builtin_tools: list[BuiltinTool],
+        custom_tools: list[ToolDefinition],
+        instruction: str | None = None,
+    ) -> list[RawMessage]:
         messages = []

         default_gen = SystemDefaultGenerator()
@@ -194,8 +193,8 @@ class LLama31Interface:
         self,
         content: str,
         stop_reason: StopReason,
-        tool_call: Optional[ToolCall] = None,
-    ) -> List[RawMessage]:
+        tool_call: ToolCall | None = None,
+    ) -> list[RawMessage]:
         tool_calls = []
         if tool_call:
             tool_calls.append(tool_call)
@@ -208,7 +207,7 @@ class LLama31Interface:
             )
         ]

-    def user_message(self, content: str) -> List[RawMessage]:
+    def user_message(self, content: str) -> list[RawMessage]:
         return [RawMessage(role="user", content=content)]

     def display_message_as_tokens(self, message: RawMessage) -> None:
@@ -228,7 +227,7 @@ class LLama31Interface:
         print("\n", end="")


-def list_jinja_templates() -> List[Template]:
+def list_jinja_templates() -> list[Template]:
     return TEMPLATES


@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import math
-from typing import Optional, Tuple

 import fairscale.nn.model_parallel.initialize as fs_init
 import torch
@@ -80,7 +79,7 @@ def apply_rotary_emb(
     xq: torch.Tensor,
     xk: torch.Tensor,
     freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
@@ -162,7 +161,7 @@ class Attention(nn.Module):
         x: torch.Tensor,
         start_pos: int,
         freqs_cis: torch.Tensor,
-        mask: Optional[torch.Tensor],
+        mask: torch.Tensor | None,
     ):
         bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
@@ -204,7 +203,7 @@ class FeedForward(nn.Module):
         dim: int,
         hidden_dim: int,
         multiple_of: int,
-        ffn_dim_multiplier: Optional[float],
+        ffn_dim_multiplier: float | None,
     ):
         super().__init__()
         hidden_dim = int(2 * hidden_dim / 3)
@@ -243,7 +242,7 @@ class TransformerBlock(nn.Module):
         x: torch.Tensor,
         start_pos: int,
         freqs_cis: torch.Tensor,
-        mask: Optional[torch.Tensor],
+        mask: torch.Tensor | None,
     ):
         h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
         out = h + self.feed_forward(self.ffn_norm(h))

@@ -14,7 +14,7 @@
 import math
 from collections import defaultdict
 from logging import getLogger
-from typing import Any, Optional, Set, Tuple
+from typing import Any

 import torch
 import torchvision.transforms as tv
@@ -26,7 +26,7 @@ IMAGE_RES = 224
 logger = getLogger()


-class VariableSizeImageTransform(object):
+class VariableSizeImageTransform:
     """
     This class accepts images of any size and dynamically resize, pads and chunks it
     based on the image aspect ratio and the number of image chunks we allow.
@@ -75,7 +75,7 @@ class VariableSizeImageTransform(object):
         self.resample = tv.InterpolationMode.BILINEAR

     @staticmethod
-    def get_factors(n: int) -> Set[int]:
+    def get_factors(n: int) -> set[int]:
         """
         Calculate all factors of a given number, i.e. a dividor that leaves
         no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
@@ -145,9 +145,9 @@ class VariableSizeImageTransform(object):

     @staticmethod
     def get_max_res_without_distortion(
-        image_size: Tuple[int, int],
-        target_size: Tuple[int, int],
-    ) -> Tuple[int, int]:
+        image_size: tuple[int, int],
+        target_size: tuple[int, int],
+    ) -> tuple[int, int]:
         """
         Determines the maximum resolution to which an image can be resized to without distorting its
         aspect ratio, based on the target resolution.
@@ -198,8 +198,8 @@ class VariableSizeImageTransform(object):
     def resize_without_distortion(
         self,
         image: torch.Tensor,
-        target_size: Tuple[int, int],
-        max_upscaling_size: Optional[int],
+        target_size: tuple[int, int],
+        max_upscaling_size: int | None,
     ) -> torch.Tensor:
         """
         Used to resize an image to target_resolution, without distortion.
@@ -261,10 +261,10 @@ class VariableSizeImageTransform(object):

     def get_best_fit(
         self,
-        image_size: Tuple[int, int],
+        image_size: tuple[int, int],
         possible_resolutions: torch.Tensor,
         resize_to_max_canvas: bool = False,
-    ) -> Tuple[int, int]:
+    ) -> tuple[int, int]:
         """
         Determines the best canvas possible from a list of possible resolutions to, without distortion,
         resize an image to.
@@ -364,7 +364,7 @@ class VariableSizeImageTransform(object):
         max_num_chunks: int,
         normalize_img: bool = True,
         resize_to_max_canvas: bool = False,
-    ) -> Tuple[Any, Any]:
+    ) -> tuple[Any, Any]:
         """
         Args:
             image (PIL.Image): Image to be resized.

@@ -6,8 +6,9 @@

 import logging
 import math
+from collections.abc import Callable
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any

 import fairscale.nn.model_parallel.initialize as fs_init
 import torch
@@ -104,9 +105,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: Union[int, Tuple[int, int]],
-        stride: Union[int, Tuple[int, int]],
-        bias: Optional[bool] = False,
+        kernel_size: int | tuple[int, int],
+        stride: int | tuple[int, int],
+        bias: bool | None = False,
     ) -> None:
         super().__init__()
         if isinstance(kernel_size, int):
@@ -390,13 +391,13 @@ class VisionEncoder(nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool = True,
-        missing_keys: List[str] = None,
-        unexpected_keys: List[str] = None,
-        error_msgs: List[str] = None,
+        missing_keys: list[str] = None,
+        unexpected_keys: list[str] = None,
+        error_msgs: list[str] = None,
         return_state_dict: bool = False,
     ) -> None:
         orig_pos_embed = state_dict.get(prefix + "positional_embedding")
@@ -641,7 +642,7 @@ class FeedForward(nn.Module):
         dim: int,
         hidden_dim: int,
         multiple_of: int,
-        ffn_dim_multiplier: Optional[float],
+        ffn_dim_multiplier: float | None,
     ):
         """
         Initialize the FeedForward module.
@@ -983,7 +984,7 @@ class CrossAttentionTransformerBlock(torch.nn.Module):
         self,
         x: torch.Tensor,
         xattn_mask: torch.Tensor,
-        full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor],
+        full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
         xattn_cache: torch.Tensor,
     ) -> torch.Tensor:
         _attn_out = self.attention(
@@ -1144,7 +1145,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
     def _init_fusion_schedule(
         self,
         num_layers: int,
-    ) -> List[int]:
+    ) -> list[int]:
         llama_layers = list(range(self.n_llama_layers))

         # uniformly spread the layers
@@ -1231,7 +1232,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
         text_dtype,
         vision_tokens,
         cross_attention_masks,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor]:
         assert vision_tokens is not None, "Vision tokens must be provided"
         vision_seqlen = vision_tokens.shape[3]
         assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (
@@ -1280,11 +1281,11 @@ class CrossAttentionTransformer(torch.nn.Module):

     def compute_vision_tokens_masks(
         self,
-        batch_images: List[List[PIL_Image.Image]],
-        batch_masks: List[List[List[int]]],
+        batch_images: list[list[PIL_Image.Image]],
+        batch_masks: list[list[list[int]]],
         total_len: int,
         device: torch.device,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         skip_vision_encoder = False

         assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"
@@ -1371,11 +1372,11 @@ class CrossAttentionTransformer(torch.nn.Module):


 def _stack_images(
-    images: List[List[PIL_Image.Image]],
+    images: list[list[PIL_Image.Image]],
     max_num_chunks: int,
     image_res: int,
     max_num_images: int,
-) -> Tuple[torch.Tensor, List[int]]:
+) -> tuple[torch.Tensor, list[int]]:
     """
     Takes a list of list of images and stacks them into a tensor.
     This function is needed since images can be of completely
@@ -1400,8 +1401,8 @@ def _stack_images(


 def _pad_masks(
-    all_masks: List[List[List[int]]],
-    all_num_chunks: List[List[int]],
+    all_masks: list[list[list[int]]],
+    all_num_chunks: list[list[int]],
     total_len: int,
     max_num_chunks: int,
 ) -> torch.Tensor:

@@ -12,7 +12,7 @@
 # the top-level of this source tree.

 from dataclasses import dataclass
-from typing import Any, Dict, List
+from typing import Any

 from jinja2 import Template

@@ -20,7 +20,7 @@ from jinja2 import Template
 @dataclass
 class PromptTemplate:
     template: str
-    data: Dict[str, Any]
+    data: dict[str, Any]

     def render(self):
         template = Template(self.template)
@@ -35,5 +35,5 @@ class PromptTemplateGeneratorBase:
     def gen(self, *args, **kwargs) -> PromptTemplate:
         raise NotImplementedError()

-    def data_examples(self) -> List[Any]:
+    def data_examples(self) -> list[Any]:
         raise NotImplementedError()

@@ -13,7 +13,7 @@

 import textwrap
 from datetime import datetime
-from typing import Any, List, Optional
+from typing import Any

 from llama_stack.apis.inference import (
     BuiltinTool,
@@ -39,12 +39,12 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
             },
         )

-    def data_examples(self) -> List[Any]:
+    def data_examples(self) -> list[Any]:
         return [None]


 class BuiltinToolGenerator(PromptTemplateGeneratorBase):
-    def _tool_breakdown(self, tools: List[ToolDefinition]):
+    def _tool_breakdown(self, tools: list[ToolDefinition]):
         builtin_tools, custom_tools = [], []
         for dfn in tools:
             if isinstance(dfn.tool_name, BuiltinTool):
@@ -54,7 +54,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):

         return builtin_tools, custom_tools

-    def gen(self, tools: List[ToolDefinition]) -> PromptTemplate:
+    def gen(self, tools: list[ToolDefinition]) -> PromptTemplate:
         builtin_tools, custom_tools = self._tool_breakdown(tools)
         template_str = textwrap.dedent(
             """
@@ -75,7 +75,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):
             },
         )

-    def data_examples(self) -> List[List[ToolDefinition]]:
+    def data_examples(self) -> list[list[ToolDefinition]]:
         return [
             # builtin tools
             [
@@ -91,7 +91,7 @@ class BuiltinToolGenerator(PromptTemplateGeneratorBase):


 class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
-    def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+    def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
             Answer the user's question by making use of the following functions if needed.
@@ -137,7 +137,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):
             {"custom_tools": [t.model_dump() for t in custom_tools]},
         )

-    def data_examples(self) -> List[List[ToolDefinition]]:
+    def data_examples(self) -> list[list[ToolDefinition]]:
         return [
             [
                 ToolDefinition(
@@ -161,7 +161,7 @@ class JsonCustomToolGenerator(PromptTemplateGeneratorBase):


 class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
-    def gen(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+    def gen(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
             You have access to the following functions:
@@ -199,7 +199,7 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
             {"custom_tools": [t.model_dump() for t in custom_tools]},
         )

-    def data_examples(self) -> List[List[ToolDefinition]]:
+    def data_examples(self) -> list[list[ToolDefinition]]:
         return [
             [
                 ToolDefinition(
@@ -238,14 +238,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
         """.strip("\n")
     )

-    def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
+    def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
         system_prompt = system_prompt or self.DEFAULT_PROMPT
         return PromptTemplate(
             system_prompt,
             {"function_description": self._gen_function_description(custom_tools)},
         )

-    def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+    def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
             Here is a list of functions in JSON format that you can invoke.
@@ -291,7 +291,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
             {"tools": [t.model_dump() for t in custom_tools]},
         ).render()

-    def data_examples(self) -> List[List[ToolDefinition]]:
+    def data_examples(self) -> list[list[ToolDefinition]]:
         return [
             [
                 ToolDefinition(

@@ -12,7 +12,6 @@
 # the top-level of this source tree.

 import textwrap
-from typing import Optional

 from .base import PromptTemplate, PromptTemplateGeneratorBase

@@ -21,8 +20,8 @@ class ToolResponseGenerator(PromptTemplateGeneratorBase):
     def gen(
         self,
         status: str,
-        stdout: Optional[str] = None,
-        stderr: Optional[str] = None,
+        stdout: str | None = None,
+        stderr: str | None = None,
     ):
         assert status in [
             "success",

@@ -6,7 +6,7 @@

 # type: ignore
 import os
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, cast

 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@@ -37,9 +37,9 @@ def swiglu_wrapper(
 def convert_to_quantized_model(
     model: Transformer | CrossAttentionTransformer,
     checkpoint_dir: str,
-    quantization_mode: Optional[str] = None,
-    fp8_activation_scale_ub: Optional[float] = 1200.0,
-    device: Optional[torch.device] = None,
+    quantization_mode: str | None = None,
+    fp8_activation_scale_ub: float | None = 1200.0,
+    device: torch.device | None = None,
 ) -> Transformer | CrossAttentionTransformer:
     if quantization_mode == QuantizationMode.fp8_mixed:
         return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
@@ -52,8 +52,8 @@ def convert_to_quantized_model(
 def convert_to_fp8_quantized_model(
     model: Transformer,
     checkpoint_dir: str,
-    fp8_activation_scale_ub: Optional[float] = 1200.0,
-    device: Optional[torch.device] = None,
+    fp8_activation_scale_ub: float | None = 1200.0,
+    device: torch.device | None = None,
 ) -> Transformer:
     # Move weights to GPU with quantization
     fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
@@ -122,8 +122,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
         precision: torch.dtype = torch.float32,
         scales_precision: torch.dtype = torch.float32,
         # LoRA parameters
-        lora_rank: Optional[int] = None,
-        lora_scale: Optional[float] = None,
+        lora_rank: int | None = None,
+        lora_scale: float | None = None,
     ) -> None:
         super().__init__(
             in_features,
@@ -134,8 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
             precision=precision,
             scales_precision=scales_precision,
         )
-        self.lora_scale: Optional[float] = None
-        self.adaptor: Optional[nn.Sequential] = None
+        self.lora_scale: float | None = None
+        self.adaptor: nn.Sequential | None = None
         if lora_rank is not None:
             assert lora_scale is not None, "Please specify lora scale for LoRA."
             # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
@@ -147,13 +147,13 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         """A hook to load the quantized weights from the state dict."""
         if prefix + "zeros" not in state_dict:
@@ -191,13 +191,13 @@ class Int8WeightEmbedding(torch.nn.Embedding):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         """A hook to load the quantized embedding weight and scales from the state dict."""
         weights = state_dict.pop(prefix + "weight")
@@ -221,13 +221,13 @@ class Int8WeightLinear(torch.nn.Linear):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         """A hook to load the quantized linear weight and scales from the state dict."""
         weights = state_dict.pop(prefix + "weight")
@@ -238,8 +238,8 @@ class Int8WeightLinear(torch.nn.Linear):
 def _prepare_model_int4_weight_int8_dynamic_activation(
     model: torch.nn.Module,
     group_size: int,
-    lora_rank: Optional[int],
-    lora_scale: Optional[float],
+    lora_rank: int | None,
+    lora_scale: float | None,
 ):
     """Prepare the model for int4 weight and int8 dynamic activation quantization.

@@ -265,7 +265,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
             )
             del module
            setattr(model, module_name, quantized_module)
-        elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
+        elif isinstance(module, ColumnParallelLinear | RowParallelLinear | nn.Linear):
             quantized_module = Int8DynActInt4WeightLinearLoRA(
                 in_features=module.in_features,
                 out_features=module.out_features,
@@ -286,7 +286,7 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
 def convert_to_int4_quantized_model(
     model: Transformer | CrossAttentionTransformer,
     checkpoint_dir: str,
-    device: Optional[torch.device] = None,
+    device: torch.device | None = None,
 ) -> Transformer | CrossAttentionTransformer:
     """Convert the model to int4 quantized model."""
     model_args = model.params

@@ -5,18 +5,11 @@
 # the root directory of this source tree.

 import os
+from collections.abc import Collection, Iterator, Sequence, Set
 from logging import getLogger
 from pathlib import Path
 from typing import (
-    AbstractSet,
-    Collection,
-    Dict,
-    Iterator,
-    List,
     Literal,
-    Optional,
-    Sequence,
-    Union,
     cast,
 )

@@ -44,7 +37,7 @@ class Tokenizer:
     Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
     """

-    special_tokens: Dict[str, int]
+    special_tokens: dict[str, int]

     num_reserved_special_tokens = 256

@@ -116,9 +109,9 @@ class Tokenizer:
         *,
         bos: bool,
         eos: bool,
-        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
-        disallowed_special: Union[Literal["all"], Collection[str]] = (),
-    ) -> List[int]:
+        allowed_special: Literal["all"] | Set[str] | None = None,
+        disallowed_special: Literal["all"] | Collection[str] = (),
+    ) -> list[int]:
         """
         Encodes a string into a list of token IDs.

@@ -151,7 +144,7 @@ class Tokenizer:
                 s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
             )
         )
-        t: List[int] = []
+        t: list[int] = []
         for substr in substrs:
             t.extend(
                 self.model.encode(
@@ -177,7 +170,7 @@ class Tokenizer:
             str: The decoded string.
         """
         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
-        return self.model.decode(cast(List[int], t))
+        return self.model.decode(cast(list[int], t))

     @staticmethod
     def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:

@@ -6,7 +6,6 @@

 import json
 import re
-from typing import Optional, Tuple

 from llama_stack.log import get_logger

@@ -172,7 +171,7 @@ class ToolUtils:
         return match is not None

     @staticmethod
-    def maybe_extract_builtin_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
+    def maybe_extract_builtin_tool_call(message_body: str) -> tuple[str, str] | None:
         # Find the first match in the text
         match = re.search(BUILTIN_TOOL_PATTERN, message_body)

@@ -185,7 +184,7 @@ class ToolUtils:
         return None

     @staticmethod
-    def maybe_extract_custom_tool_call(message_body: str) -> Optional[Tuple[str, str]]:
+    def maybe_extract_custom_tool_call(message_body: str) -> tuple[str, str] | None:
         # NOTE: Custom function too calls are still experimental
         # Sometimes, response is of the form
         # {"type": "function", "name": "function_name", "parameters": {...}
@@ -252,7 +251,7 @@ class ToolUtils:
     def format_value(value: RecursiveType) -> str:
         if isinstance(value, str):
             return f'"{value}"'
-        elif isinstance(value, (int, float, bool)) or value is None:
+        elif isinstance(value, int | float | bool) or value is None:
             return str(value)
         elif isinstance(value, list):
             return f"[{', '.join(format_value(v) for v in value)}]"

@@ -12,7 +12,6 @@
 # the top-level of this source tree.

 import textwrap
-from typing import List

 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
@@ -73,7 +72,7 @@ def wolfram_alpha_response():
     )


-def usecases() -> List[UseCase | str]:
+def usecases() -> list[UseCase | str]:
     return [
         textwrap.dedent(
             """

@@ -12,7 +12,6 @@
 # the top-level of this source tree.

 import textwrap
-from typing import List

 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
@@ -74,7 +73,7 @@ def wolfram_alpha_response():
     )


-def usecases() -> List[UseCase | str]:
+def usecases() -> list[UseCase | str]:
     return [
         textwrap.dedent(
             """

@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from enum import Enum
-from typing import Optional

 from pydantic import BaseModel, model_validator

@@ -15,8 +14,8 @@ class QuantizationScheme(Enum):


 class QuantizationArgs(BaseModel):
-    scheme: Optional[QuantizationScheme] = None
-    group_size: Optional[int] = None
+    scheme: QuantizationScheme | None = None
+    group_size: int | None = None
     spinquant: bool = False


@@ -58,32 +57,32 @@ class ModelArgs(BaseModel):
     dim: int = -1
     n_layers: int = -1
     n_heads: int = -1
-    n_kv_heads: Optional[int] = None
-    head_dim: Optional[int] = None
+    n_kv_heads: int | None = None
+    head_dim: int | None = None

     vocab_size: int = -1
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
-    ffn_dim_multiplier: Optional[float] = None
-    ffn_exp: Optional[float] = None
+    ffn_dim_multiplier: float | None = None
+    ffn_exp: float | None = None
     norm_eps: float = 1e-5

-    attention_chunk_size: Optional[int] = None
+    attention_chunk_size: int | None = None
     rope_theta: float = 500000
     use_scaled_rope: bool = False
-    rope_scaling_factor: Optional[float] = None
-    rope_high_freq_factor: Optional[float] = None
+    rope_scaling_factor: float | None = None
+    rope_high_freq_factor: float | None = None

-    nope_layer_interval: Optional[int] = None  # No position encoding in every n layers
+    nope_layer_interval: int | None = None  # No position encoding in every n layers
     use_qk_norm: bool = False
     # Set to True to enable inference-time temperature tuning (useful for very long context)
     attn_temperature_tuning: bool = False
     floor_scale: float = 8192.0
     attn_scale: float = 0.1

-    vision_args: Optional[VisionArgs] = None
-    moe_args: Optional[MoEArgs] = None
-    quantization_args: Optional[QuantizationArgs] = None
-    lora_args: Optional[LoRAArgs] = None
+    vision_args: VisionArgs | None = None
+    moe_args: MoEArgs | None = None
+    quantization_args: QuantizationArgs | None = None
+    lora_args: LoRAArgs | None = None

     max_batch_size: int = 32
     max_seq_len: int = 2048

@@ -8,7 +8,6 @@ import io
 import json
 import uuid
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple

 import torch
 from PIL import Image as PIL_Image
@@ -46,10 +45,10 @@ def role_str(role: Role) -> str:
 class TransformedImage:
     image_tiles: torch.Tensor
     # is the aspect ratio needed anywhere?
-    aspect_ratio: Tuple[int, int]
+    aspect_ratio: tuple[int, int]


-def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: PIL_Image.Image, bg: tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
     if image.mode == "RGBA":
         image.load()  # for png.split()
         new_img = PIL_Image.new("RGB", image.size, bg)
@@ -59,12 +58,12 @@ def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:


 class ChatFormat:
-    possible_headers: Dict[Role, str]
+    possible_headers: dict[Role, str]

     def __init__(
         self,
         tokenizer: Tokenizer,
-        vision_args: Optional[VisionArgs] = None,
+        vision_args: VisionArgs | None = None,
         max_num_chunks: int = 16,
     ):
         self.tokenizer = tokenizer
@@ -81,7 +80,7 @@ class ChatFormat:
                 vision_args.image_size.width, vision_args.image_size.height
             )

-    def _encode_header(self, role: str) -> List[int]:
+    def _encode_header(self, role: str) -> list[int]:
         tokens = []
         tokens.append(self.tokenizer.special_tokens["<|header_start|>"])

@@ -98,7 +97,7 @@ class ChatFormat:
     def _encode_image(
         self,
         transformed_image: TransformedImage,
-    ) -> List[int]:
+    ) -> list[int]:
         assert self.vision_args is not None, "The model is not vision-enabled"

         image_tensor = transformed_image.image_tiles
@@ -140,7 +139,7 @@ class ChatFormat:

         return tokens

-    def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
+    def _encode_content(self, content: RawContent, bos: bool = False) -> tuple[list[int], list[TransformedImage]]:
         tokens = []
         tranformed_images = []

@@ -189,7 +188,7 @@ class ChatFormat:

     def encode_message(
         self, message: RawMessage, tool_prompt_format: ToolPromptFormat
-    ) -> Tuple[List[int], List[TransformedImage]]:
+    ) -> tuple[list[int], list[TransformedImage]]:
         tokens = self._encode_header(message.role)
         images = []

@@ -223,7 +222,7 @@ class ChatFormat:

     def encode_dialog_prompt(
         self,
-        messages: List[RawMessage],
+        messages: list[RawMessage],
         tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
     ) -> LLMInput:
         tokens = []
@@ -240,7 +239,7 @@ class ChatFormat:
         return self._model_input_from_tokens_images(tokens, images)

     # TODO(this should be generic, not only for assistant messages)
-    def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
+    def decode_assistant_message(self, tokens: list[int], stop_reason: StopReason) -> RawMessage:
         content = self.tokenizer.decode(tokens)

         return self.decode_assistant_message_from_content(content, stop_reason)
@@ -312,7 +311,7 @@ class ChatFormat:
             tool_calls=tool_calls,
         )

-    def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
+    def _model_input_from_tokens_images(self, tokens: list[int], images: list[TransformedImage]) -> LLMInput:
         return LLMInput(
             tokens=tokens,
             images=[x.image_tiles for x in images] if len(images) > 0 else None,

@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from dataclasses import dataclass
-from typing import List, Optional, Union

 import torch

@@ -30,7 +29,7 @@ class LLMInput:
     tokens: torch.Tensor

     # images are already pre-processed (resized, tiled, etc.)
-    images: Optional[List[torch.Tensor]] = None
+    images: list[torch.Tensor] | None = None


 @dataclass
@@ -45,8 +44,8 @@ class TransformerInput:
     # tokens_position defines the position of the tokens in each batch,
     # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
     # - when it is an int, the start position are the same for all batches
-    tokens_position: Union[torch.Tensor, int]
-    image_embedding: Optional[MaskedEmbedding] = None
+    tokens_position: torch.Tensor | int
+    image_embedding: MaskedEmbedding | None = None


 @dataclass

@@ -11,7 +11,7 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.

-from typing import Any, Dict, List
+from typing import Any

 from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
@@ -36,13 +36,13 @@ class FeedForward(nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         if prefix + "mlp.fc1_weight" in state_dict:
             w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)

@@ -10,8 +10,8 @@ import json
 import os
 import sys
 import time
+from collections.abc import Callable, Generator
 from pathlib import Path
-from typing import Callable, Generator, List, Optional

 import torch
 import torch.nn.functional as F
@@ -38,8 +38,8 @@ class Llama4:
         ckpt_dir: str,
         max_seq_len: int,
         max_batch_size: int,
-        world_size: Optional[int] = None,
-        quantization_mode: Optional[QuantizationMode] = None,
+        world_size: int | None = None,
+        quantization_mode: QuantizationMode | None = None,
         seed: int = 1,
     ):
         if not torch.distributed.is_initialized():
@@ -63,7 +63,7 @@ class Llama4:
         ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
         print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
-        with open(Path(ckpt_dir) / "params.json", "r") as f:
+        with open(Path(ckpt_dir) / "params.json") as f:
             params = json.loads(f.read())

         model_args: ModelArgs = ModelArgs(
@@ -117,15 +117,15 @@ class Llama4:
     @torch.inference_mode()
     def generate(
         self,
-        llm_inputs: List[LLMInput],
+        llm_inputs: list[LLMInput],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         echo: bool = False,
         print_model_input: bool = False,
-        logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-    ) -> Generator[List[GenerationResult], None, None]:
+        logits_processor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+    ) -> Generator[list[GenerationResult], None, None]:
         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
             max_gen_len = self.model.args.max_seq_len - 1

@@ -245,13 +245,13 @@ class Llama4:

     def completion(
         self,
-        contents: List[RawContent],
+        contents: list[RawContent],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         echo: bool = False,
-    ) -> Generator[List[GenerationResult], None, None]:
+    ) -> Generator[list[GenerationResult], None, None]:
         llm_inputs = [self.formatter.encode_content(c) for c in contents]
         for result in self.generate(
             llm_inputs=llm_inputs,
@@ -267,13 +267,13 @@ class Llama4:

     def chat_completion(
         self,
-        messages_batch: List[List[RawMessage]],
+        messages_batch: list[list[RawMessage]],
         temperature: float = 0.6,
         top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
+        max_gen_len: int | None = None,
         logprobs: bool = False,
         echo: bool = False,
-    ) -> Generator[List[GenerationResult], None, None]:
+    ) -> Generator[list[GenerationResult], None, None]:
         llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
         for result in self.generate(
             llm_inputs=llm_inputs,
@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any

 import fairscale.nn.model_parallel.initialize as fs_init
 import torch

@@ -89,7 +89,7 @@ def apply_rotary_emb(
     xq: torch.Tensor,
     xk: torch.Tensor,
     freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)

@@ -174,13 +174,13 @@ class Attention(nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         if prefix + "wqkv.weight" in state_dict:
             wqkv = state_dict.pop(prefix + "wqkv.weight")

@@ -200,7 +200,7 @@ class Attention(nn.Module):
         x: torch.Tensor,
         start_pos: int,
         freqs_cis: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
+        mask: torch.Tensor | None = None,
     ):
         bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

@@ -288,13 +288,13 @@ class TransformerBlock(nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
             state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")

@@ -318,8 +318,8 @@ class TransformerBlock(nn.Module):
         x: torch.Tensor,
         start_pos: int,
         freqs_cis: torch.Tensor,
-        global_attn_mask: Optional[torch.Tensor],
-        local_attn_mask: Optional[torch.Tensor],
+        global_attn_mask: torch.Tensor | None,
+        local_attn_mask: torch.Tensor | None,
     ):
         # The iRoPE architecture uses global attention mask for NoPE layers or
         # if chunked local attention is not used

@@ -374,13 +374,13 @@ class Transformer(nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         if prefix + "rope.freqs" in state_dict:
             state_dict.pop(prefix + "rope.freqs")
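
The import collapse at the top of this file is the heart of the PEP 585 cleanup: `dict`, `list` and `tuple` have been subscriptable since Python 3.9, so `typing.Dict`, `typing.List` and `typing.Tuple` are redundant. A hedged, self-contained sketch in the same shape as the `load_hook` signatures above (the helper itself is hypothetical):

    from typing import Any


    def count_loaded(state_dict: dict[str, Any], expected_keys: list[str]) -> tuple[int, int]:
        # was: Dict[str, Any], List[str], Tuple[int, int]
        found = sum(1 for k in expected_keys if k in state_dict)
        return found, len(expected_keys) - found


    assert count_loaded({"wq.weight": 0}, ["wq.weight", "wk.weight"]) == (1, 1)
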
@@ -6,7 +6,7 @@

 # ruff: noqa: N806
 # pyre-strict
-from typing import Any, Dict, List
+from typing import Any

 import fairscale.nn.model_parallel.initialize as fs_init
 import torch

@@ -63,13 +63,13 @@ class Experts(nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         self.prefix = prefix
         if prefix + "moe_w_in_eD_F" in state_dict:

@@ -158,13 +158,13 @@ class MoE(torch.nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool,
-        missing_keys: List[str],
-        unexpected_keys: List[str],
-        error_msgs: List[str],
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        error_msgs: list[str],
     ) -> None:
         if prefix + "w_in_shared_FD.weight" in state_dict:
             state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")

@@ -210,5 +210,5 @@ class MoE(torch.nn.Module):


 def divide_exact(numerator: int, denominator: int) -> int:
-    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
+    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
     return numerator // denominator
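
The `divide_exact` hunk shows pyupgrade's other specialty: positional `str.format` calls become f-strings. The two spellings produce identical output, as in this quick check:

    numerator, denominator = 7, 2
    old = "{} is not divisible by {}".format(numerator, denominator)
    new = f"{numerator} is not divisible by {denominator}"
    assert old == new == "7 is not divisible by 2"
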
@@ -13,7 +13,6 @@

 import math
 from collections import defaultdict
-from typing import Optional, Set, Tuple

 import torch
 import torchvision.transforms as tv

@@ -52,7 +51,7 @@ class ResizeNormalizeImageTransform:
         return self.tv_transform(image)


-class VariableSizeImageTransform(object):
+class VariableSizeImageTransform:
     """
     This class accepts images of any size and dynamically resize, pads and chunks it
     based on the image aspect ratio and the number of image chunks we allow.

@@ -100,7 +99,7 @@ class VariableSizeImageTransform(object):
         self.resample = tv.InterpolationMode.BILINEAR

     @staticmethod
-    def get_factors(n: int) -> Set[int]:
+    def get_factors(n: int) -> set[int]:
         """
         Calculate all factors of a given number, i.e. a dividor that leaves
         no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.

@@ -170,9 +169,9 @@ class VariableSizeImageTransform(object):

     @staticmethod
     def get_max_res_without_distortion(
-        image_size: Tuple[int, int],
-        target_size: Tuple[int, int],
-    ) -> Tuple[int, int]:
+        image_size: tuple[int, int],
+        target_size: tuple[int, int],
+    ) -> tuple[int, int]:
         """
         Determines the maximum resolution to which an image can be resized to without distorting its
         aspect ratio, based on the target resolution.

@@ -223,8 +222,8 @@ class VariableSizeImageTransform(object):
     def resize_without_distortion(
         self,
         image: torch.Tensor,
-        target_size: Tuple[int, int],
-        max_upscaling_size: Optional[int],
+        target_size: tuple[int, int],
+        max_upscaling_size: int | None,
     ) -> torch.Tensor:
         """
         Used to resize an image to target_resolution, without distortion.

@@ -289,10 +288,10 @@ class VariableSizeImageTransform(object):

     def get_best_fit(
         self,
-        image_size: Tuple[int, int],
+        image_size: tuple[int, int],
         possible_resolutions: torch.Tensor,
         resize_to_max_canvas: bool = False,
-    ) -> Tuple[int, int]:
+    ) -> tuple[int, int]:
         """
         Determines the best canvas possible from a list of possible resolutions to, without distortion,
         resize an image to.

@@ -392,7 +391,7 @@ class VariableSizeImageTransform(object):
         max_num_chunks: int,
         normalize_img: bool = True,
         resize_to_max_canvas: bool = False,
-    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    ) -> tuple[torch.Tensor, tuple[int, int]]:
         """
         Args:
             image (PIL.Image): Image to be resized.
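
Dropping the explicit `object` base from `VariableSizeImageTransform` removes a Python 2 leftover: in Python 3 every class is new-style, so the base adds nothing. A two-line demonstration:

    class ImageTransform:  # behaves identically to `class ImageTransform(object):`
        pass


    assert ImageTransform.__mro__ == (ImageTransform, object)
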
@@ -12,7 +12,6 @@
 # the top-level of this source tree.

 import textwrap
-from typing import List, Optional

 from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
 from llama_stack.models.llama.llama3.prompt_templates.base import (

@@ -67,14 +66,14 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
         """.strip("\n")
     )

-    def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
+    def gen(self, custom_tools: list[ToolDefinition], system_prompt: str | None = None) -> PromptTemplate:
         system_prompt = system_prompt or self.DEFAULT_PROMPT
         return PromptTemplate(
             system_prompt,
             {"function_description": self._gen_function_description(custom_tools)},
         )

-    def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
+    def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
             Here is a list of functions in JSON format that you can invoke.

@@ -120,7 +119,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
             {"tools": [t.model_dump() for t in custom_tools]},
         ).render()

-    def data_examples(self) -> List[List[ToolDefinition]]:
+    def data_examples(self) -> list[list[ToolDefinition]]:
         return [
             [
                 ToolDefinition(
@@ -7,7 +7,6 @@
 import textwrap
 from io import BytesIO
 from pathlib import Path
-from typing import List

 from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
     PythonListCustomToolGenerator,

@@ -23,7 +22,7 @@ from ..prompt_format import (
 THIS_DIR = Path(__file__).parent


-def usecases(base_model: bool = False) -> List[UseCase | str]:
+def usecases(base_model: bool = False) -> list[UseCase | str]:
     with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
         img_small_dog = f.read()
     with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:
@@ -6,7 +6,7 @@

 import logging
 import os
-from typing import Callable, Optional
+from collections.abc import Callable

 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank

@@ -45,8 +45,8 @@ def experts_batched_swiglu_wrapper(
 def convert_to_quantized_model(
     model: Transformer,
     checkpoint_dir: str,
-    quantization_mode: Optional[str] = None,
-    fp8_activation_scale_ub: Optional[float] = 1200.0,
+    quantization_mode: str | None = None,
+    fp8_activation_scale_ub: float | None = 1200.0,
     use_rich_progress: bool = True,
 ) -> Transformer:
     from ...quantize_impls import (

@@ -213,7 +213,7 @@ def logging_callbacks(
     )
     task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")

-    def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
+    def update_status(message: str | None, completed: int | None = None) -> None:
         if use_rich_progress:
             if message is not None:
                 progress.update(task_id, status=message)
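
Note that `Callable` moves to `collections.abc` rather than disappearing: `typing.Callable` is deprecated, and the `collections.abc` version has accepted subscripts since Python 3.9. A small sketch of the preferred form (a hypothetical retry helper, not from the diff):

    from collections.abc import Callable


    def retry(op: Callable[[], str], attempts: int = 3) -> str | None:
        # Call `op` until it succeeds or the attempts run out.
        for _ in range(attempts):
            try:
                return op()
            except ValueError:
                continue
        return None


    assert retry(lambda: "ok") == "ok"
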
@@ -5,18 +5,11 @@
 # the root directory of this source tree.

 import os
+from collections.abc import Collection, Iterator, Sequence, Set
 from logging import getLogger
 from pathlib import Path
 from typing import (
-    AbstractSet,
-    Collection,
-    Dict,
-    Iterator,
-    List,
     Literal,
-    Optional,
-    Sequence,
-    Union,
     cast,
 )

@@ -114,7 +107,7 @@ class Tokenizer:
     Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
     """

-    special_tokens: Dict[str, int]
+    special_tokens: dict[str, int]

     num_reserved_special_tokens = 2048

@@ -182,9 +175,9 @@ class Tokenizer:
         *,
         bos: bool,
         eos: bool,
-        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
-        disallowed_special: Union[Literal["all"], Collection[str]] = (),
-    ) -> List[int]:
+        allowed_special: Literal["all"] | Set[str] | None = None,
+        disallowed_special: Literal["all"] | Collection[str] = (),
+    ) -> list[int]:
         """
         Encodes a string into a list of token IDs.

@@ -217,7 +210,7 @@ class Tokenizer:
                 s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
             )
         )
-        t: List[int] = []
+        t: list[int] = []
         for substr in substrs:
             t.extend(
                 self.model.encode(

@@ -243,7 +236,7 @@ class Tokenizer:
             str: The decoded string.
         """
         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
-        return self.model.decode(cast(List[int], t))
+        return self.model.decode(cast(list[int], t))

     @staticmethod
     def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
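
The tokenizer hunks bundle three related rewrites: the `typing` protocols (`AbstractSet`, `Collection`, `Iterator`, `Sequence`) move to `collections.abc`, the nested `Optional[Union[...]]` flattens into one `|` union, and `cast` now takes the builtin `list`. A simplified, hypothetical `encode`-style signature showing all three (this is not the real tokenizer):

    from collections.abc import Collection, Set
    from typing import Literal, cast


    def encode(
        s: str,
        allowed_special: Literal["all"] | Set[str] | None = None,
        disallowed_special: Literal["all"] | Collection[str] = (),
    ) -> list[int]:
        # Toy scheme: one "token" per character.
        return cast(list[int], [ord(c) for c in s])


    assert encode("ab") == [97, 98]
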
@@ -5,7 +5,8 @@
 # the root directory of this source tree.

 import math
-from typing import Any, Callable, Dict, List
+from collections.abc import Callable
+from typing import Any

 import torch
 import torch.nn as nn

@@ -136,13 +137,13 @@ class VisionEmbeddings(torch.nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool = True,
-        missing_keys: List[str] = None,
-        unexpected_keys: List[str] = None,
-        error_msgs: List[str] = None,
+        missing_keys: list[str] = None,
+        unexpected_keys: list[str] = None,
+        error_msgs: list[str] = None,
         return_state_dict: bool = False,
     ) -> None:
         original_sd = self.state_dict()

@@ -163,7 +164,7 @@ class VisionEmbeddings(torch.nn.Module):
     # each image is a tensor of shape [num_tiles, C, H, W]
     def forward(
         self,
-        image_batch: List[List[torch.Tensor]],
+        image_batch: list[list[torch.Tensor]],
         image_mask: torch.Tensor,
         h_ref: torch.Tensor,
     ) -> torch.Tensor:
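
One caveat in this `load_hook`: `missing_keys: list[str] = None` faithfully modernizes the old `List[str] = None`, but the annotation still does not admit its own `None` default; a strict type checker would want the union spelling. This is an observation, not part of the commit; the consistent form would look like:

    def load_hook_sketch(missing_keys: list[str] | None = None) -> None:
        # Normalize so downstream code can append unconditionally.
        if missing_keys is None:
            missing_keys = []
        missing_keys.append("example")


    load_hook_sketch()
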
@@ -4,7 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from collections.abc import Callable
+from typing import Any

 import fairscale.nn.model_parallel.initialize as fs_init
 import torch

@@ -42,9 +43,9 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
         self,
         in_channels: int,
         out_channels: int,
-        kernel_size: Union[int, Tuple[int, int]],
-        stride: Union[int, Tuple[int, int]],
-        bias: Optional[bool] = False,
+        kernel_size: int | tuple[int, int],
+        stride: int | tuple[int, int],
+        bias: bool | None = False,
     ) -> None:
         super().__init__()
         if isinstance(kernel_size, int):

@@ -134,15 +135,15 @@ class _TransformerBlock(nn.Module):
     def attention(
         self,
         x: torch.Tensor,
-        freq_cis: Optional[torch.Tensor] = None,
+        freq_cis: torch.Tensor | None = None,
     ):
         return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)

     def forward(
         self,
         x: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        freq_cis: Optional[torch.Tensor] = None,
+        mask: torch.Tensor | None = None,
+        freq_cis: torch.Tensor | None = None,
     ):
         _gate_attn = 1 if not self.gated else self.gate_attn.tanh()
         _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()

@@ -210,8 +211,8 @@ class PackingIndex:
 class VisionEncoder(nn.Module):
     def __init__(
         self,
-        image_size: Tuple[int, int],
-        patch_size: Tuple[int, int],
+        image_size: tuple[int, int],
+        patch_size: tuple[int, int],
         dim: int,
         layers: int,
         heads: int,

@@ -299,13 +300,13 @@ class VisionEncoder(nn.Module):

     def load_hook(
         self,
-        state_dict: Dict[str, Any],
+        state_dict: dict[str, Any],
         prefix: str,
-        local_metadata: Dict[str, Any],
+        local_metadata: dict[str, Any],
         strict: bool = True,
-        missing_keys: List[str] = None,
-        unexpected_keys: List[str] = None,
-        error_msgs: List[str] = None,
+        missing_keys: list[str] = None,
+        unexpected_keys: list[str] = None,
+        error_msgs: list[str] = None,
         return_state_dict: bool = False,
     ) -> None:
         orig_pos_embed = state_dict.get(prefix + "positional_embedding")
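
Unions of concrete types such as `kernel_size: int | tuple[int, int]` read just as well at runtime, since `isinstance` also accepts `|` unions on Python 3.10+. A sketch of the normalization idiom that conv-patch layers like the one above typically use (the helper is hypothetical):

    def normalize_size(size: int | tuple[int, int]) -> tuple[int, int]:
        # A bare int is shorthand for a square kernel or stride.
        if isinstance(size, int):
            return (size, size)
        return size


    assert normalize_size(3) == (3, 3)
    assert normalize_size((2, 4)) == (2, 4)
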
@@ -14,7 +14,6 @@
 import json
 import textwrap
 from pathlib import Path
-from typing import List

 from pydantic import BaseModel, Field

@@ -44,7 +43,7 @@ class TextCompletionContent(BaseModel):
 class UseCase(BaseModel):
     title: str = ""
     description: str = ""
-    dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
+    dialogs: list[list[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
     notes: str = ""
     tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
     max_gen_len: int = 512
@@ -7,7 +7,6 @@
 # type: ignore
 import collections
 import logging
-from typing import Optional, Tuple, Type, Union

 log = logging.getLogger(__name__)

@@ -27,7 +26,7 @@ class Fp8ScaledWeights:
     # TODO: Ugly trick so torch allows us to replace parameters
     # with our custom Fp8Weights instance. Do this properly.
     @property
-    def __class__(self) -> Type[nn.parameter.Parameter]:
+    def __class__(self) -> type[nn.parameter.Parameter]:
         return nn.Parameter

     @property

@@ -51,7 +50,7 @@ class Int4ScaledWeights:
     # TODO: Ugly trick so torch allows us to replace parameters
     # with our custom Int4Weights instance. Do this properly.
     @property
-    def __class__(self) -> Type[nn.parameter.Parameter]:
+    def __class__(self) -> type[nn.parameter.Parameter]:
         return nn.Parameter

     @property

@@ -74,7 +73,7 @@ class Int4Weights(
 def int4_row_quantize(
     x: torch.Tensor,
     group_size: int = 128,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     n_bit = 4  # Number of target bits.
     to_quant = x.reshape(-1, group_size).to(torch.float)

@@ -115,8 +114,8 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor:

 def bmm_nt(
     x: Tensor,
-    w: Union[Fp8RowwiseWeights, Int4Weights],
-    num_tokens: Optional[Tensor] = None,
+    w: Fp8RowwiseWeights | Int4Weights,
+    num_tokens: Tensor | None = None,
 ) -> Tensor:
     if isinstance(w, Fp8ScaledWeights):
         xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)

@@ -129,10 +128,10 @@ def bmm_nt(

 def ffn_swiglu(
     x: Tensor,
-    w1: Union[Fp8RowwiseWeights, Int4Weights],
-    w3: Union[Fp8RowwiseWeights, Int4Weights],
-    w2: Union[Fp8RowwiseWeights, Int4Weights],
-    num_tokens: Optional[Tensor] = None,
+    w1: Fp8RowwiseWeights | Int4Weights,
+    w3: Fp8RowwiseWeights | Int4Weights,
+    w2: Fp8RowwiseWeights | Int4Weights,
+    num_tokens: Tensor | None = None,
     is_memory_bounded: bool = False,
 ) -> Tensor:
     if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (

@@ -158,7 +157,7 @@ def ffn_swiglu(
 def quantize_fp8(
     w: Tensor,
     fp8_activation_scale_ub: float,
-    output_device: Optional[torch.device] = None,
+    output_device: torch.device | None = None,
 ) -> Fp8RowwiseWeights:
     """Quantize [n, k] weight tensor.

@@ -184,7 +183,7 @@ def quantize_fp8(
 @torch.inference_mode()
 def quantize_int4(
     w: Tensor,
-    output_device: Optional[torch.device] = None,
+    output_device: torch.device | None = None,
 ) -> Int4Weights:
     """Quantize [n, k/2] weight tensor.

@@ -213,7 +212,7 @@ def load_fp8(
     w: Tensor,
     w_scale: Tensor,
     fp8_activation_scale_ub: float,
-    output_device: Optional[torch.device] = None,
+    output_device: torch.device | None = None,
 ) -> Fp8RowwiseWeights:
     """Load FP8 [n, k] weight tensor.

@@ -239,7 +238,7 @@ def load_int4(
     w: Tensor,
     scale: Tensor,
     zero_point: Tensor,
-    output_device: Optional[torch.device] = None,
+    output_device: torch.device | None = None,
 ) -> Int4Weights:
     """Load INT4 [n, k/2] weight tensor.

@@ -256,9 +255,9 @@ def load_int4(

 def fc_dynamic(
     x: Tensor,
-    w: Union[Fp8RowwiseWeights, Int4Weights],
-    activation_scale_ub: Optional[Tensor] = None,
-    num_tokens: Optional[Tensor] = None,
+    w: Fp8RowwiseWeights | Int4Weights,
+    activation_scale_ub: Tensor | None = None,
+    num_tokens: Tensor | None = None,
     is_memory_bounded: bool = False,
 ) -> Tensor:
     """

@@ -275,11 +274,11 @@ def fc_dynamic(

 def ffn_swiglu_dynamic(
     x: Tensor,
-    w1: Union[Fp8RowwiseWeights, Int4Weights],
-    w3: Union[Fp8RowwiseWeights, Int4Weights],
-    w2: Union[Fp8RowwiseWeights, Int4Weights],
-    activation_scale_ub: Optional[Tensor] = None,
-    num_tokens: Optional[Tensor] = None,
+    w1: Fp8RowwiseWeights | Int4Weights,
+    w3: Fp8RowwiseWeights | Int4Weights,
+    w2: Fp8RowwiseWeights | Int4Weights,
+    activation_scale_ub: Tensor | None = None,
+    num_tokens: Tensor | None = None,
     is_memory_bounded: bool = False,
 ) -> Tensor:
     assert x.dim() == 3 or x.dim() == 2
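
`Type[...]` → `type[...]` is the same PEP 585 move applied to class objects; the `__class__` property trick that these quantized-weight classes rely on is untouched. A torch-free reduction of that trick (illustrative only):

    class Masquerade:
        # Overriding __class__ as a property makes isinstance() report a
        # different class, which is how the Fp8/Int4 weights above pose as
        # nn.Parameter so torch accepts them as replacement parameters.
        @property
        def __class__(self) -> type[int]:  # was: Type[int]
            return int


    assert isinstance(Masquerade(), int)
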
@@ -6,7 +6,6 @@

 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional

 from .sku_types import (
     CheckpointQuantizationFormat,

@@ -19,14 +18,14 @@ LLAMA2_VOCAB_SIZE = 32000
 LLAMA3_VOCAB_SIZE = 128256


-def resolve_model(descriptor: str) -> Optional[Model]:
+def resolve_model(descriptor: str) -> Model | None:
     for m in all_registered_models():
         if descriptor in (m.descriptor(), m.huggingface_repo):
             return m
     return None


-def all_registered_models() -> List[Model]:
+def all_registered_models() -> list[Model]:
     return (
         llama2_family()
         + llama3_family()

@@ -38,48 +37,48 @@ def all_registered_models() -> List[Model]:
     )


-def llama2_family() -> List[Model]:
+def llama2_family() -> list[Model]:
     return [
         *llama2_base_models(),
         *llama2_instruct_models(),
     ]


-def llama3_family() -> List[Model]:
+def llama3_family() -> list[Model]:
     return [
         *llama3_base_models(),
         *llama3_instruct_models(),
     ]


-def llama3_1_family() -> List[Model]:
+def llama3_1_family() -> list[Model]:
     return [
         *llama3_1_base_models(),
         *llama3_1_instruct_models(),
     ]


-def llama3_2_family() -> List[Model]:
+def llama3_2_family() -> list[Model]:
     return [
         *llama3_2_base_models(),
         *llama3_2_instruct_models(),
     ]


-def llama3_3_family() -> List[Model]:
+def llama3_3_family() -> list[Model]:
     return [
         *llama3_3_instruct_models(),
     ]


-def llama4_family() -> List[Model]:
+def llama4_family() -> list[Model]:
     return [
         *llama4_base_models(),
         *llama4_instruct_models(),
     ]


-def llama4_base_models() -> List[Model]:
+def llama4_base_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama4_scout_17b_16e,

@@ -98,7 +97,7 @@ def llama4_base_models() -> List[Model]:
     ]


-def llama4_instruct_models() -> List[Model]:
+def llama4_instruct_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama4_scout_17b_16e_instruct,

@@ -126,7 +125,7 @@ def llama4_instruct_models() -> List[Model]:
     ]


-def llama2_base_models() -> List[Model]:
+def llama2_base_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama2_7b,

@@ -185,7 +184,7 @@ def llama2_base_models() -> List[Model]:
     ]


-def llama3_base_models() -> List[Model]:
+def llama3_base_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama3_8b,

@@ -226,7 +225,7 @@ def llama3_base_models() -> List[Model]:
     ]


-def llama3_1_base_models() -> List[Model]:
+def llama3_1_base_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama3_1_8b,

@@ -324,7 +323,7 @@ def llama3_1_base_models() -> List[Model]:
     ]


-def llama3_2_base_models() -> List[Model]:
+def llama3_2_base_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama3_2_1b,

@@ -407,7 +406,7 @@ def llama3_2_base_models() -> List[Model]:
     ]


-def llama2_instruct_models() -> List[Model]:
+def llama2_instruct_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama2_7b_chat,

@@ -466,7 +465,7 @@ def llama2_instruct_models() -> List[Model]:
     ]


-def llama3_instruct_models() -> List[Model]:
+def llama3_instruct_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama3_8b_instruct,

@@ -507,7 +506,7 @@ def llama3_instruct_models() -> List[Model]:
     ]


-def llama3_1_instruct_models() -> List[Model]:
+def llama3_1_instruct_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama3_1_8b_instruct,

@@ -635,7 +634,7 @@ def arch_args_3b() -> dict:
     }


-def llama3_2_quantized_models() -> List[Model]:
+def llama3_2_quantized_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama3_2_1b_instruct,

@@ -704,7 +703,7 @@ def llama3_2_quantized_models() -> List[Model]:
     ]


-def llama3_2_instruct_models() -> List[Model]:
+def llama3_2_instruct_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama3_2_1b_instruct,

@@ -766,7 +765,7 @@ def llama3_2_instruct_models() -> List[Model]:
     ]


-def llama3_3_instruct_models() -> List[Model]:
+def llama3_3_instruct_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama3_3_70b_instruct,

@@ -790,7 +789,7 @@ def llama3_3_instruct_models() -> List[Model]:


 @lru_cache
-def safety_models() -> List[Model]:
+def safety_models() -> list[Model]:
     return [
         Model(
             core_model_id=CoreModelId.llama_guard_4_12b,

@@ -919,7 +918,7 @@ def safety_models() -> List[Model]:
 @dataclass
 class LlamaDownloadInfo:
     folder: str
-    files: List[str]
+    files: list[str]
     pth_size: int

|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
@ -159,13 +159,13 @@ def model_family(model_id) -> ModelFamily:
|
|||
class Model(BaseModel):
|
||||
core_model_id: CoreModelId
|
||||
description: str
|
||||
huggingface_repo: Optional[str] = None
|
||||
arch_args: Dict[str, Any]
|
||||
huggingface_repo: str | None = None
|
||||
arch_args: dict[str, Any]
|
||||
variant: str = ""
|
||||
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
pth_file_count: int
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# silence pydantic until we remove the `model_` fields
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
|
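
Pydantic resolves annotations at runtime, so the `Model` changes above also rely on the interpreter understanding `str | None` and `dict[str, Any]`; on Python 3.10+ both are first-class. A minimal, hypothetical model in the same style:

    from typing import Any

    from pydantic import BaseModel, Field


    class ModelSketch(BaseModel):
        huggingface_repo: str | None = None                     # was: Optional[str]
        metadata: dict[str, Any] = Field(default_factory=dict)  # was: Dict[str, Any]


    m = ModelSketch()
    assert m.huggingface_repo is None and m.metadata == {}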