mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat: introduce llama4 support (#1877)
As title says. Details in README, elsewhere.
This commit is contained in:
parent
23a99a4b22
commit
b8f1561956
61 changed files with 205222 additions and 6439 deletions
|
@ -255,7 +255,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
input_messages = last_turn_messages
|
||||
input_messages = last_turn.input_messages
|
||||
|
||||
turn_id = request.turn_id
|
||||
start_time = last_turn.started_at
|
||||
|
|
|
@ -0,0 +1,270 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
JsonSchemaResponseFormat,
|
||||
ResponseFormat,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
Model,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
get_default_tool_prompt_format,
|
||||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .inference import resolve_model
|
||||
from .llama3.generation import Llama3
|
||||
from .llama4.generation import Llama4
|
||||
|
||||
Tokenizer = Llama4Tokenizer | Llama3Tokenizer
|
||||
|
||||
|
||||
class LogitsProcessor:
|
||||
def __init__(self, token_enforcer: TokenEnforcer):
|
||||
self.token_enforcer = token_enforcer
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
|
||||
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
token_sequence = tokens[0, :].tolist()
|
||||
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
|
||||
|
||||
if self.mask is not None:
|
||||
self.mask.fill_(-math.inf)
|
||||
else:
|
||||
self.mask = torch.full_like(scores, -math.inf)
|
||||
|
||||
self.mask[:, :, allowed_tokens] = 0
|
||||
scores = scores + self.mask
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor(
|
||||
tokenizer: Tokenizer,
|
||||
vocab_size: int,
|
||||
response_format: Optional[ResponseFormat],
|
||||
) -> Optional["LogitsProcessor"]:
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if not isinstance(response_format, JsonSchemaResponseFormat):
|
||||
raise ValueError(f"Unsupported response format type {response_format.type}")
|
||||
|
||||
parser = JsonSchemaParser(response_format.json_schema)
|
||||
data = TokenEnforcerTokenizerData(
|
||||
_build_regular_tokens_list(tokenizer, vocab_size),
|
||||
tokenizer.decode,
|
||||
tokenizer.stop_tokens,
|
||||
)
|
||||
token_enforcer = TokenEnforcer(data, parser)
|
||||
return LogitsProcessor(token_enforcer)
|
||||
|
||||
|
||||
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
|
||||
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
|
||||
regular_tokens = []
|
||||
|
||||
special_token_ids = set(tokenizer.special_tokens.values())
|
||||
for token_idx in range(vocab_size):
|
||||
if token_idx in special_token_ids:
|
||||
continue
|
||||
|
||||
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
|
||||
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
|
||||
decoded_regular = tokenizer.decode([token_idx])
|
||||
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
|
||||
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
|
||||
return regular_tokens
|
||||
|
||||
|
||||
def _infer_sampling_params(sampling_params: SamplingParams):
|
||||
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
|
||||
temperature = 0.0
|
||||
top_p = 1.0
|
||||
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
temperature = sampling_params.strategy.temperature or 1.0
|
||||
top_p = sampling_params.strategy.top_p or 1.0
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
|
||||
return temperature, top_p
|
||||
|
||||
|
||||
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
||||
tool_config = request.tool_config
|
||||
if tool_config is not None and tool_config.tool_prompt_format is not None:
|
||||
return tool_config.tool_prompt_format
|
||||
else:
|
||||
return get_default_tool_prompt_format(request.model)
|
||||
|
||||
|
||||
class Llama4Generator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
ckpt_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
quantization_mode = "fp8_mixed"
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
quantization_mode = "int4_mixed"
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
||||
else:
|
||||
quantization_mode = None
|
||||
|
||||
self.inner_generator = Llama4.build(
|
||||
ckpt_dir=ckpt_dir,
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
world_size=llama_model.pth_file_count,
|
||||
quantization_mode=quantization_mode,
|
||||
)
|
||||
|
||||
self.tokenizer = self.inner_generator.tokenizer
|
||||
self.args = self.inner_generator.args
|
||||
self.formatter = self.inner_generator.formatter
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_input=self.formatter.encode_content(request.content),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Llama3Generator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
self.inner_generator = Llama3.build(
|
||||
config=config,
|
||||
model_id=model_id,
|
||||
llama_model=llama_model,
|
||||
)
|
||||
self.tokenizer = self.inner_generator.tokenizer
|
||||
self.args = self.inner_generator.args
|
||||
self.formatter = self.inner_generator.formatter
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
model_input=self.formatter.encode_content(request.content),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
|
@ -34,11 +34,16 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
ModelFamily,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
|
@ -55,7 +60,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .llama3.generation import Llama3
|
||||
from .generators import Llama3Generator, Llama4Generator
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -64,6 +69,14 @@ log = logging.getLogger(__name__)
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
|
||||
return Llama3Generator(config, model_id, llama_model)
|
||||
|
||||
|
||||
def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
|
||||
return Llama4Generator(config, model_id, llama_model)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
|
@ -77,29 +90,10 @@ class MetaReferenceInferenceImpl(
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def load_model(self, model_id, llama_model) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
|
||||
self.generator.start()
|
||||
else:
|
||||
self.generator = Llama3.build(self.config, model_id, llama_model)
|
||||
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
raise RuntimeError(
|
||||
"No avaible model yet, please register your requested model or add your model in the resouces first"
|
||||
)
|
||||
elif request.model != self.model_id:
|
||||
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
@ -127,11 +121,57 @@ class MetaReferenceInferenceImpl(
|
|||
if model.model_type == ModelType.embedding:
|
||||
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
|
||||
# TODO: what is this?! you can't really specify skipping via model metadata
|
||||
# kill this madness
|
||||
if "skip_load" in model.metadata and model.metadata["skip_load"]:
|
||||
return model
|
||||
|
||||
await self.load_model(model.identifier, llama_model)
|
||||
return model
|
||||
|
||||
async def load_model(self, model_id, llama_model) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
|
||||
if llama_model.model_family in {
|
||||
ModelFamily.llama3,
|
||||
ModelFamily.llama3_1,
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
}:
|
||||
builder_fn = llama3_builder_fn
|
||||
elif llama_model.model_family == ModelFamily.llama4:
|
||||
builder_fn = llama4_builder_fn
|
||||
else:
|
||||
raise ValueError(f"Unsupported model family: {llama_model.model_family}")
|
||||
|
||||
builder_params = [self.config, model_id, llama_model]
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(
|
||||
model_parallel_size=llama_model.pth_file_count,
|
||||
builder_fn=builder_fn,
|
||||
builder_params=builder_params,
|
||||
formatter=(
|
||||
Llama4ChatFormat(Llama4Tokenizer.get_instance())
|
||||
if llama_model.model_family == ModelFamily.llama4
|
||||
else Llama3ChatFormat(Llama3Tokenizer.get_instance())
|
||||
),
|
||||
)
|
||||
self.generator.start()
|
||||
else:
|
||||
self.generator = builder_fn(*builder_params)
|
||||
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
raise RuntimeError(
|
||||
"No avaible model yet, please register your requested model or add your model in the resouces first"
|
||||
)
|
||||
elif request.model != self.model_id:
|
||||
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -164,14 +204,16 @@ class MetaReferenceInferenceImpl(
|
|||
return await self._nonstream_completion(request)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
stop_reason = None
|
||||
|
||||
for token_result in self.generator.completion(request):
|
||||
if token_result.text == "<|eot_id|>":
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
|
@ -205,6 +247,8 @@ class MetaReferenceInferenceImpl(
|
|||
yield x
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
tokens = []
|
||||
logprobs = []
|
||||
|
@ -212,9 +256,9 @@ class MetaReferenceInferenceImpl(
|
|||
|
||||
for token_result in self.generator.completion(request):
|
||||
tokens.append(token_result.token)
|
||||
if token_result.text == "<|eot_id|>":
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
if request.logprobs:
|
||||
|
@ -225,11 +269,9 @@ class MetaReferenceInferenceImpl(
|
|||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
|
||||
tokens = tokens[:-1]
|
||||
content = self.generator.formatter.tokenizer.decode(tokens)
|
||||
if content.endswith("<|eot_id|>"):
|
||||
content = content[: -len("<|eot_id|>")]
|
||||
elif content.endswith("<|eom_id|>"):
|
||||
content = content[: -len("<|eom_id|>")]
|
||||
return CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
|
@ -288,6 +330,8 @@ class MetaReferenceInferenceImpl(
|
|||
return await self._nonstream_chat_completion(request)
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
tokens = []
|
||||
logprobs = []
|
||||
|
@ -296,9 +340,9 @@ class MetaReferenceInferenceImpl(
|
|||
for token_result in self.generator.chat_completion(request):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
if request.logprobs:
|
||||
|
@ -326,6 +370,8 @@ class MetaReferenceInferenceImpl(
|
|||
return impl()
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
@ -355,10 +401,10 @@ class MetaReferenceInferenceImpl(
|
|||
)
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
|
|
|
@ -4,17 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
from typing import Callable, Generator, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -23,27 +19,16 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
Model,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from ..common import TokenResult, model_checkpoint_dir
|
||||
from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
|
@ -51,7 +36,7 @@ from .args import ModelArgs
|
|||
from .model import Transformer
|
||||
from .multimodal.model import CrossAttentionTransformer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(__name__, category="inference")
|
||||
|
||||
|
||||
class Llama3:
|
||||
|
@ -146,7 +131,7 @@ class Llama3:
|
|||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
from ..quantization.loader import convert_to_fp8_quantized_model
|
||||
from .quantization.loader import convert_to_fp8_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
|
@ -159,7 +144,7 @@ class Llama3:
|
|||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
from ..quantization.loader import convert_to_int4_quantized_model
|
||||
from .quantization.loader import convert_to_int4_quantized_model
|
||||
|
||||
model = Transformer(model_args)
|
||||
model = convert_to_int4_quantized_model(model, model_args, config)
|
||||
|
@ -169,7 +154,7 @@ class Llama3:
|
|||
# Add a wrapper for adding hadamard transform for spinquant.
|
||||
# This needs to be done after loading the state dict otherwise an error will be raised while
|
||||
# loading the state dict.
|
||||
from ..quantization.hadamard_utils import (
|
||||
from ..hadamard_utils import (
|
||||
add_hadamard_transform_for_spinquant,
|
||||
)
|
||||
|
||||
|
@ -222,9 +207,8 @@ class Llama3:
|
|||
top_p: float = 0.9,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
include_stop_token: bool = False,
|
||||
print_input_tokens: bool = False,
|
||||
logits_processor: Optional["LogitsProcessor"] = None,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator:
|
||||
params = self.model.params
|
||||
|
||||
|
@ -292,7 +276,7 @@ class Llama3:
|
|||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
|
||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
|
@ -336,58 +320,6 @@ class Llama3:
|
|||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
model_input = self.formatter.encode_content(request.content)
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.generate(
|
||||
model_input=model_input,
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
request.messages,
|
||||
request.tool_config.tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
|
@ -412,72 +344,3 @@ def sample_top_p(probs, p):
|
|||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
||||
|
||||
|
||||
class LogitsProcessor:
|
||||
def __init__(self, token_enforcer: TokenEnforcer):
|
||||
self.token_enforcer = token_enforcer
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
|
||||
def process_logits(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
token_sequence = tokens[0, :].tolist()
|
||||
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
|
||||
|
||||
if self.mask is not None:
|
||||
self.mask.fill_(-math.inf)
|
||||
else:
|
||||
self.mask = torch.full_like(scores, -math.inf)
|
||||
|
||||
self.mask[:, :, allowed_tokens] = 0
|
||||
scores = scores + self.mask
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor(
|
||||
tokenizer: Tokenizer,
|
||||
vocab_size: int,
|
||||
response_format: Optional[ResponseFormat],
|
||||
) -> Optional["LogitsProcessor"]:
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if response_format.type != ResponseFormatType.json_schema.value:
|
||||
raise ValueError(f"Unsupported response format type {response_format.type}")
|
||||
|
||||
parser = JsonSchemaParser(response_format.json_schema)
|
||||
data = TokenEnforcerTokenizerData(
|
||||
_build_regular_tokens_list(tokenizer, vocab_size),
|
||||
tokenizer.decode,
|
||||
tokenizer.stop_tokens,
|
||||
)
|
||||
token_enforcer = TokenEnforcer(data, parser)
|
||||
return LogitsProcessor(token_enforcer)
|
||||
|
||||
|
||||
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
|
||||
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
|
||||
regular_tokens = []
|
||||
|
||||
special_token_ids = set(tokenizer.special_tokens.values())
|
||||
for token_idx in range(vocab_size):
|
||||
if token_idx in special_token_ids:
|
||||
continue
|
||||
|
||||
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
|
||||
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
|
||||
decoded_regular = tokenizer.decode([token_idx])
|
||||
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
|
||||
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
|
||||
return regular_tokens
|
||||
|
||||
|
||||
def _infer_sampling_params(sampling_params: SamplingParams):
|
||||
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
|
||||
temperature = 0.0
|
||||
top_p = 1.0
|
||||
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
temperature = sampling_params.strategy.temperature
|
||||
top_p = sampling_params.strategy.top_p
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
|
||||
return temperature, top_p
|
||||
|
|
|
@ -7,9 +7,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import logging
|
||||
# type: ignore
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
|
@ -19,22 +19,27 @@ from torch import Tensor, nn
|
|||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.inline.inference.meta_reference.quantize_impls import (
|
||||
Fp8ScaledWeights,
|
||||
ffn_swiglu,
|
||||
load_fp8,
|
||||
quantize_fp8,
|
||||
)
|
||||
|
||||
from ...llama3.args import ModelArgs
|
||||
from ...llama3.model import Transformer, TransformerBlock
|
||||
from ..config import MetaReferenceQuantizedInferenceConfig
|
||||
from ...config import MetaReferenceQuantizedInferenceConfig
|
||||
from ..args import ModelArgs
|
||||
from ..model import Transformer, TransformerBlock
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(__name__, category="quantization")
|
||||
|
||||
|
||||
def swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor,
|
||||
):
|
||||
from .fp8_impls import ffn_swiglu
|
||||
|
||||
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
return reduce_from_model_parallel_region(out)
|
||||
|
||||
|
@ -51,8 +56,7 @@ def convert_to_fp8_quantized_model(
|
|||
elif config.quantization.type != QuantizationType.fp8.value:
|
||||
raise ValueError("Only FP8 quantization is supported")
|
||||
|
||||
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
|
||||
|
||||
assert config.model is not None, "Model must be specified for quantized inference"
|
||||
llama_model = resolve_model(config.model)
|
||||
assert llama_model is not None, f"Model {config.model} not found"
|
||||
|
||||
|
@ -82,7 +86,7 @@ def convert_to_fp8_quantized_model(
|
|||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward) # type: ignore
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = quantize_fp8(
|
||||
|
@ -136,6 +140,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
|
|||
precision=precision,
|
||||
scales_precision=scales_precision,
|
||||
)
|
||||
self.lora_scale: Optional[float] = None
|
||||
self.adaptor: Optional[nn.Sequential] = None
|
||||
if lora_rank is not None:
|
||||
assert lora_scale is not None, "Please specify lora scale for LoRA."
|
||||
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
|
||||
|
@ -143,9 +149,6 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
|
|||
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
|
||||
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
|
||||
self.lora_scale = lora_scale
|
||||
else:
|
||||
self.adaptor = None
|
||||
self.lora_scale = None
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
|
@ -293,10 +296,10 @@ def convert_to_int4_quantized_model(
|
|||
) -> Transformer:
|
||||
"""Convert the model to int4 quantized model."""
|
||||
|
||||
if model_args.quantization_args is None:
|
||||
raise ValueError("'quantization_args' cannot be None. Please specify it.")
|
||||
|
||||
assert model_args.quantization_args is not None, "Quantization args must be specified."
|
||||
quantization_args = model_args.quantization_args
|
||||
if quantization_args.scheme is None:
|
||||
raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
|
||||
|
||||
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
|
||||
raise NotImplementedError(
|
||||
|
@ -317,4 +320,4 @@ def convert_to_int4_quantized_model(
|
|||
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
return model.to(device)
|
||||
return cast(Transformer, model.to(device))
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class QuantizationScheme(Enum):
|
||||
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
class QuantizationArgs(BaseModel):
|
||||
scheme: Optional[QuantizationScheme] = None
|
||||
group_size: Optional[int] = None
|
||||
spinquant: bool = False
|
||||
|
||||
|
||||
class LoRAArgs(BaseModel):
|
||||
rank: int
|
||||
scale: float
|
||||
|
||||
|
||||
class MoEArgs(BaseModel):
|
||||
num_experts: int = -1
|
||||
capacity_factor: float = 1.0 # capacity factor determines how many tokens each expert can choose
|
||||
auto_scale_F: bool = ( # noqa: N815
|
||||
True # if true, rescales hidden_dim such that number of activated params is same as equivalent dense layer
|
||||
)
|
||||
top_k: int = 1
|
||||
interleave_moe_layer_step: int = 1
|
||||
|
||||
|
||||
class Size(BaseModel):
|
||||
height: int
|
||||
width: int
|
||||
|
||||
|
||||
class VisionArgs(BaseModel):
|
||||
image_size: Size
|
||||
patch_size: Size
|
||||
|
||||
# parameters for the encoder transformer
|
||||
dim: int
|
||||
n_layers: int
|
||||
n_heads: int
|
||||
mlp_ratio: float
|
||||
output_dim: int
|
||||
|
||||
pixel_shuffle_ratio: float
|
||||
|
||||
|
||||
class ModelArgs(BaseModel):
|
||||
dim: int = -1
|
||||
n_layers: int = -1
|
||||
n_heads: int = -1
|
||||
n_kv_heads: Optional[int] = None
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
vocab_size: int = -1
|
||||
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
||||
ffn_dim_multiplier: Optional[float] = None
|
||||
ffn_exp: Optional[float] = None
|
||||
norm_eps: float = 1e-5
|
||||
|
||||
attention_chunk_size: Optional[int] = None
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
|
||||
use_qk_norm: bool = False
|
||||
# Set to True to enable inference-time temperature tuning (useful for very long context)
|
||||
attn_temperature_tuning: bool = False
|
||||
floor_scale: float = 8192.0
|
||||
attn_scale: float = 0.1
|
||||
|
||||
vision_args: Optional[VisionArgs] = None
|
||||
moe_args: Optional[MoEArgs] = None
|
||||
quantization_args: Optional[QuantizationArgs] = None
|
||||
lora_args: Optional[LoRAArgs] = None
|
||||
|
||||
max_batch_size: int = 32
|
||||
max_seq_len: int = 2048
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate(self) -> "ModelArgs":
|
||||
assert self.n_kv_heads <= self.n_heads, f"n_kv_heads ({self.n_kv_heads}) must be <= n_heads ({self.n_heads})"
|
||||
assert self.n_heads % self.n_kv_heads == 0, (
|
||||
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
|
||||
)
|
||||
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
|
||||
return self
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskedEmbedding:
|
||||
embedding: torch.Tensor
|
||||
mask: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMInput:
|
||||
"""
|
||||
This is the input to the LLM from the "user" -- the user in this case views the
|
||||
Llama4 model holistically and does not care or know about its inner workings (e.g.,
|
||||
whether it has an encoder or if it is early fusion or not.)
|
||||
|
||||
This is distinct from the "TransformerInput" class which is really the Llama4
|
||||
backbone operating on early fused modalities and producing text output
|
||||
"""
|
||||
|
||||
tokens: torch.Tensor
|
||||
|
||||
# images are already pre-processed (resized, tiled, etc.)
|
||||
images: Optional[List[torch.Tensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerInput:
|
||||
"""
|
||||
This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
|
||||
are expected to be "embedded" via encoders sitting before this layer in the model.
|
||||
"""
|
||||
|
||||
tokens: torch.Tensor
|
||||
|
||||
# tokens_position defines the position of the tokens in each batch,
|
||||
# - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
|
||||
# - when it is an int, the start position are the same for all batches
|
||||
tokens_position: Union[torch.Tensor, int]
|
||||
image_embedding: Optional[MaskedEmbedding] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMOutput:
|
||||
logits: torch.Tensor
|
||||
|
||||
|
||||
TransformerOutput = LLMOutput
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
do_reduce: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.do_reduce = do_reduce
|
||||
|
||||
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
|
||||
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "mlp.fc1_weight" in state_dict:
|
||||
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)
|
||||
state_dict[prefix + "w1.weight"] = w1
|
||||
state_dict[prefix + "w3.weight"] = w3
|
||||
state_dict[prefix + "w2.weight"] = state_dict.pop(prefix + "mlp.fc2_weight")
|
||||
|
||||
def forward(self, x):
|
||||
x = F.silu(F.linear(x, self.w1.weight)) * F.linear(x, self.w3.weight)
|
||||
out = F.linear(x, self.w2.weight)
|
||||
if self.do_reduce:
|
||||
return reduce_from_model_parallel_region(out)
|
||||
return out
|
|
@ -0,0 +1,330 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import codecs
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.models.llama.llama4.chat_format import (
|
||||
ChatFormat,
|
||||
RawContent,
|
||||
RawMessage,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
|
||||
|
||||
from ..common import TokenResult
|
||||
from .args import ModelArgs
|
||||
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
|
||||
from .model import Transformer
|
||||
|
||||
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
|
||||
|
||||
|
||||
class QuantizationMode(str, Enum):
|
||||
none = "none"
|
||||
fp8_mixed = "fp8_mixed"
|
||||
int4_mixed = "int4_mixed"
|
||||
|
||||
|
||||
class Llama4:
|
||||
@staticmethod
|
||||
def build(
|
||||
ckpt_dir: str,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[str] = None,
|
||||
seed: int = 1,
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
if world_size is None:
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(world_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert world_size == len(checkpoints), (
|
||||
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
|
||||
)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
**params,
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
# TODO: params.json should always have correct vocab_size
|
||||
if model_args.vocab_size == -1:
|
||||
model_args.vocab_size = tokenizer.n_words
|
||||
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
|
||||
print("Model args:\n", model_args.model_dump_json(indent=2))
|
||||
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
print(f"Loading checkpoint from {ckpt_dir}...")
|
||||
with open(ckpt_path, "rb") as f:
|
||||
checkpoint = torch.load(f, map_location="cpu", weights_only=True)
|
||||
print("Loaded checkpoint")
|
||||
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
print("Done...")
|
||||
model = convert_to_quantized_model(model, ckpt_dir)
|
||||
else:
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
print("Done...")
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
return Llama4(model, tokenizer, model_args)
|
||||
|
||||
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer, vision_args=args.vision_args)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
llm_input: LLMInput,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
|
||||
max_gen_len = self.model.args.max_seq_len - 1
|
||||
|
||||
params = self.model.args
|
||||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input and get_model_parallel_rank() == 0:
|
||||
tokens_to_print = list(llm_input.tokens)
|
||||
cprint(
|
||||
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
||||
"red",
|
||||
)
|
||||
prompt_tokens = [llm_input.tokens]
|
||||
|
||||
bsz = 1
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||
|
||||
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
||||
input_text_mask = tokens != pad_id
|
||||
|
||||
if echo:
|
||||
for i, t in enumerate(llm_input.tokens):
|
||||
yield TokenResult(
|
||||
token=t,
|
||||
text=self.tokenizer.decode([t]),
|
||||
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
|
||||
)
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
|
||||
|
||||
prev_pos = 0
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
image_embedding = None
|
||||
if prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0:
|
||||
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
|
||||
image_mask = image_mask.unsqueeze(-1)
|
||||
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
|
||||
|
||||
image_batch = [llm_input.images]
|
||||
image_embedding = MaskedEmbedding(
|
||||
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
|
||||
mask=image_mask,
|
||||
)
|
||||
|
||||
xformer_input = TransformerInput(
|
||||
tokens=tokens[:, prev_pos:cur_pos],
|
||||
tokens_position=prev_pos,
|
||||
image_embedding=image_embedding,
|
||||
)
|
||||
xformer_output = self.model.forward(xformer_input)
|
||||
logits = xformer_output.logits
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=target,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
yield TokenResult(
|
||||
token=next_token[0].item(),
|
||||
text=self.tokenizer.decode(next_token.tolist()),
|
||||
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
|
||||
)
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def completion(
|
||||
self,
|
||||
content: RawContent,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator:
|
||||
llm_input = self.formatter.encode_content(content)
|
||||
for result in self.generate(
|
||||
llm_input=llm_input,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
if result.token in self.tokenizer.stop_tokens:
|
||||
break
|
||||
yield result
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: List[RawMessage],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator:
|
||||
llm_input = self.formatter.encode_dialog_prompt(messages)
|
||||
for result in self.generate(
|
||||
llm_input=llm_input,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
if result.token in self.tokenizer.stop_tokens:
|
||||
break
|
||||
yield result
|
||||
|
||||
def chat_completion_raw(
|
||||
self,
|
||||
messages: List[RawMessage],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
):
|
||||
llm_input = self.formatter.encode_dialog_prompt(messages)
|
||||
output_tokens = []
|
||||
for result in self.generate(
|
||||
llm_input=llm_input,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
):
|
||||
output_tokens.append(result.token)
|
||||
|
||||
return llm_input.tokens, output_tokens
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
|
@ -0,0 +1,453 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from .args import ModelArgs
|
||||
from .datatypes import TransformerInput, TransformerOutput
|
||||
from .ffn import FeedForward
|
||||
from .moe import MoE
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
|
||||
class L2Norm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self._norm(x.float()).type_as(x)
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor):
|
||||
# Values obtained from grid search
|
||||
scale_factor = 8
|
||||
low_freq_factor = 1
|
||||
high_freq_factor = 4
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
new_freqs = []
|
||||
for freq in freqs:
|
||||
wavelen = 2 * math.pi / freq
|
||||
if wavelen < high_freq_wavelen:
|
||||
new_freqs.append(freq)
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / scale_factor)
|
||||
else:
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
|
||||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
# TODO: this module needs to be moved into a separate file since it can be used by
|
||||
# the vision encoder as well.
|
||||
def __init__(
|
||||
self,
|
||||
args: ModelArgs,
|
||||
use_qk_norm: bool,
|
||||
use_rope: bool,
|
||||
add_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_rope = use_rope
|
||||
self.use_qk_norm = use_qk_norm
|
||||
# For attention temperature tuning
|
||||
self.attn_temperature_tuning = args.attn_temperature_tuning
|
||||
self.floor_scale = args.floor_scale
|
||||
self.attn_scale = args.attn_scale
|
||||
|
||||
self.n_heads = args.n_heads
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
self.wq = ColumnParallelLinear(
|
||||
args.dim,
|
||||
args.n_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
args.n_heads * self.head_dim,
|
||||
args.dim,
|
||||
bias=add_bias,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
|
||||
self.cache_k = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
self.cache_v = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
|
||||
self.qk_norm = None
|
||||
if self.use_qk_norm:
|
||||
self.qk_norm = L2Norm(args.norm_eps)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "wqkv.weight" in state_dict:
|
||||
wqkv = state_dict.pop(prefix + "wqkv.weight")
|
||||
d, r = divmod(wqkv.shape[0], self.n_heads + 2 * self.n_kv_heads)
|
||||
if r != 0:
|
||||
raise ValueError(
|
||||
f"shape={tuple(wqkv.shape)} is not divisible by "
|
||||
f"n_heads ({self.n_heads}) + 2 * n_kv_heads ({self.n_kv_heads})"
|
||||
)
|
||||
wq, wk, wv = wqkv.split([d * self.n_heads, d * self.n_kv_heads, d * self.n_kv_heads], dim=0)
|
||||
state_dict[prefix + "wq.weight"] = wq
|
||||
state_dict[prefix + "wk.weight"] = wk
|
||||
state_dict[prefix + "wv.weight"] = wv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
if self.use_rope:
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
if self.use_qk_norm:
|
||||
xq = self.qk_norm(xq)
|
||||
xk = self.qk_norm(xk)
|
||||
|
||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
||||
# the inference-time temperature tuning function is customized to not affect short context
|
||||
# while working at very long context
|
||||
if self.attn_temperature_tuning and not self.use_rope:
|
||||
seq_positions = torch.arange(start_pos, start_pos + seqlen, device=xq.device, dtype=torch.float32)
|
||||
attn_scales = torch.log(torch.floor((seq_positions + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
|
||||
|
||||
# reshape for broadcasting [seqlen] -> [1, seqlen, 1, 1]
|
||||
attn_scales = attn_scales.view(1, seqlen, 1, 1)
|
||||
xq = xq * attn_scales
|
||||
|
||||
self.cache_k = self.cache_k.to(xq)
|
||||
self.cache_v = self.cache_v.to(xq)
|
||||
|
||||
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||
|
||||
xk = self.cache_k[:bsz, : start_pos + seqlen]
|
||||
xv = self.cache_v[:bsz, : start_pos + seqlen]
|
||||
|
||||
xq, xk, xv = [t.transpose(1, 2) for t in (xq, xk, xv)]
|
||||
|
||||
xk = xk.repeat_interleave(self.n_rep, dim=1)
|
||||
xv = xv.repeat_interleave(self.n_rep, dim=1)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
output = self.wo(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
|
||||
|
||||
self.is_nope_layer = args.nope_layer_interval is not None and (layer_id + 1) % args.nope_layer_interval == 0
|
||||
|
||||
use_rope = not self.is_nope_layer
|
||||
use_qk_norm = args.use_qk_norm and not self.is_nope_layer
|
||||
|
||||
self.attention = Attention(args, use_rope=use_rope, use_qk_norm=use_qk_norm)
|
||||
|
||||
if args.moe_args and (layer_id + 1) % args.moe_args.interleave_moe_layer_step == 0:
|
||||
self.feed_forward = MoE(
|
||||
dim=args.dim,
|
||||
hidden_dim=int(args.ffn_exp * args.dim),
|
||||
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||
multiple_of=args.multiple_of,
|
||||
moe_args=args.moe_args,
|
||||
)
|
||||
else:
|
||||
hidden_dim = int(4 * args.dim)
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
if args.ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
|
||||
|
||||
self.feed_forward = FeedForward(
|
||||
dim=args.dim,
|
||||
hidden_dim=hidden_dim,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
|
||||
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
|
||||
|
||||
if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict:
|
||||
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.mlp.layer_norm_weight")
|
||||
elif prefix + "feed_forward.norm.weight" in state_dict:
|
||||
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.norm.weight")
|
||||
|
||||
for k in (
|
||||
"feed_forward.experts.mlp",
|
||||
"feed_forward.mlp_shared",
|
||||
"attention.wo",
|
||||
"attention.wqkv",
|
||||
):
|
||||
if prefix + k + "._extra_state" in state_dict:
|
||||
state_dict.pop(prefix + k + "._extra_state")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
global_attn_mask: Optional[torch.Tensor],
|
||||
local_attn_mask: Optional[torch.Tensor],
|
||||
):
|
||||
# The iRoPE architecture uses global attention mask for NoPE layers or
|
||||
# if chunked local attention is not used
|
||||
if self.is_nope_layer or local_attn_mask is None:
|
||||
mask = global_attn_mask
|
||||
else:
|
||||
mask = local_attn_mask
|
||||
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, args: ModelArgs, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
|
||||
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(args.n_layers):
|
||||
self.layers.append(TransformerBlock(layer_id, args))
|
||||
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
|
||||
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
args.dim // args.n_heads,
|
||||
args.max_seq_len * 2,
|
||||
args.rope_theta,
|
||||
args.use_scaled_rope,
|
||||
)
|
||||
vision_args = self.args.vision_args
|
||||
if vision_args:
|
||||
# circular import otherwise until we refactor out Attention
|
||||
from .vision.embedding import VisionEmbeddings
|
||||
|
||||
self.vision_embeddings = VisionEmbeddings(vision_args)
|
||||
self.vision_projection = ColumnParallelLinear(
|
||||
vision_args.output_dim,
|
||||
args.dim,
|
||||
bias=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "rope.freqs" in state_dict:
|
||||
state_dict.pop(prefix + "rope.freqs")
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, model_input: TransformerInput) -> TransformerOutput:
|
||||
tokens = model_input.tokens
|
||||
start_pos = model_input.tokens_position
|
||||
assert isinstance(start_pos, int), (
|
||||
"This implementation does not support different start positions per batch item"
|
||||
)
|
||||
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
if image_embedding := model_input.image_embedding:
|
||||
h_image = self.vision_projection(image_embedding.embedding)
|
||||
h = h * ~image_embedding.mask + h_image * image_embedding.mask
|
||||
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||
|
||||
global_attn_mask, local_attn_mask = None, None
|
||||
if seqlen > 1:
|
||||
global_attn_mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||
global_attn_mask = torch.triu(global_attn_mask, diagonal=1).type_as(h)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/100005
|
||||
# torch.triu is buggy when the device is mps: filled values are
|
||||
# nan instead of 0.
|
||||
if global_attn_mask.device.type == torch.device("mps").type:
|
||||
global_attn_mask = torch.nan_to_num(global_attn_mask, nan=0.0)
|
||||
|
||||
if chunk_size := self.args.attention_chunk_size:
|
||||
local_attn_mask = create_chunked_attention_mask(seqlen, chunk_size, tokens.device)
|
||||
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, global_attn_mask, local_attn_mask)
|
||||
h = self.norm(h)
|
||||
output = self.output(h).float()
|
||||
|
||||
return TransformerOutput(logits=output)
|
||||
|
||||
|
||||
# tokens (0, K), (K, 2K), (2K, 3K) attend to each other when doing local chunked attention
|
||||
# in the iRoPE architecture
|
||||
def create_chunked_attention_mask(seq_len: int, attention_chunk_size: int, device: torch.device) -> torch.Tensor:
|
||||
block_pos = torch.abs(
|
||||
(torch.arange(seq_len).unsqueeze(0) // attention_chunk_size)
|
||||
- (torch.arange(seq_len).unsqueeze(1) // attention_chunk_size)
|
||||
)
|
||||
token_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
|
||||
mask = (block_pos == 0) & (token_pos <= 0)
|
||||
return mask.to(device)
|
|
@ -0,0 +1,224 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# ruff: noqa: N806
|
||||
# pyre-strict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .args import MoEArgs
|
||||
from .ffn import FeedForward
|
||||
|
||||
|
||||
class Experts(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_local_experts: int,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
self.num_local_experts = num_local_experts
|
||||
self.dim = dim
|
||||
divide_factor = fs_init.get_model_parallel_world_size()
|
||||
|
||||
self.w1: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
dim,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.w2: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.w3: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
dim,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
self.prefix = prefix
|
||||
if prefix + "moe_w_in_eD_F" in state_dict:
|
||||
e = self.num_local_experts
|
||||
D = self.dim
|
||||
state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
|
||||
state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
|
||||
state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
routed_in_egD: torch.Tensor, # noqa: N803
|
||||
) -> torch.Tensor:
|
||||
e = self.num_local_experts
|
||||
D = self.dim
|
||||
|
||||
x_egD = routed_in_egD.view(e, -1, D)
|
||||
|
||||
out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
|
||||
out_egD = out_egD.view(-1, D)
|
||||
|
||||
return out_egD
|
||||
|
||||
def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||
middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
|
||||
return torch.bmm(middle_out_egF, w2)
|
||||
|
||||
|
||||
class MoE(torch.nn.Module):
|
||||
"""
|
||||
This EC implementation is modified from the original EC module.
|
||||
We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding.
|
||||
This module supports 3 sharding methods of the experts:
|
||||
- tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding.
|
||||
- tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded.
|
||||
- dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding.
|
||||
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
|
||||
Several commonly used annotations include:
|
||||
- a: bsz*slen
|
||||
- E: number of experts
|
||||
- e: number of local experts per ep (n_experts/ep)
|
||||
- et: number of local experts per tp (n_experts/tp)
|
||||
- D: hidden dimension
|
||||
- d: D/tp
|
||||
- F: model dimension
|
||||
- f: F/tp (used in column/row-parallel linear)
|
||||
- G: number of tokens per expert (a * capacity_factor / E)
|
||||
- g: number of tokens per expert per TP rank (i.e., G/TP)
|
||||
- GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False)
|
||||
- gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True)
|
||||
|
||||
Examples:
|
||||
x_aD [a, D]
|
||||
routed_in_etG_D [et*G, D]
|
||||
x_eGGD: [e, GG, D]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
ffn_dim_multiplier: float,
|
||||
multiple_of: int,
|
||||
moe_args: MoEArgs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.moe_args = moe_args
|
||||
|
||||
hidden_dim_denom: float = 1
|
||||
if moe_args.auto_scale_F:
|
||||
hidden_dim_denom = moe_args.capacity_factor + 1
|
||||
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
|
||||
# custom dim factor multiplier
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
|
||||
if moe_args.auto_scale_F:
|
||||
hidden_dim = int(hidden_dim / hidden_dim_denom)
|
||||
|
||||
hidden_dim += -hidden_dim % multiple_of
|
||||
|
||||
num_local_experts: int = moe_args.num_experts
|
||||
dtype: torch.dtype = torch.get_default_dtype()
|
||||
self.experts = Experts(
|
||||
num_local_experts,
|
||||
dim,
|
||||
hidden_dim,
|
||||
)
|
||||
|
||||
self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
|
||||
self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "w_in_shared_FD.weight" in state_dict:
|
||||
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
|
||||
state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
|
||||
state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")
|
||||
|
||||
def forward(self, x_bsD: Tensor) -> Tensor: # noqa: N803
|
||||
_, slen, D = x_bsD.shape
|
||||
x_aD = x_bsD.view(-1, D)
|
||||
|
||||
a = x_aD.shape[0]
|
||||
|
||||
router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)
|
||||
|
||||
router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
|
||||
router_scores = (
|
||||
torch.full_like(router_scores.transpose(0, 1), float("-inf"))
|
||||
.scatter_(1, router_indices_aK, router_scores_aK)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)
|
||||
|
||||
router_scores = torch.sigmoid(router_scores)
|
||||
|
||||
routed_in_EG_D: Tensor = torch.gather(
|
||||
x_aD,
|
||||
dim=0,
|
||||
index=router_indices.reshape(-1, 1).expand(-1, D),
|
||||
)
|
||||
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
|
||||
|
||||
out_aD = self.shared_expert(x_aD)
|
||||
routed_out_egg_D = self.experts(routed_in_EG_D.detach())
|
||||
|
||||
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
|
||||
out_aD.scatter_add_(
|
||||
dim=0,
|
||||
index=router_indices_EG_D,
|
||||
src=routed_out_egg_D.view(-1, D),
|
||||
)
|
||||
out_aD = reduce_from_model_parallel_region(out_aD)
|
||||
return out_aD.view(-1, slen, D)
|
||||
|
||||
|
||||
def divide_exact(numerator: int, denominator: int) -> int:
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
|
||||
return numerator // denominator
|
|
@ -0,0 +1,436 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv
|
||||
from PIL import Image, ImageFile
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
IMAGE_RES = 448
|
||||
|
||||
|
||||
class ResizeNormalizeImageTransform:
|
||||
def __init__(
|
||||
self,
|
||||
size_width=None,
|
||||
size_height=None,
|
||||
) -> None:
|
||||
self._size_width = size_width or IMAGE_RES
|
||||
self._size_height = size_height or IMAGE_RES
|
||||
self._mean = (0.5, 0.5, 0.5)
|
||||
self._std = (0.5, 0.5, 0.5)
|
||||
|
||||
self.tv_transform = tv.Compose(
|
||||
[
|
||||
tv.Resize((self._size_height, self._size_width)),
|
||||
tv.ToTensor(),
|
||||
tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(self, image: Image.Image) -> torch.Tensor:
|
||||
return self.tv_transform(image)
|
||||
|
||||
|
||||
class VariableSizeImageTransform(object):
|
||||
"""
|
||||
This class accepts images of any size and dynamically resize, pads and chunks it
|
||||
based on the image aspect ratio and the number of image chunks we allow.
|
||||
|
||||
The algorithm will NOT distort the image fit a certain aspect ratio, because
|
||||
that leads to a significant degradation in image quality.
|
||||
|
||||
It can be summarized in 6 steps:
|
||||
1. Find all possible canvas combinations of max_num_chunks;
|
||||
2. Find the best canvas to fit the image;
|
||||
3. Resize without distortion
|
||||
4. Pad
|
||||
5. Normalize
|
||||
6. Chunk
|
||||
|
||||
For example, if an input image is of size 300x800, patch_size of 224,
|
||||
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||
is allowed within 8 image chunks, with some restrictions.
|
||||
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||
giving a total of 8 chunks.
|
||||
|
||||
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||
where we maintain the original aspect ratio and pad with zeros value for the rest.
|
||||
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||
|
||||
However, if limit_upscaling_to_patch_size is set to True,
|
||||
the upscaling will be limited to the patch size. In the example above,
|
||||
the image would remain 300x800 (no upscaling), and then padded to 448:896.
|
||||
|
||||
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||
patches are coming from the resizing and chunking.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||
self.size = size
|
||||
self.to_tensor = tv.ToTensor()
|
||||
self._mean = (0.5, 0.5, 0.5)
|
||||
self._std = (0.5, 0.5, 0.5)
|
||||
self.normalize = tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
)
|
||||
self.resample = tv.InterpolationMode.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_factors(n: int) -> Set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. a dividor that leaves
|
||||
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
||||
Args:
|
||||
n (int): The number to find factors for.
|
||||
|
||||
Returns:
|
||||
set: A set containing all factors of the number.
|
||||
"""
|
||||
factors_set = set()
|
||||
|
||||
for i in range(1, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
factors_set.add(i)
|
||||
factors_set.add(n // i)
|
||||
return factors_set
|
||||
|
||||
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Computes all of the allowed resoltuions for a fixed number of chunks
|
||||
and patch_size. Useful for when dividing an image into chunks.
|
||||
|
||||
Args:
|
||||
max_num_chunks (int): Maximum number of chunks for processing.
|
||||
patch_size (int): Size of the side of the patch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||
|
||||
Example:
|
||||
>>> max_num_chunks = 5
|
||||
>>> patch_size = 224
|
||||
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||
(672, 224), (224, 448), (448, 224)])
|
||||
|
||||
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||
{
|
||||
0.25: [(1, 4)],
|
||||
1.0: [(2, 2), (1, 1)],
|
||||
4.0: [(4, 1)],
|
||||
0.33: [(1, 3)],
|
||||
3.0: [(3, 1)],
|
||||
0.5: [(1, 2)],
|
||||
2.0: [(2, 1)]
|
||||
}
|
||||
|
||||
and return the resolutions multiplied by the patch_size:
|
||||
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||
"""
|
||||
asp_dict = defaultdict(list)
|
||||
for chunk_size in range(max_num_chunks, 0, -1):
|
||||
_factors = sorted(self.get_factors(chunk_size))
|
||||
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||
for height, width in _asp_ratios:
|
||||
ratio_float = height / width
|
||||
asp_dict[ratio_float].append((height, width))
|
||||
|
||||
# get the resolutions multiplied by the patch_size
|
||||
possible_resolutions = []
|
||||
for value in asp_dict.values():
|
||||
for height, width in value:
|
||||
possible_resolutions.append((height * patch_size, width * patch_size))
|
||||
|
||||
return possible_resolutions
|
||||
|
||||
@staticmethod
|
||||
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized to without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||
Returns:
|
||||
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||
Example:
|
||||
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||
(134, 200)
|
||||
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||
(450, 338)
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
target_width, target_height = target_size
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||
|
||||
return new_width, new_height
|
||||
|
||||
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||
new_width, new_height = target_size
|
||||
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||
new_im.paste(image)
|
||||
return new_im
|
||||
|
||||
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||
# Split image into number of required tiles (width x height)
|
||||
num_channels, height, width = image.size()
|
||||
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||
# Permute dimensions to reorder the axes
|
||||
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||
return image
|
||||
|
||||
def resize_without_distortion(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_size: Tuple[int, int],
|
||||
max_upscaling_size: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Used to resize an image to target_resolution, without distortion.
|
||||
|
||||
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||
modifying target_size works as a boundary for the image's largest side.
|
||||
|
||||
Args:
|
||||
resample (str): Resampling method used when resizing images.
|
||||
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||
If None, there is no limit.
|
||||
Examples:
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(600, 300) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (2000, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 100) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 2000
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = None
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
"""
|
||||
|
||||
image_width, image_height = image.size
|
||||
image_size = (image_width, image_height)
|
||||
|
||||
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||
if max_upscaling_size is not None:
|
||||
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||
target_size = (new_target_width, new_target_height)
|
||||
|
||||
# resize to target_size while preserving aspect ratio
|
||||
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||
|
||||
image = F.resize(
|
||||
image,
|
||||
(
|
||||
max(new_size_without_distortion[1], 1),
|
||||
max(new_size_without_distortion[0], 1),
|
||||
),
|
||||
interpolation=self.resample,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def get_best_fit(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
possible_resolutions: torch.Tensor,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the best canvas possible from a list of possible resolutions to, without distortion,
|
||||
resize an image to.
|
||||
|
||||
For each possible resolution, calculates the scaling factors for
|
||||
width and height, and selects the smallest one, which is the limiting side.
|
||||
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||
|
||||
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
|
||||
|
||||
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||
reduce downscaling as much as possible.
|
||||
|
||||
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||
has more padding.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
|
||||
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
||||
row represents a possible resolution (height, width).
|
||||
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
|
||||
|
||||
Returns:
|
||||
List[int]: The best resolution [height, width] for the given image.
|
||||
|
||||
Example:
|
||||
>>> image_size = (200, 300)
|
||||
>>> possible_resolutions = torch.tensor([[224, 672],
|
||||
... [672, 224],
|
||||
... [224, 448],
|
||||
... [448, 224],
|
||||
... [224, 224]])
|
||||
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
|
||||
[224, 448]
|
||||
|
||||
We have:
|
||||
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
||||
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
||||
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
||||
Only one of the scales > 1:
|
||||
upscaling_possible = tensor([1.1200, 1.1200])
|
||||
smallest_rescale = tensor(1.1200)
|
||||
So we pick the resolution with the smallest smallest area:
|
||||
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
|
||||
optimal_canvas = tensor([224, 448])
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
|
||||
# get all possible resolutions heights/widths
|
||||
target_widths, target_heights = (
|
||||
possible_resolutions[:, 0],
|
||||
possible_resolutions[:, 1],
|
||||
)
|
||||
|
||||
# get scaling factors to resize the image without distortion
|
||||
scale_w = target_widths / original_width
|
||||
scale_h = target_heights / original_height
|
||||
|
||||
# get the min scale between width and height (limiting side -> no distortion)
|
||||
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
||||
|
||||
# filter only scales that allow upscaling
|
||||
upscaling_options = scales[scales >= 1]
|
||||
if len(upscaling_options) > 0:
|
||||
if resize_to_max_canvas:
|
||||
selected_scale = torch.max(upscaling_options)
|
||||
else:
|
||||
selected_scale = torch.min(upscaling_options)
|
||||
else:
|
||||
# no upscaling possible,
|
||||
# get the minimum downscaling (max scale for scales<1)
|
||||
downscaling_options = scales[scales < 1]
|
||||
selected_scale = torch.max(downscaling_options)
|
||||
|
||||
# get all resolutions that support this scaling factor,
|
||||
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||
chosen_canvas = possible_resolutions[scales == selected_scale]
|
||||
|
||||
# if there are multiple resolutions,
|
||||
# get the one with minimum area to reduce padding
|
||||
if len(chosen_canvas) > 1:
|
||||
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||
optimal_idx = torch.argmin(areas)
|
||||
optimal_canvas = chosen_canvas[optimal_idx]
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return tuple(optimal_canvas.tolist())
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
max_num_chunks: int,
|
||||
normalize_img: bool = True,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
"""
|
||||
Args:
|
||||
image (PIL.Image): Image to be resized.
|
||||
max_num_chunks (int): Maximum number of chunks to split the image into.
|
||||
normalize_img (bool): Whether to normalize the image.
|
||||
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
|
||||
If True, picks the canvas the allows the largest resizing without distortion.
|
||||
If False, downsample as little as possible, including no resizing at all,
|
||||
but never upsample, unless the image is smaller than the patch size.
|
||||
"""
|
||||
assert max_num_chunks > 0
|
||||
assert isinstance(image, Image.Image), type(image)
|
||||
w, h = image.size
|
||||
|
||||
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
|
||||
possible_resolutions = torch.tensor(possible_resolutions)
|
||||
|
||||
best_resolution = self.get_best_fit(
|
||||
image_size=(w, h),
|
||||
possible_resolutions=possible_resolutions,
|
||||
resize_to_max_canvas=resize_to_max_canvas,
|
||||
)
|
||||
|
||||
max_upscaling_size = None if resize_to_max_canvas else self.size
|
||||
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
|
||||
image = self._pad(image, best_resolution)
|
||||
|
||||
image = self.to_tensor(image)
|
||||
|
||||
if normalize_img:
|
||||
image = self.normalize(image)
|
||||
|
||||
ratio_w, ratio_h = (
|
||||
best_resolution[0] // self.size,
|
||||
best_resolution[1] // self.size,
|
||||
)
|
||||
|
||||
image = self._split(image, ratio_w, ratio_h) # type: ignore
|
||||
|
||||
ar = (ratio_h, ratio_w)
|
||||
return image, ar
|
|
@ -0,0 +1,207 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..generation import QuantizationMode
|
||||
from ..model import Transformer, TransformerBlock
|
||||
from ..moe import MoE
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def experts_batched_swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor, # (e, g, D)
|
||||
w1: Tensor, # (e, D, F)
|
||||
w3: Tensor, # (e, D, F)
|
||||
w2: Tensor, # (e, F, D)
|
||||
) -> torch.Tensor:
|
||||
from ...quantize_impls import bmm_nt
|
||||
|
||||
middle_out_egF = F.silu(bmm_nt(x, w1)) * bmm_nt(x, w3) # noqa: N806
|
||||
return bmm_nt(middle_out_egF, w2)
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
model: Transformer,
|
||||
checkpoint_dir: str,
|
||||
quantization_mode: Optional[str] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
use_rich_progress: bool = True,
|
||||
) -> Transformer:
|
||||
from ...quantize_impls import (
|
||||
Fp8ScaledWeights,
|
||||
Int4ScaledWeights,
|
||||
load_fp8,
|
||||
load_int4,
|
||||
quantize_fp8,
|
||||
quantize_int4,
|
||||
)
|
||||
|
||||
rank = get_model_parallel_rank()
|
||||
|
||||
use_rich_progress = use_rich_progress and rank == 0
|
||||
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model)
|
||||
if quantization_mode == QuantizationMode.int4_mixed:
|
||||
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
|
||||
int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt")
|
||||
if os.path.isfile(int4_scales_path):
|
||||
log_status(f"Rank {rank}: Loading int4 scales")
|
||||
int4_scales = torch.load(int4_scales_path, weights_only=True)
|
||||
int4_zero_points = torch.load(int4_zero_points_path, weights_only=True)
|
||||
|
||||
def apply_quantization(key, weight):
|
||||
scale = int4_scales[key]
|
||||
zero_point = int4_zero_points[key]
|
||||
return load_int4(
|
||||
weight,
|
||||
scale,
|
||||
zero_point,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
else:
|
||||
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
|
||||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
||||
else:
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
|
||||
if os.path.isfile(fp8_scales_path):
|
||||
log_status(f"Rank {rank}: Loading fp8 scales")
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
def apply_quantization(key, weight):
|
||||
scale = fp8_scales[key]
|
||||
return load_fp8(
|
||||
weight,
|
||||
scale,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
else:
|
||||
log_status(f"Rank {rank}: Quantizing fp8 weights from bf16")
|
||||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_fp8(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
||||
|
||||
processed_blocks = 0
|
||||
try:
|
||||
if use_rich_progress:
|
||||
progress.start()
|
||||
|
||||
for _, block in model.named_modules():
|
||||
if isinstance(block, TransformerBlock):
|
||||
# Skip quantization on first and last layers
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
||||
# Skip quantization on dense layers
|
||||
if not isinstance(block.feed_forward, MoE):
|
||||
continue
|
||||
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id}")
|
||||
|
||||
# Quantize only routed experts, not shared
|
||||
prefix = f"layers.{block.layer_id}.feed_forward"
|
||||
moe = block.feed_forward
|
||||
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
|
||||
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(moe.experts, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
|
||||
setattr(
|
||||
moe.experts,
|
||||
key,
|
||||
apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()),
|
||||
)
|
||||
|
||||
processed_blocks += 1
|
||||
update_status(message=None, completed=processed_blocks)
|
||||
|
||||
update_status(f"Rank {rank} - Moving parameters to CUDA")
|
||||
|
||||
param_count = 0
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights) and not isinstance(parameter, Int4ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
param_count += 1
|
||||
|
||||
update_status(f"Rank {rank} - Completed - moved {param_count} parameters to CUDA")
|
||||
finally:
|
||||
if use_rich_progress:
|
||||
progress.stop()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
|
||||
def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
|
||||
console = None
|
||||
if use_rich_progress:
|
||||
from rich.console import Console
|
||||
|
||||
console = Console(highlight=False)
|
||||
|
||||
def log_status(message: str) -> None:
|
||||
if use_rich_progress:
|
||||
console.print(message)
|
||||
elif rank == 0: # Only log from rank 0 for non-rich logging
|
||||
log.info(message)
|
||||
|
||||
total_blocks = sum(
|
||||
1
|
||||
for _, block in model.named_modules()
|
||||
if (
|
||||
isinstance(block, TransformerBlock)
|
||||
and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
|
||||
and isinstance(block.feed_forward, MoE)
|
||||
)
|
||||
)
|
||||
progress = None
|
||||
if use_rich_progress:
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
|
||||
progress = Progress(
|
||||
SpinnerColumn(),
|
||||
BarColumn(complete_style="green", finished_style="bright_green"),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeElapsedColumn(),
|
||||
TextColumn("ETA:"),
|
||||
TimeRemainingColumn(),
|
||||
TextColumn("[bold]{task.fields[status]}"),
|
||||
console=console,
|
||||
expand=True,
|
||||
)
|
||||
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
|
||||
|
||||
def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
|
||||
if use_rich_progress:
|
||||
if message is not None:
|
||||
progress.update(task_id, status=message)
|
||||
if completed is not None:
|
||||
progress.update(task_id, completed=completed)
|
||||
elif rank == 0 and completed and completed % 10 == 0:
|
||||
log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
|
||||
|
||||
return progress, log_status, update_status
|
|
@ -0,0 +1,216 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
|
||||
from ..args import VisionArgs
|
||||
from .encoder import VisionEncoder
|
||||
|
||||
|
||||
class PixelShuffle(nn.Module):
|
||||
def __init__(self, ps_ratio):
|
||||
super().__init__()
|
||||
self.ps_ratio = ps_ratio
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, N, C], N = number of patches
|
||||
assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
|
||||
assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
|
||||
hh = ww = int(math.sqrt(x.shape[1]))
|
||||
x = x.reshape(x.shape[0], hh, ww, -1)
|
||||
x = pixel_shuffle_op(x, ps_ratio=self.ps_ratio)
|
||||
pixel_shuffle_patches = x.reshape(x.shape[0], -1, x.shape[-1])
|
||||
return pixel_shuffle_patches
|
||||
|
||||
|
||||
def pixel_shuffle_op(input_x, ps_ratio):
|
||||
n, w, h, c = input_x.size()
|
||||
input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
|
||||
input_x = input_x.permute(0, 2, 1, 3).contiguous()
|
||||
input_x = input_x.view(
|
||||
n,
|
||||
int(h * ps_ratio),
|
||||
int(w * ps_ratio),
|
||||
int(c / (ps_ratio * ps_ratio)),
|
||||
)
|
||||
input_x = input_x.permute(0, 2, 1, 3).contiguous()
|
||||
return input_x
|
||||
|
||||
|
||||
class SimpleMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
bias: bool = True,
|
||||
dropout: float = 0.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
# layers
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
hidden_dim,
|
||||
hidden_dim,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.non_linearity = act_layer()
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.c_fc(x)
|
||||
hidden = self.non_linearity(hidden)
|
||||
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
||||
return self.non_linearity(self.c_proj(hidden))
|
||||
|
||||
|
||||
class PixelShuffleMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ps_ratio: float,
|
||||
input_dim: int,
|
||||
output_dim: int = 4096,
|
||||
add_fc: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.pixel_shuffle = PixelShuffle(ps_ratio)
|
||||
self.mlp = SimpleMLP(
|
||||
int(input_dim // (ps_ratio**2)),
|
||||
output_dim,
|
||||
bias=False,
|
||||
dropout=0.0,
|
||||
act_layer=nn.GELU,
|
||||
)
|
||||
self.fc = nn.Identity()
|
||||
if add_fc:
|
||||
self.fc = ColumnParallelLinear(
|
||||
output_dim,
|
||||
output_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
|
||||
encoded_patches = self.pixel_shuffle(encoded_patches)
|
||||
return self.fc(self.mlp(encoded_patches))
|
||||
|
||||
|
||||
class VisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, args: VisionArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
image_size = args.image_size
|
||||
patch_size = args.patch_size
|
||||
self.vision_encoder = VisionEncoder(
|
||||
image_size=(image_size.height, image_size.width),
|
||||
patch_size=(patch_size.height, patch_size.width),
|
||||
dim=args.dim,
|
||||
layers=args.n_layers,
|
||||
heads=args.n_heads,
|
||||
mlp_ratio=args.mlp_ratio,
|
||||
)
|
||||
self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
|
||||
self.vision_adapter = PixelShuffleMLP(
|
||||
ps_ratio=args.pixel_shuffle_ratio,
|
||||
input_dim=args.dim,
|
||||
output_dim=args.output_dim,
|
||||
)
|
||||
|
||||
self.output_dim = args.output_dim
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
original_sd = self.state_dict()
|
||||
for k in state_dict:
|
||||
if k.startswith(prefix) and len(state_dict[k].shape) == 1 and state_dict[k].shape[0] == 0:
|
||||
state_dict[k] = state_dict[k].reshape(original_sd[k[len(prefix) :]].shape)
|
||||
|
||||
def _get_empty_sequence(self, h):
|
||||
return torch.zeros(
|
||||
h.shape[0],
|
||||
h.shape[1],
|
||||
self.output_dim,
|
||||
device=h.device,
|
||||
dtype=h.dtype,
|
||||
)
|
||||
|
||||
# x_images is batched; each batch sample contains a list of images. so this is List[List[torch.Tensor]]
|
||||
# each image is a tensor of shape [num_tiles, C, H, W]
|
||||
def forward(
|
||||
self,
|
||||
image_batch: List[List[torch.Tensor]],
|
||||
image_mask: torch.Tensor,
|
||||
h_ref: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
images_flattened = [image for sample in image_batch for image in sample]
|
||||
images_flattened = torch.vstack(images_flattened).unsqueeze(1).to(h_ref.dtype).to(h_ref.device)
|
||||
embedding = self.vision_encoder(images_flattened)
|
||||
projected_embedding = self.vision_adapter(embedding)
|
||||
|
||||
h_image = self._get_empty_sequence(h_ref)
|
||||
return scatter_embeddings(image_batch, image_mask, h_image, projected_embedding)
|
||||
|
||||
|
||||
def scatter_embeddings(image_batch, image_mask, h_image, encoded_patches_proj):
|
||||
# If dynamic transform is used and the batch contains 2 images (where image_1 has 2 chunks and image_2 has 3 chunks),
|
||||
# `num_images_per_sequence` now records the number of chunks per image as `[2, 3]`.
|
||||
# `encoded_patches_proj.split` will then split the image chunks into 2 groups: `[image_1_chunks, image_2_chunks]`.
|
||||
num_images_per_sequence = [sum(image.size(0) for image in sample_images) for sample_images in image_batch]
|
||||
|
||||
assert not torch.isnan(encoded_patches_proj).any()
|
||||
assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
|
||||
f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
|
||||
)
|
||||
|
||||
encoded_patches_list = encoded_patches_proj.split(num_images_per_sequence, dim=0)
|
||||
for index in range(h_image.size(0)):
|
||||
encoded_patches_per_sample = encoded_patches_list[index]
|
||||
sample_image_mask = image_mask[index]
|
||||
|
||||
if encoded_patches_per_sample.numel() == 0:
|
||||
continue
|
||||
encoded_patches_per_sample = encoded_patches_per_sample.contiguous().view(
|
||||
-1, encoded_patches_per_sample.size(-1)
|
||||
)
|
||||
|
||||
n_tokens_to_fill = sample_image_mask.sum()
|
||||
assert n_tokens_to_fill <= encoded_patches_per_sample.size(0)
|
||||
|
||||
h_image[index].masked_scatter_(
|
||||
sample_image_mask.expand(-1, h_image.size(-1)),
|
||||
encoded_patches_per_sample[:n_tokens_to_fill],
|
||||
)
|
||||
|
||||
return h_image
|
|
@ -0,0 +1,411 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from torch import einsum
|
||||
|
||||
from ..args import ModelArgs
|
||||
from ..model import Attention
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
return x
|
||||
|
||||
|
||||
class ColumnParallelConv2dPatch(torch.nn.Module):
|
||||
"""Conv2D Patching layer with model parallelism.
|
||||
Column parallel over unfolded input.
|
||||
Arguments:
|
||||
in_channels: Input channels.
|
||||
out_channels: Output channels.
|
||||
kernel_size: Size of convolution kernel.
|
||||
stride (default 1): Stride for convolution.
|
||||
bias (default False): Use bias in Conv2d.
|
||||
Input: (bsz, in_channels, height, width)
|
||||
Output: (bsz, num_tokens, out_channels)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]],
|
||||
bias: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
|
||||
self._linear = ColumnParallelLinear(
|
||||
in_channels * kernel_size[0] * kernel_size[1],
|
||||
out_channels,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self._unfold(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self._linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class _FeedForward(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
dropout: float,
|
||||
act_layer: Callable = nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
# layers
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.non_linearity = act_layer()
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.c_fc(x)
|
||||
hidden = self.non_linearity(hidden)
|
||||
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
||||
return self.c_proj(hidden)
|
||||
|
||||
|
||||
class _TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
gated: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert d_model % n_head == 0
|
||||
self.n_heads = n_head
|
||||
self.head_dim = d_model // self.n_heads
|
||||
|
||||
attn_args = ModelArgs(
|
||||
dim=d_model,
|
||||
head_dim=self.head_dim,
|
||||
n_heads=self.n_heads,
|
||||
n_kv_heads=self.n_heads,
|
||||
)
|
||||
self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = _FeedForward(
|
||||
dim=d_model,
|
||||
hidden_dim=int(mlp_ratio * d_model),
|
||||
dropout=0.0,
|
||||
act_layer=act_layer,
|
||||
)
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.gated = gated
|
||||
if gated:
|
||||
self.gate_attn = nn.Parameter(torch.zeros(1))
|
||||
self.gate_ffn = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def attention(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
|
||||
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
|
||||
|
||||
x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
|
||||
x = x + _gate_ffn * self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class _Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
gated: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
_TransformerBlock(
|
||||
d_model=dim,
|
||||
n_head=heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
gated=gated,
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
|
||||
out = []
|
||||
for idx, r in enumerate(self.resblocks):
|
||||
if return_intermediate is not None and idx in return_intermediate:
|
||||
out.append(x)
|
||||
x = r(x, mask=mask, freq_cis=freq_cis)
|
||||
if return_intermediate is not None:
|
||||
return x, torch.stack(out, dim=-1)
|
||||
return x
|
||||
|
||||
|
||||
class PackingIndex:
|
||||
Z = 0 # Z (time) coordinate of the token in the original sample
|
||||
Y = 1 # Y (height) coordinate of the token in the original sample
|
||||
X = 2 # X (width) coordinate of the token in the original sample
|
||||
TIME = 3 # Total number of time units (frames) in the original sample
|
||||
HEIGHT = 4 # Height of the original sample
|
||||
WIDTH = 5 # Width of the original sample
|
||||
# USE INDEX TO CHECK THE TYPE OF THE TOKEN (see ID fields below)
|
||||
IDX = 6 # Full index of the token in the original sample (x + y * w + z * w * h)
|
||||
BATCH_IDX = 7 # Which batch element this token belongs to. Note the batch idx of padding tokens is BATCH_SIZE
|
||||
|
||||
# Total size of the enum, remember to update this!
|
||||
NUM_METADATA = 8
|
||||
|
||||
# Note: For padding tokens IDX = -1
|
||||
# For cls tokens, IDX = -2
|
||||
ID_CLS_TOKEN = -2
|
||||
ID_PAD_TOKEN = -1
|
||||
|
||||
|
||||
class VisionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
patch_size: Tuple[int, int],
|
||||
dim: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float,
|
||||
in_channels: int = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (
|
||||
self.image_size[0] // self.patch_size[0],
|
||||
self.image_size[1] // self.patch_size[1],
|
||||
)
|
||||
self.conv1 = ColumnParallelConv2dPatch(
|
||||
in_channels=in_channels,
|
||||
out_channels=dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
)
|
||||
scale = dim**-0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(dim))
|
||||
|
||||
self.positional_embedding_vlm = nn.Parameter(
|
||||
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
|
||||
)
|
||||
|
||||
self.ln_pre = LayerNorm(dim)
|
||||
self.ln_post = LayerNorm(dim)
|
||||
self.transformer = _Transformer(
|
||||
dim,
|
||||
layers,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
act_layer=nn.GELU,
|
||||
)
|
||||
|
||||
# NOTE: hack for the fixed res
|
||||
image_h, image_w = self.image_size
|
||||
patch_h, patch_w = self.patch_size
|
||||
idx_h, idx_w = image_h // patch_h, image_w // patch_w
|
||||
img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
|
||||
img_idx = img_idx.reshape(idx_h * idx_w, 1)
|
||||
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
|
||||
img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
|
||||
|
||||
packed_img_idx = torch.empty(
|
||||
img_idx.shape[0],
|
||||
img_idx.shape[1],
|
||||
PackingIndex.NUM_METADATA - 1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
|
||||
packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
|
||||
packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
|
||||
packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
|
||||
packed_img_idx[:, :, PackingIndex.IDX] = img_idx
|
||||
packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
|
||||
self.packed_img_idx = packed_img_idx # for positional embedding load hook
|
||||
|
||||
# compute rope freqs
|
||||
rope_freq = self.get_rope_freqs(dim // heads // 2)
|
||||
freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
|
||||
freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
|
||||
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
|
||||
# disable RoPE for padding and cls tokens
|
||||
freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
|
||||
# compute complex freqs
|
||||
self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
|
||||
# xlf automatically broadcasts
|
||||
self.freq_cis = self.freq_cis.squeeze(0)
|
||||
self.n_heads = heads // fs_init.get_model_parallel_world_size()
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def get_rope_freqs(self, dim, theta=10000):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
return freqs
|
||||
|
||||
@torch.amp.autocast("cuda", enabled=False)
|
||||
def compute_rope_freqs(self, freqs, t):
|
||||
freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
|
||||
freqs = freqs.repeat_interleave(2, dim=-1)
|
||||
return freqs
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
|
||||
if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
|
||||
raise ValueError(
|
||||
f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
|
||||
)
|
||||
|
||||
batch_size, token_per_image, _ = self.packed_img_idx.shape
|
||||
# Input points for idx are [x, y, w, h]
|
||||
idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
|
||||
total_windows, window_size, _ = idx.shape
|
||||
|
||||
# Grid values are [-1, 1] and coords are w, h
|
||||
grid = (
|
||||
(idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
|
||||
)[None, ...]
|
||||
|
||||
# In this mode, cls token has no position embedding
|
||||
if orig_pos_embed is not None:
|
||||
posemb = (
|
||||
orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
posemb = posemb.to(device=grid.device, dtype=grid.dtype)
|
||||
sample = F.grid_sample(
|
||||
posemb, grid, padding_mode="zeros"
|
||||
) # padding tokens / class token will get zero for posemb
|
||||
sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
|
||||
sample = torch.where(
|
||||
idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
|
||||
orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
|
||||
sample,
|
||||
)
|
||||
|
||||
new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
|
||||
|
||||
state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
|
||||
|
||||
if return_state_dict:
|
||||
return state_dict
|
||||
|
||||
def apply_class_embedding(self, x):
|
||||
x = torch.cat(
|
||||
[
|
||||
x,
|
||||
self.class_embedding.to(x.dtype)
|
||||
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
||||
],
|
||||
dim=1,
|
||||
) # shape = [*, grid ** 2 + 1, width]
|
||||
return x
|
||||
|
||||
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
|
||||
if images.ndim == 5:
|
||||
num_concurrent_media = 1
|
||||
bsz, num_chunks, nch, h, w = images.shape
|
||||
else:
|
||||
bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
|
||||
|
||||
images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
||||
# patch embedding
|
||||
x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
||||
x = self.conv1(x) # shape = [*, width, grid ** 2]
|
||||
_, ntok, dim = x.shape
|
||||
x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
|
||||
|
||||
# apply cls token
|
||||
x = self.apply_class_embedding(x)
|
||||
ntok += 1
|
||||
|
||||
# apply position embeddings
|
||||
if self.positional_embedding_vlm is not None:
|
||||
x = x + self.positional_embedding_vlm.to(x.dtype)
|
||||
|
||||
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
|
||||
|
||||
x = self.ln_pre(x)
|
||||
x = x.view(bsz * num_concurrent_media, -1, dim)
|
||||
freq_cis = self.freq_cis.to(images.device)
|
||||
|
||||
tf_output = self.transformer(
|
||||
x,
|
||||
freq_cis=freq_cis,
|
||||
)
|
||||
|
||||
int_x = None
|
||||
if isinstance(tf_output, tuple):
|
||||
x, int_x = tf_output
|
||||
else:
|
||||
x = tf_output
|
||||
x = self.ln_post(x)
|
||||
|
||||
# remove cls token output
|
||||
x = x[:, :-1, :]
|
||||
|
||||
# add and output x + int_x features
|
||||
if int_x is not None:
|
||||
int_x = int_x[:, :-1, :, :]
|
||||
int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
|
||||
x = torch.cat([x, int_x], dim=-1)
|
||||
|
||||
return x
|
|
@ -4,23 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Generator
|
||||
from typing import Any, Callable, Generator
|
||||
|
||||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .llama3.generation import Llama3
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
|
@ -39,11 +33,10 @@ class ModelRunner:
|
|||
|
||||
|
||||
def init_model_cb(
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
builder_fn: Callable,
|
||||
params: list[Any],
|
||||
):
|
||||
llama = Llama3.build(config, model_id, llama_model)
|
||||
llama = builder_fn(*params)
|
||||
return ModelRunner(llama)
|
||||
|
||||
|
||||
|
@ -60,25 +53,15 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
model_parallel_size: int,
|
||||
builder_fn: Callable,
|
||||
builder_params: list[Any],
|
||||
formatter: Llama3ChatFormat | Llama4ChatFormat,
|
||||
):
|
||||
self.config = config
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
||||
# while the tool-use loop is going
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
checkpoint_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
|
||||
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
|
||||
self.model_parallel_size = model_parallel_size
|
||||
self.builder_fn = builder_fn
|
||||
self.builder_params = builder_params
|
||||
self.formatter = formatter
|
||||
|
||||
def start(self):
|
||||
self.__enter__()
|
||||
|
@ -87,11 +70,9 @@ class LlamaModelParallelGenerator:
|
|||
self.__exit__(None, None, None)
|
||||
|
||||
def __enter__(self):
|
||||
model_parallel_size = self.llama_model.pth_file_count
|
||||
|
||||
self.group = ModelParallelProcessGroup(
|
||||
model_parallel_size,
|
||||
init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
|
||||
self.model_parallel_size,
|
||||
init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
|
||||
)
|
||||
self.group.start()
|
||||
return self
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,177 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
from typing import Optional, Type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
log.info("Using efficient FP8 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
log.error("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Fp8ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Fp8Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Fp8RowwiseWeights(
|
||||
Fp8ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Fp8RowwiseWeights",
|
||||
["weight", "scale", "shape", "activation_scale_ub"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def ffn_swiglu(
|
||||
x: Tensor,
|
||||
w1: Fp8RowwiseWeights,
|
||||
w3: Fp8RowwiseWeights,
|
||||
w2: Fp8RowwiseWeights,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
if isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights):
|
||||
return ffn_swiglu_fp8_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
assert isinstance(w1, Tensor)
|
||||
assert isinstance(w3, Tensor)
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
assert isinstance(w2, Tensor)
|
||||
return (z @ w2.T).view(B, T, D)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_fp8(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Quantize [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||
del w
|
||||
return Fp8RowwiseWeights(
|
||||
weight=wq,
|
||||
scale=w_scale,
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_fp8(
|
||||
w: Tensor,
|
||||
w_scale: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Load FP8 [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input FP8.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
return Fp8RowwiseWeights(
|
||||
weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
|
||||
scale=w_scale.to(device="cuda"),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
def fc_fp8_dynamic(
|
||||
x: Tensor,
|
||||
w: Fp8RowwiseWeights,
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Single w8a8 fc layer with dynamic row-wise scaling.
|
||||
"""
|
||||
if isinstance(w, Fp8RowwiseWeights):
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
|
||||
del xq
|
||||
return y
|
||||
|
||||
|
||||
def ffn_swiglu_fp8_dynamic(
|
||||
x: Tensor,
|
||||
w1: Fp8RowwiseWeights,
|
||||
w3: Fp8RowwiseWeights,
|
||||
w2: Fp8RowwiseWeights,
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
HD_L = w1.shape[0] # noqa: N806
|
||||
assert HD_L == w3.shape[0]
|
||||
x1 = fc_fp8_dynamic(
|
||||
x.view(B * T, D),
|
||||
w1,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
x2 = fc_fp8_dynamic(
|
||||
x.view(B * T, D),
|
||||
w3,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
|
||||
z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
return z_.view(B, T, D)
|
|
@ -1,78 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
# The file gets a special treatment for now?
|
||||
# ruff: noqa: N803
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from fp8_impls import FfnQuantizeMode, ffn_swiglu_fp8_dynamic, quantize_fp8
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
|
||||
"Skip when H100 is not available",
|
||||
)
|
||||
class FP8Tests(unittest.TestCase):
|
||||
@settings(deadline=None)
|
||||
@given(
|
||||
D=st.sampled_from([4096, 8192]),
|
||||
HD_L=st.sampled_from([1280, 2560]),
|
||||
B=st.sampled_from([1, 2]),
|
||||
T=st.sampled_from([2048, 4096]),
|
||||
UB=st.sampled_from([1000, 10000]),
|
||||
)
|
||||
def test_fp8_ffn(
|
||||
self,
|
||||
D: int, # noqa
|
||||
HD_L: int,
|
||||
B: int,
|
||||
T: int,
|
||||
UB: float,
|
||||
) -> None:
|
||||
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
|
||||
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
|
||||
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
|
||||
|
||||
v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
|
||||
|
||||
# Fake quant
|
||||
x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
|
||||
w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
|
||||
w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
|
||||
w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
|
||||
|
||||
v_ref = ref_ffn(x, w1, w3, w2)
|
||||
|
||||
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,152 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama3.args import ModelArgs
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama3.model import Transformer, TransformerBlock
|
||||
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
|
||||
quantize_fp8,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main(
|
||||
ckpt_dir: str,
|
||||
tokenizer_path: str,
|
||||
quantized_ckpt_dir: str,
|
||||
max_seq_len: Optional[int] = 512,
|
||||
max_batch_size: Optional[int] = 4,
|
||||
model_parallel_size: Optional[int] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
seed: int = 1,
|
||||
):
|
||||
""" """
|
||||
if not os.path.exists(quantized_ckpt_dir):
|
||||
os.makedirs(quantized_ckpt_dir)
|
||||
shutil.copy(
|
||||
os.path.join(ckpt_dir, "params.json"),
|
||||
os.path.join(quantized_ckpt_dir, "params.json"),
|
||||
)
|
||||
shutil.copy(
|
||||
os.path.join(ckpt_dir, "tokenizer.model"),
|
||||
os.path.join(quantized_ckpt_dir, "tokenizer.model"),
|
||||
)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
if not model_parallel_is_initialized():
|
||||
if model_parallel_size is None:
|
||||
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(model_parallel_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# seed must be the same in all processes
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(checkpoints), (
|
||||
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
)
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
**params,
|
||||
)
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
assert model_args.vocab_size == tokenizer.n_words, (
|
||||
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
)
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
log.info(ckpt_path)
|
||||
assert quantized_ckpt_dir is not None, "QUantized checkpoint directory should not be None"
|
||||
fp8_scales = {}
|
||||
for block in model.layers:
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w1.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cpu"),
|
||||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"] = fp8_weight.scale
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w3.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cpu"),
|
||||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"] = fp8_weight.scale
|
||||
|
||||
fp8_weight = quantize_fp8(
|
||||
block.feed_forward.w2.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cpu"),
|
||||
)
|
||||
with torch.inference_mode():
|
||||
block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
|
||||
fp8_scales[f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"] = fp8_weight.scale
|
||||
|
||||
fp8_scales_path = os.path.join(quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
|
||||
torch.save(fp8_scales, fp8_scales_path)
|
||||
|
||||
ckpt_path = os.path.join(
|
||||
quantized_ckpt_dir,
|
||||
"consolidated.{:02d}.pth".format(get_model_parallel_rank()),
|
||||
)
|
||||
torch.save(model.state_dict(), ckpt_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
|
@ -1,31 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
set -x
|
||||
|
||||
cd $(dirname "$(realpath "$0")")
|
||||
|
||||
MASTER_HOST=$1
|
||||
RUN_ID=$2
|
||||
CKPT_DIR=$3
|
||||
QUANT_CKPT_DIR=$4
|
||||
TOKENIZER_PATH=$5
|
||||
NNODES=$6
|
||||
NPROC=$7
|
||||
|
||||
echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR
|
||||
|
||||
NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-stack" \
|
||||
torchrun \
|
||||
--nnodes=$NNODES --nproc_per_node=$NPROC \
|
||||
--rdzv_id=$RUN_ID \
|
||||
--rdzv_conf='timeout=120' \
|
||||
--rdzv_backend=c10d \
|
||||
--rdzv_endpoint="${MASTER_HOST}:29502" \
|
||||
quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR
|
|
@ -0,0 +1,332 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# type: ignore
|
||||
import collections
|
||||
import logging
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Fp8ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Fp8Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Fp8RowwiseWeights(
|
||||
Fp8ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Fp8RowwiseWeights",
|
||||
["weight", "scale", "shape", "activation_scale_ub"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class Int4ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Int4Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Int4Weights(
|
||||
Int4ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Int4Weights",
|
||||
["weight", "scale", "zero_point", "shape", "activation_scale_ub"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def int4_row_quantize(
|
||||
x: torch.Tensor,
|
||||
group_size: int = 128,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
n_bit = 4 # Number of target bits.
|
||||
to_quant = x.reshape(-1, group_size).to(torch.float)
|
||||
|
||||
max_val = to_quant.amax(dim=1, keepdim=True)
|
||||
min_val = to_quant.amin(dim=1, keepdim=True)
|
||||
max_int = 2**n_bit - 1
|
||||
min_int = 0
|
||||
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
||||
|
||||
zeros = min_val + scales * (2 ** (n_bit - 1))
|
||||
|
||||
out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
|
||||
|
||||
# Recenter output and move to int8.
|
||||
out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
|
||||
|
||||
# Cutlass expects column major layout for scale and zero point,
|
||||
# so we transpose here and make them contiguous.
|
||||
scales = scales.view(x.shape[0], -1).t().contiguous()
|
||||
zeros = zeros.view(x.shape[0], -1).t().contiguous()
|
||||
|
||||
return out, scales, zeros
|
||||
|
||||
|
||||
def pack_int4(x: torch.Tensor) -> torch.Tensor:
|
||||
# Given int8 x, pack adjacent int4 values into a single int8.
|
||||
low_x = x[:, ::2]
|
||||
high_x = x[:, 1::2]
|
||||
|
||||
# High bits need to left shift, this also masks off extra bits.
|
||||
high_x = torch.bitwise_left_shift(high_x, 4)
|
||||
# Low bits need to have sign bits removed.
|
||||
low_x = torch.bitwise_and(low_x, 0xF)
|
||||
|
||||
# Recombine into a single value with bitwise or.
|
||||
return torch.bitwise_or(low_x, high_x).contiguous()
|
||||
|
||||
|
||||
def bmm_nt(
|
||||
x: Tensor,
|
||||
w: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
if isinstance(w, Fp8ScaledWeights):
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
|
||||
return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, w.weight, x_scale, w.scale)
|
||||
elif isinstance(w, Int4ScaledWeights):
|
||||
return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)
|
||||
else:
|
||||
raise ValueError("Unsupported quantization type")
|
||||
|
||||
|
||||
def ffn_swiglu(
|
||||
x: Tensor,
|
||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
|
||||
isinstance(w1, Int4ScaledWeights) and isinstance(w3, Int4ScaledWeights) and isinstance(w2, Int4ScaledWeights)
|
||||
):
|
||||
return ffn_swiglu_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
assert isinstance(w1, Tensor)
|
||||
assert isinstance(w3, Tensor)
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
assert isinstance(w2, Tensor)
|
||||
return (z @ w2.T).view(B, T, D)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_fp8(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Quantize [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||
del w
|
||||
return Fp8RowwiseWeights(
|
||||
weight=wq,
|
||||
scale=w_scale,
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_int4(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Quantize [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
if w.ndim >= 3:
|
||||
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
|
||||
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
|
||||
scale = torch.stack(scale, dim=0)
|
||||
zero_point = torch.stack(zero_point, dim=0)
|
||||
else:
|
||||
wq, scale, zero_point = int4_row_quantize(w)
|
||||
wq = pack_int4(wq)
|
||||
del w
|
||||
return Int4Weights(
|
||||
weight=wq.to(output_device),
|
||||
scale=scale.to(output_device),
|
||||
zero_point=zero_point.to(output_device),
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_fp8(
|
||||
w: Tensor,
|
||||
w_scale: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Load FP8 [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input FP8.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
return Fp8RowwiseWeights(
|
||||
weight=w.to(torch.float8_e4m3fn).to(device=output_device),
|
||||
scale=w_scale.to(device=output_device),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_int4(
|
||||
w: Tensor,
|
||||
scale: Tensor,
|
||||
zero_point: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Load INT4 [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input INT4.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
return Int4Weights(
|
||||
weight=w.to(torch.int8).to(device=output_device),
|
||||
scale=scale.to(device=output_device),
|
||||
zero_point=zero_point.to(device=output_device),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
def fc_dynamic(
|
||||
x: Tensor,
|
||||
w: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Single w8a8 fc layer with dynamic row-wise scaling, or w4a16 fc layer with dyanmic row-wise scaling
|
||||
"""
|
||||
if isinstance(w, Int4Weights):
|
||||
y = torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
|
||||
else:
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
|
||||
del xq
|
||||
return y
|
||||
|
||||
|
||||
def ffn_swiglu_dynamic(
|
||||
x: Tensor,
|
||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
assert x.dim() == 3 or x.dim() == 2
|
||||
if x.dim() == 3:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
else:
|
||||
(T, D) = x.shape # noqa: N806
|
||||
B = 1 # noqa: N806
|
||||
|
||||
HD_L = w1.shape[0] # noqa: N806
|
||||
assert HD_L == w3.shape[0]
|
||||
x1 = fc_dynamic(
|
||||
x.view(B * T, D),
|
||||
w1,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
x2 = fc_dynamic(
|
||||
x.view(B * T, D),
|
||||
w3,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
|
||||
z_ = fc_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
if x.dim() == 3:
|
||||
return z_.view(B, T, D)
|
||||
else:
|
||||
return z_
|
|
@ -126,7 +126,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = event.span_id
|
||||
span_id = int(event.span_id, 16)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
|
|
|
@ -39,13 +39,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::meta-reference-quantized",
|
||||
pip_packages=(
|
||||
META_REFERENCE_DEPS
|
||||
+ [
|
||||
"fbgemm-gpu",
|
||||
"torchao==0.5.0",
|
||||
]
|
||||
),
|
||||
pip_packages=META_REFERENCE_DEPS + ["fbgemm-gpu", "torchao==0.5.0"],
|
||||
module="llama_stack.providers.inline.inference.meta_reference",
|
||||
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
|
||||
),
|
||||
|
|
|
@ -27,7 +27,7 @@ def supported_inference_models() -> List[Model]:
|
|||
m
|
||||
for m in all_registered_models()
|
||||
if (
|
||||
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
|
||||
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3, ModelFamily.llama4}
|
||||
or is_supported_safety_model(m)
|
||||
)
|
||||
]
|
||||
|
|
|
@ -33,9 +33,7 @@ from llama_stack.apis.inference import (
|
|||
from llama_stack.apis.models.models import Model
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict_new,
|
||||
convert_openai_chat_completion_choice,
|
||||
|
@ -55,10 +53,22 @@ class LiteLLMOpenAIMixin(
|
|||
Inference,
|
||||
NeedsRequestProviderData,
|
||||
):
|
||||
def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
|
||||
def __init__(
|
||||
self,
|
||||
model_entries,
|
||||
api_key_from_config: Optional[str],
|
||||
provider_data_api_key_field: str,
|
||||
openai_compat_api_base: str | None = None,
|
||||
):
|
||||
ModelRegistryHelper.__init__(self, model_entries)
|
||||
self.api_key_from_config = api_key_from_config
|
||||
self.provider_data_api_key_field = provider_data_api_key_field
|
||||
self.api_base = openai_compat_api_base
|
||||
|
||||
if openai_compat_api_base:
|
||||
self.is_openai_compat = True
|
||||
else:
|
||||
self.is_openai_compat = False
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
@ -98,6 +108,7 @@ class LiteLLMOpenAIMixin(
|
|||
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
model = await self.model_store.get_model(model_id)
|
||||
request = ChatCompletionRequest(
|
||||
model=model.provider_resource_id,
|
||||
|
@ -111,6 +122,9 @@ class LiteLLMOpenAIMixin(
|
|||
)
|
||||
|
||||
params = await self._get_params(request)
|
||||
if self.is_openai_compat:
|
||||
params["model"] = "openai/" + params["model"]
|
||||
|
||||
logger.debug(f"params to litellm (openai compat): {params}")
|
||||
# unfortunately, we need to use synchronous litellm.completion here because litellm
|
||||
# caches various httpx.client objects in a non-eventloop aware manner
|
||||
|
@ -208,6 +222,7 @@ class LiteLLMOpenAIMixin(
|
|||
return {
|
||||
"model": request.model,
|
||||
"api_key": api_key,
|
||||
"api_base": self.api_base,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
|
|
|
@ -573,21 +573,24 @@ async def convert_message_to_openai_dict_new(
|
|||
content=await _convert_message_content(message.content),
|
||||
)
|
||||
elif isinstance(message, CompletionMessage):
|
||||
tool_calls = [
|
||||
OpenAIChatCompletionMessageToolCall(
|
||||
id=tool.call_id,
|
||||
function=OpenAIFunction(
|
||||
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
|
||||
arguments=json.dumps(tool.arguments),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
for tool in message.tool_calls
|
||||
]
|
||||
params = {}
|
||||
if tool_calls:
|
||||
params = {"tool_calls": tool_calls}
|
||||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=await _convert_message_content(message.content),
|
||||
tool_calls=[
|
||||
OpenAIChatCompletionMessageToolCall(
|
||||
id=tool.call_id,
|
||||
function=OpenAIFunction(
|
||||
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
|
||||
arguments=json.dumps(tool.arguments),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
for tool in message.tool_calls
|
||||
]
|
||||
or None,
|
||||
**params,
|
||||
)
|
||||
elif isinstance(message, ToolResponseMessage):
|
||||
out = OpenAIChatCompletionToolMessage(
|
||||
|
@ -801,7 +804,7 @@ def _convert_openai_logprobs(
|
|||
- token, logprob
|
||||
|
||||
"""
|
||||
if not logprobs:
|
||||
if not logprobs or not logprobs.content:
|
||||
return None
|
||||
|
||||
return [
|
||||
|
|
|
@ -224,7 +224,9 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
|
|||
return formatter.tokenizer.decode(model_input.tokens)
|
||||
|
||||
|
||||
async def completion_request_to_prompt_model_input_info(request: CompletionRequest) -> Tuple[str, int]:
|
||||
async def completion_request_to_prompt_model_input_info(
|
||||
request: CompletionRequest,
|
||||
) -> Tuple[str, int]:
|
||||
content = augment_content_with_response_format_prompt(request.response_format, request.content)
|
||||
request.content = content
|
||||
request = await convert_request_to_raw(request)
|
||||
|
@ -302,8 +304,12 @@ def chat_completion_request_to_messages(
|
|||
):
|
||||
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama_3_1(request)
|
||||
elif model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
|
||||
# llama3.2 and llama3.3 models follow the same tool prompt format
|
||||
elif model.model_family in (
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
):
|
||||
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
|
||||
messages = augment_messages_for_tools_llama_3_2(request)
|
||||
else:
|
||||
messages = request.messages
|
||||
|
@ -471,7 +477,11 @@ def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
|
|||
):
|
||||
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
|
||||
return ToolPromptFormat.json
|
||||
elif llama_model.model_family in (ModelFamily.llama3_2, ModelFamily.llama3_3):
|
||||
elif llama_model.model_family in (
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
):
|
||||
# llama3.2 and llama3.3 models follow the same tool prompt format
|
||||
return ToolPromptFormat.python_list
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue