Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-21 03:59:42 +00:00)

Merge branch 'main' into nvidia-e2e-notebook

Commit 7cdd2a0410: 264 changed files with 229,042 additions and 8,445 deletions
@@ -6,11 +6,8 @@
from typing import List, Optional, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
    ChatCompletionResponse,
    CompletionResponse,
    InterleavedContent,
    LogProbConfig,
    Message,

@@ -20,41 +17,39 @@ from llama_stack.apis.inference import (
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, webmethod


@json_schema_type
class BatchCompletionResponse(BaseModel):
    batch: List[CompletionResponse]


@json_schema_type
class BatchChatCompletionResponse(BaseModel):
    batch: List[ChatCompletionResponse]
from llama_stack.schema_utils import webmethod


@runtime_checkable
class BatchInference(Protocol):
    """Batch inference API for generating completions and chat completions.

    This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.

    NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
    including (post-training, evals, etc).
    """

    @webmethod(route="/batch-inference/completion", method="POST")
    async def batch_completion(
    async def completion(
        self,
        model: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchCompletionResponse: ...
    ) -> Job: ...

    @webmethod(route="/batch-inference/chat-completion", method="POST")
    async def batch_chat_completion(
    async def chat_completion(
        self,
        model: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        # zero-shot tool definitions as input to the model
        tools: Optional[List[ToolDefinition]] = list,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchChatCompletionResponse: ...
    ) -> Job: ...
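With this change both batch endpoints are job-based: instead of returning results inline, they hand back a Job that the caller polls until it completes, as the docstring above describes. The sketch below illustrates that submit-then-poll pattern with plain asyncio and deliberately uses stand-in types (FakeJob, FakeBatchService) rather than the real Job and BatchInference classes, whose exact fields are not shown in this hunk.

# Generic submit-then-poll sketch of the job-based batch API described above.
# FakeJob and FakeBatchService are illustrative stand-ins, not llama-stack classes.
import asyncio
import uuid
from dataclasses import dataclass, field


@dataclass
class FakeJob:
    job_id: str
    status: str = "in_progress"
    results: list = field(default_factory=list)


class FakeBatchService:
    def __init__(self) -> None:
        self.jobs: dict[str, FakeJob] = {}

    async def completion(self, model: str, content_batch: list[str]) -> FakeJob:
        # Kick off the work in the background and return a handle immediately.
        job = FakeJob(job_id=uuid.uuid4().hex)
        self.jobs[job.job_id] = job
        asyncio.create_task(self._run(job, content_batch))
        return job

    async def _run(self, job: FakeJob, content_batch: list[str]) -> None:
        await asyncio.sleep(0.1)  # stand-in for actual batched inference
        job.results = [f"completion for: {text}" for text in content_batch]
        job.status = "completed"


async def main() -> None:
    service = FakeBatchService()
    job = await service.completion("some-model", ["hello", "world"])
    while job.status != "completed":  # poll the job until it finishes
        await asyncio.sleep(0.05)
    print(job.results)


asyncio.run(main())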
@@ -18,22 +18,71 @@ from typing import (
)

from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from typing_extensions import Annotated, TypedDict

from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
    BuiltinTool,
    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)


@json_schema_type
class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


@json_schema_type
class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = Field(..., gt=0.0)
    top_p: Optional[float] = 0.95


@json_schema_type
class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int = Field(..., ge=1)


SamplingStrategy = Annotated[
    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
    Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")


@json_schema_type
class SamplingParams(BaseModel):
    """Sampling parameters.

    :param strategy: The sampling strategy.
    :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
        your prompt plus max_tokens cannot exceed the model's context length.
    :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
        based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    :param stop: Up to 4 sequences where the API will stop generating further tokens.
        The returned text will not contain the stop sequence.
    """

    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)

    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None
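The strategy field is a discriminated union, so callers switch sampling behavior just by passing a different strategy object. A minimal construction sketch follows; it assumes these classes are importable from llama_stack.apis.inference, the module this hunk modifies.

# Sketch: building the new SamplingParams with a top-p strategy.
# The import path is an assumption based on the module being edited above.
from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy

params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=256,
    stop=["</answer>"],
)
# The serialized form carries the "type" discriminator registered via register_schema.
print(params.model_dump_json())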
class LogProbConfig(BaseModel):
    """

@@ -48,18 +97,18 @@ class QuantizationType(Enum):
    """Type of model quantization to run inference with.

    :cvar bf16: BFloat16 typically this means _no_ quantization
    :cvar fp8: 8-bit floating point quantization
    :cvar int4: 4-bit integer quantization
    :cvar fp8_mixed: 8-bit floating point quantization with mixed precision
    :cvar int4_mixed: 4-bit integer quantization with mixed precision
    """

    bf16 = "bf16"
    fp8 = "fp8"
    int4 = "int4"
    fp8_mixed = "fp8_mixed"
    int4_mixed = "int4_mixed"


@json_schema_type
class Fp8QuantizationConfig(BaseModel):
    type: Literal["fp8"] = "fp8"
    type: Literal["fp8_mixed"] = "fp8_mixed"


@json_schema_type

@@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
    :param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
    """

    type: Literal["int4"] = "int4"
    type: Literal["int4_mixed"] = "int4_mixed"
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
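The two config classes keep their field layout; only the type discriminators move to the *_mixed values that were added to QuantizationType above. A small hedged sketch of selecting between them (the import location is assumed, as before):

# Sketch: choosing a quantization config under the renamed "*_mixed" discriminators.
from typing import Union

from llama_stack.apis.inference import Fp8QuantizationConfig, Int4QuantizationConfig

QuantizationConfig = Union[Fp8QuantizationConfig, Int4QuantizationConfig]


def pick_quantization(kind: str) -> QuantizationConfig:
    if kind == "fp8_mixed":
        return Fp8QuantizationConfig()
    # int4_mixed keeps its default scheme, "int4_weight_int8_dynamic_activation"
    return Int4QuantizationConfig()


print(pick_quantization("fp8_mixed").type)  # -> "fp8_mixed"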
@@ -393,6 +442,352 @@ class EmbeddingsResponse(BaseModel):
    embeddings: List[List[float]]


@json_schema_type
class OpenAIChatCompletionContentPartTextParam(BaseModel):
    type: Literal["text"] = "text"
    text: str


@json_schema_type
class OpenAIImageURL(BaseModel):
    url: str
    detail: Optional[str] = None


@json_schema_type
class OpenAIChatCompletionContentPartImageParam(BaseModel):
    type: Literal["image_url"] = "image_url"
    image_url: OpenAIImageURL


OpenAIChatCompletionContentPartParam = Annotated[
    Union[
        OpenAIChatCompletionContentPartTextParam,
        OpenAIChatCompletionContentPartImageParam,
    ],
    Field(discriminator="type"),
]
register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")


OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]


@json_schema_type
class OpenAIUserMessageParam(BaseModel):
    """A message from the user in an OpenAI-compatible chat completion request.

    :param role: Must be "user" to identify this as a user message
    :param content: The content of the message, which can include text and other media
    :param name: (Optional) The name of the user message participant.
    """

    role: Literal["user"] = "user"
    content: OpenAIChatCompletionMessageContent
    name: Optional[str] = None


@json_schema_type
class OpenAISystemMessageParam(BaseModel):
    """A system message providing instructions or context to the model.

    :param role: Must be "system" to identify this as a system message
    :param content: The content of the "system prompt". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions).
    :param name: (Optional) The name of the system message participant.
    """

    role: Literal["system"] = "system"
    content: OpenAIChatCompletionMessageContent
    name: Optional[str] = None


@json_schema_type
class OpenAIChatCompletionToolCallFunction(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


@json_schema_type
class OpenAIChatCompletionToolCall(BaseModel):
    index: Optional[int] = None
    id: Optional[str] = None
    type: Literal["function"] = "function"
    function: Optional[OpenAIChatCompletionToolCallFunction] = None


@json_schema_type
class OpenAIAssistantMessageParam(BaseModel):
    """A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.

    :param role: Must be "assistant" to identify this as the model's response
    :param content: The content of the model's response
    :param name: (Optional) The name of the assistant message participant.
    :param tool_calls: List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object.
    """

    role: Literal["assistant"] = "assistant"
    content: OpenAIChatCompletionMessageContent
    name: Optional[str] = None
    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)


@json_schema_type
class OpenAIToolMessageParam(BaseModel):
    """A message representing the result of a tool invocation in an OpenAI-compatible chat completion request.

    :param role: Must be "tool" to identify this as a tool response
    :param tool_call_id: Unique identifier for the tool call this response is for
    :param content: The response content from the tool
    """

    role: Literal["tool"] = "tool"
    tool_call_id: str
    content: OpenAIChatCompletionMessageContent


@json_schema_type
class OpenAIDeveloperMessageParam(BaseModel):
    """A message from the developer in an OpenAI-compatible chat completion request.

    :param role: Must be "developer" to identify this as a developer message
    :param content: The content of the developer message
    :param name: (Optional) The name of the developer message participant.
    """

    role: Literal["developer"] = "developer"
    content: OpenAIChatCompletionMessageContent
    name: Optional[str] = None


OpenAIMessageParam = Annotated[
    Union[
        OpenAIUserMessageParam,
        OpenAISystemMessageParam,
        OpenAIAssistantMessageParam,
        OpenAIToolMessageParam,
        OpenAIDeveloperMessageParam,
    ],
    Field(discriminator="role"),
]
register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
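These models mirror the OpenAI chat message shapes, so a request body can be built either from plain strings or from typed content parts, and the role field acts as the discriminator for OpenAIMessageParam. A short construction sketch; the import path follows the llama_stack.apis.inference.inference imports used later in this diff, and the URL is a placeholder.

# Sketch: a system prompt plus a mixed text + image user message.
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIImageURL,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)

messages = [
    OpenAISystemMessageParam(content="You are a concise assistant."),
    OpenAIUserMessageParam(
        content=[
            OpenAIChatCompletionContentPartTextParam(text="What is shown in this image?"),
            OpenAIChatCompletionContentPartImageParam(
                image_url=OpenAIImageURL(url="https://example.com/cat.png", detail="low")
            ),
        ]
    ),
]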
@json_schema_type
class OpenAIResponseFormatText(BaseModel):
    type: Literal["text"] = "text"


@json_schema_type
class OpenAIJSONSchema(TypedDict, total=False):
    name: str
    description: Optional[str] = None
    strict: Optional[bool] = None

    # Pydantic BaseModel cannot be used with a schema param, since it already
    # has one. And, we don't want to alias here because then have to handle
    # that alias when converting to OpenAI params. So, to support schema,
    # we use a TypedDict.
    schema: Optional[Dict[str, Any]] = None


@json_schema_type
class OpenAIResponseFormatJSONSchema(BaseModel):
    type: Literal["json_schema"] = "json_schema"
    json_schema: OpenAIJSONSchema


@json_schema_type
class OpenAIResponseFormatJSONObject(BaseModel):
    type: Literal["json_object"] = "json_object"


OpenAIResponseFormatParam = Annotated[
    Union[
        OpenAIResponseFormatText,
        OpenAIResponseFormatJSONSchema,
        OpenAIResponseFormatJSONObject,
    ],
    Field(discriminator="type"),
]
register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
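The json_schema variant is how structured output is requested through the OpenAI-compatible endpoint, and OpenAIJSONSchema is a TypedDict precisely so its schema key can be passed straight through when converting to OpenAI parameters. A hedged construction sketch, with the same assumed import location as above:

# Sketch: requesting structured JSON output via response_format.
from llama_stack.apis.inference.inference import (
    OpenAIJSONSchema,
    OpenAIResponseFormatJSONSchema,
)

response_format = OpenAIResponseFormatJSONSchema(
    json_schema=OpenAIJSONSchema(
        name="weather_report",
        strict=True,
        schema={
            "type": "object",
            "properties": {"city": {"type": "string"}, "temp_c": {"type": "number"}},
            "required": ["city", "temp_c"],
        },
    )
)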
@json_schema_type
class OpenAITopLogProb(BaseModel):
    """The top log probability for a token from an OpenAI-compatible chat completion response.

    :token: The token
    :bytes: (Optional) The bytes for the token
    :logprob: The log probability of the token
    """

    token: str
    bytes: Optional[List[int]] = None
    logprob: float


@json_schema_type
class OpenAITokenLogProb(BaseModel):
    """The log probability for a token from an OpenAI-compatible chat completion response.

    :token: The token
    :bytes: (Optional) The bytes for the token
    :logprob: The log probability of the token
    :top_logprobs: The top log probabilities for the token
    """

    token: str
    bytes: Optional[List[int]] = None
    logprob: float
    top_logprobs: List[OpenAITopLogProb]


@json_schema_type
class OpenAIChoiceLogprobs(BaseModel):
    """The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.

    :param content: (Optional) The log probabilities for the tokens in the message
    :param refusal: (Optional) The log probabilities for the tokens in the message
    """

    content: Optional[List[OpenAITokenLogProb]] = None
    refusal: Optional[List[OpenAITokenLogProb]] = None


@json_schema_type
class OpenAIChoiceDelta(BaseModel):
    """A delta from an OpenAI-compatible chat completion streaming response.

    :param content: (Optional) The content of the delta
    :param refusal: (Optional) The refusal of the delta
    :param role: (Optional) The role of the delta
    :param tool_calls: (Optional) The tool calls of the delta
    """

    content: Optional[str] = None
    refusal: Optional[str] = None
    role: Optional[str] = None
    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None


@json_schema_type
class OpenAIChunkChoice(BaseModel):
    """A chunk choice from an OpenAI-compatible chat completion streaming response.

    :param delta: The delta from the chunk
    :param finish_reason: The reason the model stopped generating
    :param index: The index of the choice
    :param logprobs: (Optional) The log probabilities for the tokens in the message
    """

    delta: OpenAIChoiceDelta
    finish_reason: str
    index: int
    logprobs: Optional[OpenAIChoiceLogprobs] = None


@json_schema_type
class OpenAIChoice(BaseModel):
    """A choice from an OpenAI-compatible chat completion response.

    :param message: The message from the model
    :param finish_reason: The reason the model stopped generating
    :param index: The index of the choice
    :param logprobs: (Optional) The log probabilities for the tokens in the message
    """

    message: OpenAIMessageParam
    finish_reason: str
    index: int
    logprobs: Optional[OpenAIChoiceLogprobs] = None


@json_schema_type
class OpenAIChatCompletion(BaseModel):
    """Response from an OpenAI-compatible chat completion request.

    :param id: The ID of the chat completion
    :param choices: List of choices
    :param object: The object type, which will be "chat.completion"
    :param created: The Unix timestamp in seconds when the chat completion was created
    :param model: The model that was used to generate the chat completion
    """

    id: str
    choices: List[OpenAIChoice]
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str


@json_schema_type
class OpenAIChatCompletionChunk(BaseModel):
    """Chunk from a streaming response to an OpenAI-compatible chat completion request.

    :param id: The ID of the chat completion
    :param choices: List of choices
    :param object: The object type, which will be "chat.completion.chunk"
    :param created: The Unix timestamp in seconds when the chat completion was created
    :param model: The model that was used to generate the chat completion
    """

    id: str
    choices: List[OpenAIChunkChoice]
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int
    model: str


@json_schema_type
class OpenAICompletionLogprobs(BaseModel):
    """The log probabilities for the tokens in the message from an OpenAI-compatible completion response.

    :text_offset: (Optional) The offset of the token in the text
    :token_logprobs: (Optional) The log probabilities for the tokens
    :tokens: (Optional) The tokens
    :top_logprobs: (Optional) The top log probabilities for the tokens
    """

    text_offset: Optional[List[int]] = None
    token_logprobs: Optional[List[float]] = None
    tokens: Optional[List[str]] = None
    top_logprobs: Optional[List[Dict[str, float]]] = None


@json_schema_type
class OpenAICompletionChoice(BaseModel):
    """A choice from an OpenAI-compatible completion response.

    :finish_reason: The reason the model stopped generating
    :text: The text of the choice
    :index: The index of the choice
    :logprobs: (Optional) The log probabilities for the tokens in the choice
    """

    finish_reason: str
    text: str
    index: int
    logprobs: Optional[OpenAIChoiceLogprobs] = None


@json_schema_type
class OpenAICompletion(BaseModel):
    """Response from an OpenAI-compatible completion request.

    :id: The ID of the completion
    :choices: List of choices
    :created: The Unix timestamp in seconds when the completion was created
    :model: The model that was used to generate the completion
    :object: The object type, which will be "text_completion"
    """

    id: str
    choices: List[OpenAICompletionChoice]
    created: int
    model: str
    object: Literal["text_completion"] = "text_completion"


class ModelStore(Protocol):
    async def get_model(self, identifier: str) -> Model: ...


@@ -421,6 +816,16 @@ class EmbeddingTaskType(Enum):
    document = "document"


@json_schema_type
class BatchCompletionResponse(BaseModel):
    batch: List[CompletionResponse]


@json_schema_type
class BatchChatCompletionResponse(BaseModel):
    batch: List[ChatCompletionResponse]


@runtime_checkable
@trace_protocol
class Inference(Protocol):

@@ -456,6 +861,17 @@ class Inference(Protocol):
        """
        ...

    @webmethod(route="/inference/batch-completion", method="POST", experimental=True)
    async def batch_completion(
        self,
        model_id: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchCompletionResponse:
        raise NotImplementedError("Batch completion is not implemented")

    @webmethod(route="/inference/chat-completion", method="POST")
    async def chat_completion(
        self,

@@ -496,6 +912,19 @@ class Inference(Protocol):
        """
        ...

    @webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_config: Optional[ToolConfig] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchChatCompletionResponse:
        raise NotImplementedError("Batch chat completion is not implemented")

    @webmethod(route="/inference/embeddings", method="POST")
    async def embeddings(
        self,

@@ -515,3 +944,105 @@ class Inference(Protocol):
        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
        """
        ...

    @webmethod(route="/openai/v1/completions", method="POST")
    async def openai_completion(
        self,
        # Standard OpenAI completion parameters
        model: str,
        prompt: Union[str, List[str], List[int], List[List[int]]],
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
        # vLLM-specific parameters
        guided_choice: Optional[List[str]] = None,
        prompt_logprobs: Optional[int] = None,
    ) -> OpenAICompletion:
        """Generate an OpenAI-compatible completion for the given prompt using the specified model.

        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
        :param prompt: The prompt to generate a completion for
        :param best_of: (Optional) The number of completions to generate
        :param echo: (Optional) Whether to echo the prompt
        :param frequency_penalty: (Optional) The penalty for repeated tokens
        :param logit_bias: (Optional) The logit bias to use
        :param logprobs: (Optional) The log probabilities to use
        :param max_tokens: (Optional) The maximum number of tokens to generate
        :param n: (Optional) The number of completions to generate
        :param presence_penalty: (Optional) The penalty for repeated tokens
        :param seed: (Optional) The seed to use
        :param stop: (Optional) The stop tokens to use
        :param stream: (Optional) Whether to stream the response
        :param stream_options: (Optional) The stream options to use
        :param temperature: (Optional) The temperature to use
        :param top_p: (Optional) The top p to use
        :param user: (Optional) The user to use
        """
        ...

    @webmethod(route="/openai/v1/chat/completions", method="POST")
    async def openai_chat_completion(
        self,
        model: str,
        messages: List[OpenAIMessageParam],
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        parallel_tool_calls: Optional[bool] = None,
        presence_penalty: Optional[float] = None,
        response_format: Optional[OpenAIResponseFormatParam] = None,
        seed: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        stream: Optional[bool] = None,
        stream_options: Optional[Dict[str, Any]] = None,
        temperature: Optional[float] = None,
        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        top_logprobs: Optional[int] = None,
        top_p: Optional[float] = None,
        user: Optional[str] = None,
    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
        """Generate an OpenAI-compatible chat completion for the given messages using the specified model.

        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
        :param messages: List of messages in the conversation
        :param frequency_penalty: (Optional) The penalty for repeated tokens
        :param function_call: (Optional) The function call to use
        :param functions: (Optional) List of functions to use
        :param logit_bias: (Optional) The logit bias to use
        :param logprobs: (Optional) The log probabilities to use
        :param max_completion_tokens: (Optional) The maximum number of tokens to generate
        :param max_tokens: (Optional) The maximum number of tokens to generate
        :param n: (Optional) The number of completions to generate
        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls
        :param presence_penalty: (Optional) The penalty for repeated tokens
        :param response_format: (Optional) The response format to use
        :param seed: (Optional) The seed to use
        :param stop: (Optional) The stop tokens to use
        :param stream: (Optional) Whether to stream the response
        :param stream_options: (Optional) The stream options to use
        :param temperature: (Optional) The temperature to use
        :param tool_choice: (Optional) The tool choice to use
        :param tools: (Optional) The tools to use
        :param top_logprobs: (Optional) The top log probabilities to use
        :param top_p: (Optional) The top p to use
        :param user: (Optional) The user to use
        """
        ...
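Because openai_completion and openai_chat_completion are mounted under /openai/v1, any OpenAI-compatible client can call them directly. A hedged sketch using the official openai Python package against a locally running Llama Stack server; the base URL, port, and model identifier are placeholders.

# Sketch: driving the /openai/v1 routes with the standard OpenAI client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/openai/v1", api_key="none")

chat = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
)
print(chat.choices[0].message.content)

# Streaming uses the same route; each chunk mirrors OpenAIChatCompletionChunk above.
stream = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Count to three."}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")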
@@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod


@@ -20,8 +21,7 @@ class RouteInfo(BaseModel):

@json_schema_type
class HealthInfo(BaseModel):
    status: str
    # TODO: add a provider level status
    status: HealthStatus


@json_schema_type

@@ -56,12 +56,35 @@ class ListModelsResponse(BaseModel):
    data: List[Model]


@json_schema_type
class OpenAIModel(BaseModel):
    """A model from OpenAI.

    :id: The ID of the model
    :object: The object type, which will be "model"
    :created: The Unix timestamp in seconds when the model was created
    :owned_by: The owner of the model
    """

    id: str
    object: Literal["model"] = "model"
    created: int
    owned_by: str


class OpenAIListModelsResponse(BaseModel):
    data: List[OpenAIModel]


@runtime_checkable
@trace_protocol
class Models(Protocol):
    @webmethod(route="/models", method="GET")
    async def list_models(self) -> ListModelsResponse: ...

    @webmethod(route="/openai/v1/models", method="GET")
    async def openai_list_models(self) -> OpenAIListModelsResponse: ...

    @webmethod(route="/models/{model_id:path}", method="GET")
    async def get_model(
        self,

@@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
@json_schema_type
class TrainingConfig(BaseModel):
    n_epochs: int
    max_steps_per_epoch: int
    gradient_accumulation_steps: int
    max_validation_steps: int
    data_config: DataConfig
    optimizer_config: OptimizerConfig
    max_steps_per_epoch: int = 1
    gradient_accumulation_steps: int = 1
    max_validation_steps: Optional[int] = 1
    data_config: Optional[DataConfig] = None
    optimizer_config: Optional[OptimizerConfig] = None
    efficiency_config: Optional[EfficiencyConfig] = None
    dtype: Optional[str] = "bf16"
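After this change only n_epochs is required: the step counts get defaults of 1, and the data and optimizer configs become optional so providers that manage their own training pipeline can accept a minimal config. A hedged sketch, assuming TrainingConfig is importable from llama_stack.apis.post_training:

# Sketch: minimal vs. fuller post-training configs now that most fields are optional.
from llama_stack.apis.post_training import TrainingConfig

minimal = TrainingConfig(n_epochs=1)  # max_steps_per_epoch and gradient_accumulation_steps default to 1
fuller = TrainingConfig(
    n_epochs=3,
    max_steps_per_epoch=100,
    gradient_accumulation_steps=4,
    dtype="bf16",
)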
@@ -177,9 +177,9 @@ class PostTraining(Protocol):
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str = Field(
            default="Llama3.2-3B-Instruct",
            description="Model descriptor from `llama model list`",
        model: Optional[str] = Field(
            default=None,
            description="Model descriptor for training if not in provider config`",
        ),
        checkpoint_dir: Optional[str] = None,
        algorithm_config: Optional[AlgorithmConfig] = None,

@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable

from pydantic import BaseModel

from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type, webmethod


@@ -17,6 +18,7 @@ class ProviderInfo(BaseModel):
    provider_id: str
    provider_type: str
    config: Dict[str, Any]
    health: HealthResponse


class ListProvidersResponse(BaseModel):

@@ -29,8 +29,8 @@ from rich.progress import (
from termcolor import cprint

from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import Model


class Download(Subcommand):

@@ -162,6 +162,10 @@ class ParallelDownloader:
            raise last_exception

    async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
        if task.total_size > 0:
            self.progress.update(task.task_id, total=task.total_size)
            return

        async def _get_info():
            response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
            response.raise_for_status()

@@ -282,7 +286,7 @@ class ParallelDownloader:
        if not tasks:
            raise ValueError("No download tasks provided")

        if not self.has_disk_space(tasks):
        if not os.environ.get("LLAMA_DOWNLOAD_NO_SPACE_CHECK") and not self.has_disk_space(tasks):
            raise DownloadError("Insufficient disk space for downloads")

        failed_tasks = []

@@ -63,17 +63,6 @@ class ModelDescribe(Subcommand):
            ("Model params.json", json.dumps(model.arch_args, indent=4)),
        ]

        if model.recommended_sampling_params is not None:
            sampling_params = model.recommended_sampling_params.model_dump()
            for k in ("max_tokens", "repetition_penalty"):
                del sampling_params[k]
            rows.append(
                (
                    "Recommended sampling params",
                    json.dumps(sampling_params, indent=4),
                )
            )

        print_table(
            rows,
            headers,

@@ -11,7 +11,7 @@ from pathlib import Path

from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family

ROOT_DIR = Path(__file__).parent.parent.parent


@@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional
from typing import Any, Dict

from pydantic import BaseModel, ConfigDict, Field

from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat


class PromptGuardModel(BaseModel):

@@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
    is_instruct_model: bool = False
    quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
    arch_args: Dict[str, Any] = Field(default_factory=dict)
    recommended_sampling_params: Optional[SamplingParams] = None

    def descriptor(self) -> str:
        return self.model_id

@@ -89,6 +89,43 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
            color="red",
        )
        sys.exit(1)
    elif args.providers:
        providers = dict()
        for api_provider in args.providers.split(","):
            if "=" not in api_provider:
                cprint(
                    "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
                    color="red",
                )
                sys.exit(1)
            api, provider = api_provider.split("=")
            providers_for_api = get_provider_registry().get(Api(api), None)
            if providers_for_api is None:
                cprint(
                    f"{api} is not a valid API.",
                    color="red",
                )
                sys.exit(1)
            if provider in providers_for_api:
                providers.setdefault(api, []).append(provider)
            else:
                cprint(
                    f"{provider} is not a valid provider for the {api} API.",
                    color="red",
                )
                sys.exit(1)
        distribution_spec = DistributionSpec(
            providers=providers,
            description=",".join(args.providers),
        )
        if not args.image_type:
            cprint(
                f"Please specify a image-type (container | conda | venv) for {args.template}",
                color="red",
            )
            sys.exit(1)

        build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec)
    elif not args.config and not args.template:
        name = prompt(
            "> Enter a name for your Llama Stack (e.g. my-local-stack): ",

@@ -57,7 +57,7 @@ class StackBuild(Subcommand):
            type=str,
            help=textwrap.dedent(
                f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the conda or virtual environment to use for
the build. If not specified, currently active Conda environment will be used if found.
the build. If not specified, currently active environment will be used if found.
            """
            ),
            default=None,

@@ -75,6 +75,12 @@ the build. If not specified, currently active Conda environment will be used if
            default=False,
            help="Run the stack after building using the same image type, name, and other applicable arguments",
        )
        self.parser.add_argument(
            "--providers",
            type=str,
            default=None,
            help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
        )

    def _run_stack_build_command(self, args: argparse.Namespace) -> None:
        # always keep implementation completely silo-ed away from CLI so CLI

@@ -45,7 +45,7 @@ class StackRun(Subcommand):
            "--image-name",
            type=str,
            default=os.environ.get("CONDA_DEFAULT_ENV"),
            help="Name of the image to run. Defaults to the current conda environment",
            help="Name of the image to run. Defaults to the current environment",
        )
        self.parser.add_argument(
            "--disable-ipv6",

@@ -312,6 +312,11 @@ a default SQLite store will be used.""",
        description="Configuration for the HTTP(S) server",
    )

    external_providers_dir: Optional[str] = Field(
        default=None,
        description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
    )


class BuildConfig(BaseModel):
    version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

@@ -4,12 +4,25 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import glob
import importlib
from typing import Dict, List
import os
from typing import Any, Dict, List

import yaml
from pydantic import BaseModel

from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    AdapterSpec,
    Api,
    InlineProviderSpec,
    ProviderSpec,
    remote_provider_spec,
)

logger = get_logger(name=__name__, category="core")


def stack_apis() -> List[Api]:

@@ -59,11 +72,116 @@ def providable_apis() -> List[Api]:
    return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]


def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
    ret = {}
def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
    adapter = AdapterSpec(**spec_data["adapter"])
    spec = remote_provider_spec(
        api=api,
        adapter=adapter,
        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
    )
    return spec


def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
    spec = InlineProviderSpec(
        api=api,
        provider_type=f"inline::{provider_name}",
        pip_packages=spec_data.get("pip_packages", []),
        module=spec_data["module"],
        config_class=spec_data["config_class"],
        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
        provider_data_validator=spec_data.get("provider_data_validator"),
        container_image=spec_data.get("container_image"),
    )
    return spec
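These helpers turn one parsed YAML document into a ProviderSpec: a remote spec must carry an adapter block, while an inline spec needs at least module and config_class. The sketch below shows the kind of dictionary _load_inline_provider_spec expects; the key names follow the code above, while the module, class, and package names are made-up placeholders, and the import path of the helper itself is assumed.

# Sketch: feeding an external inline provider definition to the loader above.
from llama_stack.distribution.distribution import _load_inline_provider_spec  # assumed path
from llama_stack.providers.datatypes import Api

spec_data = {
    "module": "my_company.providers.inference.custom",
    "config_class": "my_company.providers.inference.custom.CustomConfig",
    "pip_packages": ["my-inference-sdk>=1.0"],
    "api_dependencies": [],
}
spec = _load_inline_provider_spec(spec_data, Api.inference, provider_name="custom")
print(spec.provider_type)  # -> "inline::custom"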
def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
    """Get the provider registry, optionally including external providers.

    This function loads both built-in providers and external providers from YAML files.
    External providers are loaded from a directory structure like:

    providers.d/
      remote/
        inference/
          custom_ollama.yaml
          vllm.yaml
        vector_io/
          qdrant.yaml
        safety/
          llama-guard.yaml
      inline/
        inference/
          custom_ollama.yaml
          vllm.yaml
        vector_io/
          qdrant.yaml
        safety/
          llama-guard.yaml

    Args:
        config: Optional StackRunConfig containing the external providers directory path

    Returns:
        A dictionary mapping APIs to their available providers

    Raises:
        FileNotFoundError: If the external providers directory doesn't exist
        ValueError: If any provider spec is invalid
    """

    ret: Dict[Api, Dict[str, ProviderSpec]] = {}
    for api in providable_apis():
        name = api.name.lower()
        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
        ret[api] = {a.provider_type: a for a in module.available_providers()}
        logger.debug(f"Importing module {name}")
        try:
            module = importlib.import_module(f"llama_stack.providers.registry.{name}")
            ret[api] = {a.provider_type: a for a in module.available_providers()}
        except ImportError as e:
            logger.warning(f"Failed to import module {name}: {e}")

    if config and config.external_providers_dir:
        external_providers_dir = os.path.abspath(config.external_providers_dir)
        if not os.path.exists(external_providers_dir):
            raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
        logger.info(f"Loading external providers from {external_providers_dir}")

        for api in providable_apis():
            api_name = api.name.lower()

            # Process both remote and inline providers
            for provider_type in ["remote", "inline"]:
                api_dir = os.path.join(external_providers_dir, provider_type, api_name)
                if not os.path.exists(api_dir):
                    logger.debug(f"No {provider_type} provider directory found for {api_name}")
                    continue

                # Look for provider spec files in the API directory
                for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
                    provider_name = os.path.splitext(os.path.basename(spec_path))[0]
                    logger.info(f"Loading {provider_type} provider spec from {spec_path}")

                    try:
                        with open(spec_path) as f:
                            spec_data = yaml.safe_load(f)

                        if provider_type == "remote":
                            spec = _load_remote_provider_spec(spec_data, api)
                            provider_type_key = f"remote::{provider_name}"
                        else:
                            spec = _load_inline_provider_spec(spec_data, api, provider_name)
                            provider_type_key = f"inline::{provider_name}"

                        logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
                        if provider_type_key in ret[api]:
                            logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
                        ret[api][provider_type_key] = spec
                    except yaml.YAMLError as yaml_err:
                        logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
                        raise yaml_err
                    except Exception as e:
                        logger.error(f"Failed to load provider spec from {spec_path}: {e}")
                        raise e
    return ret
@@ -17,6 +17,7 @@ from llama_stack.apis.inspect import (
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.providers.datatypes import HealthStatus


class DistributionInspectConfig(BaseModel):

@@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect):
        return ListRoutesResponse(data=ret)

    async def health(self) -> HealthInfo:
        return HealthInfo(status="OK")
        return HealthInfo(status=HealthStatus.OK)

    async def version(self) -> VersionInfo:
        return VersionInfo(version=version("llama-stack"))


@@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import (
from llama_stack.distribution.stack import (
    construct_stack,
    get_stack_run_config_from_template,
    redact_sensitive_fields,
    replace_env_vars,
)
from llama_stack.distribution.utils.config import redact_sensitive_fields
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
from llama_stack.distribution.utils.exec import in_notebook
from llama_stack.providers.utils.telemetry.tracing import (


@@ -4,14 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
from typing import Any, Dict

from pydantic import BaseModel

from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus

from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields
from .utils.config import redact_sensitive_fields

logger = get_logger(name=__name__, category="core")


@@ -41,19 +44,24 @@ class ProviderImpl(Providers):
    async def list_providers(self) -> ListProvidersResponse:
        run_config = self.config.run_config
        safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
        providers_health = await self.get_providers_health()
        ret = []
        for api, providers in safe_config.providers.items():
            ret.extend(
                [
            for p in providers:
                ret.append(
                    ProviderInfo(
                        api=api,
                        provider_id=p.provider_id,
                        provider_type=p.provider_type,
                        config=p.config,
                        health=providers_health.get(api, {}).get(
                            p.provider_id,
                            HealthResponse(
                                status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
                            ),
                        ),
                    )
                    for p in providers
                ]
            )
                )

        return ListProvidersResponse(data=ret)

@@ -64,3 +72,57 @@ class ProviderImpl(Providers):
            return p

        raise ValueError(f"Provider {provider_id} not found")

    async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
        """Get health status for all providers.

        Returns:
            Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
                Each API maps to a dictionary of provider IDs to their health responses.
        """
        providers_health: Dict[str, Dict[str, HealthResponse]] = {}
        timeout = 1.0

        async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
            # Skip special implementations (inspect/providers) that don't have provider specs
            if not hasattr(impl, "__provider_spec__"):
                return None
            api_name = impl.__provider_spec__.api.name
            if not hasattr(impl, "health"):
                return (
                    api_name,
                    HealthResponse(
                        status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
                    ),
                )

            try:
                health = await asyncio.wait_for(impl.health(), timeout=timeout)
                return api_name, health
            except asyncio.TimeoutError:
                return (
                    api_name,
                    HealthResponse(
                        status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
                    ),
                )
            except Exception as e:
                return (
                    api_name,
                    HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
                )

        # Create tasks for all providers
        tasks = [check_provider_health(impl) for impl in self.deps.values()]

        # Wait for all health checks to complete
        results = await asyncio.gather(*tasks)

        # Organize results by API and provider ID
        for result in results:
            if result is None:  # Skip special implementations
                continue
            api_name, health_response = result
            providers_health[api_name] = health_response

        return providers_health
@@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import (
    Api,
    BenchmarksProtocolPrivate,
    DatasetsProtocolPrivate,
    InlineProviderSpec,
    ModelsProtocolPrivate,
    ProviderSpec,
    RemoteProviderConfig,

@@ -230,50 +229,9 @@ def sort_providers_by_deps(
        {k: list(v.values()) for k, v in providers_with_specs.items()}
    )

    # Append built-in "inspect" provider
    apis = [x[1].spec.api for x in sorted_providers]
    sorted_providers.append(
        (
            "inspect",
            ProviderWithSpec(
                provider_id="__builtin__",
                provider_type="__builtin__",
                config={"run_config": run_config.model_dump()},
                spec=InlineProviderSpec(
                    api=Api.inspect,
                    provider_type="__builtin__",
                    config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
                    module="llama_stack.distribution.inspect",
                    api_dependencies=apis,
                    deps__=[x.value for x in apis],
                ),
            ),
        )
    )

    sorted_providers.append(
        (
            "providers",
            ProviderWithSpec(
                provider_id="__builtin__",
                provider_type="__builtin__",
                config={"run_config": run_config.model_dump()},
                spec=InlineProviderSpec(
                    api=Api.providers,
                    provider_type="__builtin__",
                    config_class="llama_stack.distribution.providers.ProviderImplConfig",
                    module="llama_stack.distribution.providers",
                    api_dependencies=apis,
                    deps__=[x.value for x in apis],
                ),
            ),
        )
    )

    logger.debug(f"Resolved {len(sorted_providers)} providers")
    for api_str, provider in sorted_providers:
        logger.debug(f" {api_str} => {provider.provider_id}")
    logger.debug("")
    return sorted_providers


@@ -351,6 +309,7 @@ async def instantiate_provider(
    if not hasattr(provider_spec, "module"):
        raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

    logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
    module = importlib.import_module(provider_spec.module)
    args = []
    if isinstance(provider_spec, RemoteProviderSpec):

@@ -399,6 +358,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
    mro = type(obj).__mro__
    for name, value in inspect.getmembers(protocol):
        if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
            if value.__webmethod__.experimental:
                continue
            if not hasattr(obj, name):
                missing_methods.append((name, "missing"))
            elif not callable(getattr(obj, name)):
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
|
@ -17,6 +18,8 @@ from llama_stack.apis.datasetio import DatasetIO
|
|||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
|
@ -35,6 +38,13 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.scoring import (
|
||||
|
@ -57,7 +67,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
@ -333,6 +343,30 @@ class InferenceRouter(Inference):
|
|||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: List[List[Message]],
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_chat_completion(
|
||||
model_id=model_id,
|
||||
messages_batch=messages_batch,
|
||||
tools=tools,
|
||||
tool_config=tool_config,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -397,6 +431,20 @@ class InferenceRouter(Inference):
|
|||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||
return response
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> BatchCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -419,6 +467,149 @@ class InferenceRouter(Inference):
|
|||
task_type=task_type,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
best_of: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"Model '{model}' not found")
|
||||
if model_obj.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
|
||||
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_completion(**params)
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ValueError(f"Model '{model}' not found")
|
||||
if model_obj.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
|
||||
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
provider = self.routing_table.get_provider_impl(model_obj.identifier)
|
||||
return await provider.openai_chat_completion(**params)
|
||||
|
||||
async def health(self) -> Dict[str, HealthResponse]:
|
||||
health_statuses = {}
|
||||
timeout = 0.5
|
||||
for provider_id, impl in self.routing_table.impls_by_provider_id.items():
|
||||
try:
|
||||
# check if the provider has a health method
|
||||
if not hasattr(impl, "health"):
|
||||
continue
|
||||
health = await asyncio.wait_for(impl.health(), timeout=timeout)
|
||||
health_statuses[provider_id] = health
|
||||
except asyncio.TimeoutError:
|
||||
health_statuses[provider_id] = HealthResponse(
|
||||
status=HealthStatus.ERROR,
|
||||
message=f"Health check timed out after {timeout} seconds",
|
||||
)
|
||||
except NotImplementedError:
|
||||
health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
|
||||
except Exception as e:
|
||||
health_statuses[provider_id] = HealthResponse(
|
||||
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||
)
|
||||
return health_statuses
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
def __init__(
|
||||
|
|
|
@@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
@@ -23,7 +24,7 @@ from llama_stack.apis.datasets import (
|
|||
RowsDataSource,
|
||||
URIDataSource,
|
||||
)
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
ListScoringFunctionsResponse,
|
||||
|
@@ -254,6 +255,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
models = await self.get_all_with_type("model")
|
||||
openai_models = [
|
||||
OpenAIModel(
|
||||
id=model.identifier,
|
||||
object="model",
|
||||
created=int(time.time()),
|
||||
owned_by="llama_stack",
|
||||
)
|
||||
for model in models
|
||||
]
|
||||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
async def get_model(self, model_id: str) -> Model:
|
||||
model = await self.get_object_by_identifier("model", model_id)
|
||||
if model is None:
|
||||
|
@@ -608,8 +622,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
tools = (await self.list_tools(toolgroup_id)).data
|
||||
for tool in tools:
|
||||
tools = await self.list_tools(toolgroup_id)
|
||||
for tool in getattr(tools, "data", []):
|
||||
await self.unregister_object(tool)
|
||||
await self.unregister_object(tool_group)
|
||||
|
||||
|
|
|
@@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import (
|
|||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
redact_sensitive_fields,
|
||||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.distribution.utils.config import redact_sensitive_fields
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
@@ -229,15 +229,30 @@ class TracingMiddleware:
|
|||
def __init__(self, app, impls):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "endpoint_impls"):
|
||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
||||
|
||||
try:
|
||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||
|
||||
|
@@ -388,7 +403,12 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
safe_config = redact_sensitive_fields(config.model_dump())
|
||||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
)
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
|
|
|
@@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs
|
|||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
@@ -96,7 +98,10 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
|||
|
||||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
await method(**obj.model_dump())
|
||||
# we want to maintain the type information in arguments to method.
|
||||
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
||||
# we use model_dump() to find all the attrs and then getattr to get the still-typed value.
|
||||
await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})
|
||||
|
||||
method = getattr(impls[api], list_method)
|
||||
response = await method()
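A minimal sketch of why the attribute-based expansion in the hunk above preserves types; the `ModelInput`/`Metadata` models here are invented for illustration and are not part of this change. `model_dump()` coerces nested pydantic models into plain dicts, while `getattr` hands the registration method the original typed objects.

```python
from typing import Optional

from pydantic import BaseModel


class Metadata(BaseModel):
    embedding_dimension: int


class ModelInput(BaseModel):
    model_id: str
    metadata: Optional[Metadata] = None


obj = ModelInput(model_id="my-model", metadata=Metadata(embedding_dimension=384))

# Dict-based expansion: the nested model is downgraded to a plain dict.
dumped = obj.model_dump()
assert isinstance(dumped["metadata"], dict)

# Attribute-based expansion (the pattern used above): types are preserved.
kwargs = {k: getattr(obj, k) for k in obj.model_dump().keys()}
assert isinstance(kwargs["metadata"], Metadata)
```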
|
||||
|
@@ -116,26 +121,6 @@ class EnvVarError(Exception):
|
|||
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
result[k] = _redact_dict(v)
|
||||
elif isinstance(v, list):
|
||||
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
|
||||
elif any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
return _redact_dict(data)
|
||||
|
||||
|
||||
def replace_env_vars(config: Any, path: str = "") -> Any:
|
||||
if isinstance(config, dict):
|
||||
result = {}
|
||||
|
@@ -212,13 +197,37 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
|
|||
) from e
|
||||
|
||||
|
||||
def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
|
||||
"""Add internal implementations (inspect and providers) to the implementations dictionary.
|
||||
|
||||
Args:
|
||||
impls: Dictionary of API implementations
|
||||
run_config: Stack run configuration
|
||||
"""
|
||||
inspect_impl = DistributionInspectImpl(
|
||||
DistributionInspectConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.inspect] = inspect_impl
|
||||
|
||||
providers_impl = ProviderImpl(
|
||||
ProviderImplConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.providers] = providers_impl
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
||||
) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
return impls
|
||||
|
||||
|
|
|
@@ -18,6 +18,7 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
|
|||
set -euo pipefail
|
||||
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
error_handler() {
|
||||
|
@@ -73,7 +74,7 @@ done
|
|||
PYTHON_BINARY="python"
|
||||
case "$env_type" in
|
||||
"venv")
|
||||
if [ -n "$VIRTUAL_ENV" && "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
|
||||
if [ -n "$VIRTUAL_ENV" ] && [ "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
|
||||
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
|
||||
else
|
||||
# Activate virtual environment
|
||||
|
|
|
@@ -1,7 +1,7 @@
|
|||
# More info on playground configuration can be found here:
|
||||
# https://llama-stack.readthedocs.io/en/latest/playground
|
||||
|
||||
FROM python:3.9-slim
|
||||
FROM python:3.12-slim
|
||||
WORKDIR /app
|
||||
COPY . /app/
|
||||
RUN /usr/local/bin/python -m pip install --upgrade pip && \
|
||||
|
|
|
@@ -36,9 +36,7 @@ llama-stack-client benchmarks register \
|
|||
3. Start Streamlit UI
|
||||
|
||||
```bash
|
||||
cd llama_stack/distribution/ui
|
||||
pip install -r requirements.txt
|
||||
streamlit run app.py
|
||||
uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
|
|
@@ -24,6 +24,7 @@ def main():
|
|||
# Playground pages
|
||||
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
|
||||
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
|
||||
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
|
||||
|
||||
# Distribution pages
|
||||
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
|
||||
|
@@ -39,6 +40,7 @@ def main():
|
|||
"Playground": [
|
||||
chat_page,
|
||||
rag_page,
|
||||
tool_page,
|
||||
application_evaluation_page,
|
||||
native_evaluation_page,
|
||||
],
|
||||
|
|
|
@@ -19,6 +19,7 @@ class LlamaStackApi:
|
|||
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
|
||||
"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
|
||||
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
|
||||
"tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@@ -4,9 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||
|
||||
from llama_stack.apis.common.content_types import ToolCallDelta
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
||||
|
||||
|
@@ -14,9 +17,16 @@ from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
|||
def rag_chat_page():
|
||||
st.title("🦙 RAG")
|
||||
|
||||
def reset_agent_and_chat():
|
||||
st.session_state.clear()
|
||||
st.cache_resource.clear()
|
||||
|
||||
def should_disable_input():
|
||||
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
|
||||
|
||||
with st.sidebar:
|
||||
# File/Directory Upload Section
|
||||
st.subheader("Upload Documents")
|
||||
st.subheader("Upload Documents", divider=True)
|
||||
uploaded_files = st.file_uploader(
|
||||
"Upload file(s) or directory",
|
||||
accept_multiple_files=True,
|
||||
|
@@ -27,11 +37,11 @@ def rag_chat_page():
|
|||
st.success(f"Successfully uploaded {len(uploaded_files)} files")
|
||||
# Add memory bank name input field
|
||||
vector_db_name = st.text_input(
|
||||
"Vector Database Name",
|
||||
"Document Collection Name",
|
||||
value="rag_vector_db",
|
||||
help="Enter a unique identifier for this vector database",
|
||||
help="Enter a unique identifier for this document collection",
|
||||
)
|
||||
if st.button("Create Vector Database"):
|
||||
if st.button("Create Document Collection"):
|
||||
documents = [
|
||||
RAGDocument(
|
||||
document_id=uploaded_file.name,
|
||||
|
@@ -62,26 +72,45 @@ def rag_chat_page():
|
|||
)
|
||||
st.success("Vector database created successfully!")
|
||||
|
||||
st.subheader("Configure Agent")
|
||||
st.subheader("RAG Parameters", divider=True)
|
||||
|
||||
rag_mode = st.radio(
|
||||
"RAG mode",
|
||||
["Direct", "Agent-based"],
|
||||
captions=[
|
||||
"RAG is performed by directly retrieving the information and augmenting the user query",
|
||||
"RAG is performed by an agent activating a dedicated knowledge search tool.",
|
||||
],
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
# select memory banks
|
||||
vector_dbs = llama_stack_api.client.vector_dbs.list()
|
||||
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
||||
selected_vector_dbs = st.multiselect(
|
||||
"Select Vector Databases",
|
||||
vector_dbs,
|
||||
label="Select Document Collections to use in RAG queries",
|
||||
options=vector_dbs,
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
st.subheader("Inference Parameters", divider=True)
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
|
||||
selected_model = st.selectbox(
|
||||
"Choose a model",
|
||||
available_models,
|
||||
label="Choose a model",
|
||||
options=available_models,
|
||||
index=0,
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
system_prompt = st.text_area(
|
||||
"System Prompt",
|
||||
value="You are a helpful assistant. ",
|
||||
help="Initial instructions given to the AI to set its behavior and context",
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
|
@@ -90,6 +119,8 @@ def rag_chat_page():
|
|||
value=0.0,
|
||||
step=0.1,
|
||||
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
top_p = st.slider(
|
||||
|
@@ -98,19 +129,23 @@ def rag_chat_page():
|
|||
max_value=1.0,
|
||||
value=0.95,
|
||||
step=0.1,
|
||||
on_change=reset_agent_and_chat,
|
||||
disabled=should_disable_input(),
|
||||
)
|
||||
|
||||
# Add clear chat button to sidebar
|
||||
if st.button("Clear Chat", use_container_width=True):
|
||||
st.session_state.messages = []
|
||||
reset_agent_and_chat()
|
||||
st.rerun()
|
||||
|
||||
# Chat Interface
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
if "displayed_messages" not in st.session_state:
|
||||
st.session_state.displayed_messages = []
|
||||
|
||||
# Display chat history
|
||||
for message in st.session_state.messages:
|
||||
for message in st.session_state.displayed_messages:
|
||||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
|
@@ -123,33 +158,37 @@ def rag_chat_page():
|
|||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
agent = Agent(
|
||||
llama_stack_api.client,
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
dict(
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
session_id = agent.create_session("rag-session")
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
llama_stack_api.client,
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
dict(
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Ask a question about your documents"):
|
||||
if rag_mode == "Agent-based":
|
||||
agent = create_agent()
|
||||
if "agent_session_id" not in st.session_state:
|
||||
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
|
||||
|
||||
session_id = st.session_state["agent_session_id"]
|
||||
|
||||
def agent_process_prompt(prompt):
|
||||
# Add user message to chat history
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Display user message
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
# Send the prompt to the agent
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
|
@@ -177,6 +216,79 @@ def rag_chat_page():
|
|||
message_placeholder.markdown(full_response)
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||
st.session_state.displayed_messages.append({"role": "assistant", "content": full_response})
|
||||
|
||||
def direct_process_prompt(prompt):
|
||||
# Add the system prompt at the beginning of the conversation
|
||||
if len(st.session_state.messages) == 0:
|
||||
st.session_state.messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
# Query the vector DB
|
||||
rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
|
||||
content=prompt, vector_db_ids=list(selected_vector_dbs)
|
||||
)
|
||||
prompt_context = rag_response.content
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
retrieval_message_placeholder = st.empty()
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
|
||||
# Display the retrieved content
|
||||
retrieval_response += str(prompt_context)
|
||||
retrieval_message_placeholder.info(retrieval_response)
|
||||
|
||||
# Construct the extended prompt
|
||||
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
|
||||
|
||||
# Run inference directly
|
||||
st.session_state.messages.append({"role": "user", "content": extended_prompt})
|
||||
response = llama_stack_api.client.inference.chat_completion(
|
||||
messages=st.session_state.messages,
|
||||
model_id=selected_model,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Display assistant response
|
||||
for chunk in response:
|
||||
response_delta = chunk.event.delta
|
||||
if isinstance(response_delta, ToolCallDelta):
|
||||
retrieval_response += response_delta.tool_call.replace("====", "").strip()
|
||||
retrieval_message_placeholder.info(retrieval_response)
|
||||
else:
|
||||
full_response += chunk.event.delta.text
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
|
||||
st.session_state.messages.append(response_dict)
|
||||
st.session_state.displayed_messages.append(response_dict)
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Ask a question about your documents"):
|
||||
# Add user message to chat history
|
||||
st.session_state.displayed_messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Display user message
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
# store the prompt so it can be processed after the page refresh
|
||||
st.session_state.prompt = prompt
|
||||
|
||||
# force page refresh to disable the settings widgets
|
||||
st.rerun()
|
||||
|
||||
if "prompt" in st.session_state and st.session_state.prompt is not None:
|
||||
if rag_mode == "Agent-based":
|
||||
agent_process_prompt(st.session_state.prompt)
|
||||
else: # rag_mode == "Direct"
|
||||
direct_process_prompt(st.session_state.prompt)
|
||||
st.session_state.prompt = None
|
||||
|
||||
|
||||
rag_chat_page()
|
||||
|
|
llama_stack/distribution/ui/page/playground/tools.py (new file, 116 lines)
|
@@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def tool_chat_page():
|
||||
st.title("🛠 Tools")
|
||||
|
||||
client = llama_stack_api.client
|
||||
models = client.models.list()
|
||||
model_list = [model.identifier for model in models if model.api_model_type == "llm"]
|
||||
|
||||
tool_groups = client.toolgroups.list()
|
||||
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
||||
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
||||
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
||||
|
||||
def reset_agent():
|
||||
st.session_state.clear()
|
||||
st.cache_resource.clear()
|
||||
|
||||
with st.sidebar:
|
||||
st.subheader("Model")
|
||||
model = st.selectbox(label="models", options=model_list, on_change=reset_agent)
|
||||
|
||||
st.subheader("Builtin Tools")
|
||||
toolgroup_selection = st.pills(
|
||||
label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
|
||||
)
|
||||
|
||||
st.subheader("MCP Servers")
|
||||
mcp_selection = st.pills(
|
||||
label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
|
||||
)
|
||||
|
||||
toolgroup_selection.extend(mcp_selection)
|
||||
|
||||
active_tool_list = []
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
active_tool_list.extend(
|
||||
[
|
||||
f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
|
||||
for t in client.tools.list(toolgroup_id=toolgroup_id)
|
||||
]
|
||||
)
|
||||
|
||||
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
|
||||
st.json(active_tool_list)
|
||||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
client,
|
||||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={
|
||||
"strategy": {"type": "greedy"},
|
||||
},
|
||||
)
|
||||
|
||||
agent = create_agent()
|
||||
|
||||
if "agent_session_id" not in st.session_state:
|
||||
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
|
||||
|
||||
session_id = st.session_state["agent_session_id"]
|
||||
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
||||
|
||||
for msg in st.session_state.messages:
|
||||
with st.chat_message(msg["role"]):
|
||||
st.markdown(msg["content"])
|
||||
|
||||
if prompt := st.chat_input(placeholder=""):
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
turn_response = agent.create_turn(
|
||||
session_id=session_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def response_generator(turn_response):
|
||||
for response in turn_response:
|
||||
if hasattr(response.event, "payload"):
|
||||
print(response.event.payload)
|
||||
if response.event.payload.event_type == "step_progress":
|
||||
if hasattr(response.event.payload.delta, "text"):
|
||||
yield response.event.payload.delta.text
|
||||
if response.event.payload.event_type == "step_complete":
|
||||
if response.event.payload.step_details.step_type == "tool_execution":
|
||||
yield " 🛠 "
|
||||
else:
|
||||
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
response = st.write_stream(response_generator(turn_response))
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": response})
|
||||
|
||||
|
||||
tool_chat_page()
|
|
@@ -1,4 +1,5 @@
|
|||
streamlit
|
||||
pandas
|
||||
llama-stack-client>=0.0.55
|
||||
llama-stack-client>=0.2.1
|
||||
streamlit-option-menu
|
||||
llama-stack>=0.2.1
|
||||
|
|
llama_stack/distribution/utils/config.py (new file, 30 lines)
|
@@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact sensitive information from config before printing."""
|
||||
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
|
||||
|
||||
def _redact_value(v: Any) -> Any:
|
||||
if isinstance(v, dict):
|
||||
return _redact_dict(v)
|
||||
elif isinstance(v, list):
|
||||
return [_redact_value(i) for i in v]
|
||||
return v
|
||||
|
||||
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if any(pattern in k.lower() for pattern in sensitive_patterns):
|
||||
result[k] = "********"
|
||||
else:
|
||||
result[k] = _redact_value(v)
|
||||
return result
|
||||
|
||||
return _redact_dict(data)
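A small usage sketch for the helper above; the config keys and values here are invented for illustration, not taken from any real run config.

```python
config = {
    "providers": {
        "inference": [
            {"provider_id": "remote", "config": {"api_key": "sk-123", "url": "https://example.com"}},
        ],
    },
    "metadata_store": {"password": "hunter2"},
}

redacted = redact_sensitive_fields(config)
assert redacted["providers"]["inference"][0]["config"]["api_key"] == "********"
assert redacted["providers"]["inference"][0]["config"]["url"] == "https://example.com"
assert redacted["metadata_store"]["password"] == "********"
```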
|
|
@@ -29,6 +29,11 @@ def preserve_contexts_async_generator(
|
|||
context_var.set(initial_context_values[context_var.name])
|
||||
|
||||
item = await gen.__anext__()
|
||||
|
||||
# Update our tracked values with any changes made during this iteration
|
||||
for context_var in context_vars:
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
|
||||
yield item
|
||||
|
||||
except StopAsyncIteration:
|
||||
|
|
|
@@ -1,155 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import ContextVar
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.distribution.utils.context import preserve_contexts_async_generator
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_with_exception():
|
||||
# Create context variable
|
||||
context_var = ContextVar("exception_var", default="initial")
|
||||
token = context_var.set("start_value")
|
||||
|
||||
# Create an async generator that raises an exception
|
||||
async def exception_generator():
|
||||
yield context_var.get()
|
||||
context_var.set("modified")
|
||||
raise ValueError("Test exception")
|
||||
yield None # This will never be reached
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(exception_generator(), [context_var])
|
||||
|
||||
# First iteration should work
|
||||
value = await wrapped_gen.__anext__()
|
||||
assert value == "start_value"
|
||||
|
||||
# Second iteration should raise the exception
|
||||
with pytest.raises(ValueError, match="Test exception"):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_empty_generator():
|
||||
# Create context variable
|
||||
context_var = ContextVar("empty_var", default="initial")
|
||||
token = context_var.set("value")
|
||||
|
||||
# Create an empty async generator
|
||||
async def empty_generator():
|
||||
if False: # This condition ensures the generator yields nothing
|
||||
yield None
|
||||
|
||||
# Wrap the generator
|
||||
wrapped_gen = preserve_contexts_async_generator(empty_generator(), [context_var])
|
||||
|
||||
# The generator should raise StopAsyncIteration immediately
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await wrapped_gen.__anext__()
|
||||
|
||||
# Context variable should remain unchanged
|
||||
assert context_var.get() == "value"
|
||||
|
||||
# Clean up
|
||||
context_var.reset(token)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserve_contexts_across_event_loops():
|
||||
"""
|
||||
Test that context variables are preserved across event loop boundaries with nested generators.
|
||||
This simulates the real-world scenario where:
|
||||
1. A new event loop is created for each streaming request
|
||||
2. The async generator runs inside that loop
|
||||
3. There are multiple levels of nested generators
|
||||
4. Context needs to be preserved across these boundaries
|
||||
"""
|
||||
# Create context variables
|
||||
request_id = ContextVar("request_id", default=None)
|
||||
user_id = ContextVar("user_id", default=None)
|
||||
|
||||
# Set initial values
|
||||
|
||||
# Results container to verify values across thread boundaries
|
||||
results = []
|
||||
|
||||
# Inner-most generator (level 2)
|
||||
async def inner_generator():
|
||||
# Should have the context from the outer scope
|
||||
yield (1, request_id.get(), user_id.get())
|
||||
|
||||
# Modify one context variable
|
||||
user_id.set("user-modified")
|
||||
|
||||
# Should reflect the modification
|
||||
yield (2, request_id.get(), user_id.get())
|
||||
|
||||
# Middle generator (level 1)
|
||||
async def middle_generator():
|
||||
inner_gen = inner_generator()
|
||||
|
||||
# Forward the first yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
# Forward the second yield from inner
|
||||
item = await inner_gen.__anext__()
|
||||
yield item
|
||||
|
||||
request_id.set("req-modified")
|
||||
|
||||
# Add our own yield with both modified variables
|
||||
yield (3, request_id.get(), user_id.get())
|
||||
|
||||
# Function to run in a separate thread with a new event loop
|
||||
def run_in_new_loop():
|
||||
# Create a new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# Outer generator (runs in the new loop)
|
||||
async def outer_generator():
|
||||
request_id.set("req-12345")
|
||||
user_id.set("user-6789")
|
||||
# Wrap the middle generator
|
||||
wrapped_gen = preserve_contexts_async_generator(middle_generator(), [request_id, user_id])
|
||||
|
||||
# Process all items from the middle generator
|
||||
async for item in wrapped_gen:
|
||||
# Store results for verification
|
||||
results.append(item)
|
||||
|
||||
# Run the outer generator in the new loop
|
||||
loop.run_until_complete(outer_generator())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Run the generator chain in a separate thread with a new event loop
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
future.result() # Wait for completion
|
||||
|
||||
# Verify the results
|
||||
assert len(results) == 3
|
||||
|
||||
# First yield should have original values
|
||||
assert results[0] == (1, "req-12345", "user-6789")
|
||||
|
||||
# Second yield should have modified user_id
|
||||
assert results[1] == (2, "req-12345", "user-modified")
|
||||
|
||||
# Third yield should have both modified values
|
||||
assert results[2] == (3, "req-modified", "user-modified")
|
llama_stack/models/llama/checkpoint.py (new file, 164 lines)
|
@@ -0,0 +1,164 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import concurrent.futures
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
|
||||
|
||||
|
||||
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
|
||||
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
|
||||
if new_mp_size % old_mp_size == 0:
|
||||
# Read old MP shard and split it into smaller ones
|
||||
return [new_mp_rank * old_mp_size // new_mp_size]
|
||||
elif old_mp_size % new_mp_size == 0:
|
||||
# Merge old MP shards into a single one
|
||||
mp_factor = old_mp_size // new_mp_size
|
||||
return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Either old MP size or new MP size should be a multiple of the other: "
|
||||
f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
|
||||
)
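Two illustrative calls showing both directions of the mapping; the sizes and ranks here are made up for illustration.

```python
# Going from 2 old shards to 4 new ranks: each new rank reads part of one old shard.
assert map_mp_rank(old_mp_size=2, new_mp_size=4, new_mp_rank=3) == [1]

# Going from 4 old shards to 2 new ranks: each new rank merges two old shards.
assert map_mp_rank(old_mp_size=4, new_mp_size=2, new_mp_rank=1) == [2, 3]
```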
|
||||
|
||||
|
||||
def maybe_reshard_state_dict(
|
||||
ckpt_paths: List[Path],
|
||||
n_kv_heads: int,
|
||||
moe_num_experts: Optional[int] = None,
|
||||
map_location: Union[str, torch.device] = "cpu",
|
||||
mmap: bool = True,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if str(map_location) == "cpu":
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
|
||||
ckpt_paths = np.array(sorted(ckpt_paths))
|
||||
|
||||
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
|
||||
old_mp_size = len(ckpt_paths)
|
||||
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
|
||||
|
||||
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
|
||||
paths = ckpt_paths[old_mp_ranks] # type: ignore
|
||||
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
|
||||
|
||||
if new_mp_size == old_mp_size:
|
||||
return state_dicts[0] # type: ignore
|
||||
|
||||
if moe_num_experts is not None:
|
||||
state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
|
||||
|
||||
print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
|
||||
return reshard_mp(
|
||||
state_dicts,
|
||||
size=max(new_mp_size // old_mp_size, 1),
|
||||
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
|
||||
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
|
||||
)
|
||||
|
||||
|
||||
_WEIGHT_ROW_KEY = {
|
||||
"feed_forward.w2",
|
||||
"feed_forward.mlp.fc2",
|
||||
"attention.wo",
|
||||
"feed_forward.mlp.fc2_weight",
|
||||
"feed_forward.w_out_shared_DF.weight",
|
||||
"attn.wo.weight",
|
||||
"mlp.c_proj.weight",
|
||||
}
|
||||
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}
|
||||
|
||||
_WEIGHT_COLUMN_KEY = {
|
||||
"output",
|
||||
"feed_forward.(w1|w3)",
|
||||
"feed_forward.mlp.(fc1|fc3)",
|
||||
"feed_forward.mlp.fc1_weight",
|
||||
"attention.(wk|wq|wv|wqkv).weight",
|
||||
"feed_forward.(w_in_shared_FD|w_swiglu_FD)",
|
||||
"attn.(wk|wq|wv).weight",
|
||||
"attn.(wk|wq|wv).bias",
|
||||
"mlp.c_fc.weight",
|
||||
"mlp.c_fc.bias",
|
||||
"conv1._linear.weight",
|
||||
"tok_embeddings.weight",
|
||||
"vision_projection.weight",
|
||||
}
|
||||
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
|
||||
|
||||
|
||||
def reshard_mp(
|
||||
state_dicts: List[Dict[str, torch.Tensor]],
|
||||
size: int,
|
||||
rank: int,
|
||||
repeat_qk_qv: int = 1,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Reshard a list of state dicts into a single state dict given a change in MP size.
|
||||
If the list has more than one state dict, we concatenate the values of the same
|
||||
key across all state dicts. Otherwise, we just slice it for the current MP rank.
|
||||
"""
|
||||
|
||||
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
|
||||
if len(tensors) > 1:
|
||||
return torch.cat(tensors, dim=dim)
|
||||
return tensors[0].chunk(size, dim=dim)[rank].clone()
|
||||
|
||||
def process_key(key: str) -> torch.Tensor:
|
||||
if row_regex.search(key):
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
|
||||
elif column_regex.search(key):
|
||||
if "w13" in key or "fc1_weight" in key:
|
||||
dims = state_dicts[0][key].size()
|
||||
values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
|
||||
return concat_or_chunk(values, dim=1).flatten(0, 1)
|
||||
elif "qkv" in key:
|
||||
q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
|
||||
kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
|
||||
values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
|
||||
return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
|
||||
elif "wk.weight" in key or "wv.weight" in key:
|
||||
# Support MP > #kv_head
|
||||
return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
|
||||
elif key == "output.bias" or key == "fc.weight":
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
|
||||
elif "w_" in key:
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
|
||||
else:
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
|
||||
else:
|
||||
return state_dicts[0][key].clone()
|
||||
|
||||
row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
|
||||
column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY
|
||||
|
||||
column_regex = re.compile("|".join(column_keys))
|
||||
row_regex = re.compile("|".join(row_keys))
|
||||
|
||||
output: Dict[str, torch.Tensor] = {}
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# Note: only processes keys in the first state dict.
|
||||
# Assumes keys are the same across all state dicts.
|
||||
mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
|
||||
for future in concurrent.futures.as_completed(mappings):
|
||||
output[mappings[future]] = future.result()
|
||||
return output
|
||||
|
||||
|
||||
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
|
||||
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
|
||||
routed_regex = re.compile("|".join(routed_keys))
|
||||
keys = list(state_dict.keys())
|
||||
for key in keys:
|
||||
if routed_regex.search(key):
|
||||
state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
|
||||
return state_dict
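A tiny CPU-only sketch of `reshard_mp` on made-up tensors, assuming the key classification above: a row-parallel key is concatenated on the last dim when merging shards, and a column-parallel key is chunked on dim 0 when splitting one shard across more ranks.

```python
import torch

# Merge two shards into one rank: row-parallel weights concatenate on dim=-1.
shards = [
    {"layers.0.attention.wo.weight": torch.ones(4, 2)},
    {"layers.0.attention.wo.weight": torch.ones(4, 2)},
]
merged = reshard_mp(shards, size=1, rank=0)
assert merged["layers.0.attention.wo.weight"].shape == (4, 4)

# Split one shard across two ranks: column-parallel weights are chunked on dim=0.
single = [{"tok_embeddings.weight": torch.arange(8.0).reshape(4, 2)}]
sliced = reshard_mp(single, size=2, rank=0)
assert sliced["tok_embeddings.weight"].shape == (2, 2)
```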
|
|
@@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import base64
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
|
@@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
|||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
# The goal is that these set of types are relevant for all Llama models.
|
||||
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
|
||||
# the llama3 series of models.
|
||||
|
@@ -98,6 +89,29 @@ class StopReason(Enum):
|
|||
out_of_tokens = "out_of_tokens"
|
||||
|
||||
|
||||
class ToolParamDefinition(BaseModel):
|
||||
param_type: str
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = True
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
class RawMediaItem(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
data: bytes | BytesIO
|
||||
|
@@ -140,267 +154,25 @@ class RawMessage(BaseModel):
|
|||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema(ToolCall)
|
||||
class GenerationResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
source: Literal["input"] | Literal["output"]
|
||||
|
||||
# index within the batch
|
||||
batch_idx: int
|
||||
# whether generation for this item is already finished. note that tokens can
|
||||
# get returned even afterwards since other items in the batch can still be generating tokens
|
||||
finished: bool
|
||||
# because a batch is processed in parallel, useful decoding for one item can correspond to processing
|
||||
# pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
|
||||
# but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
|
||||
ignore_token: bool
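A hypothetical consumer of this structure (not part of the diff), reassembling the generated text for one batch item while skipping padding and ignored positions; the function name is invented.

```python
from typing import List


def collect_output_text(results: List[GenerationResult], batch_idx: int) -> str:
    # Keep only tokens generated for this batch item, ignoring padding and
    # tokens produced while other items in the batch were still finishing.
    return "".join(
        r.text
        for r in results
        if r.batch_idx == batch_idx and r.source == "output" and not r.ignore_token
    )
```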
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolParamDefinition(BaseModel):
|
||||
param_type: str
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = True
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDefinition(BaseModel):
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
type: Literal["greedy"] = "greedy"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopPSamplingStrategy(BaseModel):
|
||||
type: Literal["top_p"] = "top_p"
|
||||
temperature: Optional[float] = Field(..., gt=0.0)
|
||||
top_p: Optional[float] = 0.95
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopKSamplingStrategy(BaseModel):
|
||||
type: Literal["top_k"] = "top_k"
|
||||
top_k: int = Field(..., ge=1)
|
||||
|
||||
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SamplingParams(BaseModel):
|
||||
"""Sampling parameters.
|
||||
|
||||
:param strategy: The sampling strategy.
|
||||
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
|
||||
your prompt plus max_tokens cannot exceed the model's context length.
|
||||
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
|
||||
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
|
||||
:param stop: Up to 4 sequences where the API will stop generating further tokens.
|
||||
The returned text will not contain the stop sequence.
|
||||
"""
|
||||
|
||||
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
||||
|
||||
max_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
class CheckpointQuantizationFormat(Enum):
|
||||
# default format
|
||||
bf16 = "bf16"
|
||||
|
||||
# used for enabling fp8_rowwise inference, some weights are bf16
|
||||
fp8_mixed = "fp8-mixed"
|
||||
|
||||
int8 = "int8"
|
||||
|
||||
int4 = "int4"
|
||||
|
||||
|
||||
class ModelFamily(Enum):
|
||||
llama2 = "llama2"
|
||||
llama3 = "llama3"
|
||||
llama3_1 = "llama3_1"
|
||||
llama3_2 = "llama3_2"
|
||||
llama3_3 = "llama3_3"
|
||||
safety = "safety"
|
||||
|
||||
|
||||
class CoreModelId(Enum):
|
||||
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
|
||||
|
||||
# Llama 2 family
|
||||
llama2_7b = "Llama-2-7b"
|
||||
llama2_13b = "Llama-2-13b"
|
||||
llama2_70b = "Llama-2-70b"
|
||||
llama2_7b_chat = "Llama-2-7b-chat"
|
||||
llama2_13b_chat = "Llama-2-13b-chat"
|
||||
llama2_70b_chat = "Llama-2-70b-chat"
|
||||
|
||||
# Llama 3 family
|
||||
llama3_8b = "Llama-3-8B"
|
||||
llama3_70b = "Llama-3-70B"
|
||||
llama3_8b_instruct = "Llama-3-8B-Instruct"
|
||||
llama3_70b_instruct = "Llama-3-70B-Instruct"
|
||||
|
||||
# Llama 3.1 family
|
||||
llama3_1_8b = "Llama3.1-8B"
|
||||
llama3_1_70b = "Llama3.1-70B"
|
||||
llama3_1_405b = "Llama3.1-405B"
|
||||
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
|
||||
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
|
||||
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
|
||||
|
||||
# Llama 3.2 family
|
||||
llama3_2_1b = "Llama3.2-1B"
|
||||
llama3_2_3b = "Llama3.2-3B"
|
||||
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
|
||||
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
|
||||
llama3_2_11b_vision = "Llama3.2-11B-Vision"
|
||||
llama3_2_90b_vision = "Llama3.2-90B-Vision"
|
||||
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
|
||||
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
|
||||
|
||||
# Llama 3.3 family
|
||||
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
|
||||
|
||||
# Safety models
|
||||
llama_guard_3_8b = "Llama-Guard-3-8B"
|
||||
llama_guard_2_8b = "Llama-Guard-2-8B"
|
||||
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
|
||||
llama_guard_3_1b = "Llama-Guard-3-1B"
|
||||
|
||||
|
||||
def is_multimodal(model_id) -> bool:
|
||||
if model_id in [
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def model_family(model_id) -> ModelFamily:
|
||||
if model_id in [
|
||||
CoreModelId.llama2_7b,
|
||||
CoreModelId.llama2_13b,
|
||||
CoreModelId.llama2_70b,
|
||||
CoreModelId.llama2_7b_chat,
|
||||
CoreModelId.llama2_13b_chat,
|
||||
CoreModelId.llama2_70b_chat,
|
||||
]:
|
||||
return ModelFamily.llama2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_8b,
|
||||
CoreModelId.llama3_70b,
|
||||
CoreModelId.llama3_8b_instruct,
|
||||
CoreModelId.llama3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_1_8b,
|
||||
CoreModelId.llama3_1_70b,
|
||||
CoreModelId.llama3_1_405b,
|
||||
CoreModelId.llama3_1_8b_instruct,
|
||||
CoreModelId.llama3_1_70b_instruct,
|
||||
CoreModelId.llama3_1_405b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_1
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_2_1b,
|
||||
CoreModelId.llama3_2_3b,
|
||||
CoreModelId.llama3_2_1b_instruct,
|
||||
CoreModelId.llama3_2_3b_instruct,
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_3
|
||||
elif model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_2_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return ModelFamily.safety
|
||||
else:
|
||||
raise ValueError(f"Unknown model family for {model_id}")
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
core_model_id: CoreModelId
|
||||
description: str
|
||||
huggingface_repo: Optional[str] = None
|
||||
recommended_sampling_params: Optional[SamplingParams] = None
|
||||
arch_args: Dict[str, Any]
|
||||
variant: str = ""
|
||||
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
pth_file_count: int
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
# silence pydantic until we remove the `model_` fields
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def model_family(self) -> ModelFamily:
|
||||
return model_family(self.core_model_id)
|
||||
|
||||
# The SKU is uniquely identified by (model_id, variant) combo
|
||||
def descriptor(self, shorten_default_variant: bool = True) -> str:
|
||||
if not self.variant:
|
||||
return self.core_model_id.value
|
||||
return f"{self.core_model_id.value}:{self.variant}"
|
||||
|
||||
@property
|
||||
def is_instruct_model(self) -> bool:
|
||||
return "instruct" in self.id.name
|
||||
|
||||
# Featured models are shown in the non-exhaustive model list
|
||||
@property
|
||||
def is_featured(self) -> bool:
|
||||
return self.model_family in [
|
||||
ModelFamily.llama3_1,
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.safety,
|
||||
]
|
||||
|
||||
@property
|
||||
def max_seq_length(self) -> int:
|
||||
if self.model_family == ModelFamily.llama2:
|
||||
return 4096
|
||||
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
|
||||
return 4096
|
||||
elif self.model_family == ModelFamily.llama3:
|
||||
return 8192
|
||||
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama3_2:
|
||||
if self.quantization_format == CheckpointQuantizationFormat.int4:
|
||||
return 8192
|
||||
return 131072
|
||||
elif self.core_model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return 131072
|
||||
else:
|
||||
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
|
||||
class QuantizationMode(str, Enum):
|
||||
none = "none"
|
||||
fp8_mixed = "fp8_mixed"
|
||||
int4_mixed = "int4_mixed"
|
||||
|
|
|
@@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
|
@@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
|
@@ -19,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
|
@@ -30,7 +23,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .tool_utils import ToolUtils
|
||||
|
||||
|
@@ -234,7 +226,6 @@ class ChatFormat:
|
|||
arguments_json=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
||||
return RawMessage(
|
||||
role="assistant",
|
||||
|
|
llama_stack/models/llama/llama3/generation.py (new file, 371 lines)
|
@@ -0,0 +1,371 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from ..checkpoint import maybe_reshard_state_dict
|
||||
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
|
||||
from .args import ModelArgs
|
||||
from .chat_format import ChatFormat, LLMInput
|
||||
from .model import Transformer
|
||||
from .multimodal.model import CrossAttentionTransformer
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
|
||||
class Llama3:
|
||||
@staticmethod
|
||||
def build(
|
||||
ckpt_dir: str,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[QuantizationMode] = None,
|
||||
seed: int = 1,
|
||||
device: str = "cuda",
|
||||
):
|
||||
device = torch.device(device)
|
||||
if (
|
||||
device.type == "cuda"
|
||||
and not torch.cuda.is_available()
|
||||
or device.type == "xpu"
|
||||
and not torch.xpu.is_available()
|
||||
):
|
||||
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
if device.type == "cuda":
|
||||
torch.distributed.init_process_group("nccl")
|
||||
else:
|
||||
torch.distributed.init_process_group("gloo")
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
if world_size is None:
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(world_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
if device.type == "cuda":
|
||||
torch.cuda.set_device(local_rank)
|
||||
elif device.type == "xpu":
|
||||
torch.xpu.set_device(local_rank)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
**params,
|
||||
)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
state_dict = maybe_reshard_state_dict(
|
||||
ckpt_paths,
|
||||
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
|
||||
)
|
||||
|
||||
assert model_args.vocab_size == tokenizer.n_words
|
||||
|
||||
def build_model():
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
return model
|
||||
|
||||
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = build_model()
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
|
||||
torch.set_default_device(device)
|
||||
else:
|
||||
print(f"Setting default device to {device}")
|
||||
if device.type == "cuda":
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.Float16Tensor)
|
||||
elif device.type == "xpu":
|
||||
if torch.xpu.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.xpu.Float16Tensor)
|
||||
|
||||
model = build_model()
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device)
|
||||
print("Done...")
|
||||
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
return Llama3(model, tokenizer, model_args)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Transformer | CrossAttentionTransformer,
|
||||
tokenizer: Tokenizer,
|
||||
args: ModelArgs,
|
||||
):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
llm_inputs: List[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
params = self.model.params
|
||||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input:
|
||||
for inp in llm_inputs:
|
||||
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
|
||||
cprint(
|
||||
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
||||
"red",
|
||||
)
|
||||
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||
|
||||
bsz = len(llm_inputs)
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||
|
||||
is_vision = not isinstance(self.model, Transformer)
|
||||
if is_vision:
|
||||
images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
|
||||
mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
|
||||
|
||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
||||
batch_images=images,
|
||||
batch_masks=mask,
|
||||
total_len=total_len,
|
||||
device=tokens.device,
|
||||
)
|
||||
|
||||
eos_reached = torch.tensor([False] * bsz)
|
||||
input_text_mask = tokens != pad_id
|
||||
|
||||
if echo:
|
||||
for i in range(max_prompt_len):
|
||||
results = []
|
||||
for j, t in enumerate(tokens[:, i]):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="input",
|
||||
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
|
||||
batch_idx=j,
|
||||
finished=False,
|
||||
ignore_token=t.item() == pad_id,
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||
|
||||
prev_pos = 0
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
if is_vision:
|
||||
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
||||
text_only_inference = all(inp.vision is None for inp in llm_inputs)
|
||||
logits = self.model.forward(
|
||||
position_ids,
|
||||
tokens,
|
||||
cross_attention_masks,
|
||||
full_text_row_masked_out_mask,
|
||||
xattn_caches,
|
||||
text_only_inference,
|
||||
)
|
||||
else:
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if is_vision:
|
||||
# the logits space (num_classes) is designed to never contain a media_token
|
||||
# however our input token stream does contain them. we need to nuke them here
|
||||
# or else the CUDA kernels will crash with an illegal memory access
|
||||
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
|
||||
masks = [target.eq(t) for t in vision_tokens]
|
||||
if len(masks) > 1:
|
||||
mask = torch.logical_or(*masks)
|
||||
else:
|
||||
mask = masks[0]
|
||||
target[mask] = 0
|
||||
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=target,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
results = []
|
||||
for idx, t in enumerate(next_token):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="output",
|
||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||
batch_idx=idx,
|
||||
finished=eos_reached[idx].item(),
|
||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def completion(
|
||||
self,
|
||||
contents: List[RawContent],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
model_inputs = [self.formatter.encode_content(c) for c in contents]
|
||||
for result in self.generate(
|
||||
model_inputs=model_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages_batch: List[List[RawMessage]],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
|
||||
for result in self.generate(
|
||||
model_inputs=model_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
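For orientation, a minimal, hypothetical driver for the class above (the checkpoint path and message are placeholders; chat_completion yields one GenerationResult per batch element per step, as defined earlier in this file):

    from llama_stack.models.llama.datatypes import RawMessage

    llama = Llama3.build(
        ckpt_dir="/path/to/checkpoints",  # placeholder
        max_seq_len=8192,
        max_batch_size=1,
    )
    dialog = [RawMessage(role="user", content="Write a haiku about tensors.")]
    for results in llama.chat_completion([dialog], temperature=0.6, top_p=0.9):
        r = results[0]
        print(r.text, end="", flush=True)
        if r.finished:
            break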
@ -16,7 +16,7 @@ from typing import List, Optional
|
|||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
|
@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from . import template_data
|
||||
from .chat_format import ChatFormat
|
||||
from .prompt_templates import (
|
||||
|
|
|
@ -4,16 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
@ -29,6 +19,10 @@ from torch import nn
|
|||
|
||||
from .args import ModelArgs
|
||||
|
||||
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
|
||||
# dependencies. These dependencies are not part of the default dependencies
|
||||
# (requirements.txt) of the `llama-models` package.
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
|
@ -111,9 +105,9 @@ class Attention(nn.Module):
|
|||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
|
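To make the head-splitting arithmetic concrete with assumed, illustrative sizes: for n_heads = 32, n_kv_heads = 8, dim = 4096 and a model-parallel world_size of 4, each rank holds n_local_heads = 32 // 4 = 8 and n_local_kv_heads = 8 // 4 = 2, so n_rep = 8 // 2 = 4 (each local KV head serves four query heads) and head_dim = 4096 // 32 = 128.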
@ -4,16 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from functools import partial
|
||||
|
@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
|
|||
n_heads,
|
||||
):
|
||||
super().__init__()
|
||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
qkvo_replication = 1
|
||||
if model_parallel_size > 16:
|
||||
qkvo_replication = model_parallel_size // 8
|
||||
if world_size > 16:
|
||||
qkvo_replication = world_size // 8
|
||||
|
||||
self.n_kv_heads = n_heads
|
||||
self.n_local_heads = n_heads * qkvo_replication // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size
|
||||
self.n_local_heads = n_heads * qkvo_replication // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
|
@ -536,16 +526,16 @@ class Attention(nn.Module):
|
|||
cache_v (torch.Tensor): Cached values for attention.
|
||||
"""
|
||||
super().__init__()
|
||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
replication_factor = 1
|
||||
if model_parallel_size > 8:
|
||||
replication_factor = model_parallel_size // MP_SCALE
|
||||
if world_size > 8:
|
||||
replication_factor = world_size // MP_SCALE
|
||||
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
self.n_kv_heads *= replication_factor
|
||||
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.max_seq_len = args.max_seq_len
|
||||
|
@ -587,13 +577,11 @@ class Attention(nn.Module):
|
|||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
device = next(self.parameters()).device
|
||||
self.register_buffer(
|
||||
"key_cache",
|
||||
torch.zeros(
|
||||
cache_shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
@ -602,7 +590,6 @@ class Attention(nn.Module):
|
|||
torch.zeros(
|
||||
cache_shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
@ -614,6 +601,9 @@ class Attention(nn.Module):
|
|||
freqs_cis: torch.Tensor,
|
||||
position_ids: torch.LongTensor,
|
||||
):
|
||||
self.key_cache = self.key_cache.to(x.device)
|
||||
self.value_cache = self.value_cache.to(x.device)
|
||||
|
||||
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
|
||||
|
||||
bs, slen, _ = xq.shape
|
||||
|
@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
|
|||
norm_eps: float,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
self.world_size = fs_init.get_model_parallel_world_size()
|
||||
replication_factor = 1
|
||||
if self.model_parallel_size > 8:
|
||||
replication_factor = self.model_parallel_size // MP_SCALE
|
||||
if self.world_size > 8:
|
||||
replication_factor = self.world_size // MP_SCALE
|
||||
n_kv_heads *= replication_factor
|
||||
|
||||
assert n_heads % n_kv_heads == 0
|
||||
|
@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
|
|||
# trunk LLM (i.e., group query attention) -- @dubeya
|
||||
# local heads
|
||||
assert self.n_heads % self.n_kv_heads == 0
|
||||
assert self.n_heads % self.model_parallel_size == 0
|
||||
assert self.n_kv_heads % self.model_parallel_size == 0
|
||||
self.n_local_heads = self.n_heads // self.model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
|
||||
assert self.n_heads % self.world_size == 0
|
||||
assert self.n_kv_heads % self.world_size == 0
|
||||
self.n_local_heads = self.n_heads // self.world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // self.world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
|
||||
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
|
|||
self.image_res = args.vision_chunk_size
|
||||
self.max_num_chunks = args.vision_max_num_chunks
|
||||
if return_intermediate is not None:
|
||||
return_intermediate = [int(level) for level in return_intermediate.split(",")]
|
||||
return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
|
||||
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
|
||||
self.patch_size = 14
|
||||
self.vision_encoder = VisionEncoder(
|
||||
|
@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
|||
|
||||
def __init__(self, args: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
self.world_size = fs_init.get_model_parallel_world_size()
|
||||
assert args.vocab_size > 0
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
|
||||
assert self.vocab_size % self.model_parallel_size == 0
|
||||
self.n_local_kv_heads = self.n_kv_heads // self.world_size
|
||||
assert self.vocab_size % self.world_size == 0
|
||||
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
|
||||
self.pos_embeddings = None
|
||||
# final norm layer (not necessary for post-norm)
|
||||
|
@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
|||
text_only_inference: bool = False,
|
||||
):
|
||||
assert self.cache_is_setup, "Please set up cache before calling forward"
|
||||
self.mask_cache = self.mask_cache.to(h.device)
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
mask = self.mask_cache.index_select(2, position_ids)
|
||||
freqs_cis = self.freqs_cis.index_select(0, position_ids)
|
||||
|
||||
|
@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
|||
output = gather_from_tensor_model_parallel_region(output)
|
||||
return output.float()
|
||||
|
||||
def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
|
||||
def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
|
||||
# Set up the text kv caches
|
||||
device = next(self.parameters()).device
|
||||
ones = torch.ones(
|
||||
(self.max_seq_len, self.max_seq_len),
|
||||
dtype=torch.bool,
|
||||
|
@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
|||
|
||||
return (
|
||||
cross_attention_masks.to(device=text_device, dtype=text_dtype),
|
||||
full_text_row_masked_out_mask,
|
||||
full_text_row_masked_out_mask.to(device=text_device),
|
||||
)
|
||||
|
||||
|
||||
|
@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
|
|||
max_num_chunks=args.vision_max_num_chunks,
|
||||
)
|
||||
|
||||
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
|
||||
self.text_model.setup_cache(max_batch_size, dtype)
|
||||
def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
|
||||
self.text_model.setup_cache(max_batch_size, device, dtype)
|
||||
|
||||
def compute_vision_tokens_masks(
|
||||
self,
|
||||
batch_images: List[List[PIL_Image.Image]],
|
||||
batch_masks: List[List[List[int]]],
|
||||
total_len: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
skip_vision_encoder = False
|
||||
|
||||
|
@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
|
|||
image_res=self.params.vision_chunk_size,
|
||||
max_num_images=max_num_images,
|
||||
)
|
||||
stacked_images = stacked_images.to(device=device)
|
||||
|
||||
if skip_vision_encoder:
|
||||
vision_tokens = torch.zeros(
|
||||
|
@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
|
|||
),
|
||||
)
|
||||
else:
|
||||
vision_tokens = self.vision_model(stacked_images, aspect_ratios)
|
||||
vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
|
||||
|
||||
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
|
||||
xattn_caches = torch.stack(
|
|
@ -15,7 +15,7 @@ import textwrap
|
|||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from llama_stack.apis.inference import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
|
@ -229,6 +229,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
If you decide to invoke a function, you SHOULD NOT include any other text in the response besides the function call in the above format.
|
||||
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
|
||||
|
||||
|
||||
{{ function_description }}
|
||||
""".strip("\n")
|
||||
)
|
||||
|
@ -243,10 +248,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||
template_str = textwrap.dedent(
|
||||
"""
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
|
@ -279,6 +280,10 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{% endif -%}
|
||||
{%- endfor %}
|
||||
]
|
||||
|
||||
You can answer general questions or invoke tools when necessary.
|
||||
In addition to tool calls, you should also augment your responses by using the tool outputs.
|
||||
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
|
|
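As a hedged illustration of the format this template asks the model to emit (the function names and parameters are invented for the example), a conforming response is a single bracketed list such as:

    [get_weather(city="San Francisco", units="celsius"), get_time(timezone="PST")]

which the parser added to tool_utils.py later in this change turns into (function_name, arguments) pairs.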
5  llama_stack/models/llama/llama3/quantization/__init__.py  (new file)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
|
|
@ -4,12 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import logging
|
||||
# type: ignore
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
|
@ -18,52 +15,53 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi
|
|||
from torch import Tensor, nn
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
from ...llama3.args import ModelArgs
|
||||
from ...llama3.model import Transformer, TransformerBlock
|
||||
from ..config import MetaReferenceQuantizedInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
from ...datatypes import QuantizationMode
|
||||
from ...quantize_impls import (
|
||||
Fp8ScaledWeights,
|
||||
ffn_swiglu,
|
||||
load_fp8,
|
||||
quantize_fp8,
|
||||
)
|
||||
from ..model import Transformer, TransformerBlock
|
||||
from ..multimodal.model import CrossAttentionTransformer
|
||||
|
||||
|
||||
def swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor,
|
||||
):
|
||||
from .fp8_impls import ffn_swiglu
|
||||
|
||||
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
return reduce_from_model_parallel_region(out)
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
model: Transformer | CrossAttentionTransformer,
|
||||
checkpoint_dir: str,
|
||||
quantization_mode: Optional[str] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Transformer | CrossAttentionTransformer:
|
||||
if quantization_mode == QuantizationMode.fp8_mixed:
|
||||
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
|
||||
elif quantization_mode == QuantizationMode.int4_mixed:
|
||||
return convert_to_int4_quantized_model(model, checkpoint_dir, device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
|
||||
|
||||
|
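For reference, a minimal sketch of how this dispatcher is invoked (the checkpoint path is a placeholder; Llama3.build in generation.py makes essentially this call when a quantization mode is requested):

    quantized = convert_to_quantized_model(
        model,
        checkpoint_dir="/path/to/checkpoints",  # placeholder
        quantization_mode=QuantizationMode.int4_mixed,
        device=torch.device("cuda"),
    )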
||||
def convert_to_fp8_quantized_model(
|
||||
model: Transformer,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
checkpoint_dir: str,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Transformer:
|
||||
if config.quantization.type == QuantizationType.bf16.value:
|
||||
return model
|
||||
|
||||
elif config.quantization.type != QuantizationType.fp8.value:
|
||||
raise ValueError("Only FP8 quantization is supported")
|
||||
|
||||
from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
|
||||
|
||||
llama_model = resolve_model(config.model)
|
||||
assert llama_model is not None, f"Model {config.model} not found"
|
||||
|
||||
# Move weights to GPU with quantization
|
||||
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
||||
log.info("Loading fp8 scales...")
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
|
||||
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
|
||||
if os.path.isfile(fp8_scales_path):
|
||||
print("Loading fp8 scales...")
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
for block in model.layers:
|
||||
for _, block in model.named_modules():
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
@ -77,23 +75,23 @@ def convert_to_fp8_quantized_model(
|
|||
fp8_activation_scale_ub,
|
||||
)
|
||||
else:
|
||||
log.info("Quantizing fp8 weights from bf16...")
|
||||
for block in model.layers:
|
||||
print("Quantizing fp8 weights from bf16...")
|
||||
for _, block in model.named_modules():
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
|
||||
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward) # type: ignore
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(block.feed_forward, key)
|
||||
param.weight = quantize_fp8(
|
||||
param.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
output_device=device,
|
||||
)
|
||||
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
parameter.data = parameter.to(device=device)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -136,6 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
|
|||
precision=precision,
|
||||
scales_precision=scales_precision,
|
||||
)
|
||||
self.lora_scale: Optional[float] = None
|
||||
self.adaptor: Optional[nn.Sequential] = None
|
||||
if lora_rank is not None:
|
||||
assert lora_scale is not None, "Please specify lora scale for LoRA."
|
||||
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
|
||||
|
@ -143,9 +143,6 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
|
|||
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
|
||||
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
|
||||
self.lora_scale = lora_scale
|
||||
else:
|
||||
self.adaptor = None
|
||||
self.lora_scale = None
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
|
@ -287,16 +284,16 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
|
|||
|
||||
|
||||
def convert_to_int4_quantized_model(
|
||||
model: Transformer,
|
||||
model_args: ModelArgs,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
) -> Transformer:
|
||||
model: Transformer | CrossAttentionTransformer,
|
||||
checkpoint_dir: str,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Transformer | CrossAttentionTransformer:
|
||||
"""Convert the model to int4 quantized model."""
|
||||
|
||||
if model_args.quantization_args is None:
|
||||
raise ValueError("'quantization_args' cannot be None. Please specify it.")
|
||||
|
||||
model_args = model.params
|
||||
assert model_args.quantization_args is not None, "Quantization args must be specified."
|
||||
quantization_args = model_args.quantization_args
|
||||
if quantization_args.scheme is None:
|
||||
raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
|
||||
|
||||
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
|
||||
raise NotImplementedError(
|
||||
|
@ -316,5 +313,4 @@ def convert_to_int4_quantized_model(
|
|||
lora_scale = model_args.lora_args.scale
|
||||
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
return model.to(device)
|
||||
return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
|
|
@ -12,8 +12,7 @@
|
|||
# the top-level of this source tree.
|
||||
|
||||
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
|
||||
from ..datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
|
|
|
@ -4,16 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
|
|
@ -4,19 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
@ -34,80 +28,141 @@ def is_json(s):
|
|||
return True
|
||||
|
||||
|
||||
def is_valid_python_list(input_string):
|
||||
"""Check if the input string is a valid Python list of function calls"""
|
||||
try:
|
||||
# Try to parse the string
|
||||
tree = ast.parse(input_string)
|
||||
|
||||
# Check if it's a single expression
|
||||
if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
|
||||
return False
|
||||
|
||||
# Check if the expression is a list
|
||||
expr = tree.body[0].value
|
||||
if not isinstance(expr, ast.List):
|
||||
return False
|
||||
|
||||
# Check if the list is empty
|
||||
if len(expr.elts) == 0:
|
||||
return False
|
||||
|
||||
# Check if all elements in the list are function calls
|
||||
for element in expr.elts:
|
||||
if not isinstance(element, ast.Call):
|
||||
return False
|
||||
|
||||
# Check if the function call has a valid name
|
||||
if not isinstance(element.func, ast.Name):
|
||||
return False
|
||||
|
||||
# Check if all arguments are keyword arguments
|
||||
if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except SyntaxError:
|
||||
# If parsing fails, it's not a valid Python expression
|
||||
return False
|
||||
|
||||
|
||||
def parse_python_list_for_function_calls(input_string):
|
||||
def parse_llama_tool_call_format(input_string):
|
||||
"""
|
||||
Parse a Python list of function calls and
|
||||
return a list of tuples containing the function name and arguments
|
||||
"""
|
||||
# Parse the string into an AST
|
||||
tree = ast.parse(input_string)
|
||||
Parse tool calls in the format:
|
||||
[func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
|
||||
# Ensure the input is a list
|
||||
if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
|
||||
raise ValueError("Input must be a list of function calls")
|
||||
Returns a list of (function_name, arguments_dict) tuples or None if parsing fails.
|
||||
"""
|
||||
# Strip outer brackets and whitespace
|
||||
input_string = input_string.strip()
|
||||
if not (input_string.startswith("[") and input_string.endswith("]")):
|
||||
return None
|
||||
|
||||
content = input_string[1:-1].strip()
|
||||
if not content:
|
||||
return None
|
||||
|
||||
result = []
|
||||
|
||||
# Iterate through each function call in the list
|
||||
for node in tree.body[0].value.elts:
|
||||
if isinstance(node, ast.Call):
|
||||
function_name = node.func.id
|
||||
function_args = {}
|
||||
# State variables for parsing
|
||||
pos = 0
|
||||
length = len(content)
|
||||
|
||||
# Extract keyword arguments
|
||||
for keyword in node.keywords:
|
||||
try:
|
||||
function_args[keyword.arg] = ast.literal_eval(keyword.value)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
|
||||
) from e
|
||||
while pos < length:
|
||||
# Find function name
|
||||
name_end = content.find("(", pos)
|
||||
if name_end == -1:
|
||||
break
|
||||
|
||||
result.append((function_name, function_args))
|
||||
func_name = content[pos:name_end].strip()
|
||||
|
||||
return result
|
||||
# Find closing parenthesis for this function call
|
||||
paren_level = 1
|
||||
args_start = name_end + 1
|
||||
args_end = args_start
|
||||
|
||||
while args_end < length and paren_level > 0:
|
||||
if content[args_end] == "(":
|
||||
paren_level += 1
|
||||
elif content[args_end] == ")":
|
||||
paren_level -= 1
|
||||
args_end += 1
|
||||
|
||||
if paren_level != 0:
|
||||
# Unmatched parentheses
|
||||
return None
|
||||
|
||||
# Parse arguments
|
||||
args_str = content[args_start : args_end - 1].strip()
|
||||
args_dict = {}
|
||||
|
||||
if args_str:
|
||||
# Split by commas, but respect nested structures
|
||||
parts = []
|
||||
part_start = 0
|
||||
in_quotes = False
|
||||
quote_char = None
|
||||
nested_level = 0
|
||||
|
||||
for i, char in enumerate(args_str):
|
||||
if char in ('"', "'") and (i == 0 or args_str[i - 1] != "\\"):
|
||||
if not in_quotes:
|
||||
in_quotes = True
|
||||
quote_char = char
|
||||
elif char == quote_char:
|
||||
in_quotes = False
|
||||
quote_char = None
|
||||
elif not in_quotes:
|
||||
if char in ("{", "["):
|
||||
nested_level += 1
|
||||
elif char in ("}", "]"):
|
||||
nested_level -= 1
|
||||
elif char == "," and nested_level == 0:
|
||||
parts.append(args_str[part_start:i].strip())
|
||||
part_start = i + 1
|
||||
|
||||
parts.append(args_str[part_start:].strip())
|
||||
|
||||
# Process each key=value pair
|
||||
for part in parts:
|
||||
if "=" in part:
|
||||
key, value = part.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Try to convert value to appropriate Python type
|
||||
if (value.startswith('"') and value.endswith('"')) or (
|
||||
value.startswith("'") and value.endswith("'")
|
||||
):
|
||||
# String
|
||||
value = value[1:-1]
|
||||
elif value.lower() == "true":
|
||||
value = True
|
||||
elif value.lower() == "false":
|
||||
value = False
|
||||
elif value.lower() == "none":
|
||||
value = None
|
||||
elif value.startswith("{") and value.endswith("}"):
|
||||
# This is a nested dictionary
|
||||
try:
|
||||
# Try to parse as JSON
|
||||
value = json.loads(value.replace("'", '"'))
|
||||
except json.JSONDecodeError:
|
||||
# Keep as string if parsing fails
|
||||
pass
|
||||
elif value.startswith("[") and value.endswith("]"):
|
||||
# This is a nested list
|
||||
try:
|
||||
# Try to parse as JSON
|
||||
value = json.loads(value.replace("'", '"'))
|
||||
except json.JSONDecodeError:
|
||||
# Keep as string if parsing fails
|
||||
pass
|
||||
else:
|
||||
# Try to convert to number
|
||||
try:
|
||||
if "." in value:
|
||||
value = float(value)
|
||||
else:
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
# Keep as string if not a valid number
|
||||
pass
|
||||
|
||||
args_dict[key] = value
|
||||
|
||||
result.append((func_name, args_dict))
|
||||
|
||||
# Move to the next function call
|
||||
pos = args_end
|
||||
|
||||
# Skip the comma between function calls if present
|
||||
if pos < length and content[pos] == ",":
|
||||
pos += 1
|
||||
|
||||
return result if result else None
|
||||
|
||||
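A quick illustrative call (the function name and arguments are invented for the example):

    parse_llama_tool_call_format('[get_weather(city="Tokyo", units=\'metric\', detailed=True)]')
    # -> [("get_weather", {"city": "Tokyo", "units": "metric", "detailed": True})]

Malformed input (text that is not a bracketed list, or a call with unbalanced parentheses) returns None rather than raising, which is what the ToolUtils caller below relies on.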
|
||||
class ToolUtils:
|
||||
|
@ -149,17 +204,19 @@ class ToolUtils:
|
|||
return None
|
||||
elif is_json(message_body):
|
||||
response = json.loads(message_body)
|
||||
if ("type" in response and response["type"] == "function") or ("name" in response):
|
||||
if ("type" in response and response["type"] == "function") or (
|
||||
"name" in response and "parameters" in response
|
||||
):
|
||||
function_name = response["name"]
|
||||
args = response["parameters"]
|
||||
return function_name, args
|
||||
else:
|
||||
return None
|
||||
elif is_valid_python_list(message_body):
|
||||
res = parse_python_list_for_function_calls(message_body)
|
||||
elif function_calls := parse_llama_tool_call_format(message_body):
|
||||
# FIXME: Enable multiple tool calls
|
||||
return res[0]
|
||||
return function_calls[0]
|
||||
else:
|
||||
logger.debug(f"Did not parse tool call from message body: {message_body}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -21,8 +21,7 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from ..prompt_format import (
|
||||
from llama_stack.models.llama.prompt_format import (
|
||||
# llama3_1_e2e_tool_call_dialog,
|
||||
TextCompletionContent,
|
||||
UseCase,
|
||||
|
|
|
@ -3,10 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
|
|
@ -4,12 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
|
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
|
|
5  llama_stack/models/llama/llama4/__init__.py  (new file)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

108  llama_stack/models/llama/llama4/args.py  (new file)
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Optional

from pydantic import BaseModel, model_validator


class QuantizationScheme(Enum):
    int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"


class QuantizationArgs(BaseModel):
    scheme: Optional[QuantizationScheme] = None
    group_size: Optional[int] = None
    spinquant: bool = False


class LoRAArgs(BaseModel):
    rank: int
    scale: float


class MoEArgs(BaseModel):
    num_experts: int = -1
    capacity_factor: float = 1.0  # capacity factor determines how many tokens each expert can choose
    auto_scale_F: bool = (  # noqa: N815
        True  # if true, rescales hidden_dim such that number of activated params is same as equivalent dense layer
    )
    top_k: int = 1
    interleave_moe_layer_step: int = 1


class Size(BaseModel):
    height: int
    width: int


class VisionArgs(BaseModel):
    image_size: Size
    patch_size: Size

    # parameters for the encoder transformer
    dim: int
    n_layers: int
    n_heads: int
    mlp_ratio: float
    output_dim: int

    pixel_shuffle_ratio: float


class ModelArgs(BaseModel):
    dim: int = -1
    n_layers: int = -1
    n_heads: int = -1
    n_kv_heads: Optional[int] = None
    head_dim: Optional[int] = None

    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    ffn_exp: Optional[float] = None
    norm_eps: float = 1e-5

    attention_chunk_size: Optional[int] = None
    rope_theta: float = 500000
    use_scaled_rope: bool = False
    rope_scaling_factor: Optional[float] = None
    rope_high_freq_factor: Optional[float] = None

    nope_layer_interval: Optional[int] = None  # No position encoding in every n layers
    use_qk_norm: bool = False
    # Set to True to enable inference-time temperature tuning (useful for very long context)
    attn_temperature_tuning: bool = False
    floor_scale: float = 8192.0
    attn_scale: float = 0.1

    vision_args: Optional[VisionArgs] = None
    moe_args: Optional[MoEArgs] = None
    quantization_args: Optional[QuantizationArgs] = None
    lora_args: Optional[LoRAArgs] = None

    max_batch_size: int = 32
    max_seq_len: int = 2048

    @model_validator(mode="after")
    def validate(self) -> "ModelArgs":
        assert self.n_kv_heads <= self.n_heads, f"n_kv_heads ({self.n_kv_heads}) must be <= n_heads ({self.n_heads})"
        assert self.n_heads % self.n_kv_heads == 0, (
            f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
        )
        assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"

        if self.use_scaled_rope:
            # NOTE: ideally these values should have come from params.json. However, we have
            # shipped the models everywhere. Only Llama-4-Scout uses scaled rope and needs these
            # specific values.
            if self.rope_scaling_factor is None:
                self.rope_scaling_factor = 16
            if self.rope_high_freq_factor is None:
                self.rope_high_freq_factor = 1

        return self
316  llama_stack/models/llama/llama4/chat_format.py  (new file)
@ -0,0 +1,316 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import io
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
# TODO: either fork these or move them to the common package
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
RawMessage,
|
||||
RawTextItem,
|
||||
Role,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from ..llama3.tool_utils import ToolUtils
|
||||
from .args import VisionArgs
|
||||
from .datatypes import LLMInput
|
||||
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
|
||||
def role_str(role: Role) -> str:
|
||||
role_strs = {
|
||||
Role.user: "user",
|
||||
Role.system: "system",
|
||||
Role.tool: "ipython", # special
|
||||
Role.assistant: "assistant",
|
||||
}
|
||||
return role_strs[role]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformedImage:
|
||||
image_tiles: torch.Tensor
|
||||
# is the aspect ratio needed anywhere?
|
||||
aspect_ratio: Tuple[int, int]
|
||||
|
||||
|
||||
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
||||
if image.mode == "RGBA":
|
||||
image.load() # for png.split()
|
||||
new_img = PIL_Image.new("RGB", image.size, bg)
|
||||
new_img.paste(image, mask=image.split()[3]) # 3 is the alpha channel
|
||||
return new_img
|
||||
return image.convert("RGB")
|
||||
|
||||
|
||||
class ChatFormat:
|
||||
possible_headers: Dict[Role, str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Tokenizer,
|
||||
vision_args: Optional[VisionArgs] = None,
|
||||
max_num_chunks: int = 16,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.vision_args = vision_args
|
||||
self.max_num_chunks = max_num_chunks
|
||||
|
||||
self.possible_headers = {role: f"<|header_start|>{role_str(role)}<|header_end|>\n\n" for role in Role}
|
||||
|
||||
self.image_transform = None
|
||||
self.dynamic_image_transform = None
|
||||
if vision_args:
|
||||
self.dynamic_image_transform = VariableSizeImageTransform(vision_args.image_size.width)
|
||||
self.image_transform = ResizeNormalizeImageTransform(
|
||||
vision_args.image_size.width, vision_args.image_size.height
|
||||
)
|
||||
|
||||
def _encode_header(self, role: str) -> List[int]:
|
||||
tokens = []
|
||||
tokens.append(self.tokenizer.special_tokens["<|header_start|>"])
|
||||
|
||||
# TODO: need to check if this is correct
|
||||
tokens.extend(self.tokenizer.encode("ipython" if role == "tool" else role, bos=False, eos=False))
|
||||
tokens.append(self.tokenizer.special_tokens["<|header_end|>"])
|
||||
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
|
||||
return tokens
|
||||
|
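Decoded back to text, the header emitted here for an assistant turn is the literal string <|header_start|>assistant<|header_end|> followed by a blank line, which matches the possible_headers map built in __init__ above.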
||||
def encode_content(self, content: RawContent) -> LLMInput:
|
||||
tokens, images = self._encode_content(content, bos=True)
|
||||
return self._model_input_from_tokens_images(tokens, images)
|
||||
|
||||
def _encode_image(
|
||||
self,
|
||||
transformed_image: TransformedImage,
|
||||
) -> List[int]:
|
||||
assert self.vision_args is not None, "The model is not vision-enabled"
|
||||
|
||||
image_tensor = transformed_image.image_tiles
|
||||
image_channels = image_tensor.shape[-3]
|
||||
image_height = image_tensor.shape[-2]
|
||||
image_width = image_tensor.shape[-1]
|
||||
image_chunks = image_tensor.view(-1, image_channels, image_height, image_width).shape[0]
|
||||
|
||||
patch_height = self.vision_args.patch_size.height
|
||||
patch_width = self.vision_args.patch_size.width
|
||||
|
||||
if image_height % patch_height != 0:
|
||||
raise ValueError(f"{image_height=} not divisible by {patch_height=}")
|
||||
if image_width % patch_width != 0:
|
||||
raise ValueError(f"{image_width=} not divisible by {patch_width=}")
|
||||
|
||||
ds_ratio = int(round(1.0 / (self.vision_args.pixel_shuffle_ratio**2)))
|
||||
n_patches_per_chunk = int((image_height // patch_height) * (image_width // patch_width) // ds_ratio)
|
||||
|
||||
image_ar = transformed_image.aspect_ratio
|
||||
tokens = [self.tokenizer.special_tokens["<|image_start|>"]]
|
||||
if image_chunks == 1:
|
||||
tokens += [self.tokenizer.special_tokens["<|image|>"]]
|
||||
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
|
||||
tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
|
||||
else:
|
||||
ratio_h, ratio_w = image_ar
|
||||
for _ in range(ratio_h):
|
||||
for xx in range(ratio_w):
|
||||
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
|
||||
if xx < ratio_w - 1:
|
||||
tokens.append(self.tokenizer.special_tokens["<|tile_x_separator|>"])
|
||||
|
||||
tokens.append(self.tokenizer.special_tokens["<|tile_y_separator|>"])
|
||||
|
||||
tokens += [self.tokenizer.special_tokens["<|image|>"]]
|
||||
tokens += [self.tokenizer.special_tokens["<|patch|>"]] * n_patches_per_chunk
|
||||
tokens += [self.tokenizer.special_tokens["<|image_end|>"]]
|
||||
|
||||
return tokens
|
||||
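To ground the arithmetic with assumed, illustrative vision settings: a 336x336 tile, 14x14 patches and pixel_shuffle_ratio = 0.5 give ds_ratio = round(1 / 0.25) = 4 and n_patches_per_chunk = (24 * 24) // 4 = 144, so each chunk contributes 144 <|patch|> tokens in addition to the surrounding image markers.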
|
||||
def _encode_content(self, content: RawContent, bos: bool = False) -> Tuple[List[int], List[TransformedImage]]:
|
||||
tokens = []
|
||||
tranformed_images = []
|
||||
|
||||
added_bos = False
|
||||
|
||||
def _process(c):
|
||||
nonlocal added_bos, bos
|
||||
|
||||
if isinstance(c, str) or isinstance(c, RawTextItem):
|
||||
if isinstance(c, RawTextItem):
|
||||
c = c.text
|
||||
tokens.extend(self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False))
|
||||
added_bos = True
|
||||
|
||||
elif isinstance(c, RawMediaItem):
|
||||
if not self.vision_args:
|
||||
raise ValueError("The model is not vision-enabled, but a media item was found")
|
||||
|
||||
bos = False if added_bos else bos
|
||||
if bos:
|
||||
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||
added_bos = True
|
||||
|
||||
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
||||
image = PIL_Image.open(bytes_io)
|
||||
image = convert_image_to_rgb(image)
|
||||
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
|
||||
|
||||
if image_tiles.shape[0] > 1:
|
||||
image_global = self.image_transform(image)
|
||||
image_global = image_global.unsqueeze(0)
|
||||
image_combine = torch.cat((image_tiles, image_global), dim=0)
|
||||
image_tiles = image_combine
|
||||
|
||||
transformed_image = TransformedImage(image_tiles=image_tiles, aspect_ratio=ar)
|
||||
tokens.extend(self._encode_image(transformed_image))
|
||||
transformed_images.append(transformed_image)
|
||||
|
||||
if isinstance(content, list):
|
||||
for c in content:
|
||||
_process(c)
|
||||
else:
|
||||
_process(content)
|
||||
|
||||
return tokens, transformed_images
|
||||
|
||||
def encode_message(
|
||||
self, message: RawMessage, tool_prompt_format: ToolPromptFormat
|
||||
) -> Tuple[List[int], List[TransformedImage]]:
|
||||
tokens = self._encode_header(message.role)
|
||||
images = []
|
||||
|
||||
def _process_content(c):
|
||||
toks, imgs = self._encode_content(c)
|
||||
tokens.extend(toks)
|
||||
images.extend(imgs)
|
||||
|
||||
_process_content(message.content)
|
||||
|
||||
if message.role == "user" and message.context is not None:
|
||||
# This is RAG context; why is it here in the chat format? I don't think
|
||||
# this is needed and can be moved upwards
|
||||
_process_content("\n\n")
|
||||
_process_content(message.context)
|
||||
|
||||
if message.role == "assistant":
|
||||
for t in message.tool_calls:
|
||||
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
|
||||
_process_content(content)
|
||||
|
||||
# Tool calls and Tool Response messages should be eom
|
||||
eom = False
|
||||
if message.role == "assistant":
|
||||
eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
|
||||
elif message.role == "tool":
|
||||
eom = True
|
||||
|
||||
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
|
||||
return tokens, images
|
||||
|
||||
def encode_dialog_prompt(
|
||||
self,
|
||||
messages: List[RawMessage],
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
) -> LLMInput:
|
||||
tokens = []
|
||||
images = []
|
||||
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
|
||||
for message in messages:
|
||||
toks, imgs = self.encode_message(message, tool_prompt_format)
|
||||
tokens.extend(toks)
|
||||
images.extend(imgs)
|
||||
|
||||
# Add the start of an assistant message for the model to complete.
|
||||
tokens.extend(self._encode_header("assistant"))
|
||||
|
||||
return self._model_input_from_tokens_images(tokens, images)
|
||||
|
||||
# TODO: this should be generic, not only for assistant messages
|
||||
def decode_assistant_message(self, tokens: List[int], stop_reason: StopReason) -> RawMessage:
|
||||
content = self.tokenizer.decode(tokens)
|
||||
|
||||
return self.decode_assistant_message_from_content(content, stop_reason)
|
||||
|
||||
def decode_assistant_message_from_content(self, content: str, stop_reason: StopReason) -> RawMessage:
|
||||
content = content.strip(" ")
|
||||
header_str = self.possible_headers[Role.assistant]
|
||||
if content.startswith(header_str):
|
||||
content = content[len(header_str) :]
|
||||
|
||||
ipython = content.startswith("<|python_start|>")
|
||||
if ipython:
|
||||
content = content[len("<|python_start|>") :]
|
||||
content = content.replace("<|python_end|>", "")
|
||||
|
||||
if content.endswith("<|eot|>"):
|
||||
content = content[: -len("<|eot|>")]
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif content.endswith("<|eom|>"):
|
||||
content = content[: -len("<|eom|>")]
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
tool_name = None
|
||||
tool_arguments = {}
|
||||
|
||||
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
|
||||
if custom_tool_info is not None:
|
||||
tool_name, tool_arguments = custom_tool_info
|
||||
# Sometimes when the agent has custom tools alongside builtin tools,
# it responds to builtin tool calls in the format of the custom tools.
# This code tries to handle that case.
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
tool_arguments = {
|
||||
"query": list(tool_arguments.values())[0],
|
||||
}
|
||||
else:
|
||||
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
|
||||
if builtin_tool_info is not None:
|
||||
tool_name, query = builtin_tool_info
|
||||
tool_arguments = {
|
||||
"query": query,
|
||||
}
|
||||
if tool_name in BuiltinTool.__members__:
|
||||
tool_name = BuiltinTool[tool_name]
|
||||
elif ipython:
|
||||
tool_name = BuiltinTool.code_interpreter
|
||||
tool_arguments = {
|
||||
"code": content,
|
||||
}
|
||||
|
||||
tool_calls = []
|
||||
if tool_name is not None and tool_arguments is not None:
|
||||
call_id = str(uuid.uuid4())
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_arguments,
|
||||
)
|
||||
)
|
||||
|
||||
return RawMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
def _model_input_from_tokens_images(self, tokens: List[int], images: List[TransformedImage]) -> LLMInput:
|
||||
return LLMInput(
|
||||
tokens=tokens,
|
||||
images=[x.image_tiles for x in images] if len(images) > 0 else None,
|
||||
)
|
llama_stack/models/llama/llama4/datatypes.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Optional, Union

import torch


@dataclass
class MaskedEmbedding:
    embedding: torch.Tensor
    mask: torch.Tensor


@dataclass
class LLMInput:
    """
    This is the input to the LLM from the "user" -- the user in this case views the
    Llama4 model holistically and does not care or know about its inner workings (e.g.,
    whether it has an encoder or if it is early fusion or not.)

    This is distinct from the "TransformerInput" class which is really the Llama4
    backbone operating on early fused modalities and producing text output
    """

    tokens: torch.Tensor

    # images are already pre-processed (resized, tiled, etc.)
    images: Optional[List[torch.Tensor]] = None


@dataclass
class TransformerInput:
    """
    This is the "core" backbone transformer of the Llama4 model. Inputs for other modalities
    are expected to be "embedded" via encoders sitting before this layer in the model.
    """

    tokens: torch.Tensor

    # tokens_position defines the position of the tokens in each batch,
    # - when it is a tensor ([batch_size,]), it is the start position of the tokens in each batch
    # - when it is an int, the start position are the same for all batches
    tokens_position: Union[torch.Tensor, int]
    image_embedding: Optional[MaskedEmbedding] = None


@dataclass
class LLMOutput:
    logits: torch.Tensor


TransformerOutput = LLMOutput
llama_stack/models/llama/llama4/ffn.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from typing import Any, Dict, List

from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import nn
from torch.nn import functional as F


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        do_reduce: bool = True,
    ):
        super().__init__()
        self.do_reduce = do_reduce

        self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
        self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
        self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        if prefix + "mlp.fc1_weight" in state_dict:
            w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)
            state_dict[prefix + "w1.weight"] = w1
            state_dict[prefix + "w3.weight"] = w3
            state_dict[prefix + "w2.weight"] = state_dict.pop(prefix + "mlp.fc2_weight")

    def forward(self, x):
        x = F.silu(F.linear(x, self.w1.weight)) * F.linear(x, self.w3.weight)
        out = F.linear(x, self.w2.weight)
        if self.do_reduce:
            return reduce_from_model_parallel_region(out)
        return out
llama_stack/models/llama/llama4/generation.py (new file, 313 lines)
@@ -0,0 +1,313 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import codecs
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from ..checkpoint import maybe_reshard_state_dict
|
||||
from ..datatypes import GenerationResult, QuantizationMode
|
||||
from .args import ModelArgs
|
||||
from .chat_format import ChatFormat, RawContent, RawMessage
|
||||
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
|
||||
from .model import Transformer
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
|
||||
|
||||
|
||||
class Llama4:
|
||||
@staticmethod
|
||||
def build(
|
||||
ckpt_dir: str,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[QuantizationMode] = None,
|
||||
seed: int = 1,
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
if world_size is None:
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(world_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
**params,
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
# TODO: params.json should always have correct vocab_size
|
||||
if model_args.vocab_size == -1:
|
||||
model_args.vocab_size = tokenizer.n_words
|
||||
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
|
||||
print("Model args:\n", model_args.model_dump_json(indent=2))
|
||||
|
||||
state_dict = maybe_reshard_state_dict(
|
||||
ckpt_paths,
|
||||
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
|
||||
moe_num_experts=model_args.moe_args.num_experts,
|
||||
)
|
||||
print("Loaded checkpoint")
|
||||
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
|
||||
else:
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
return Llama4(model, tokenizer, model_args)
|
||||
|
||||
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer, vision_args=args.vision_args)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
llm_inputs: List[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
|
||||
max_gen_len = self.model.args.max_seq_len - 1
|
||||
|
||||
params = self.model.args
|
||||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input:
|
||||
cprint("Input to model:\n", "yellow")
|
||||
for inp in llm_inputs:
|
||||
cprint(self.tokenizer.decode(inp.tokens), "grey")
|
||||
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||
|
||||
bsz = len(llm_inputs)
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||
|
||||
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
||||
input_text_mask = tokens != pad_id
|
||||
|
||||
if echo:
|
||||
for i in range(max_prompt_len):
|
||||
results = []
|
||||
for j, t in enumerate(tokens[:, i]):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="input",
|
||||
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
|
||||
batch_idx=j,
|
||||
finished=False,
|
||||
ignore_token=t.item() == pad_id,
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
|
||||
|
||||
prev_pos = 0
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
image_embedding = None
|
||||
if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
|
||||
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
|
||||
image_mask = image_mask.unsqueeze(-1)
|
||||
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
|
||||
|
||||
image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
|
||||
image_embedding = MaskedEmbedding(
|
||||
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
|
||||
mask=image_mask,
|
||||
)
|
||||
|
||||
xformer_input = TransformerInput(
|
||||
tokens=tokens[:, prev_pos:cur_pos],
|
||||
tokens_position=prev_pos,
|
||||
image_embedding=image_embedding,
|
||||
)
|
||||
xformer_output = self.model.forward(xformer_input)
|
||||
logits = xformer_output.logits
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=target,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
|
||||
results = []
|
||||
for idx, t in enumerate(next_token):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="output",
|
||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||
batch_idx=idx,
|
||||
finished=eos_reached[idx].item(),
|
||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def completion(
|
||||
self,
|
||||
contents: List[RawContent],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_content(c) for c in contents]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages_batch: List[List[RawMessage]],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
|
||||
for result in self.generate(
|
||||
llm_inputs=llm_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
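# Worked example (illustrative values, not from the original source): with
# probs = [[0.5, 0.3, 0.15, 0.05]] and p = 0.9, the sorted cumulative sums are
# [0.5, 0.8, 0.95, 1.0] and the shifted sums (cumsum - prob) are [0.0, 0.5, 0.8, 0.95].
# Only the last entry exceeds p, so it is zeroed out and the remaining mass
# [0.5, 0.3, 0.15] is renormalized to roughly [0.53, 0.32, 0.16] before
# torch.multinomial draws one of the first three token ids:
#
#   probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
#   next_token = sample_top_p(probs, p=0.9)  # shape [1, 1], value in {0, 1, 2}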
|
llama_stack/models/llama/llama4/model.py (new file, 437 lines)
@@ -0,0 +1,437 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from .args import ModelArgs
|
||||
from .datatypes import TransformerInput, TransformerOutput
|
||||
from .ffn import FeedForward
|
||||
from .moe import MoE
|
||||
|
||||
|
||||
def rmsnorm(x, eps):
|
||||
def _norm(y):
|
||||
return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
|
||||
|
||||
return _norm(x.float()).type_as(x)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return rmsnorm(x, self.eps) * self.weight
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
|
||||
low_freq_factor = 1
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
new_freqs = []
|
||||
for freq in freqs:
|
||||
wavelen = 2 * math.pi / freq
|
||||
if wavelen < high_freq_wavelen:
|
||||
new_freqs.append(freq)
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / scale_factor)
|
||||
else:
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
|
||||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||
|
||||
|
||||
def precompute_freqs_cis(
|
||||
dim: int,
|
||||
end: int,
|
||||
theta: float,
|
||||
use_scaled: bool,
|
||||
scale_factor: float,
|
||||
high_freq_factor: float,
|
||||
):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
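# Hedged note (illustrative, assuming use_scaled is False): entry [t, k] of the
# returned tensor is the unit complex number exp(i * t * theta ** (-2k / dim)),
# i.e. the rotation applied to the k-th feature pair of a token at position t
# by apply_rotary_emb below.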
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
# TODO: this module needs to be moved into a separate file since it can be used by
|
||||
# the vision encoder as well.
|
||||
def __init__(
|
||||
self,
|
||||
args: ModelArgs,
|
||||
use_qk_norm: bool,
|
||||
use_rope: bool,
|
||||
add_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_rope = use_rope
|
||||
self.use_qk_norm = use_qk_norm
|
||||
# For attention temperature tuning
|
||||
self.attn_temperature_tuning = args.attn_temperature_tuning
|
||||
self.floor_scale = args.floor_scale
|
||||
self.attn_scale = args.attn_scale
|
||||
|
||||
self.n_heads = args.n_heads
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
||||
self.wq = ColumnParallelLinear(
|
||||
args.dim,
|
||||
args.n_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
args.dim,
|
||||
self.n_kv_heads * self.head_dim,
|
||||
bias=add_bias,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
args.n_heads * self.head_dim,
|
||||
args.dim,
|
||||
bias=add_bias,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
|
||||
self.cache_k = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
self.cache_v = torch.zeros(
|
||||
(
|
||||
args.max_batch_size,
|
||||
args.max_seq_len,
|
||||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
self.norm_eps = args.norm_eps
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "wqkv.weight" in state_dict:
|
||||
wqkv = state_dict.pop(prefix + "wqkv.weight")
|
||||
d, r = divmod(wqkv.shape[0], self.n_heads + 2 * self.n_kv_heads)
|
||||
if r != 0:
|
||||
raise ValueError(
|
||||
f"shape={tuple(wqkv.shape)} is not divisible by "
|
||||
f"n_heads ({self.n_heads}) + 2 * n_kv_heads ({self.n_kv_heads})"
|
||||
)
|
||||
wq, wk, wv = wqkv.split([d * self.n_heads, d * self.n_kv_heads, d * self.n_kv_heads], dim=0)
|
||||
state_dict[prefix + "wq.weight"] = wq
|
||||
state_dict[prefix + "wk.weight"] = wk
|
||||
state_dict[prefix + "wv.weight"] = wv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
||||
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
||||
|
||||
if self.use_rope:
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
if self.use_qk_norm:
|
||||
xq = rmsnorm(xq, self.norm_eps)
|
||||
xk = rmsnorm(xk, self.norm_eps)
|
||||
|
||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
||||
# the inference-time temperature tuning function is customized to not affect short context
|
||||
# while working at very long context
|
||||
if self.attn_temperature_tuning and not self.use_rope:
|
||||
seq_positions = torch.arange(start_pos, start_pos + seqlen, device=xq.device, dtype=torch.float32)
|
||||
attn_scales = torch.log(torch.floor((seq_positions + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
|
||||
|
||||
# reshape for broadcasting [seqlen] -> [1, seqlen, 1, 1]
|
||||
attn_scales = attn_scales.view(1, seqlen, 1, 1)
|
||||
xq = xq * attn_scales
|
||||
|
||||
self.cache_k = self.cache_k.to(xq)
|
||||
self.cache_v = self.cache_v.to(xq)
|
||||
|
||||
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
|
||||
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
|
||||
|
||||
xk = self.cache_k[:bsz, : start_pos + seqlen]
|
||||
xv = self.cache_v[:bsz, : start_pos + seqlen]
|
||||
|
||||
xq, xk, xv = [t.transpose(1, 2) for t in (xq, xk, xv)]
|
||||
|
||||
xk = xk.repeat_interleave(self.n_rep, dim=1)
|
||||
xv = xv.repeat_interleave(self.n_rep, dim=1)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
output = self.wo(attn_output)
|
||||
return output
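# Illustrative sketch (assumed numbers, not from any particular config): with
# n_local_heads = 8 and n_local_kv_heads = 2, n_rep = 4, so after the KV cache is
# sliced to [:bsz, : start_pos + seqlen] each cached K/V head is duplicated 4 times
# by repeat_interleave, giving the 8 matching heads that scaled_dot_product_attention
# expects for this grouped-query attention layout.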
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_heads = args.n_heads
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads if args.head_dim is None else args.head_dim
|
||||
|
||||
self.is_nope_layer = args.nope_layer_interval is not None and (layer_id + 1) % args.nope_layer_interval == 0
|
||||
|
||||
use_rope = not self.is_nope_layer
|
||||
use_qk_norm = args.use_qk_norm and not self.is_nope_layer
|
||||
|
||||
self.attention = Attention(args, use_rope=use_rope, use_qk_norm=use_qk_norm)
|
||||
|
||||
if args.moe_args and (layer_id + 1) % args.moe_args.interleave_moe_layer_step == 0:
|
||||
self.feed_forward = MoE(
|
||||
dim=args.dim,
|
||||
hidden_dim=int(args.ffn_exp * args.dim),
|
||||
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
||||
multiple_of=args.multiple_of,
|
||||
moe_args=args.moe_args,
|
||||
)
|
||||
else:
|
||||
hidden_dim = int(4 * args.dim)
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
if args.ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
|
||||
|
||||
self.feed_forward = FeedForward(
|
||||
dim=args.dim,
|
||||
hidden_dim=hidden_dim,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "attention.wqkv.layer_norm_weight" in state_dict:
|
||||
state_dict[prefix + "attention_norm.weight"] = state_dict.pop(prefix + "attention.wqkv.layer_norm_weight")
|
||||
|
||||
if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict:
|
||||
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.mlp.layer_norm_weight")
|
||||
elif prefix + "feed_forward.norm.weight" in state_dict:
|
||||
state_dict[prefix + "ffn_norm.weight"] = state_dict.pop(prefix + "feed_forward.norm.weight")
|
||||
|
||||
for k in (
|
||||
"feed_forward.experts.mlp",
|
||||
"feed_forward.mlp_shared",
|
||||
"attention.wo",
|
||||
"attention.wqkv",
|
||||
):
|
||||
if prefix + k + "._extra_state" in state_dict:
|
||||
state_dict.pop(prefix + k + "._extra_state")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
start_pos: int,
|
||||
freqs_cis: torch.Tensor,
|
||||
global_attn_mask: Optional[torch.Tensor],
|
||||
local_attn_mask: Optional[torch.Tensor],
|
||||
):
|
||||
# The iRoPE architecture uses global attention mask for NoPE layers or
|
||||
# if chunked local attention is not used
|
||||
if self.is_nope_layer or local_attn_mask is None:
|
||||
mask = global_attn_mask
|
||||
else:
|
||||
mask = local_attn_mask
|
||||
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, args: ModelArgs, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
|
||||
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
|
||||
|
||||
self.layers = torch.nn.ModuleList()
|
||||
for layer_id in range(args.n_layers):
|
||||
self.layers.append(TransformerBlock(layer_id, args))
|
||||
|
||||
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||||
self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
|
||||
|
||||
self.freqs_cis = precompute_freqs_cis(
|
||||
args.dim // args.n_heads,
|
||||
args.max_seq_len * 2,
|
||||
args.rope_theta,
|
||||
args.use_scaled_rope,
|
||||
args.rope_scaling_factor,
|
||||
args.rope_high_freq_factor,
|
||||
)
|
||||
vision_args = self.args.vision_args
|
||||
if vision_args:
|
||||
# circular import otherwise until we refactor out Attention
|
||||
from .vision.embedding import VisionEmbeddings
|
||||
|
||||
self.vision_embeddings = VisionEmbeddings(vision_args)
|
||||
self.vision_projection = ColumnParallelLinear(
|
||||
vision_args.output_dim,
|
||||
args.dim,
|
||||
bias=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "rope.freqs" in state_dict:
|
||||
state_dict.pop(prefix + "rope.freqs")
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, model_input: TransformerInput) -> TransformerOutput:
|
||||
tokens = model_input.tokens
|
||||
start_pos = model_input.tokens_position
|
||||
assert isinstance(start_pos, int), (
|
||||
"This implementation does not support different start positions per batch item"
|
||||
)
|
||||
|
||||
_bsz, seqlen = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
if image_embedding := model_input.image_embedding:
|
||||
h_image = self.vision_projection(image_embedding.embedding)
|
||||
h = h * ~image_embedding.mask + h_image * image_embedding.mask
|
||||
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
||||
|
||||
global_attn_mask, local_attn_mask = None, None
|
||||
if seqlen > 1:
|
||||
global_attn_mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
|
||||
global_attn_mask = torch.triu(global_attn_mask, diagonal=1).type_as(h)
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/100005
|
||||
# torch.triu is buggy when the device is mps: filled values are
|
||||
# nan instead of 0.
|
||||
if global_attn_mask.device.type == torch.device("mps").type:
|
||||
global_attn_mask = torch.nan_to_num(global_attn_mask, nan=0.0)
|
||||
|
||||
if chunk_size := self.args.attention_chunk_size:
|
||||
local_attn_mask = create_chunked_attention_mask(seqlen, chunk_size, tokens.device)
|
||||
|
||||
for layer in self.layers:
|
||||
h = layer(h, start_pos, freqs_cis, global_attn_mask, local_attn_mask)
|
||||
h = self.norm(h)
|
||||
output = self.output(h).float()
|
||||
|
||||
return TransformerOutput(logits=output)
|
||||
|
||||
|
||||
# tokens (0, K), (K, 2K), (2K, 3K) attend to each other when doing local chunked attention
|
||||
# in the iRoPE architecture
|
||||
def create_chunked_attention_mask(seq_len: int, attention_chunk_size: int, device: torch.device) -> torch.Tensor:
|
||||
block_pos = torch.abs(
|
||||
(torch.arange(seq_len).unsqueeze(0) // attention_chunk_size)
|
||||
- (torch.arange(seq_len).unsqueeze(1) // attention_chunk_size)
|
||||
)
|
||||
token_pos = torch.arange(seq_len).unsqueeze(0) - torch.arange(seq_len).unsqueeze(1)
|
||||
mask = (block_pos == 0) & (token_pos <= 0)
|
||||
return mask.to(device)
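# Worked example (illustrative): create_chunked_attention_mask(seq_len=4,
# attention_chunk_size=2, device) yields causal attention restricted to each
# chunk of 2 tokens:
#   [[True,  False, False, False],
#    [True,  True,  False, False],
#    [False, False, True,  False],
#    [False, False, True,  True ]]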
|
llama_stack/models/llama/llama4/moe.py (new file, 214 lines)
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# ruff: noqa: N806
|
||||
# pyre-strict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .args import MoEArgs
|
||||
from .ffn import FeedForward
|
||||
|
||||
|
||||
class Experts(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_local_experts: int,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
dtype = torch.get_default_dtype()
|
||||
self.num_local_experts = num_local_experts
|
||||
self.dim = dim
|
||||
divide_factor = fs_init.get_model_parallel_world_size()
|
||||
|
||||
self.w1: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
dim,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.w2: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.w3: nn.Parameter = nn.Parameter(
|
||||
torch.empty(
|
||||
num_local_experts,
|
||||
dim,
|
||||
divide_exact(hidden_dim, divide_factor),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
self.prefix = prefix
|
||||
if prefix + "moe_w_in_eD_F" in state_dict:
|
||||
e = self.num_local_experts
|
||||
D = self.dim
|
||||
state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
|
||||
state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
|
||||
state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
routed_in_egD: torch.Tensor, # noqa: N803
|
||||
) -> torch.Tensor:
|
||||
e = self.num_local_experts
|
||||
D = self.dim
|
||||
|
||||
x_egD = routed_in_egD.view(e, -1, D)
|
||||
|
||||
out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
|
||||
out_egD = out_egD.view(-1, D)
|
||||
|
||||
return out_egD
|
||||
|
||||
def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||
middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
|
||||
return torch.bmm(middle_out_egF, w2)
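# Shape sketch (hedged, derived from the parameter definitions above): with x of
# shape [e, G, D], w1/w3 of shape [e, D, F_local] and w2 of shape [e, F_local, D]
# (F_local = hidden_dim split across model-parallel ranks), the two bmm calls
# produce [e, G, F_local] activations and the final bmm maps them back to [e, G, D].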
|
||||
|
||||
|
||||
class MoE(torch.nn.Module):
|
||||
"""
|
||||
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
|
||||
Several commonly used annotations include:
|
||||
- a: bsz*slen
|
||||
- E: number of experts
|
||||
- e: number of local experts per ep (n_experts/ep)
|
||||
- D: hidden dimension
|
||||
- d: D/tp
|
||||
- F: model dimension
|
||||
- G: number of tokens per expert (a * capacity_factor / E)
|
||||
- g: number of tokens per expert per TP rank (i.e., G/TP)
|
||||
|
||||
Examples:
|
||||
x_aD [a, D]
|
||||
routed_in_etG_D [et*G, D]
|
||||
x_eGD: [e, G, D]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
ffn_dim_multiplier: float,
|
||||
multiple_of: int,
|
||||
moe_args: MoEArgs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.moe_args = moe_args
|
||||
|
||||
hidden_dim_denom: float = 1
|
||||
if moe_args.auto_scale_F:
|
||||
hidden_dim_denom = moe_args.capacity_factor + 1
|
||||
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
|
||||
# custom dim factor multiplier
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
|
||||
if moe_args.auto_scale_F:
|
||||
hidden_dim = int(hidden_dim / hidden_dim_denom)
|
||||
|
||||
hidden_dim += -hidden_dim % multiple_of
|
||||
|
||||
num_local_experts: int = moe_args.num_experts
|
||||
dtype: torch.dtype = torch.get_default_dtype()
|
||||
self.experts = Experts(
|
||||
num_local_experts,
|
||||
dim,
|
||||
hidden_dim,
|
||||
)
|
||||
|
||||
self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
|
||||
self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
if prefix + "w_in_shared_FD.weight" in state_dict:
|
||||
state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
|
||||
state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
|
||||
state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")
|
||||
|
||||
def forward(self, x_bsD: Tensor) -> Tensor: # noqa: N803
|
||||
_, slen, D = x_bsD.shape
|
||||
x_aD = x_bsD.view(-1, D)
|
||||
|
||||
a = x_aD.shape[0]
|
||||
|
||||
router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)
|
||||
|
||||
router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
|
||||
router_scores = (
|
||||
torch.full_like(router_scores.transpose(0, 1), float("-inf"))
|
||||
.scatter_(1, router_indices_aK, router_scores_aK)
|
||||
.transpose(0, 1)
|
||||
)
|
||||
router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)
|
||||
|
||||
router_scores = torch.sigmoid(router_scores)
|
||||
|
||||
routed_in_EG_D: Tensor = torch.gather(
|
||||
x_aD,
|
||||
dim=0,
|
||||
index=router_indices.reshape(-1, 1).expand(-1, D),
|
||||
)
|
||||
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
|
||||
|
||||
out_aD = self.shared_expert(x_aD)
|
||||
routed_out_eg_D = self.experts(routed_in_EG_D.detach())
|
||||
|
||||
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
|
||||
out_aD.scatter_add_(
|
||||
dim=0,
|
||||
index=router_indices_EG_D,
|
||||
src=routed_out_eg_D.view(-1, D),
|
||||
)
|
||||
out_aD = reduce_from_model_parallel_region(out_aD)
|
||||
return out_aD.view(-1, slen, D)
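# Routing note (hedged sketch of the code above): router_scores has shape [E, a];
# for each token only its top_k expert scores are kept, the rest are set to -inf,
# and the sigmoid turns those into exactly 0. Every expert therefore receives all
# `a` tokens, but tokens that were not routed to it are scaled to zero before the
# expert MLP, and the expert outputs are scatter-added onto the shared-expert
# output out_aD.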
|
||||
|
||||
|
||||
def divide_exact(numerator: int, denominator: int) -> int:
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
|
||||
return numerator // denominator
|
llama_stack/models/llama/llama4/preprocess.py (new file, 436 lines)
@@ -0,0 +1,436 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv
|
||||
from PIL import Image, ImageFile
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
IMAGE_RES = 448
|
||||
|
||||
|
||||
class ResizeNormalizeImageTransform:
|
||||
def __init__(
|
||||
self,
|
||||
size_width=None,
|
||||
size_height=None,
|
||||
) -> None:
|
||||
self._size_width = size_width or IMAGE_RES
|
||||
self._size_height = size_height or IMAGE_RES
|
||||
self._mean = (0.5, 0.5, 0.5)
|
||||
self._std = (0.5, 0.5, 0.5)
|
||||
|
||||
self.tv_transform = tv.Compose(
|
||||
[
|
||||
tv.Resize((self._size_height, self._size_width)),
|
||||
tv.ToTensor(),
|
||||
tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(self, image: Image.Image) -> torch.Tensor:
|
||||
return self.tv_transform(image)
|
||||
|
||||
|
||||
class VariableSizeImageTransform(object):
|
||||
"""
|
||||
This class accepts images of any size and dynamically resizes, pads and chunks them
|
||||
based on the image aspect ratio and the number of image chunks we allow.
|
||||
|
||||
The algorithm will NOT distort the image to fit a certain aspect ratio, because
|
||||
that leads to a significant degradation in image quality.
|
||||
|
||||
It can be summarized in 6 steps:
|
||||
1. Find all possible canvas combinations of max_num_chunks;
|
||||
2. Find the best canvas to fit the image;
|
||||
3. Resize without distortion
|
||||
4. Pad
|
||||
5. Normalize
|
||||
6. Chunk
|
||||
|
||||
For example, if an input image is of size 300x800, patch_size of 224,
|
||||
and max_num_chunks = 8, it will find the closest aspect ratio that
|
||||
is allowed within 8 image chunks, with some restrictions.
|
||||
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
|
||||
giving a total of 8 chunks.
|
||||
|
||||
If resize_to_max_canvas, the image will be resized (without distortion),
|
||||
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
|
||||
where we maintain the original aspect ratio and pad the rest with zeros.
|
||||
This approach minimizes the amount of padding required for any arbitrary resolution.
|
||||
|
||||
However, if limit_upscaling_to_patch_size is set to True,
|
||||
the upscaling will be limited to the patch size. In the example above,
|
||||
the image would remain 300x800 (no upscaling), and then padded to 448:896.
|
||||
|
||||
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
|
||||
patches are coming from the resizing and chunking.
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = IMAGE_RES) -> None:
|
||||
self.size = size
|
||||
self.to_tensor = tv.ToTensor()
|
||||
self._mean = (0.5, 0.5, 0.5)
|
||||
self._std = (0.5, 0.5, 0.5)
|
||||
self.normalize = tv.Normalize(
|
||||
mean=self._mean,
|
||||
std=self._std,
|
||||
inplace=True,
|
||||
)
|
||||
self.resample = tv.InterpolationMode.BILINEAR
|
||||
|
||||
@staticmethod
|
||||
def get_factors(n: int) -> Set[int]:
|
||||
"""
|
||||
Calculate all factors of a given number, i.e. every divisor that leaves
|
||||
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
|
||||
|
||||
Args:
|
||||
n (int): The number to find factors for.
|
||||
|
||||
Returns:
|
||||
set: A set containing all factors of the number.
|
||||
"""
|
||||
factors_set = set()
|
||||
|
||||
for i in range(1, int(n**0.5) + 1):
|
||||
if n % i == 0:
|
||||
factors_set.add(i)
|
||||
factors_set.add(n // i)
|
||||
return factors_set
|
||||
|
||||
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Computes all of the allowed resolutions for a fixed number of chunks
|
||||
and patch_size. Useful for when dividing an image into chunks.
|
||||
|
||||
Args:
|
||||
max_num_chunks (int): Maximum number of chunks for processing.
|
||||
patch_size (int): Size of the side of the patch.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: List of possible resolutions as tuples (height, width).
|
||||
|
||||
Example:
|
||||
>>> max_num_chunks = 5
|
||||
>>> patch_size = 224
|
||||
>>> find_supported_resolutions(max_num_chunks, patch_size)
|
||||
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
|
||||
(672, 224), (224, 448), (448, 224)])
|
||||
|
||||
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
|
||||
{
|
||||
0.25: [(1, 4)],
|
||||
1.0: [(2, 2), (1, 1)],
|
||||
4.0: [(4, 1)],
|
||||
0.33: [(1, 3)],
|
||||
3.0: [(3, 1)],
|
||||
0.5: [(1, 2)],
|
||||
2.0: [(2, 1)]
|
||||
}
|
||||
|
||||
and return the resolutions multiplied by the patch_size:
|
||||
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
|
||||
"""
|
||||
asp_dict = defaultdict(list)
|
||||
for chunk_size in range(max_num_chunks, 0, -1):
|
||||
_factors = sorted(self.get_factors(chunk_size))
|
||||
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
|
||||
for height, width in _asp_ratios:
|
||||
ratio_float = height / width
|
||||
asp_dict[ratio_float].append((height, width))
|
||||
|
||||
# get the resolutions multiplied by the patch_size
|
||||
possible_resolutions = []
|
||||
for value in asp_dict.values():
|
||||
for height, width in value:
|
||||
possible_resolutions.append((height * patch_size, width * patch_size))
|
||||
|
||||
return possible_resolutions
|
||||
|
||||
@staticmethod
|
||||
def get_max_res_without_distortion(
|
||||
image_size: Tuple[int, int],
|
||||
target_size: Tuple[int, int],
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines the maximum resolution to which an image can be resized without distorting its
|
||||
aspect ratio, based on the target resolution.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): The original resolution of the image (height, width).
|
||||
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
|
||||
Returns:
|
||||
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
|
||||
Example:
|
||||
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
|
||||
(134, 200)
|
||||
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
|
||||
(450, 338)
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
target_width, target_height = target_size
|
||||
|
||||
scale_w = target_width / original_width
|
||||
scale_h = target_height / original_height
|
||||
|
||||
if scale_w < scale_h:
|
||||
new_width = target_width
|
||||
new_height = min(math.floor(original_height * scale_w), target_height)
|
||||
else:
|
||||
new_height = target_height
|
||||
new_width = min(math.floor(original_width * scale_h), target_width)
|
||||
|
||||
return new_width, new_height
|
||||
|
||||
def _pad(self, image: Image.Image, target_size) -> Image.Image:
|
||||
new_width, new_height = target_size
|
||||
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
|
||||
new_im.paste(image)
|
||||
return new_im
|
||||
|
||||
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
|
||||
# Split image into number of required tiles (width x height)
|
||||
num_channels, height, width = image.size()
|
||||
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
|
||||
# Permute dimensions to reorder the axes
|
||||
image = image.permute(1, 3, 0, 2, 4).contiguous()
|
||||
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
|
||||
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
|
||||
return image
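# Worked example (illustrative values): a padded image tensor of shape (3, 448, 896)
# split with ncw=2, nch=1 is viewed as (3, 1, 448, 2, 448), permuted to
# (1, 2, 3, 448, 448) and reshaped to (2, 3, 448, 448) -- two 448x448 tiles laid
# out left-to-right.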
|
||||
|
||||
def resize_without_distortion(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
target_size: Tuple[int, int],
|
||||
max_upscaling_size: Optional[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Used to resize an image to target_resolution, without distortion.
|
||||
|
||||
If target_size requires upscaling the image, the user can set max_upscaling_size to
|
||||
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
|
||||
modifying target_size works as a boundary for the image's largest side.
|
||||
|
||||
Args:
|
||||
resample (str): Resampling method used when resizing images.
|
||||
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
|
||||
max_upscaling_size (int): The maximum size to upscale the image to.
|
||||
If None, there is no limit.
|
||||
Examples:
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(600, 300) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 600
|
||||
>>> image_size = (2000, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 100) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = 2000
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
|
||||
>>> target_size = (1000, 1200)
|
||||
>>> max_upscaling_size = None
|
||||
>>> image_size = (400, 200)
|
||||
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
|
||||
(1000, 500) # new_size_without_distortion
|
||||
"""
|
||||
|
||||
image_width, image_height = image.size
|
||||
image_size = (image_width, image_height)
|
||||
|
||||
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
|
||||
if max_upscaling_size is not None:
|
||||
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
|
||||
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
|
||||
target_size = (new_target_width, new_target_height)
|
||||
|
||||
# resize to target_size while preserving aspect ratio
|
||||
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
|
||||
|
||||
image = F.resize(
|
||||
image,
|
||||
(
|
||||
max(new_size_without_distortion[1], 1),
|
||||
max(new_size_without_distortion[0], 1),
|
||||
),
|
||||
interpolation=self.resample,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
def get_best_fit(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
possible_resolutions: torch.Tensor,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Determines, from a list of possible resolutions, the best canvas onto which an image
can be resized without distortion.
|
||||
|
||||
For each possible resolution, calculates the scaling factors for
|
||||
width and height, and selects the smallest one, which is the limiting side.
|
||||
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
|
||||
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
|
||||
|
||||
If upscaling is possible (any of the scaling factors is greater than 1),
|
||||
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
|
||||
|
||||
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
|
||||
reduce downscaling as much as possible.
|
||||
|
||||
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
|
||||
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
|
||||
has more padding.
|
||||
|
||||
Args:
|
||||
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
|
||||
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
|
||||
row represents a possible resolution (height, width).
|
||||
resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The best canvas resolution for the given image.
|
||||
|
||||
Example:
|
||||
>>> image_size = (200, 300)
|
||||
>>> possible_resolutions = torch.tensor([[224, 672],
|
||||
... [672, 224],
|
||||
... [224, 448],
|
||||
... [448, 224],
|
||||
... [224, 224]])
|
||||
>>> get_best_fit(image_size, possible_resolutions)
|
||||
[224, 448]
|
||||
|
||||
We have:
|
||||
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
|
||||
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
|
||||
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
|
||||
Only one of the scales > 1:
|
||||
upscaling_possible = tensor([1.1200, 1.1200])
|
||||
smallest_rescale = tensor(1.1200)
|
||||
So we pick the resolution with the smallest area:
|
||||
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
|
||||
optimal_canvas = tensor([224, 448])
|
||||
"""
|
||||
|
||||
original_width, original_height = image_size
|
||||
|
||||
# get all possible resolutions heights/widths
|
||||
target_widths, target_heights = (
|
||||
possible_resolutions[:, 0],
|
||||
possible_resolutions[:, 1],
|
||||
)
|
||||
|
||||
# get scaling factors to resize the image without distortion
|
||||
scale_w = target_widths / original_width
|
||||
scale_h = target_heights / original_height
|
||||
|
||||
# get the min scale between width and height (limiting side -> no distortion)
|
||||
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
|
||||
|
||||
# filter only scales that allow upscaling
|
||||
upscaling_options = scales[scales >= 1]
|
||||
if len(upscaling_options) > 0:
|
||||
if resize_to_max_canvas:
|
||||
selected_scale = torch.max(upscaling_options)
|
||||
else:
|
||||
selected_scale = torch.min(upscaling_options)
|
||||
else:
|
||||
# no upscaling possible,
|
||||
# get the minimum downscaling (max scale for scales<1)
|
||||
downscaling_options = scales[scales < 1]
|
||||
selected_scale = torch.max(downscaling_options)
|
||||
|
||||
# get all resolutions that support this scaling factor,
|
||||
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
|
||||
chosen_canvas = possible_resolutions[scales == selected_scale]
|
||||
|
||||
# if there are multiple resolutions,
|
||||
# get the one with minimum area to reduce padding
|
||||
if len(chosen_canvas) > 1:
|
||||
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
|
||||
optimal_idx = torch.argmin(areas)
|
||||
optimal_canvas = chosen_canvas[optimal_idx]
|
||||
else:
|
||||
optimal_canvas = chosen_canvas[0]
|
||||
|
||||
return tuple(optimal_canvas.tolist())
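
A self-contained sketch of the canvas-selection logic documented above, written as a free function for illustration; it reproduces the docstring's worked example:

```python
import torch


def best_fit_canvas(
    image_size: tuple, possible_resolutions: torch.Tensor, resize_to_max_canvas: bool = False
) -> tuple:
    width, height = image_size
    scale_w = possible_resolutions[:, 0] / width
    scale_h = possible_resolutions[:, 1] / height
    # Limiting side -> no distortion.
    scales = torch.minimum(scale_w, scale_h)

    upscaling = scales[scales >= 1]
    if len(upscaling) > 0:
        selected = upscaling.max() if resize_to_max_canvas else upscaling.min()
    else:
        selected = scales[scales < 1].max()

    # Among canvases sharing the selected scale, take the smallest area to minimize padding.
    candidates = possible_resolutions[scales == selected]
    areas = candidates[:, 0] * candidates[:, 1]
    return tuple(candidates[areas.argmin()].tolist())


# Worked example from the docstring (the rows there are written as (h, w); taking the
# minimum over both per-axis scales makes the final choice independent of that ordering):
resolutions = torch.tensor([[224, 672], [672, 224], [224, 448], [448, 224], [224, 224]])
print(best_fit_canvas((200, 300), resolutions))  # -> (224, 448)
```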
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image,
|
||||
max_num_chunks: int,
|
||||
normalize_img: bool = True,
|
||||
resize_to_max_canvas: bool = False,
|
||||
) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
"""
|
||||
Args:
|
||||
image (PIL.Image): Image to be resized.
|
||||
max_num_chunks (int): Maximum number of chunks to split the image into.
|
||||
normalize_img (bool): Whether to normalize the image.
|
||||
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
|
||||
If True, picks the canvas that allows the largest resizing without distortion.
|
||||
If False, downsample as little as possible, including no resizing at all,
|
||||
but never upsample, unless the image is smaller than the patch size.
|
||||
"""
|
||||
assert max_num_chunks > 0
|
||||
assert isinstance(image, Image.Image), type(image)
|
||||
w, h = image.size
|
||||
|
||||
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
|
||||
possible_resolutions = torch.tensor(possible_resolutions)
|
||||
|
||||
best_resolution = self.get_best_fit(
|
||||
image_size=(w, h),
|
||||
possible_resolutions=possible_resolutions,
|
||||
resize_to_max_canvas=resize_to_max_canvas,
|
||||
)
|
||||
|
||||
max_upscaling_size = None if resize_to_max_canvas else self.size
|
||||
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
|
||||
image = self._pad(image, best_resolution)
|
||||
|
||||
image = self.to_tensor(image)
|
||||
|
||||
if normalize_img:
|
||||
image = self.normalize(image)
|
||||
|
||||
ratio_w, ratio_h = (
|
||||
best_resolution[0] // self.size,
|
||||
best_resolution[1] // self.size,
|
||||
)
|
||||
|
||||
image = self._split(image, ratio_w, ratio_h) # type: ignore
|
||||
|
||||
ar = (ratio_h, ratio_w)
|
||||
return image, ar
|
llama_stack/models/llama/llama4/prompt_format.md (new file, 277 lines; diff suppressed because one or more lines are too long)
llama_stack/models/llama/llama4/prompts.py (new file, 306 lines)
|
@ -0,0 +1,306 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import textwrap
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||
from ..prompt_format import (
|
||||
Llama4UseCase,
|
||||
TextCompletionContent,
|
||||
UseCase,
|
||||
)
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def usecases(base_model: bool = False) -> List[UseCase | str]:
|
||||
with open(THIS_DIR.parent / "resources/small_dog.jpg", "rb") as f:
|
||||
img_small_dog = f.read()
|
||||
with open(THIS_DIR.parent / "resources/dog.jpg", "rb") as f:
|
||||
img_dog = f.read()
|
||||
with open(THIS_DIR.parent / "resources/pasta.jpeg", "rb") as f:
|
||||
img_pasta = f.read()
|
||||
out = []
|
||||
out.extend(
|
||||
[
|
||||
textwrap.dedent(
|
||||
"""
|
||||
# Llama 4 - Prompt Formats
|
||||
## Tokens
|
||||
Here is a list of special tokens that are supported by Llama 4:
|
||||
- `<|begin_of_text|>`: Specifies the start of the prompt
|
||||
- `<|end_of_text|>`: Model will cease to generate more tokens. This token is generated only by the base models.
|
||||
- `<|header_start|>` and `<|header_end|>`: These tokens enclose the role for a particular message. The possible roles are: [system, user and assistant].
|
||||
- `<|eot|>`: End of turn. Represents when the model has determined that it has finished interacting with the user message that initiated its response. This is used in two scenarios:
|
||||
- at the end of a direct interaction between the model and the user
|
||||
- at the end of multiple interactions between the model and any available tools
|
||||
This token signals to the executor that the model has finished generating a response.
|
||||
- `<|image_start|>` and `<|image_end|>`: These tokens enclose the image data in the prompt.
|
||||
- `<|patch|>`: This token represents a piece of the tile
|
||||
- `<|tile_y_separator|>` and `<|tile_x_separator|>`: These tokens are used to separate the y and x tiles of an image
|
||||
- `<|image|>`: In the new architecture, this token now separates the regular sized image information from a downsized version of it that fits in a single tile. The longer side is used for calculating the scale factor and the rest is padded to fit the tile.
|
||||
"""
|
||||
),
|
||||
textwrap.dedent(
|
||||
"""
|
||||
There are 3 different roles that are supported by Llama 4
|
||||
- `system`: Sets the context in which to interact with the AI model. It typically includes rules, guidelines, or necessary information that helps the model respond effectively.
|
||||
- `user`: Represents the human interacting with the model. It includes the inputs, commands, and questions to the model.
|
||||
- `assistant`: Represents the response generated by the AI model based on the context provided in the `system`, `tool` and `user` prompts.
|
||||
"""
|
||||
),
|
||||
]
|
||||
)
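
As a quick orientation before the use cases below, here is a hedged sketch of how a single-turn chat prompt is assembled from the special tokens and roles just described; the exact whitespace and token ordering should be taken from the rendered examples, not from this sketch:

```python
# Illustrative only — the rendered use cases below are authoritative for exact formatting.
system = "You are a helpful assistant"
user = "Answer who are you in the form of jeopardy?"

prompt = (
    "<|begin_of_text|>"
    + "<|header_start|>system<|header_end|>\n\n" + system + "<|eot|>"
    + "<|header_start|>user<|header_end|>\n\n" + user + "<|eot|>"
    + "<|header_start|>assistant<|header_end|>\n\n"
)
print(prompt)
```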
|
||||
|
||||
if base_model:
|
||||
out.extend(
|
||||
[
|
||||
"# Llama 4 Base Model",
|
||||
Llama4UseCase(
|
||||
title="Text completion - Paris information",
|
||||
description="Text completion for Llama 4 base model uses this format.",
|
||||
dialogs=[TextCompletionContent(content="The capital of France is Paris")],
|
||||
),
|
||||
Llama4UseCase(
|
||||
title="Text completion - The color of the sky",
|
||||
description="Text completion for Llama 4 base model uses this format.",
|
||||
dialogs=[
|
||||
TextCompletionContent(content="The color of the sky is blue but sometimes it can also be")
|
||||
],
|
||||
notes="",
|
||||
),
|
||||
Llama4UseCase(
|
||||
title="Text completion - Translation example",
|
||||
description="Text completion for Llama 4 base model uses this format.",
|
||||
dialogs=[
|
||||
TextCompletionContent(
|
||||
content="""apple is pomme,
|
||||
banana is banane,
|
||||
cherry is"""
|
||||
)
|
||||
],
|
||||
notes="",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
out.extend(
|
||||
[
|
||||
"# Llama 4 Instruct Model",
|
||||
Llama4UseCase(
|
||||
title="Simple User and assistant conversation",
|
||||
description="Here is a regular multi-turn user assistant conversation and how its formatted.",
|
||||
dialogs=[
|
||||
[
|
||||
RawMessage(role="system", content="You are a helpful assistant"),
|
||||
RawMessage(
|
||||
role="user",
|
||||
content="Answer who are you in the form of jeopardy?",
|
||||
),
|
||||
]
|
||||
],
|
||||
notes="",
|
||||
max_gen_len=512,
|
||||
),
|
||||
"# Image prompt format",
|
||||
Llama4UseCase(
|
||||
title="Single image prompt format - small image",
|
||||
description="This example passes an image that is smaller than the tile size, to show the tile separator tokens are not needed",
|
||||
dialogs=[
|
||||
[
|
||||
RawMessage(
|
||||
role="user",
|
||||
content=[
|
||||
RawMediaItem(data=BytesIO(img_small_dog)),
|
||||
RawTextItem(text="Describe this image in two sentences"),
|
||||
],
|
||||
)
|
||||
]
|
||||
],
|
||||
notes="""Notice the structure of the image section:
|
||||
```
|
||||
<|image_start|><|image|><|patch|>...<|patch|><|image_end|>
|
||||
```
|
||||
This is due to the image being smaller than the tile size.
|
||||
""",
|
||||
max_gen_len=512,
|
||||
),
|
||||
Llama4UseCase(
|
||||
title="Single image prompt format",
|
||||
description="Here is an example of how to pass an image to the model",
|
||||
dialogs=[
|
||||
[
|
||||
RawMessage(
|
||||
role="user",
|
||||
content=[
|
||||
RawMediaItem(data=BytesIO(img_dog)),
|
||||
RawTextItem(text="Describe this image in two sentences"),
|
||||
],
|
||||
)
|
||||
]
|
||||
],
|
||||
notes="""With a bigger image, the image will include the tile separator tokens. Additionally, the image tag now separates a scaled down version of the image from the regular sized image.
|
||||
```
|
||||
<|image_start|><|patch|>...<|patch|><|tile_x_separator|><|patch|>...<|patch|><|tile_y_separator|><|patch|>...<|patch|><|image|><|patch|>...<|patch|><|image_end|>
|
||||
```
|
||||
""",
|
||||
max_gen_len=1024,
|
||||
),
|
||||
Llama4UseCase(
|
||||
title="Multiple images prompt format",
|
||||
description="Here is an example of how to pass an image to the model",
|
||||
dialogs=[
|
||||
[
|
||||
RawMessage(
|
||||
role="user",
|
||||
content=[
|
||||
RawMediaItem(data=BytesIO(img_dog)),
|
||||
RawMediaItem(data=BytesIO(img_pasta)),
|
||||
RawTextItem(text="Describe these images in two sentences"),
|
||||
],
|
||||
)
|
||||
]
|
||||
],
|
||||
notes="With multiple images, each one is encapsulated in their corresponding image tags.",
|
||||
max_gen_len=4096,
|
||||
),
|
||||
"# Tool calling\nWe are continuing the format for zero shot function calling used in previous versions of Llama. All available functions can be provided either in the system message or in the user message.",
|
||||
Llama4UseCase(
|
||||
title="Zero shot function calling - system message",
|
||||
dialogs=[
|
||||
[
|
||||
RawMessage(
|
||||
role="system",
|
||||
content="""You are an expert in composing functions. You are given a question and a set of possible functions.
|
||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
||||
also point it out. You should only return the function call in tools call sections.
|
||||
|
||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
||||
You SHOULD NOT include any other text in the response.
|
||||
|
||||
Here is a list of functions in JSON format that you can invoke.
|
||||
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather info for places",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": [
|
||||
"city"
|
||||
],
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city to get the weather for"
|
||||
},
|
||||
"metric": {
|
||||
"type": "string",
|
||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
||||
"default": "celsius"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
]
""",
|
||||
),
|
||||
RawMessage(
|
||||
role="user",
|
||||
content="What is the weather in SF and Seattle?",
|
||||
),
|
||||
]
|
||||
],
|
||||
notes=textwrap.dedent(
|
||||
"""
|
||||
- The output supports multiple and parallel tool calls natively
- The JSON format for defining the functions in the system prompt is similar to Llama 3.1
|
||||
"""
|
||||
),
|
||||
),
|
||||
Llama4UseCase(
|
||||
title="Zero shot function calling - user message",
|
||||
description=textwrap.dedent(
|
||||
"""
|
||||
Similar to the above example, you can also provide information for all the available tools in the user message.
|
||||
"""
|
||||
),
|
||||
dialogs=[
|
||||
[
|
||||
RawMessage(
|
||||
role="user",
|
||||
content="""Questions: Can you retrieve the details for the user with the ID 7890, who has black as their special request?
|
||||
Here is a list of functions in JSON format that you can invoke:
|
||||
[
|
||||
{
|
||||
"name": "get_user_info",
|
||||
"description": "Retrieve details for a specific user by their unique identifier. Note that the provided function is in Python 3 syntax.",
|
||||
"parameters": {
|
||||
"type": "dict",
|
||||
"required": [
|
||||
"user_id"
|
||||
],
|
||||
"properties": {
|
||||
"user_id": {
|
||||
"type": "integer",
|
||||
"description": "The unique identifier of the user. It is used to fetch the specific user details from the database."
|
||||
},
|
||||
"special": {
|
||||
"type": "string",
|
||||
"description": "Any special information or parameters that need to be considered while fetching user details.",
|
||||
"default": "none"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
Should you decide to return the function call(s), put them in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]
|
||||
|
||||
You SHOULD NOT include any other text in the response.""",
|
||||
),
|
||||
]
|
||||
],
|
||||
notes=textwrap.dedent(
|
||||
"""
|
||||
- The tool call format for the model is the same whether your function calls are provided in the system or user message.
|
||||
"""
|
||||
),
|
||||
),
|
||||
Llama4UseCase(
|
||||
title="Tool calling with custom formats",
|
||||
description=textwrap.dedent(
|
||||
"""
|
||||
Here is an example of how you could also write custom instructions for the model to do zero-shot tool calling.
|
||||
In this example, we define a custom tool calling format using the `<function>` tag.
|
||||
"""
|
||||
),
|
||||
dialogs=[
|
||||
[
|
||||
RawMessage(
|
||||
role="user",
|
||||
content="""You have access to the following functions:\nUse the function 'trending_songs' to 'Returns the trending songs on a Music site':\n{"name": "trending_songs", "description": "Returns the trending songs on a Music site", "parameters": {"genre": {"description": "The genre of the songs to return", "param_type": "str", "required": false}, "n": {"description": "The number of songs to return", "param_type": "int", "required": true}}}\n\nThink very carefully before calling functions.\nIf you choose to call a function ONLY reply in the following format with no prefix or suffix:\n\n<function=example_function_name>{"example_name": "example_value"}</function>
|
||||
Reminder:
|
||||
- If looking for real time information use relevant functions before falling back to brave_search
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line<|eot_id|>""",
|
||||
),
|
||||
RawMessage(
|
||||
role="user",
|
||||
content="Use tools to get latest trending songs",
|
||||
),
|
||||
]
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
return out
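
The zero-shot tool-calling use cases above ask the model to reply with bracketed, Python-style calls. A hedged sketch of parsing that reply format on the client side; the regexes and the example completion string are illustrative assumptions, not part of this module:

```python
import re

# Example completion for the weather question above, in the format the system
# prompt asks for: [func_name1(param=value, ...), func_name2(...)]
completion = '[get_weather(city="San Francisco", metric="celsius"), get_weather(city="Seattle")]'

calls = re.findall(r"(\w+)\((.*?)\)", completion)
for name, args in calls:
    print(name, dict(re.findall(r'(\w+)="([^"]*)"', args)))
# get_weather {'city': 'San Francisco', 'metric': 'celsius'}
# get_weather {'city': 'Seattle'}
```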
|
llama_stack/models/llama/llama4/quantization/__init__.py (new file, 5 lines)
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
llama_stack/models/llama/llama4/quantization/loader.py (new file, 225 lines)
|
@ -0,0 +1,225 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...datatypes import QuantizationMode
|
||||
from ..model import Transformer, TransformerBlock
|
||||
from ..moe import MoE
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def swiglu_wrapper_no_reduce(
|
||||
self,
|
||||
x: Tensor,
|
||||
):
|
||||
from ...quantize_impls import ffn_swiglu
|
||||
|
||||
return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
|
||||
|
||||
def experts_batched_swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor, # (e, g, D)
|
||||
w1: Tensor, # (e, D, F)
|
||||
w3: Tensor, # (e, D, F)
|
||||
w2: Tensor, # (e, F, D)
|
||||
) -> torch.Tensor:
|
||||
from ...quantize_impls import bmm_nt
|
||||
|
||||
middle_out_egF = F.silu(bmm_nt(x, w1)) * bmm_nt(x, w3) # noqa: N806
|
||||
return bmm_nt(middle_out_egF, w2)
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
model: Transformer,
|
||||
checkpoint_dir: str,
|
||||
quantization_mode: Optional[str] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
use_rich_progress: bool = True,
|
||||
) -> Transformer:
|
||||
from ...quantize_impls import (
|
||||
Fp8ScaledWeights,
|
||||
Int4ScaledWeights,
|
||||
load_fp8,
|
||||
load_int4,
|
||||
quantize_fp8,
|
||||
quantize_int4,
|
||||
)
|
||||
|
||||
rank = get_model_parallel_rank()
|
||||
|
||||
def should_quantize_block(block: nn.Module) -> bool:
|
||||
if not isinstance(block, TransformerBlock):
|
||||
return False
|
||||
|
||||
is_moe = isinstance(block.feed_forward, MoE)
|
||||
if quantization_mode == QuantizationMode.fp8_mixed:
|
||||
# skip quantization on first and last layers
|
||||
return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
|
||||
|
||||
return is_moe
|
||||
|
||||
use_rich_progress = use_rich_progress and rank == 0
|
||||
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
|
||||
if quantization_mode == QuantizationMode.int4_mixed:
|
||||
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
|
||||
if os.path.isfile(int4_scales_path):
|
||||
log_status(f"Rank {rank}: Loading int4 scales")
|
||||
int4_scales = torch.load(int4_scales_path, weights_only=True)
|
||||
|
||||
def apply_quantization(key, weight):
|
||||
scale = int4_scales[key]
|
||||
return load_int4(
|
||||
weight,
|
||||
scale,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
else:
|
||||
log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
|
||||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_int4(weight, output_device=torch.device("cuda"))
|
||||
|
||||
else:
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
|
||||
if os.path.isfile(fp8_scales_path):
|
||||
log_status(f"Rank {rank}: Loading fp8 scales")
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
def apply_quantization(key, weight):
|
||||
scale = fp8_scales[key]
|
||||
return load_fp8(
|
||||
weight,
|
||||
scale,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
else:
|
||||
log_status(f"Rank {rank}: Quantizing fp8 weights from bf16")
|
||||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_fp8(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
||||
|
||||
processed_blocks = 0
|
||||
try:
|
||||
if use_rich_progress:
|
||||
progress.start()
|
||||
|
||||
for _, block in model.named_modules():
|
||||
if not should_quantize_block(block):
|
||||
continue
|
||||
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id}")
|
||||
|
||||
# Quantize only routed experts, not shared
|
||||
prefix = f"layers.{block.layer_id}.feed_forward"
|
||||
moe = block.feed_forward
|
||||
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
|
||||
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(moe.experts, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
|
||||
setattr(
|
||||
moe.experts,
|
||||
key,
|
||||
apply_quantization(
|
||||
f"{prefix}.experts.{key}",
|
||||
param.transpose(1, 2).contiguous(),
|
||||
),
|
||||
)
|
||||
|
||||
if quantization_mode == QuantizationMode.int4_mixed:
|
||||
# Quantize shared experts
|
||||
moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(moe.shared_expert, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
|
||||
param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
|
||||
|
||||
processed_blocks += 1
|
||||
update_status(message=None, completed=processed_blocks)
|
||||
|
||||
update_status(f"Rank {rank} - Moving parameters to CUDA")
|
||||
|
||||
param_count = 0
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights) and not isinstance(parameter, Int4ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
param_count += 1
|
||||
|
||||
update_status(f"Rank {rank} - Completed - moved {param_count} parameters to CUDA")
|
||||
finally:
|
||||
if use_rich_progress:
|
||||
progress.stop()
|
||||
|
||||
return model
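
The loader above patches modules by binding free functions as methods via the descriptor protocol (`experts_batched_swiglu_wrapper.__get__(moe.experts)` and `swiglu_wrapper_no_reduce.__get__(moe.shared_expert)`). A minimal, self-contained illustration of that binding trick, with toy names standing in for the real wrappers:

```python
class Expert:
    def __init__(self, scale: float) -> None:
        self.scale = scale

    def forward(self, x: float) -> float:
        return self.scale * x


def quantized_forward(self, x: float) -> float:
    # Stand-in for swiglu_wrapper_no_reduce / experts_batched_swiglu_wrapper:
    # same signature, different kernel.
    return round(self.scale * x, 1)


expert = Expert(scale=0.3333)
# Descriptor protocol: __get__ binds the plain function to this instance, so `self`
# is passed automatically, exactly as in the loader above.
expert.forward = quantized_forward.__get__(expert)
print(expert.forward(3.0))  # 1.0
```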
|
||||
|
||||
|
||||
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
|
||||
def logging_callbacks(
|
||||
use_rich_progress: bool,
|
||||
rank: int,
|
||||
model: Transformer,
|
||||
should_quantize_block: Callable[[nn.Module], bool],
|
||||
):
|
||||
console = None
|
||||
if use_rich_progress:
|
||||
from rich.console import Console
|
||||
|
||||
console = Console(highlight=False)
|
||||
|
||||
def log_status(message: str) -> None:
|
||||
if use_rich_progress:
|
||||
console.print(message)
|
||||
elif rank == 0: # Only log from rank 0 for non-rich logging
|
||||
log.info(message)
|
||||
|
||||
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
|
||||
progress = None
|
||||
if use_rich_progress:
|
||||
from rich.progress import (
|
||||
BarColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
|
||||
progress = Progress(
|
||||
SpinnerColumn(),
|
||||
BarColumn(complete_style="green", finished_style="bright_green"),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeElapsedColumn(),
|
||||
TextColumn("ETA:"),
|
||||
TimeRemainingColumn(),
|
||||
TextColumn("[bold]{task.fields[status]}"),
|
||||
console=console,
|
||||
expand=True,
|
||||
)
|
||||
task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
|
||||
|
||||
def update_status(message: Optional[str], completed: Optional[int] = None) -> None:
|
||||
if use_rich_progress:
|
||||
if message is not None:
|
||||
progress.update(task_id, status=message)
|
||||
if completed is not None:
|
||||
progress.update(task_id, completed=completed)
|
||||
elif rank == 0 and completed and completed % 10 == 0:
|
||||
log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
|
||||
|
||||
return progress, log_status, update_status
|
llama_stack/models/llama/llama4/tokenizer.model (new executable file, 200000 lines; diff suppressed because it is too large)
llama_stack/models/llama/llama4/tokenizer.py (new file, 270 lines)
|
@ -0,0 +1,270 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
# The tiktoken tokenizer can handle <=400k chars without
|
||||
# pyo3_runtime.PanicException.
|
||||
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||
|
||||
# https://github.com/openai/tiktoken/issues/195
|
||||
# Here we iterate over subsequences and split if we exceed the limit
|
||||
# of max consecutive non-whitespace or whitespace characters.
|
||||
MAX_NO_WHITESPACES_CHARS = 25_000
|
||||
|
||||
|
||||
_INSTANCE = None
|
||||
|
||||
|
||||
def get_reserved_special_tokens(name, count, start_index=0):
|
||||
return [f"<|{name}_reserved_special_token_{i}|>" for i in range(start_index, start_index + count)]
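
For reference, the helper above simply materializes a contiguous run of reserved-token names; a quick standalone check (duplicating the helper so the snippet runs on its own):

```python
def get_reserved_special_tokens(name, count, start_index=0):
    return [f"<|{name}_reserved_special_token_{i}|>" for i in range(start_index, start_index + count)]


print(get_reserved_special_tokens("vision", 3, 7))
# ['<|vision_reserved_special_token_7|>',
#  '<|vision_reserved_special_token_8|>',
#  '<|vision_reserved_special_token_9|>']
```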
|
||||
|
||||
|
||||
# 200005, ..., 200079
|
||||
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
|
||||
"<|header_start|>",
|
||||
"<|header_end|>",
|
||||
"<|eom|>",
|
||||
"<|eot|>",
|
||||
"<|step|>",
|
||||
"<|text_post_train_reserved_special_token_0|>",
|
||||
"<|text_post_train_reserved_special_token_1|>",
|
||||
"<|text_post_train_reserved_special_token_2|>",
|
||||
"<|text_post_train_reserved_special_token_3|>",
|
||||
"<|text_post_train_reserved_special_token_4|>",
|
||||
"<|text_post_train_reserved_special_token_5|>",
|
||||
"<|python_start|>",
|
||||
"<|python_end|>",
|
||||
"<|finetune_right_pad|>",
|
||||
] + get_reserved_special_tokens(
|
||||
"text_post_train", 61, 8
|
||||
) # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
|
||||
|
||||
# 200080, ..., 201133
|
||||
LLAMA4_VISION_SPECIAL_TOKENS = [
|
||||
"<|image_start|>",
|
||||
"<|image_end|>",
|
||||
"<|vision_reserved_special_token_0|>",
|
||||
"<|vision_reserved_special_token_1|>",
|
||||
"<|tile_x_separator|>",
|
||||
"<|tile_y_separator|>",
|
||||
"<|vision_reserved_special_token_2|>",
|
||||
"<|vision_reserved_special_token_3|>",
|
||||
"<|vision_reserved_special_token_4|>",
|
||||
"<|vision_reserved_special_token_5|>",
|
||||
"<|image|>",
|
||||
"<|vision_reserved_special_token_6|>",
|
||||
"<|patch|>",
|
||||
] + get_reserved_special_tokens(
|
||||
"vision", 1041, 7
|
||||
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
|
||||
|
||||
# 201134, ..., 201143
|
||||
LLAMA4_REASONING_SPECIAL_TOKENS = [
|
||||
"<|reasoning_reserved_special_token_0|>",
|
||||
"<|reasoning_reserved_special_token_1|>",
|
||||
"<|reasoning_reserved_special_token_2|>",
|
||||
"<|reasoning_reserved_special_token_3|>",
|
||||
"<|reasoning_reserved_special_token_4|>",
|
||||
"<|reasoning_reserved_special_token_5|>",
|
||||
"<|reasoning_reserved_special_token_6|>",
|
||||
"<|reasoning_reserved_special_token_7|>",
|
||||
"<|reasoning_thinking_start|>",
|
||||
"<|reasoning_thinking_end|>",
|
||||
]
|
||||
|
||||
LLAMA4_SPECIAL_TOKENS = (
|
||||
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
|
||||
)
|
||||
|
||||
BASIC_SPECIAL_TOKENS = [
|
||||
"<|begin_of_text|>",
|
||||
"<|end_of_text|>",
|
||||
"<|fim_prefix|>",
|
||||
"<|fim_middle|>",
|
||||
"<|fim_suffix|>",
|
||||
]
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
|
||||
"""
|
||||
|
||||
special_tokens: Dict[str, int]
|
||||
|
||||
num_reserved_special_tokens = 2048
|
||||
|
||||
O200K_PATTERN = r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
global _INSTANCE
|
||||
|
||||
if _INSTANCE is None:
|
||||
_INSTANCE = Tokenizer(os.path.join(os.path.dirname(__file__), "tokenizer.model"))
|
||||
return _INSTANCE
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
"""
|
||||
Initializes the Tokenizer with a Tiktoken model.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the Tiktoken model file.
|
||||
"""
|
||||
assert os.path.isfile(model_path), model_path
|
||||
|
||||
mergeable_ranks = load_tiktoken_bpe(model_path)
|
||||
num_base_tokens = len(mergeable_ranks)
|
||||
|
||||
special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
|
||||
assert len(set(special_tokens)) == len(special_tokens)
|
||||
assert len(special_tokens) <= self.num_reserved_special_tokens
|
||||
|
||||
reserved_tokens = [
|
||||
f"<|reserved_special_token_{i}|>" for i in range(self.num_reserved_special_tokens - len(special_tokens))
|
||||
]
|
||||
special_tokens = special_tokens + reserved_tokens
|
||||
|
||||
self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
|
||||
self.model = tiktoken.Encoding(
|
||||
name=Path(model_path).name,
|
||||
pat_str=self.O200K_PATTERN,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=self.special_tokens,
|
||||
)
|
||||
|
||||
self.n_words: int = num_base_tokens + len(special_tokens)
|
||||
|
||||
# BOS / EOS token IDs
|
||||
self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
|
||||
self.eos_id: int = self.special_tokens["<|end_of_text|>"]
|
||||
|
||||
self.pad_id: int = self.special_tokens["<|finetune_right_pad|>"]
|
||||
self.eot_id: int = self.special_tokens["<|eot|>"]
|
||||
self.eom_id: int = self.special_tokens["<|eom|>"]
|
||||
|
||||
self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
|
||||
self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
|
||||
|
||||
self.stop_tokens = [
|
||||
self.eos_id,
|
||||
self.special_tokens["<|eom|>"],
|
||||
self.special_tokens["<|eot|>"],
|
||||
]
|
||||
|
||||
def encode(
|
||||
self,
|
||||
s: str,
|
||||
*,
|
||||
bos: bool,
|
||||
eos: bool,
|
||||
allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = (),
|
||||
) -> List[int]:
|
||||
"""
|
||||
Encodes a string into a list of token IDs.
|
||||
|
||||
Args:
|
||||
s (str): The input string to be encoded.
|
||||
bos (bool): Whether to prepend the beginning-of-sequence token.
|
||||
eos (bool): Whether to append the end-of-sequence token.
|
||||
allowed_special ("all"|set[str]): allowed special tokens in string
|
||||
disallowed_special ("all"|set[str]): special tokens that raise an error when in string
|
||||
|
||||
Returns:
|
||||
list[int]: A list of token IDs.
|
||||
|
||||
By default, setting disallowed_special=() encodes a string by ignoring
|
||||
special tokens. Specifically:
|
||||
- Setting `disallowed_special` to () will cause all text corresponding
to special tokens to be encoded as natural text (instead of raising
an error).
- Setting `allowed_special` to "all" will cause all text corresponding
to special tokens to be encoded as special tokens.
|
||||
"""
|
||||
if allowed_special is None:
|
||||
allowed_special = set()
|
||||
assert type(s) is str
|
||||
|
||||
substrs = (
|
||||
substr
|
||||
for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
|
||||
for substr in self._split_whitespaces_or_nonwhitespaces(
|
||||
s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||
)
|
||||
)
|
||||
t: List[int] = []
|
||||
for substr in substrs:
|
||||
t.extend(
|
||||
self.model.encode(
|
||||
substr,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
)
|
||||
if bos:
|
||||
t.insert(0, self.bos_id)
|
||||
if eos:
|
||||
t.append(self.eos_id)
|
||||
return t
|
||||
|
||||
def decode(self, t: Sequence[int]) -> str:
|
||||
"""
|
||||
Decodes a list of token IDs into a string.
|
||||
|
||||
Args:
|
||||
t (List[int]): The list of token IDs to be decoded.
|
||||
|
||||
Returns:
|
||||
str: The decoded string.
|
||||
"""
|
||||
# Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
|
||||
return self.model.decode(cast(List[int], t))
|
||||
|
||||
@staticmethod
|
||||
def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int) -> Iterator[str]:
|
||||
"""
|
||||
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
||||
consecutive whitespaces or consecutive non-whitespaces.
|
||||
"""
|
||||
current_slice_len = 0
|
||||
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
||||
slice_start = 0
|
||||
|
||||
for i in range(len(s)):
|
||||
is_now_space = s[i].isspace()
|
||||
|
||||
if current_slice_is_space ^ is_now_space:
|
||||
current_slice_len = 1
|
||||
current_slice_is_space = is_now_space
|
||||
else:
|
||||
current_slice_len += 1
|
||||
if current_slice_len > max_consecutive_slice_len:
|
||||
yield s[slice_start:i]
|
||||
slice_start = i
|
||||
current_slice_len = 1
|
||||
yield s[slice_start:]
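
A self-contained check of the splitting helper above, which is what keeps each `model.encode` call under the practical limits noted at the top of the file (using a tiny limit of 3 instead of 25,000):

```python
def split_ws_or_non_ws(s, max_consecutive_slice_len):
    current_slice_len = 0
    current_slice_is_space = s[0].isspace() if len(s) > 0 else False
    slice_start = 0
    for i in range(len(s)):
        is_now_space = s[i].isspace()
        if current_slice_is_space ^ is_now_space:
            current_slice_len = 1
            current_slice_is_space = is_now_space
        else:
            current_slice_len += 1
            if current_slice_len > max_consecutive_slice_len:
                yield s[slice_start:i]
                slice_start = i
                current_slice_len = 1
    yield s[slice_start:]


print(list(split_ws_or_non_ws("aaaaa   bb", 3)))  # ['aaa', 'aa   bb']
```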
|
llama_stack/models/llama/llama4/vision/embedding.py (new file, 209 lines)
|
@ -0,0 +1,209 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
|
||||
from ..args import VisionArgs
|
||||
from .encoder import VisionEncoder
|
||||
|
||||
|
||||
class PixelShuffle(nn.Module):
|
||||
def __init__(self, ps_ratio):
|
||||
super().__init__()
|
||||
self.ps_ratio = ps_ratio
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, N, C], N = number of patches
|
||||
assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
|
||||
assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
|
||||
hh = ww = int(math.sqrt(x.shape[1]))
|
||||
x = x.reshape(x.shape[0], hh, ww, -1)
|
||||
x = pixel_shuffle_op(x, ps_ratio=self.ps_ratio)
|
||||
pixel_shuffle_patches = x.reshape(x.shape[0], -1, x.shape[-1])
|
||||
return pixel_shuffle_patches
|
||||
|
||||
|
||||
def pixel_shuffle_op(input_x, ps_ratio):
|
||||
n, w, h, c = input_x.size()
|
||||
input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
|
||||
input_x = input_x.permute(0, 2, 1, 3).contiguous()
|
||||
input_x = input_x.view(
|
||||
n,
|
||||
int(h * ps_ratio),
|
||||
int(w * ps_ratio),
|
||||
int(c / (ps_ratio * ps_ratio)),
|
||||
)
|
||||
input_x = input_x.permute(0, 2, 1, 3).contiguous()
|
||||
return input_x
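
A shape-only sanity check of the pixel shuffle above (the batch size, grid size and channel width here are arbitrary, not the model's real dimensions): with `ps_ratio = 0.5`, a 16x16 grid of patches is folded into an 8x8 grid whose channel dimension grows by 4x.

```python
import torch


def pixel_shuffle_op(input_x, ps_ratio):
    n, w, h, c = input_x.size()
    input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
    input_x = input_x.permute(0, 2, 1, 3).contiguous()
    input_x = input_x.view(
        n, int(h * ps_ratio), int(w * ps_ratio), int(c / (ps_ratio * ps_ratio))
    )
    return input_x.permute(0, 2, 1, 3).contiguous()


x = torch.randn(2, 16, 16, 1408)   # [B, W, H, C] grid of encoded patches (toy sizes)
y = pixel_shuffle_op(x, ps_ratio=0.5)
assert y.shape == (2, 8, 8, 5632)  # 4x fewer tokens, 4x wider channels
```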
|
||||
|
||||
|
||||
class SimpleMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
bias: bool = True,
|
||||
dropout: float = 0.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
# layers
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
hidden_dim,
|
||||
hidden_dim,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.non_linearity = act_layer()
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.c_fc(x)
|
||||
hidden = self.non_linearity(hidden)
|
||||
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
||||
return self.non_linearity(self.c_proj(hidden))
|
||||
|
||||
|
||||
class PixelShuffleMLP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ps_ratio: float,
|
||||
input_dim: int,
|
||||
output_dim: int = 4096,
|
||||
add_fc: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.pixel_shuffle = PixelShuffle(ps_ratio)
|
||||
self.mlp = SimpleMLP(
|
||||
int(input_dim // (ps_ratio**2)),
|
||||
output_dim,
|
||||
bias=False,
|
||||
dropout=0.0,
|
||||
act_layer=nn.GELU,
|
||||
)
|
||||
self.fc = nn.Identity()
|
||||
if add_fc:
|
||||
self.fc = ColumnParallelLinear(
|
||||
output_dim,
|
||||
output_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
|
||||
encoded_patches = self.pixel_shuffle(encoded_patches)
|
||||
return self.fc(self.mlp(encoded_patches))
|
||||
|
||||
|
||||
class VisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, args: VisionArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
|
||||
image_size = args.image_size
|
||||
patch_size = args.patch_size
|
||||
self.vision_encoder = VisionEncoder(
|
||||
image_size=(image_size.height, image_size.width),
|
||||
patch_size=(patch_size.height, patch_size.width),
|
||||
dim=args.dim,
|
||||
layers=args.n_layers,
|
||||
heads=args.n_heads,
|
||||
mlp_ratio=args.mlp_ratio,
|
||||
)
|
||||
self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
|
||||
self.vision_adapter = PixelShuffleMLP(
|
||||
ps_ratio=args.pixel_shuffle_ratio,
|
||||
input_dim=args.dim,
|
||||
output_dim=args.output_dim,
|
||||
)
|
||||
|
||||
self.output_dim = args.output_dim
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
original_sd = self.state_dict()
|
||||
for k in state_dict:
|
||||
if k.startswith(prefix) and len(state_dict[k].shape) == 1 and state_dict[k].shape[0] == 0:
|
||||
state_dict[k] = state_dict[k].reshape(original_sd[k[len(prefix) :]].shape)
|
||||
|
||||
def _get_empty_sequence(self, h):
|
||||
return torch.zeros(
|
||||
h.shape[0],
|
||||
h.shape[1],
|
||||
self.output_dim,
|
||||
device=h.device,
|
||||
dtype=h.dtype,
|
||||
)
|
||||
|
||||
# x_images is batched; each batch sample contains a list of images. so this is List[List[torch.Tensor]]
|
||||
# each image is a tensor of shape [num_tiles, C, H, W]
|
||||
def forward(
|
||||
self,
|
||||
image_batch: List[List[torch.Tensor]],
|
||||
image_mask: torch.Tensor,
|
||||
h_ref: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
images_flattened = [image for sample in image_batch for image in sample]
|
||||
images_flattened = torch.vstack(images_flattened).unsqueeze(1).to(h_ref.dtype).to(h_ref.device)
|
||||
embedding = self.vision_encoder(images_flattened)
|
||||
projected_embedding = self.vision_adapter(embedding)
|
||||
|
||||
h_image = self._get_empty_sequence(h_ref)
|
||||
return scatter_embeddings(image_batch, image_mask, h_image, projected_embedding)
|
||||
|
||||
|
||||
def scatter_embeddings(image_batch, image_mask, h_image, encoded_patches_proj):
|
||||
# If dynamic transform is used and the batch contains 2 images (where image_1 has 2 chunks and image_2 has 3 chunks),
|
||||
# `num_images_per_sequence` now records the number of chunks per image as `[2, 3]`.
|
||||
# `encoded_patches_proj.split` will then split the image chunks into 2 groups: `[image_1_chunks, image_2_chunks]`.
|
||||
num_images_per_sequence = [sum(image.size(0) for image in sample_images) for sample_images in image_batch]
|
||||
|
||||
assert not torch.isnan(encoded_patches_proj).any()
|
||||
assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
|
||||
f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
|
||||
)
|
||||
|
||||
encoded_patches_list = encoded_patches_proj.split(num_images_per_sequence, dim=0)
|
||||
for index in range(h_image.size(0)):
|
||||
encoded_patches_per_sample = encoded_patches_list[index]
|
||||
sample_image_mask = image_mask[index]
|
||||
|
||||
if encoded_patches_per_sample.numel() == 0:
|
||||
continue
|
||||
encoded_patches_per_sample = encoded_patches_per_sample.contiguous().view(
|
||||
-1, encoded_patches_per_sample.size(-1)
|
||||
)
|
||||
|
||||
n_tokens_to_fill = sample_image_mask.sum()
|
||||
assert n_tokens_to_fill <= encoded_patches_per_sample.size(0)
|
||||
|
||||
h_image[index].masked_scatter_(
|
||||
sample_image_mask.expand(-1, h_image.size(-1)),
|
||||
encoded_patches_per_sample[:n_tokens_to_fill],
|
||||
)
|
||||
|
||||
return h_image
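
A toy illustration of the `masked_scatter_` step above: positions where the image mask is set are filled, in order, with the projected patch embeddings for that sample. All sizes are made up for the example:

```python
import torch

seq_len, dim = 6, 4
h_image = torch.zeros(seq_len, dim)

# Mask marking which sequence positions are image-patch tokens (shape [seq_len, 1]).
sample_image_mask = torch.tensor([[0], [1], [1], [0], [1], [0]], dtype=torch.bool)

# Three projected patch embeddings for this sample.
encoded_patches = torch.arange(3 * dim, dtype=torch.float32).reshape(3, dim)

h_image.masked_scatter_(
    sample_image_mask.expand(-1, dim),
    encoded_patches,
)
print(h_image)
# Rows 1, 2 and 4 now hold the three patch embeddings; the rest stay zero.
```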
|
llama_stack/models/llama/llama4/vision/encoder.py (new file, 411 lines)
|
@ -0,0 +1,411 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fairscale.nn.model_parallel.initialize as fs_init
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from torch import einsum
|
||||
|
||||
from ..args import ModelArgs
|
||||
from ..model import Attention
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
return x
|
||||
|
||||
|
||||
class ColumnParallelConv2dPatch(torch.nn.Module):
|
||||
"""Conv2D Patching layer with model parallelism.
|
||||
Column parallel over unfolded input.
|
||||
Arguments:
|
||||
in_channels: Input channels.
|
||||
out_channels: Output channels.
|
||||
kernel_size: Size of convolution kernel.
|
||||
stride (default 1): Stride for convolution.
|
||||
bias (default False): Use bias in Conv2d.
|
||||
Input: (bsz, in_channels, height, width)
|
||||
Output: (bsz, num_tokens, out_channels)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]],
|
||||
bias: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
|
||||
self._linear = ColumnParallelLinear(
|
||||
in_channels * kernel_size[0] * kernel_size[1],
|
||||
out_channels,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self._unfold(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self._linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class _FeedForward(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
dropout: float,
|
||||
act_layer: Callable = nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
# layers
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
dim,
|
||||
hidden_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
init_method=lambda x: x,
|
||||
)
|
||||
self.non_linearity = act_layer()
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.c_fc(x)
|
||||
hidden = self.non_linearity(hidden)
|
||||
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
|
||||
return self.c_proj(hidden)
|
||||
|
||||
|
||||
class _TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
gated: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert d_model % n_head == 0
|
||||
self.n_heads = n_head
|
||||
self.head_dim = d_model // self.n_heads
|
||||
|
||||
attn_args = ModelArgs(
|
||||
dim=d_model,
|
||||
head_dim=self.head_dim,
|
||||
n_heads=self.n_heads,
|
||||
n_kv_heads=self.n_heads,
|
||||
)
|
||||
self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = _FeedForward(
|
||||
dim=d_model,
|
||||
hidden_dim=int(mlp_ratio * d_model),
|
||||
dropout=0.0,
|
||||
act_layer=act_layer,
|
||||
)
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.gated = gated
|
||||
if gated:
|
||||
self.gate_attn = nn.Parameter(torch.zeros(1))
|
||||
self.gate_ffn = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def attention(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
freq_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
|
||||
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
|
||||
|
||||
x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
|
||||
x = x + _gate_ffn * self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class _Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
act_layer: Callable = nn.GELU,
|
||||
gated: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
_TransformerBlock(
|
||||
d_model=dim,
|
||||
n_head=heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
gated=gated,
|
||||
)
|
||||
for _ in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
|
||||
out = []
|
||||
for idx, r in enumerate(self.resblocks):
|
||||
if return_intermediate is not None and idx in return_intermediate:
|
||||
out.append(x)
|
||||
x = r(x, mask=mask, freq_cis=freq_cis)
|
||||
if return_intermediate is not None:
|
||||
return x, torch.stack(out, dim=-1)
|
||||
return x
|
||||
|
||||
|
||||
class PackingIndex:
|
||||
Z = 0 # Z (time) coordinate of the token in the original sample
|
||||
Y = 1 # Y (height) coordinate of the token in the original sample
|
||||
X = 2 # X (width) coordinate of the token in the original sample
|
||||
TIME = 3 # Total number of time units (frames) in the original sample
|
||||
HEIGHT = 4 # Height of the original sample
|
||||
WIDTH = 5 # Width of the original sample
|
||||
# USE INDEX TO CHECK THE TYPE OF THE TOKEN (see ID fields below)
|
||||
IDX = 6 # Full index of the token in the original sample (x + y * w + z * w * h)
|
||||
BATCH_IDX = 7 # Which batch element this token belongs to. Note the batch idx of padding tokens is BATCH_SIZE
|
||||
|
||||
# Total size of the enum, remember to update this!
|
||||
NUM_METADATA = 8
|
||||
|
||||
# Note: For padding tokens IDX = -1
|
||||
# For cls tokens, IDX = -2
|
||||
ID_CLS_TOKEN = -2
|
||||
ID_PAD_TOKEN = -1
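
A small concrete rendering of this metadata layout for a 2x2 patch grid plus the trailing CLS slot, mirroring how `VisionEncoder.__init__` fills `packed_img_idx` below; the column indices follow the class above:

```python
import torch

idx_h, idx_w = 2, 2
img_idx = torch.arange(idx_h * idx_w, dtype=torch.int32).reshape(idx_h * idx_w, 1)
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)  # append a slot for the CLS token
img_idx[-1, -1] = -2                                 # PackingIndex.ID_CLS_TOKEN

meta = torch.empty(img_idx.shape[0], img_idx.shape[1], 7, dtype=torch.int32)  # NUM_METADATA - 1
meta[:, :, 1] = img_idx // idx_w   # Y
meta[:, :, 2] = img_idx % idx_w    # X
meta[:, :, 4].fill_(idx_h)         # HEIGHT
meta[:, :, 5].fill_(idx_w)         # WIDTH
meta[:, :, 6] = img_idx            # IDX

print(meta[:, 0, [1, 2, 6]])
# tensor([[ 0,  0,  0],
#         [ 0,  1,  1],
#         [ 1,  0,  2],
#         [ 1,  1,  3],
#         [-1,  0, -2]], dtype=torch.int32)
```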
|
||||
|
||||
|
||||
class VisionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
image_size: Tuple[int, int],
|
||||
patch_size: Tuple[int, int],
|
||||
dim: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
mlp_ratio: float,
|
||||
in_channels: int = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (
|
||||
self.image_size[0] // self.patch_size[0],
|
||||
self.image_size[1] // self.patch_size[1],
|
||||
)
|
||||
self.conv1 = ColumnParallelConv2dPatch(
|
||||
in_channels=in_channels,
|
||||
out_channels=dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
)
|
||||
scale = dim**-0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(dim))
|
||||
|
||||
self.positional_embedding_vlm = nn.Parameter(
|
||||
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
|
||||
)
|
||||
|
||||
self.ln_pre = LayerNorm(dim)
|
||||
self.ln_post = LayerNorm(dim)
|
||||
self.transformer = _Transformer(
|
||||
dim,
|
||||
layers,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
act_layer=nn.GELU,
|
||||
)
|
||||
|
||||
# NOTE: hack for the fixed res
|
||||
image_h, image_w = self.image_size
|
||||
patch_h, patch_w = self.patch_size
|
||||
idx_h, idx_w = image_h // patch_h, image_w // patch_w
|
||||
img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
|
||||
img_idx = img_idx.reshape(idx_h * idx_w, 1)
|
||||
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
|
||||
img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
|
||||
|
||||
packed_img_idx = torch.empty(
|
||||
img_idx.shape[0],
|
||||
img_idx.shape[1],
|
||||
PackingIndex.NUM_METADATA - 1,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
|
||||
packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
|
||||
packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
|
||||
packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
|
||||
packed_img_idx[:, :, PackingIndex.IDX] = img_idx
|
||||
packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
|
||||
self.packed_img_idx = packed_img_idx # for positional embedding load hook
|
||||
|
||||
# compute rope freqs
|
||||
rope_freq = self.get_rope_freqs(dim // heads // 2)
|
||||
freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
|
||||
freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
|
||||
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
|
||||
# disable RoPE for padding and cls tokens
|
||||
freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
|
||||
# compute complex freqs
|
||||
self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
|
||||
# xlf automatically broadcasts
|
||||
self.freq_cis = self.freq_cis.squeeze(0)
|
||||
self.n_heads = heads // fs_init.get_model_parallel_world_size()
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def get_rope_freqs(self, dim, theta=10000):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
return freqs
|
||||
|
||||
@torch.amp.autocast("cuda", enabled=False)
|
||||
def compute_rope_freqs(self, freqs, t):
|
||||
freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
|
||||
freqs = freqs.repeat_interleave(2, dim=-1)
|
||||
return freqs
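
A tiny numeric check of the RoPE frequency construction above, using the same two functions on a toy coordinate row (the dimensions are arbitrary): per-axis frequencies are scaled by the (1-indexed) patch coordinate and then interleaved, so each dimension pair rotates at a coordinate-dependent angle.

```python
import torch
from torch import einsum


def get_rope_freqs(dim, theta=10000):
    return 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))


def compute_rope_freqs(freqs, t):
    freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
    return freqs.repeat_interleave(2, dim=-1)


rope_freq = get_rope_freqs(4)            # e.g. dim // heads // 2 == 4 -> two base frequencies
coords_x = torch.tensor([[1, 2, 1, 2]])  # x coordinate (+1) of four patches in a 2x2 grid
print(compute_rope_freqs(rope_freq, coords_x).shape)  # torch.Size([1, 4, 4])
```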
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool = True,
|
||||
missing_keys: List[str] = None,
|
||||
unexpected_keys: List[str] = None,
|
||||
error_msgs: List[str] = None,
|
||||
return_state_dict: bool = False,
|
||||
) -> None:
|
||||
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
|
||||
if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
|
||||
raise ValueError(
|
||||
f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
|
||||
)
|
||||
|
||||
batch_size, token_per_image, _ = self.packed_img_idx.shape
|
||||
# Input points for idx are [x, y, w, h]
|
||||
idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
|
||||
total_windows, window_size, _ = idx.shape
|
||||
|
||||
# Grid values are [-1, 1] and coords are w, h
|
||||
grid = (
|
||||
(idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
|
||||
)[None, ...]
|
||||
|
||||
# In this mode, cls token has no position embedding
|
||||
if orig_pos_embed is not None:
|
||||
posemb = (
|
||||
orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
posemb = posemb.to(device=grid.device, dtype=grid.dtype)
|
||||
sample = F.grid_sample(
|
||||
posemb, grid, padding_mode="zeros"
|
||||
) # padding tokens / class token will get zero for posemb
|
||||
sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
|
||||
sample = torch.where(
|
||||
idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
|
||||
orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
|
||||
sample,
|
||||
)
|
||||
|
||||
new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
|
||||
|
||||
state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
|
||||
|
||||
if return_state_dict:
|
||||
return state_dict
|
||||
|
||||
def apply_class_embedding(self, x):
|
||||
x = torch.cat(
|
||||
[
|
||||
x,
|
||||
self.class_embedding.to(x.dtype)
|
||||
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
||||
],
|
||||
dim=1,
|
||||
) # shape = [*, grid ** 2 + 1, width]
|
||||
return x
|
||||
|
||||
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
|
||||
if images.ndim == 5:
|
||||
num_concurrent_media = 1
|
||||
bsz, num_chunks, nch, h, w = images.shape
|
||||
else:
|
||||
bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
|
||||
|
||||
images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
||||
# patch embedding
|
||||
x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
||||
x = self.conv1(x) # shape = [*, width, grid ** 2]
|
||||
_, ntok, dim = x.shape
|
||||
x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
|
||||
|
||||
# apply cls token
|
||||
x = self.apply_class_embedding(x)
|
||||
ntok += 1
|
||||
|
||||
# apply position embeddings
|
||||
if self.positional_embedding_vlm is not None:
|
||||
x = x + self.positional_embedding_vlm.to(x.dtype)
|
||||
|
||||
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
|
||||
|
||||
x = self.ln_pre(x)
|
||||
x = x.view(bsz * num_concurrent_media, -1, dim)
|
||||
freq_cis = self.freq_cis.to(images.device)
|
||||
|
||||
tf_output = self.transformer(
|
||||
x,
|
||||
freq_cis=freq_cis,
|
||||
)
|
||||
|
||||
int_x = None
|
||||
if isinstance(tf_output, tuple):
|
||||
x, int_x = tf_output
|
||||
else:
|
||||
x = tf_output
|
||||
x = self.ln_post(x)
|
||||
|
||||
# remove cls token output
|
||||
x = x[:, :-1, :]
|
||||
|
||||
# add and output x + int_x features
|
||||
if int_x is not None:
|
||||
int_x = int_x[:, :-1, :, :]
|
||||
int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
|
||||
x = torch.cat([x, int_x], dim=-1)
|
||||
|
||||
return x
|
|
@ -27,6 +27,7 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
|
||||
|
||||
from .llama3.interface import LLama31Interface
|
||||
from .llama3.template_data import (
|
||||
|
@ -46,6 +47,7 @@ class UseCase(BaseModel):
|
|||
dialogs: List[List[RawMessage] | TextCompletionContent | str] = Field(default_factory=list)
|
||||
notes: str = ""
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json
|
||||
max_gen_len: int = 512
|
||||
|
||||
def md_format(self):
|
||||
section = textwrap.dedent(
|
||||
|
@ -71,22 +73,22 @@ class UseCase(BaseModel):
|
|||
text += dialog
|
||||
text += "\n\n"
|
||||
continue
|
||||
|
||||
elif isinstance(dialog, TextCompletionContent):
|
||||
input_tokens, output_tokens = generator.text_completion_raw(
|
||||
dialog.content,
|
||||
max_gen_len=64,
|
||||
temperature=0.1,
|
||||
top_p=0.95,
|
||||
)
|
||||
else:
|
||||
input_tokens, output_tokens = generator.chat_completion_raw(
|
||||
dialog,
|
||||
max_gen_len=512,
|
||||
temperature=0.0,
|
||||
top_p=0.95,
|
||||
tool_prompt_format=self.tool_prompt_format,
|
||||
batch = [dialog]
|
||||
method = (
|
||||
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
|
||||
)
|
||||
input_tokens = []
|
||||
output_tokens = []
|
||||
for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95):
|
||||
result = token_results[0]
|
||||
if result.source == "input":
|
||||
input_tokens.append(result.token)
|
||||
else:
|
||||
output_tokens.append(result.token)
|
||||
|
||||
if result.finished:
|
||||
break
|
||||
text += "##### Input Prompt Format\n"
|
||||
|
||||
# FIXME: This is added to undo the hack in chat_formatter where
|
||||
|
@ -115,6 +117,45 @@ class UseCase(BaseModel):
|
|||
return section
|
||||
|
||||
|
||||
class Llama4UseCase(UseCase):
    def dialogs_to_text(self, generator) -> str:
        def _code_block(text):
            return f"```\n{text}\n```"

        text = ""
        tokenizer = Tokenizer.get_instance()
        for dialog in self.dialogs:
            if isinstance(dialog, str):
                text += dialog
                text += "\n\n"
                continue
            else:
                batch = [dialog]
                method = (
                    generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
                )
                input_tokens = []
                output_tokens = []
                for token_results in method(batch, echo=True, temperature=0.0):
                    result = token_results[0]
                    if result.source == "input":
                        input_tokens.append(result.token)
                    else:
                        output_tokens.append(result.token)

                    if result.finished:
                        break

                text += "##### Input Prompt Format\n"
                text += _code_block(tokenizer.decode(input_tokens))
                text += "\n\n"
                text += "##### Model Response Format\n"
                text += _code_block(tokenizer.decode(output_tokens))
                text += "\n\n"

        return text
|
||||
|
||||
|
||||
def llama3_1_builtin_tool_call_dialog(tool_prompt_format=ToolPromptFormat.json):
|
||||
interface = LLama31Interface(tool_prompt_format)
|
||||
|
||||
|
|
llama_stack/models/llama/quantize_impls.py (new file, 316 lines)
@@ -0,0 +1,316 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# type: ignore
|
||||
import collections
|
||||
import logging
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Fp8ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Fp8Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Fp8RowwiseWeights(
|
||||
Fp8ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Fp8RowwiseWeights",
|
||||
["weight", "scale", "shape", "activation_scale_ub"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class Int4ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Int4Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Int4Weights(
|
||||
Int4ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Int4Weights",
|
||||
["weight", "scale", "zero_point", "shape"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def int4_row_quantize(
    x: torch.Tensor,
    group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    n_bit = 4  # Number of target bits.
    to_quant = x.reshape(-1, group_size).to(torch.float)

    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    max_int = 2**n_bit - 1
    min_int = 0
    scales = (max_val - min_val).clamp(min=1e-6) / max_int

    zeros = min_val + scales * (2 ** (n_bit - 1))

    out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)

    # Recenter output and move to int8.
    out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)

    # Cutlass expects column major layout for scale and zero point,
    # so we transpose here and make them contiguous.
    scales = scales.view(x.shape[0], -1).t().contiguous()
    zeros = zeros.view(x.shape[0], -1).t().contiguous()

    return out, scales, zeros


def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # Given int8 x, pack adjacent int4 values into a single int8.
    low_x = x[:, ::2]
    high_x = x[:, 1::2]

    # High bits need to left shift, this also masks off extra bits.
    high_x = torch.bitwise_left_shift(high_x, 4)
    # Low bits need to have sign bits removed.
    low_x = torch.bitwise_and(low_x, 0xF)

    # Recombine into a single value with bitwise or.
    return torch.bitwise_or(low_x, high_x).contiguous()
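A minimal sketch of how these two helpers compose, assuming `fbgemm_gpu` is installed so `quantize_impls` imports cleanly; tensor shapes and values below are illustrative only.

```python
# Hedged sketch: round-trip through int4_row_quantize + pack_int4.
import torch

from llama_stack.models.llama.quantize_impls import int4_row_quantize, pack_int4

x = torch.randn(2, 128)                     # one group of 128 values per row
q, scales, zeros = int4_row_quantize(x)     # q: int8 values in [-8, 7], shape (2, 128)
packed = pack_int4(q)                       # adjacent int4 pairs share a byte -> shape (2, 64)
print(q.shape, packed.shape, scales.shape)  # (2, 128), (2, 64), (1, 2)
```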
|
||||
|
||||
|
||||
def bmm_nt(
|
||||
x: Tensor,
|
||||
w: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
if isinstance(w, Fp8ScaledWeights):
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
|
||||
return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, w.weight, x_scale, w.scale)
|
||||
elif isinstance(w, Int4ScaledWeights):
|
||||
return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)
|
||||
else:
|
||||
raise ValueError("Unsupported quantization type")
|
||||
|
||||
|
||||
def ffn_swiglu(
|
||||
x: Tensor,
|
||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
|
||||
isinstance(w1, Int4ScaledWeights) and isinstance(w3, Int4ScaledWeights) and isinstance(w2, Int4ScaledWeights)
|
||||
):
|
||||
return ffn_swiglu_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
assert isinstance(w1, Tensor)
|
||||
assert isinstance(w3, Tensor)
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
assert isinstance(w2, Tensor)
|
||||
return (z @ w2.T).view(B, T, D)
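For reference, the unquantized fallback above is plain SwiGLU on dense tensors: silu(x @ w1.T) * (x @ w3.T), projected back by w2. A small self-contained sketch with illustrative shapes:

```python
# Hedged sketch of the dense fallback path in ffn_swiglu (shapes are illustrative only).
import torch
import torch.nn.functional as F

B, T, D, HD_L = 2, 4, 8, 16
x = torch.randn(B, T, D)
w1, w3 = torch.randn(HD_L, D), torch.randn(HD_L, D)
w2 = torch.randn(D, HD_L)

z = F.silu(x.view(B * T, D) @ w1.T) * (x.view(B * T, D) @ w3.T)
out = (z @ w2.T).view(B, T, D)
print(out.shape)  # torch.Size([2, 4, 8])
```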
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_fp8(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Quantize [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||
del w
|
||||
return Fp8RowwiseWeights(
|
||||
weight=wq,
|
||||
scale=w_scale,
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_int4(
|
||||
w: Tensor,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Quantize [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input high precision tensor to quantize.
|
||||
"""
|
||||
if w.ndim >= 3:
|
||||
wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
|
||||
wq = torch.stack([pack_int4(i) for i in wq], dim=0)
|
||||
scale = torch.stack(scale, dim=0)
|
||||
zero_point = torch.stack(zero_point, dim=0)
|
||||
else:
|
||||
wq, scale, zero_point = int4_row_quantize(w)
|
||||
wq = pack_int4(wq)
|
||||
del w
|
||||
return Int4Weights(
|
||||
weight=wq.to(output_device),
|
||||
scale=scale.to(output_device),
|
||||
zero_point=zero_point.to(output_device),
|
||||
shape=wq.shape,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_fp8(
|
||||
w: Tensor,
|
||||
w_scale: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Load FP8 [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input FP8.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device=output_device,
|
||||
)
|
||||
return Fp8RowwiseWeights(
|
||||
weight=w.to(torch.float8_e4m3fn).to(device=output_device),
|
||||
scale=w_scale.to(device=output_device),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_int4(
|
||||
w: Tensor,
|
||||
scale: Tensor,
|
||||
zero_point: Tensor,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Int4Weights:
|
||||
"""Load INT4 [n, k/2] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k/2] input INT4.
|
||||
"""
|
||||
return Int4Weights(
|
||||
weight=w.to(torch.int8).to(device=output_device),
|
||||
scale=scale.to(device=output_device),
|
||||
zero_point=zero_point.to(device=output_device),
|
||||
shape=w.shape,
|
||||
)
|
||||
|
||||
|
||||
def fc_dynamic(
    x: Tensor,
    w: Union[Fp8RowwiseWeights, Int4Weights],
    activation_scale_ub: Optional[Tensor] = None,
    num_tokens: Optional[Tensor] = None,
    is_memory_bounded: bool = False,
) -> Tensor:
    """
    Single w8a8 fc layer with dynamic row-wise scaling, or w4a16 fc layer with dynamic row-wise scaling
    """
    if isinstance(w, Int4Weights):
        y = torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
    else:
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
        y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
        del xq
    return y
|
||||
|
||||
|
||||
def ffn_swiglu_dynamic(
|
||||
x: Tensor,
|
||||
w1: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w3: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
w2: Union[Fp8RowwiseWeights, Int4Weights],
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
assert x.dim() == 3 or x.dim() == 2
|
||||
if x.dim() == 3:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
else:
|
||||
(T, D) = x.shape # noqa: N806
|
||||
B = 1 # noqa: N806
|
||||
|
||||
HD_L = w1.shape[0] # noqa: N806
|
||||
assert HD_L == w3.shape[0]
|
||||
x1 = fc_dynamic(
|
||||
x.view(B * T, D),
|
||||
w1,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
x2 = fc_dynamic(
|
||||
x.view(B * T, D),
|
||||
w3,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
|
||||
z_ = fc_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
if x.dim() == 3:
|
||||
return z_.view(B, T, D)
|
||||
else:
|
||||
return z_
|
llama_stack/models/llama/resources/dog.jpg (new binary file, 39 KiB; not shown)
llama_stack/models/llama/resources/pasta.jpeg (new binary file, 438 KiB; not shown)
llama_stack/models/llama/resources/small_dog.jpg (new binary file, 41 KiB; not shown)
|
@ -4,23 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from .datatypes import (
|
||||
from .sku_types import (
|
||||
CheckpointQuantizationFormat,
|
||||
CoreModelId,
|
||||
Model,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
ModelFamily,
|
||||
)
|
||||
|
||||
LLAMA2_VOCAB_SIZE = 32000
|
||||
|
@ -36,16 +28,13 @@ def resolve_model(descriptor: str) -> Optional[Model]:
|
|||
|
||||
def all_registered_models() -> List[Model]:
|
||||
return (
|
||||
llama2_family() + llama3_family() + llama3_1_family() + llama3_2_family() + llama3_3_family() + safety_models()
|
||||
)
|
||||
|
||||
|
||||
def recommended_sampling_params() -> SamplingParams:
|
||||
return SamplingParams(
|
||||
strategy=TopPSamplingStrategy(
|
||||
temperature=1.0,
|
||||
top_p=0.9,
|
||||
)
|
||||
llama2_family()
|
||||
+ llama3_family()
|
||||
+ llama3_1_family()
|
||||
+ llama3_2_family()
|
||||
+ llama3_3_family()
|
||||
+ llama4_family()
|
||||
+ safety_models()
|
||||
)
|
||||
|
||||
|
||||
|
@ -83,13 +72,66 @@ def llama3_3_family() -> List[Model]:
|
|||
]
|
||||
|
||||
|
||||
def llama4_family() -> List[Model]:
|
||||
return [
|
||||
*llama4_base_models(),
|
||||
*llama4_instruct_models(),
|
||||
]
|
||||
|
||||
|
||||
def llama4_base_models() -> List[Model]:
|
||||
return [
|
||||
Model(
|
||||
core_model_id=CoreModelId.llama4_scout_17b_16e,
|
||||
description="Llama 4 Scout (17b 16 experts model)",
|
||||
huggingface_repo="meta-llama/Llama-4-Scout-17B-16E",
|
||||
pth_file_count=8,
|
||||
arch_args={},
|
||||
),
|
||||
Model(
|
||||
core_model_id=CoreModelId.llama4_maverick_17b_128e,
|
||||
description="Llama 4 Maverick (17b 128 experts model)",
|
||||
huggingface_repo="meta-llama/Llama-4-Maverick-17B-128E",
|
||||
pth_file_count=8,
|
||||
arch_args={},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def llama4_instruct_models() -> List[Model]:
|
||||
return [
|
||||
Model(
|
||||
core_model_id=CoreModelId.llama4_scout_17b_16e_instruct,
|
||||
description="Llama 4 Scout (17b 16 experts instruct model)",
|
||||
huggingface_repo="meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
pth_file_count=8,
|
||||
arch_args={},
|
||||
),
|
||||
Model(
|
||||
core_model_id=CoreModelId.llama4_maverick_17b_128e_instruct,
|
||||
description="Llama 4 Maverick (17b 128 experts instruct model)",
|
||||
huggingface_repo="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
|
||||
pth_file_count=8,
|
||||
arch_args={},
|
||||
),
|
||||
Model(
|
||||
core_model_id=CoreModelId.llama4_maverick_17b_128e_instruct,
|
||||
description="Llama 4 Maverick (FP8 quantized)",
|
||||
huggingface_repo="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
|
||||
pth_file_count=8,
|
||||
variant="fp8",
|
||||
arch_args={},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def llama2_base_models() -> List[Model]:
|
||||
return [
|
||||
Model(
|
||||
core_model_id=CoreModelId.llama2_7b,
|
||||
description="Llama 2 7b model",
|
||||
huggingface_repo="meta-llama/Llama-2-7b",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -108,7 +150,6 @@ def llama2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_13b,
|
||||
description="Llama 2 13b model",
|
||||
huggingface_repo="meta-llama/Llama-2-13b",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
|
@ -127,7 +168,6 @@ def llama2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_70b,
|
||||
description="Llama 2 70b model",
|
||||
huggingface_repo="meta-llama/Llama-2-70b",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -169,7 +209,6 @@ def llama3_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_70b,
|
||||
description="Llama 3 70b model",
|
||||
huggingface_repo="meta-llama/Llama-3-70B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -193,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_8b,
|
||||
description="Llama 3.1 8b model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-8B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -212,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_70b,
|
||||
description="Llama 3.1 70b model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-70B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -232,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
variant="bf16-mp8",
|
||||
description="Llama 3.1 405b model (BF16 weights)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -252,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
description="Llama 3.1 405b model (FP8 quantized)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
|
||||
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -272,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
variant="bf16-mp16",
|
||||
description="Llama 3.1 405b model (BF16 weights for mp16)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -296,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_1b,
|
||||
description="Llama 3.2 1b model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 2048,
|
||||
"n_layers": 16,
|
||||
|
@ -315,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_3b,
|
||||
description="Llama 3.2 3b model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 3072,
|
||||
"n_layers": 28,
|
||||
|
@ -334,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_11b_vision,
|
||||
description="Llama 3.2 11b vision model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -356,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_90b_vision,
|
||||
description="Llama 3.2 90b vision model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -383,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_7b_chat,
|
||||
description="Llama 2 7b chat model",
|
||||
huggingface_repo="meta-llama/Llama-2-7b-chat",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -402,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_13b_chat,
|
||||
description="Llama 2 13b chat model",
|
||||
huggingface_repo="meta-llama/Llama-2-13b-chat",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
|
@ -421,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_70b_chat,
|
||||
description="Llama 2 70b chat model",
|
||||
huggingface_repo="meta-llama/Llama-2-70b-chat",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -445,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_8b_instruct,
|
||||
description="Llama 3 8b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3-8B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -464,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_70b_instruct,
|
||||
description="Llama 3 70b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3-70B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -488,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_8b_instruct,
|
||||
description="Llama 3.1 8b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -507,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_70b_instruct,
|
||||
description="Llama 3.1 70b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -527,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
variant="bf16-mp8",
|
||||
description="Llama 3.1 405b instruct model (BF16 weights)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -547,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
description="Llama 3.1 405b instruct model (FP8 quantized)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
|
||||
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -567,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
variant="bf16-mp16",
|
||||
description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -623,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 1b INT4 quantized LoRA",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_1b(),
|
||||
"quantization_args": {
|
||||
|
@ -642,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 1b INT4 quantized SpinQuant",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_1b(),
|
||||
"quantization_args": {
|
||||
|
@ -657,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 3b INT4 quantized LoRA",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_3b(),
|
||||
"quantization_args": {
|
||||
|
@ -676,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 3b INT4 quantized SpinQuant",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_3b(),
|
||||
"quantization_args": {
|
||||
|
@ -694,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_1b_instruct,
|
||||
description="Llama 3.2 1b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args=arch_args_1b(),
|
||||
pth_file_count=1,
|
||||
),
|
||||
|
@ -702,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_3b_instruct,
|
||||
description="Llama 3.2 3b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args=arch_args_3b(),
|
||||
pth_file_count=1,
|
||||
),
|
||||
|
@ -711,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
|
||||
description="Llama 3.2 11b vision instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -733,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
|
||||
description="Llama 3.2 90b vision instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -760,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_3_70b_instruct,
|
||||
description="Llama 3.3 70b instruct",
|
||||
huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -785,7 +796,6 @@ def safety_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama_guard_3_11b_vision,
|
||||
description="Llama Guard v3 11b vision system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -809,7 +819,6 @@ def safety_models() -> List[Model]:
|
|||
description="Llama Guard v3 1b 'int4' quantized system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
|
||||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 2048,
|
||||
"n_layers": 12,
|
||||
|
@ -827,7 +836,6 @@ def safety_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama_guard_3_1b,
|
||||
description="Llama Guard v3 1b system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-3-1B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 2048,
|
||||
"n_layers": 16,
|
||||
|
@ -989,12 +997,24 @@ def llama_meta_pth_size(model: Model) -> int:
|
|||
if model.core_model_id not in (
|
||||
CoreModelId.llama3_1_405b,
|
||||
CoreModelId.llama3_1_405b_instruct,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct,
|
||||
):
|
||||
return 0
|
||||
|
||||
if model.pth_file_count == 16:
|
||||
return 51268302389
|
||||
elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
|
||||
return 60903742309
|
||||
else:
|
||||
return 101470976045
|
||||
if model.model_family == ModelFamily.llama3_1:
|
||||
if model.pth_file_count == 16:
|
||||
return 51268302389
|
||||
elif model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
|
||||
return 60903742309
|
||||
else:
|
||||
return 101470976045
|
||||
|
||||
if model.model_family == ModelFamily.llama4:
|
||||
if model.core_model_id == CoreModelId.llama4_maverick_17b_128e:
|
||||
return 100458118386
|
||||
elif model.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
|
||||
if model.quantization_format == CheckpointQuantizationFormat.fp8_mixed:
|
||||
return 54121549657
|
||||
else:
|
||||
return 100426653046
|
||||
|
|
llama_stack/models/llama/sku_types.py (new file, 229 lines)
@@ -0,0 +1,229 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class CheckpointQuantizationFormat(Enum):
|
||||
# default format
|
||||
bf16 = "bf16"
|
||||
|
||||
# used for enabling fp8_rowwise inference, some weights are bf16
|
||||
fp8_mixed = "fp8-mixed"
|
||||
|
||||
int8 = "int8"
|
||||
|
||||
int4 = "int4"
|
||||
|
||||
|
||||
class ModelFamily(Enum):
|
||||
llama2 = "llama2"
|
||||
llama3 = "llama3"
|
||||
llama3_1 = "llama3_1"
|
||||
llama3_2 = "llama3_2"
|
||||
llama3_3 = "llama3_3"
|
||||
llama4 = "llama4"
|
||||
safety = "safety"
|
||||
|
||||
|
||||
class CoreModelId(Enum):
|
||||
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
|
||||
|
||||
# Llama 2 family
|
||||
llama2_7b = "Llama-2-7b"
|
||||
llama2_13b = "Llama-2-13b"
|
||||
llama2_70b = "Llama-2-70b"
|
||||
llama2_7b_chat = "Llama-2-7b-chat"
|
||||
llama2_13b_chat = "Llama-2-13b-chat"
|
||||
llama2_70b_chat = "Llama-2-70b-chat"
|
||||
|
||||
# Llama 3 family
|
||||
llama3_8b = "Llama-3-8B"
|
||||
llama3_70b = "Llama-3-70B"
|
||||
llama3_8b_instruct = "Llama-3-8B-Instruct"
|
||||
llama3_70b_instruct = "Llama-3-70B-Instruct"
|
||||
|
||||
# Llama 3.1 family
|
||||
llama3_1_8b = "Llama3.1-8B"
|
||||
llama3_1_70b = "Llama3.1-70B"
|
||||
llama3_1_405b = "Llama3.1-405B"
|
||||
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
|
||||
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
|
||||
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
|
||||
|
||||
# Llama 3.2 family
|
||||
llama3_2_1b = "Llama3.2-1B"
|
||||
llama3_2_3b = "Llama3.2-3B"
|
||||
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
|
||||
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
|
||||
llama3_2_11b_vision = "Llama3.2-11B-Vision"
|
||||
llama3_2_90b_vision = "Llama3.2-90B-Vision"
|
||||
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
|
||||
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
|
||||
|
||||
# Llama 3.3 family
|
||||
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
|
||||
|
||||
# Llama 4 family
|
||||
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
|
||||
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
|
||||
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
|
||||
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
|
||||
|
||||
# Safety models
|
||||
llama_guard_3_8b = "Llama-Guard-3-8B"
|
||||
llama_guard_2_8b = "Llama-Guard-2-8B"
|
||||
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
|
||||
llama_guard_3_1b = "Llama-Guard-3-1B"
|
||||
|
||||
|
||||
def is_multimodal(model_id) -> bool:
|
||||
if model_id in [
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def model_family(model_id) -> ModelFamily:
|
||||
if model_id in [
|
||||
CoreModelId.llama2_7b,
|
||||
CoreModelId.llama2_13b,
|
||||
CoreModelId.llama2_70b,
|
||||
CoreModelId.llama2_7b_chat,
|
||||
CoreModelId.llama2_13b_chat,
|
||||
CoreModelId.llama2_70b_chat,
|
||||
]:
|
||||
return ModelFamily.llama2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_8b,
|
||||
CoreModelId.llama3_70b,
|
||||
CoreModelId.llama3_8b_instruct,
|
||||
CoreModelId.llama3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_1_8b,
|
||||
CoreModelId.llama3_1_70b,
|
||||
CoreModelId.llama3_1_405b,
|
||||
CoreModelId.llama3_1_8b_instruct,
|
||||
CoreModelId.llama3_1_70b_instruct,
|
||||
CoreModelId.llama3_1_405b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_1
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_2_1b,
|
||||
CoreModelId.llama3_2_3b,
|
||||
CoreModelId.llama3_2_1b_instruct,
|
||||
CoreModelId.llama3_2_3b_instruct,
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_3
|
||||
elif model_id in [
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_scout_17b_16e_instruct,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct,
|
||||
]:
|
||||
return ModelFamily.llama4
|
||||
elif model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_2_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return ModelFamily.safety
|
||||
else:
|
||||
raise ValueError(f"Unknown model family for {model_id}")
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
core_model_id: CoreModelId
|
||||
description: str
|
||||
huggingface_repo: Optional[str] = None
|
||||
arch_args: Dict[str, Any]
|
||||
variant: str = ""
|
||||
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
pth_file_count: int
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# silence pydantic until we remove the `model_` fields
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def model_family(self) -> ModelFamily:
|
||||
return model_family(self.core_model_id)
|
||||
|
||||
# The SKU is uniquely identified by (model_id, variant) combo
|
||||
def descriptor(self, shorten_default_variant: bool = True) -> str:
|
||||
if not self.variant:
|
||||
return self.core_model_id.value
|
||||
return f"{self.core_model_id.value}:{self.variant}"
|
||||
|
||||
@property
|
||||
def is_instruct_model(self) -> bool:
|
||||
return "instruct" in self.core_model_id.value
|
||||
|
||||
# Featured models are shown in the non-exhaustive model list
|
||||
@property
|
||||
def is_featured(self) -> bool:
|
||||
return self.model_family in [
|
||||
ModelFamily.llama3_1,
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
ModelFamily.safety,
|
||||
]
|
||||
|
||||
@property
|
||||
def max_seq_length(self) -> int:
|
||||
if self.model_family == ModelFamily.llama2:
|
||||
return 4096
|
||||
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
|
||||
return 4096
|
||||
elif self.model_family == ModelFamily.llama3:
|
||||
return 8192
|
||||
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama3_2:
|
||||
if self.quantization_format == CheckpointQuantizationFormat.int4:
|
||||
return 8192
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama4:
|
||||
if self.core_model_id in {
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
}:
|
||||
return 262144
|
||||
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
|
||||
return 10485760
|
||||
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
|
||||
return 1048576
|
||||
|
||||
raise AssertionError(f"Unexpected core model id: {self.core_model_id}")
|
||||
elif self.core_model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return 131072
|
||||
else:
|
||||
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
@ -201,3 +202,12 @@ def remote_provider_spec(
|
|||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
)
|
||||
|
||||
|
||||
class HealthStatus(str, Enum):
    OK = "OK"
    ERROR = "Error"
    NOT_IMPLEMENTED = "Not Implemented"


HealthResponse = dict[str, Any]
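A hedged sketch of how a provider health check might use these new types. The import path assumes they land in `llama_stack.providers.datatypes`, and the "message" key is illustrative rather than part of the declared type.

```python
from llama_stack.providers.datatypes import HealthResponse, HealthStatus

ok: HealthResponse = {"status": HealthStatus.OK}
down: HealthResponse = {"status": HealthStatus.ERROR, "message": "connection refused"}
```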
|
||||
|
|
|
@ -52,6 +52,7 @@ from llama_stack.apis.inference import (
|
|||
StopReason,
|
||||
SystemMessage,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
|
@ -63,7 +64,6 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolCall,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
@ -89,7 +89,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
agent_id: str,
|
||||
agent_config: AgentConfig,
|
||||
tempdir: str,
|
||||
inference_api: Inference,
|
||||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
|
@ -99,7 +98,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
self.tempdir = tempdir
|
||||
self.inference_api = inference_api
|
||||
self.safety_api = safety_api
|
||||
self.vector_io_api = vector_io_api
|
||||
|
@ -255,7 +253,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
)
|
||||
input_messages = last_turn_messages
|
||||
input_messages = last_turn.input_messages
|
||||
|
||||
turn_id = request.turn_id
|
||||
start_time = last_turn.started_at
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
|
@ -64,7 +63,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.tool_groups_api = tool_groups_api
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
|
@ -107,7 +105,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
return ChatAgent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_config,
|
||||
tempdir=self.tempdir,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
vector_io_api=self.vector_io_api,
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
config: MetaReferenceInferenceConfig,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
|
|
@ -5,19 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
|
||||
class TokenResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
|
||||
def model_checkpoint_dir(model_id) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model_id))
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int = 4096
|
||||
max_batch_size: int = 1
|
||||
model_parallel_size: Optional[int] = None
|
||||
|
||||
# when this is False, we assume that the distributed process group is setup by someone
|
||||
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
|
||||
|
@ -31,6 +32,8 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
# can override by specifying the directory explicitly
|
||||
checkpoint_dir: Optional[str] = None
|
||||
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
|
@ -47,27 +50,19 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
cls,
|
||||
model: str = "Llama3.2-3B-Instruct",
|
||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
||||
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
|
||||
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
|
||||
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
|
||||
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": model,
|
||||
"max_seq_len": 4096,
|
||||
"checkpoint_dir": checkpoint_dir,
|
||||
"quantization": {
|
||||
"type": quantization_type,
|
||||
},
|
||||
"model_parallel_size": model_parallel_size,
|
||||
"max_batch_size": max_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
|
||||
|
||||
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
|
||||
quantization: QuantizationConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
model: str = "Llama3.2-3B-Instruct",
|
||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
config = super().sample_run_config(model, checkpoint_dir, **kwargs)
|
||||
config["quantization"] = {
|
||||
"type": "fp8",
|
||||
}
|
||||
return config
|
||||
|
|
|
@ -0,0 +1,212 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
GreedySamplingStrategy,
|
||||
JsonSchemaResponseFormat,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode
|
||||
from llama_stack.models.llama.llama3.generation import Llama3
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.generation import Llama4
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
get_default_tool_prompt_format,
|
||||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .inference import resolve_model
|
||||
|
||||
Tokenizer = Llama4Tokenizer | Llama3Tokenizer
|
||||
|
||||
|
||||
class LogitsProcessor:
|
||||
def __init__(self, token_enforcer: TokenEnforcer):
|
||||
self.token_enforcer = token_enforcer
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
|
||||
def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
token_sequence = tokens[0, :].tolist()
|
||||
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
|
||||
|
||||
if self.mask is not None:
|
||||
self.mask.fill_(-math.inf)
|
||||
else:
|
||||
self.mask = torch.full_like(scores, -math.inf)
|
||||
|
||||
self.mask[:, :, allowed_tokens] = 0
|
||||
scores = scores + self.mask
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor(
|
||||
tokenizer: Tokenizer,
|
||||
vocab_size: int,
|
||||
response_format: Optional[ResponseFormat],
|
||||
) -> Optional["LogitsProcessor"]:
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if not isinstance(response_format, JsonSchemaResponseFormat):
|
||||
raise ValueError(f"Unsupported response format type {response_format.type}")
|
||||
|
||||
parser = JsonSchemaParser(response_format.json_schema)
|
||||
data = TokenEnforcerTokenizerData(
|
||||
_build_regular_tokens_list(tokenizer, vocab_size),
|
||||
tokenizer.decode,
|
||||
tokenizer.stop_tokens,
|
||||
)
|
||||
token_enforcer = TokenEnforcer(data, parser)
|
||||
return LogitsProcessor(token_enforcer)
|
||||
|
||||
|
||||
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
|
||||
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
|
||||
regular_tokens = []
|
||||
|
||||
special_token_ids = set(tokenizer.special_tokens.values())
|
||||
for token_idx in range(vocab_size):
|
||||
if token_idx in special_token_ids:
|
||||
continue
|
||||
|
||||
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
|
||||
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
|
||||
decoded_regular = tokenizer.decode([token_idx])
|
||||
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
|
||||
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
|
||||
return regular_tokens
|
||||
|
||||
|
||||
def _infer_sampling_params(sampling_params: SamplingParams):
    if isinstance(sampling_params.strategy, GreedySamplingStrategy):
        temperature = 0.0
        top_p = 1.0
    elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
        temperature = sampling_params.strategy.temperature or 1.0
        top_p = sampling_params.strategy.top_p or 1.0
    else:
        raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
    return temperature, top_p
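A quick sketch of the mapping this helper implements. It is module-private to generators.py, so the calls below assume they run inside that module; the strategy values are illustrative.

```python
from llama_stack.apis.inference import GreedySamplingStrategy, SamplingParams, TopPSamplingStrategy

greedy = SamplingParams(strategy=GreedySamplingStrategy())
nucleus = SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9))

print(_infer_sampling_params(greedy))   # (0.0, 1.0)
print(_infer_sampling_params(nucleus))  # (0.7, 0.9)
```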
|
||||
|
||||
|
||||
def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
||||
tool_config = request.tool_config
|
||||
if tool_config is not None and tool_config.tool_prompt_format is not None:
|
||||
return tool_config.tool_prompt_format
|
||||
else:
|
||||
return get_default_tool_prompt_format(request.model)
|
||||
|
||||
|
||||
class LlamaGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
ckpt_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
if config.quantization:
|
||||
if config.quantization.type == "fp8_mixed":
|
||||
quantization_mode = QuantizationMode.fp8_mixed
|
||||
elif config.quantization.type == "int4_mixed":
|
||||
quantization_mode = QuantizationMode.int4_mixed
|
||||
elif config.quantization.type == "bf16":
|
||||
quantization_mode = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
||||
else:
|
||||
quantization_mode = None
|
||||
|
||||
cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
|
||||
self.inner_generator = cls.build(
|
||||
ckpt_dir=ckpt_dir,
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
world_size=config.model_parallel_size or llama_model.pth_file_count,
|
||||
quantization_mode=quantization_mode,
|
||||
)
|
||||
|
||||
self.tokenizer = self.inner_generator.tokenizer
|
||||
self.args = self.inner_generator.args
|
||||
self.formatter = self.inner_generator.formatter
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request_batch: List[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(first_request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
):
|
||||
yield result
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request_batch: List[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[
|
||||
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
|
||||
for request in request_batch
|
||||
],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(first_request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
first_request.response_format,
|
||||
),
|
||||
):
|
||||
yield result
|
|
@ -5,15 +5,20 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
|
@ -28,18 +33,23 @@ from llama_stack.apis.inference import (
|
|||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
@@ -48,6 +58,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
augment_content_with_response_format_prompt,
|
||||
chat_completion_request_to_messages,
|
||||
|
@@ -55,16 +69,22 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .llama3.generation import Llama3
|
||||
from .generators import LlamaGenerator
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log = get_logger(__name__, category="inference")
|
||||
# there's a single model parallel process running, serving the model. for now,
|
||||
# we don't support multiple concurrent requests to this process.
|
||||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
|
||||
return LlamaGenerator(config, model_id, llama_model)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
|
@@ -77,29 +97,10 @@ class MetaReferenceInferenceImpl(
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def load_model(self, model_id, llama_model) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(self.config, model_id, llama_model)
|
||||
self.generator.start()
|
||||
else:
|
||||
self.generator = Llama3.build(self.config, model_id, llama_model)
|
||||
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
raise RuntimeError(
|
||||
"No avaible model yet, please register your requested model or add your model in the resouces first"
|
||||
)
|
||||
elif request.model != self.model_id:
|
||||
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
@@ -127,11 +128,58 @@ class MetaReferenceInferenceImpl(
|
|||
if model.model_type == ModelType.embedding:
|
||||
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
|
||||
# TODO: what is this?! you can't really specify skipping via model metadata
|
||||
# kill this madness
|
||||
if "skip_load" in model.metadata and model.metadata["skip_load"]:
|
||||
return model
|
||||
|
||||
await self.load_model(model.identifier, llama_model)
|
||||
return model
|
||||
|
||||
async def load_model(self, model_id, llama_model) -> None:
|
||||
log.info(f"Loading model `{model_id}`")
|
||||
|
||||
builder_params = [self.config, model_id, llama_model]
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(
|
||||
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
|
||||
builder_fn=llama_builder_fn,
|
||||
builder_params=builder_params,
|
||||
formatter=(
|
||||
Llama4ChatFormat(Llama4Tokenizer.get_instance())
|
||||
if llama_model.model_family == ModelFamily.llama4
|
||||
else Llama3ChatFormat(Llama3Tokenizer.get_instance())
|
||||
),
|
||||
)
|
||||
self.generator.start()
|
||||
else:
|
||||
self.generator = llama_builder_fn(*builder_params)
|
||||
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
log.info("Warming up...")
|
||||
await self.completion(
|
||||
model_id=model_id,
|
||||
content="Hello, world!",
|
||||
sampling_params=SamplingParams(max_tokens=10),
|
||||
)
|
||||
await self.chat_completion(
|
||||
model_id=model_id,
|
||||
messages=[UserMessage(content="Hi how are you?")],
|
||||
sampling_params=SamplingParams(max_tokens=20),
|
||||
)
|
||||
log.info("Warmed up!")
|
||||
|
||||
def check_model(self, request) -> None:
|
||||
if self.model_id is None or self.llama_model is None:
|
||||
raise RuntimeError(
|
||||
"No avaible model yet, please register your requested model or add your model in the resouces first"
|
||||
)
|
||||
elif request.model != self.model_id:
|
||||
raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@@ -161,17 +209,55 @@ class MetaReferenceInferenceImpl(
|
|||
if request.stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
results = await self._nonstream_completion([request])
|
||||
return results[0]
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: List[InterleavedContent],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> BatchCompletionResponse:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
content_batch = [
|
||||
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
|
||||
]
|
||||
|
||||
request_batch = []
|
||||
for content in content_batch:
|
||||
request = CompletionRequest(
|
||||
model=model_id,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
self.check_model(request)
|
||||
request = await convert_request_to_raw(request)
|
||||
request_batch.append(request)
|
||||
|
||||
results = await self._nonstream_completion(request_batch)
|
||||
return BatchCompletionResponse(batch=results)
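# Illustrative usage sketch only (inside an async context; `impl` is an assumed handle to this
# MetaReferenceInferenceImpl instance and the model_id is assumed to match the loaded model):
#
#     response = await impl.batch_completion(
#         model_id="meta-llama/Llama-3.1-8B-Instruct",
#         content_batch=["Write a haiku about GPUs.", "Name one use of FP8 quantization."],
#         sampling_params=SamplingParams(max_tokens=64),
#     )
#     for completion in response.batch:
#         print(completion.content, completion.stop_reason)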
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
stop_reason = None
|
||||
|
||||
for token_result in self.generator.completion(request):
|
||||
if token_result.text == "<|eot_id|>":
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
|
@@ -204,37 +290,54 @@ class MetaReferenceInferenceImpl(
|
|||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
first_request = request_batch[0]
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: List[int] = []
|
||||
logprobs: List[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
def impl():
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
states = [ItemState() for _ in request_batch]
|
||||
|
||||
for token_result in self.generator.completion(request):
|
||||
tokens.append(token_result.token)
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
results = []
|
||||
for token_results in self.generator.completion(request_batch):
|
||||
for result in token_results:
|
||||
idx = result.batch_idx
|
||||
state = states[idx]
|
||||
if state.finished or result.ignore_token:
|
||||
continue
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
state.finished = result.finished
|
||||
if first_request.logprobs:
|
||||
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
|
||||
|
||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||
state.tokens.append(result.token)
|
||||
if result.token == tokenizer.eot_id:
|
||||
state.stop_reason = StopReason.end_of_turn
|
||||
elif result.token == tokenizer.eom_id:
|
||||
state.stop_reason = StopReason.end_of_message
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
for state in states:
|
||||
if state.stop_reason is None:
|
||||
state.stop_reason = StopReason.out_of_tokens
|
||||
|
||||
content = self.generator.formatter.tokenizer.decode(tokens)
|
||||
if content.endswith("<|eot_id|>"):
|
||||
content = content[: -len("<|eot_id|>")]
|
||||
elif content.endswith("<|eom_id|>"):
|
||||
content = content[: -len("<|eom_id|>")]
|
||||
return CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
)
|
||||
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
|
||||
state.tokens = state.tokens[:-1]
|
||||
content = self.generator.formatter.tokenizer.decode(state.tokens)
|
||||
results.append(
|
||||
CompletionResponse(
|
||||
content=content,
|
||||
stop_reason=state.stop_reason,
|
||||
logprobs=state.logprobs if first_request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
|
@@ -269,7 +372,7 @@ class MetaReferenceInferenceImpl(
|
|||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
tool_config=tool_config or ToolConfig(),
|
||||
)
|
||||
self.check_model(request)
|
||||
|
||||
|
@@ -285,39 +388,110 @@ class MetaReferenceInferenceImpl(
|
|||
if request.stream:
|
||||
return self._stream_chat_completion(request)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request)
|
||||
results = await self._nonstream_chat_completion([request])
|
||||
return results[0]
|
||||
|
||||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
def impl():
|
||||
tokens = []
|
||||
logprobs = []
|
||||
stop_reason = None
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: List[List[Message]],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if logprobs:
|
||||
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
|
||||
|
||||
for token_result in self.generator.chat_completion(request):
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
stop_reason = StopReason.end_of_turn
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||
|
||||
if stop_reason is None:
|
||||
stop_reason = StopReason.out_of_tokens
|
||||
|
||||
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=logprobs if request.logprobs else None,
|
||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||
request_batch = []
|
||||
for messages in messages_batch:
|
||||
request = ChatCompletionRequest(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
response_format=response_format,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config or ToolConfig(),
|
||||
)
|
||||
self.check_model(request)
|
||||
|
||||
# augment and rewrite messages depending on the model
|
||||
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
|
||||
# download media and convert to raw content so we can send it to the model
|
||||
request = await convert_request_to_raw(request)
|
||||
request_batch.append(request)
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
if SEMAPHORE.locked():
|
||||
raise RuntimeError("Only one concurrent request is supported")
|
||||
|
||||
results = await self._nonstream_chat_completion(request_batch)
|
||||
return BatchChatCompletionResponse(batch=results)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request_batch: List[ChatCompletionRequest]
|
||||
) -> List[ChatCompletionResponse]:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
first_request = request_batch[0]
|
||||
|
||||
class ItemState(BaseModel):
|
||||
tokens: List[int] = []
|
||||
logprobs: List[TokenLogProbs] = []
|
||||
stop_reason: StopReason | None = None
|
||||
finished: bool = False
|
||||
|
||||
def impl():
|
||||
states = [ItemState() for _ in request_batch]
|
||||
|
||||
for token_results in self.generator.chat_completion(request_batch):
|
||||
first = token_results[0]
|
||||
if not first.finished and not first.ignore_token:
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
|
||||
cprint(first.text, "cyan", end="")
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{first.token}>", "magenta", end="")
|
||||
|
||||
for result in token_results:
|
||||
idx = result.batch_idx
|
||||
state = states[idx]
|
||||
if state.finished or result.ignore_token:
|
||||
continue
|
||||
|
||||
state.finished = result.finished
|
||||
if first_request.logprobs:
|
||||
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
|
||||
|
||||
state.tokens.append(result.token)
|
||||
if result.token == tokenizer.eot_id:
|
||||
state.stop_reason = StopReason.end_of_turn
|
||||
elif result.token == tokenizer.eom_id:
|
||||
state.stop_reason = StopReason.end_of_message
|
||||
|
||||
results = []
|
||||
for state in states:
|
||||
if state.stop_reason is None:
|
||||
state.stop_reason = StopReason.out_of_tokens
|
||||
|
||||
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
|
||||
results.append(
|
||||
ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=state.logprobs if first_request.logprobs else None,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
if self.config.create_distributed_process_group:
|
||||
async with SEMAPHORE:
|
||||
|
@@ -326,6 +500,8 @@ class MetaReferenceInferenceImpl(
|
|||
return impl()
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
def impl():
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
@@ -340,6 +516,25 @@ class MetaReferenceInferenceImpl(
|
|||
ipython = False
|
||||
|
||||
for token_result in self.generator.chat_completion(request):
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
||||
cprint(token_result.text, "cyan", end="")
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
|
||||
cprint(f"<{token_result.token}>", "magenta", end="")
|
||||
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
text = token_result.text
|
||||
|
||||
if request.logprobs:
|
||||
assert len(token_result.logprobs) == 1
|
||||
|
||||
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
|
||||
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||
|
@@ -355,10 +550,10 @@ class MetaReferenceInferenceImpl(
|
|||
)
|
||||
continue
|
||||
|
||||
if token_result.text == "<|eot_id|>":
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
stop_reason = StopReason.end_of_turn
|
||||
text = ""
|
||||
elif token_result.text == "<|eom_id|>":
|
||||
elif token_result.token == tokenizer.eom_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
text = ""
|
||||
else:
|
||||
|
|
|
@@ -1,483 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
Model,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from ..common import TokenResult, model_checkpoint_dir
|
||||
from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .args import ModelArgs
|
||||
from .model import Transformer
|
||||
from .multimodal.model import CrossAttentionTransformer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Llama3:
|
||||
@staticmethod
|
||||
def build(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
"""
|
||||
Build a Llama instance by initializing and loading a model checkpoint.
|
||||
|
||||
Note:
|
||||
This method initializes the distributed process group, sets the device to CUDA,
|
||||
and loads the pre-trained model and tokenizer.
|
||||
"""
|
||||
if "DEVICE" in os.environ:
|
||||
device = os.environ.get("DEVICE")
|
||||
if device == "cuda":
|
||||
assert torch.cuda.is_available(), "PyTorch CUDA backend not available"
|
||||
if device == "xpu":
|
||||
assert torch.xpu.is_available(), "PyTorch XPU backend not available"
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif torch.xpu.is_available():
|
||||
device = "xpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
log.info(f"Using {device} device")
|
||||
|
||||
llama_model_id = llama_model.core_model_id.value
|
||||
if not torch.distributed.is_initialized():
|
||||
if device == "cuda":
|
||||
torch.distributed.init_process_group("nccl")
|
||||
else:
|
||||
torch.distributed.init_process_group("gloo")
|
||||
|
||||
model_parallel_size = llama_model.pth_file_count
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(model_parallel_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
if device == "cuda":
|
||||
torch.cuda.set_device(local_rank)
|
||||
elif device == "xpu":
|
||||
torch.xpu.set_device(local_rank)
|
||||
|
||||
# seed must be the same in all processes
|
||||
if config.torch_seed is not None:
|
||||
torch.manual_seed(config.torch_seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
ckpt_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(checkpoints), (
|
||||
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
)
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
if "model" in params:
|
||||
params = params["model"]
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
**params,
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
assert model_args.vocab_size == tokenizer.n_words, (
|
||||
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
)
|
||||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
from ..quantization.loader import convert_to_fp8_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
from ..quantization.loader import convert_to_int4_quantized_model
|
||||
|
||||
model = Transformer(model_args)
|
||||
model = convert_to_int4_quantized_model(model, model_args, config)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
if model_args.quantization_args is not None and model_args.quantization_args.spinquant:
|
||||
# Add a wrapper for adding hadamard transform for spinquant.
|
||||
# This needs to be done after loading the state dict otherwise an error will be raised while
|
||||
# loading the state dict.
|
||||
from ..quantization.hadamard_utils import (
|
||||
add_hadamard_transform_for_spinquant,
|
||||
)
|
||||
|
||||
add_hadamard_transform_for_spinquant(model)
|
||||
else:
|
||||
raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")
|
||||
else:
|
||||
if device == "cuda":
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
else:
|
||||
torch.set_default_device(device)
|
||||
if device == "xpu" and torch.xpu.is_bf16_supported():
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
else:
|
||||
torch.set_default_dtype(torch.half)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
model.to(device)
|
||||
|
||||
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
return Llama3(model, tokenizer, model_args, llama_model_id)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Transformer,
|
||||
tokenizer: Tokenizer,
|
||||
args: ModelArgs,
|
||||
llama_model: str,
|
||||
):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
self.llama_model = llama_model
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_input: LLMInput,
|
||||
max_gen_len: int,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
include_stop_token: bool = False,
|
||||
print_input_tokens: bool = False,
|
||||
logits_processor: Optional["LogitsProcessor"] = None,
|
||||
) -> Generator:
|
||||
params = self.model.params
|
||||
|
||||
if print_input_tokens:
|
||||
input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
|
||||
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
|
||||
prompt_tokens = [model_input.tokens]
|
||||
|
||||
bsz = 1
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
is_vision = isinstance(self.model, CrossAttentionTransformer)
|
||||
if is_vision:
|
||||
images = model_input.vision.images if model_input.vision is not None else []
|
||||
mask = model_input.vision.mask if model_input.vision is not None else []
|
||||
|
||||
# the method works for bsz > 1 so add a batch dimension
|
||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
||||
batch_images=[images],
|
||||
batch_masks=[mask],
|
||||
total_len=total_len,
|
||||
)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens)
|
||||
|
||||
prev_pos = 0
|
||||
eos_reached = torch.tensor([False] * bsz)
|
||||
input_text_mask = tokens != pad_id
|
||||
if min_prompt_len == total_len:
|
||||
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
|
||||
logits = self.model.forward(tokens, prev_pos)
|
||||
token_logprobs = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=tokens,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
if is_vision:
|
||||
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
||||
logits = self.model.forward(
|
||||
position_ids,
|
||||
tokens,
|
||||
cross_attention_masks,
|
||||
full_text_row_masked_out_mask,
|
||||
xattn_caches,
|
||||
)
|
||||
else:
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor.process_logits(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if is_vision:
|
||||
# the logits space (num_classes) is designed to never contain a media_token
|
||||
# however our input token stream does contain them. we need to nuke them here
|
||||
# or else the CUDA kernels will crash with an illegal memory access
|
||||
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
|
||||
masks = [target.eq(t) for t in vision_tokens]
|
||||
if len(masks) > 1:
|
||||
mask = torch.logical_or(*masks)
|
||||
else:
|
||||
mask = masks[0]
|
||||
target[mask] = 0
|
||||
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=tokens[:, prev_pos + 1 : cur_pos + 1],
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
yield TokenResult(
|
||||
token=next_token[0].item(),
|
||||
text=self.tokenizer.decode(next_token.tolist()),
|
||||
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
|
||||
)
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
model_input = self.formatter.encode_content(request.content)
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.generate(
|
||||
model_input=model_input,
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
|
||||
max_gen_len = self.model.params.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(
|
||||
request.messages,
|
||||
request.tool_config.tool_prompt_format,
|
||||
),
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
include_stop_token=True,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
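# Illustrative sketch of the helper above (made-up numbers): tokens are kept while the
# cumulative probability mass before them is <= p; the rest are zeroed out, the survivors
# are renormalized, and one token index per row is drawn from the renormalized distribution.
#
#     probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), dim=-1)  # shape (bsz=1, vocab=4)
#     next_token = sample_top_p(probs, p=0.9)
#     assert next_token.shape == (1, 1)  # one sampled token id per sequence in the batch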
|
||||
|
||||
|
||||
class LogitsProcessor:
|
||||
def __init__(self, token_enforcer: TokenEnforcer):
|
||||
self.token_enforcer = token_enforcer
|
||||
self.mask: Optional[torch.Tensor] = None
|
||||
|
||||
def process_logits(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
|
||||
token_sequence = tokens[0, :].tolist()
|
||||
allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
|
||||
|
||||
if self.mask is not None:
|
||||
self.mask.fill_(-math.inf)
|
||||
else:
|
||||
self.mask = torch.full_like(scores, -math.inf)
|
||||
|
||||
self.mask[:, :, allowed_tokens] = 0
|
||||
scores = scores + self.mask
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor(
|
||||
tokenizer: Tokenizer,
|
||||
vocab_size: int,
|
||||
response_format: Optional[ResponseFormat],
|
||||
) -> Optional["LogitsProcessor"]:
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if response_format.type != ResponseFormatType.json_schema.value:
|
||||
raise ValueError(f"Unsupported response format type {response_format.type}")
|
||||
|
||||
parser = JsonSchemaParser(response_format.json_schema)
|
||||
data = TokenEnforcerTokenizerData(
|
||||
_build_regular_tokens_list(tokenizer, vocab_size),
|
||||
tokenizer.decode,
|
||||
tokenizer.stop_tokens,
|
||||
)
|
||||
token_enforcer = TokenEnforcer(data, parser)
|
||||
return LogitsProcessor(token_enforcer)
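# Illustrative sketch only: wiring the enforcer above into decoding. `JsonSchemaResponseFormat`
# is an assumed name for the json_schema variant of ResponseFormat; any object whose .type equals
# ResponseFormatType.json_schema.value and which carries a `json_schema` dict works here, and
# `tokenizer` / `model_args` are the objects built in Llama3.build above.
#
#     response_format = JsonSchemaResponseFormat(
#         json_schema={"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]},
#     )
#     processor = get_logits_processor(tokenizer, model_args.vocab_size, response_format)
#     # inside generate(), each decoding step then masks schema-illegal tokens to -inf:
#     #   logits = processor.process_logits(tokens[:, :cur_pos], logits)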
|
||||
|
||||
|
||||
def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> List[Tuple[int, str, bool]]:
|
||||
token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
|
||||
regular_tokens = []
|
||||
|
||||
special_token_ids = set(tokenizer.special_tokens.values())
|
||||
for token_idx in range(vocab_size):
|
||||
if token_idx in special_token_ids:
|
||||
continue
|
||||
|
||||
# We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
|
||||
decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
|
||||
decoded_regular = tokenizer.decode([token_idx])
|
||||
is_word_start_token = len(decoded_after_0) > len(decoded_regular)
|
||||
regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
|
||||
return regular_tokens
|
||||
|
||||
|
||||
def _infer_sampling_params(sampling_params: SamplingParams):
|
||||
if isinstance(sampling_params.strategy, GreedySamplingStrategy):
|
||||
temperature = 0.0
|
||||
top_p = 1.0
|
||||
elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
|
||||
temperature = sampling_params.strategy.temperature
|
||||
top_p = sampling_params.strategy.top_p
|
||||
else:
|
||||
raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
|
||||
return temperature, top_p
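# Illustrative sketch only (constructor fields assumed from the attribute accesses above):
#
#     greedy = SamplingParams(strategy=GreedySamplingStrategy())
#     nucleus = SamplingParams(strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9), max_tokens=128)
#     _infer_sampling_params(greedy)   # -> (0.0, 1.0)
#     _infer_sampling_params(nucleus)  # -> (0.7, 0.9)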
|
|
@@ -4,23 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Generator
|
||||
from typing import Any, Callable, Generator, List
|
||||
|
||||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .llama3.generation import Llama3
|
||||
from .parallel_utils import ModelParallelProcessGroup
|
||||
|
||||
|
||||
|
@@ -29,21 +23,20 @@ class ModelRunner:
|
|||
self.llama = llama
|
||||
|
||||
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
|
||||
def __call__(self, req: Any):
|
||||
if isinstance(req, ChatCompletionRequestWithRawContent):
|
||||
return self.llama.chat_completion(req)
|
||||
elif isinstance(req, CompletionRequestWithRawContent):
|
||||
return self.llama.completion(req)
|
||||
def __call__(self, task: Any):
|
||||
if task[0] == "chat_completion":
|
||||
return self.llama.chat_completion(task[1])
|
||||
elif task[0] == "completion":
|
||||
return self.llama.completion(task[1])
|
||||
else:
|
||||
raise ValueError(f"Unexpected task type {type(req)}")
|
||||
raise ValueError(f"Unexpected task type {task[0]}")
|
||||
|
||||
|
||||
def init_model_cb(
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
builder_fn: Callable,
|
||||
params: list[Any],
|
||||
):
|
||||
llama = Llama3.build(config, model_id, llama_model)
|
||||
llama = builder_fn(*params)
|
||||
return ModelRunner(llama)
|
||||
|
||||
|
||||
|
@@ -60,25 +53,15 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
model_parallel_size: int,
|
||||
builder_fn: Callable,
|
||||
builder_params: list[Any],
|
||||
formatter: Llama3ChatFormat | Llama4ChatFormat,
|
||||
):
|
||||
self.config = config
|
||||
self.model_id = model_id
|
||||
self.llama_model = llama_model
|
||||
|
||||
# this is a hack because Agent's loop uses this to tokenize and check if input is too long
|
||||
# while the tool-use loop is going
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
checkpoint_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
checkpoint_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
tokenizer_path = os.path.join(checkpoint_dir, "tokenizer.model")
|
||||
self.formatter = ChatFormat(Tokenizer(tokenizer_path))
|
||||
self.model_parallel_size = model_parallel_size
|
||||
self.builder_fn = builder_fn
|
||||
self.builder_params = builder_params
|
||||
self.formatter = formatter
|
||||
|
||||
def start(self):
|
||||
self.__enter__()
|
||||
|
@@ -87,11 +70,9 @@ class LlamaModelParallelGenerator:
|
|||
self.__exit__(None, None, None)
|
||||
|
||||
def __enter__(self):
|
||||
model_parallel_size = self.llama_model.pth_file_count
|
||||
|
||||
self.group = ModelParallelProcessGroup(
|
||||
model_parallel_size,
|
||||
init_model_cb=partial(init_model_cb, self.config, self.model_id, self.llama_model),
|
||||
self.model_parallel_size,
|
||||
init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
|
||||
)
|
||||
self.group.start()
|
||||
return self
|
||||
|
@@ -101,16 +82,16 @@ class LlamaModelParallelGenerator:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequestWithRawContent,
|
||||
request_batch: List[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request)
|
||||
gen = self.group.run_inference(req_obj)
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("completion", req_obj))
|
||||
yield from gen
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
request_batch: List[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
req_obj = deepcopy(request)
|
||||
gen = self.group.run_inference(req_obj)
|
||||
req_obj = deepcopy(request_batch)
|
||||
gen = self.group.run_inference(("chat_completion", req_obj))
|
||||
yield from gen
|
||||
|
|
|
@@ -19,7 +19,7 @@ import tempfile
|
|||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Callable, Generator, Literal, Optional, Union
|
||||
from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
@@ -32,13 +32,12 @@ from pydantic import BaseModel, Field
|
|||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .common import TokenResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@@ -70,12 +69,12 @@ class CancelSentinel(BaseModel):
|
|||
|
||||
class TaskRequest(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
||||
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
|
||||
task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
|
||||
|
||||
|
||||
class TaskResponse(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
|
||||
result: TokenResult
|
||||
result: List[GenerationResult]
|
||||
|
||||
|
||||
class ExceptionResponse(BaseModel):
|
||||
|
@@ -332,7 +331,7 @@ class ModelParallelProcessGroup:
|
|||
|
||||
def run_inference(
|
||||
self,
|
||||
req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
|
||||
req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
|
||||
) -> Generator:
|
||||
assert not self.running, "inference already running"
|
||||
|
||||
|
|
|
@@ -1,177 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import collections
|
||||
import logging
|
||||
from typing import Optional, Type
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||
|
||||
log.info("Using efficient FP8 operators in FBGEMM.")
|
||||
except ImportError:
|
||||
log.error("No efficient FP8 operators. Please install FBGEMM in fp8_requirements.txt.")
|
||||
raise
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Fp8ScaledWeights:
|
||||
# TODO: Ugly trick so torch allows us to replace parameters
|
||||
# with our custom Fp8Weights instance. Do this properly.
|
||||
@property
|
||||
def __class__(self) -> Type[nn.parameter.Parameter]:
|
||||
return nn.Parameter
|
||||
|
||||
@property
|
||||
def grad_fn(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
|
||||
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
|
||||
class Fp8RowwiseWeights(
|
||||
Fp8ScaledWeights,
|
||||
collections.namedtuple(
|
||||
"Fp8RowwiseWeights",
|
||||
["weight", "scale", "shape", "activation_scale_ub"],
|
||||
),
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
def ffn_swiglu(
|
||||
x: Tensor,
|
||||
w1: Fp8RowwiseWeights,
|
||||
w3: Fp8RowwiseWeights,
|
||||
w2: Fp8RowwiseWeights,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
if isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights):
|
||||
return ffn_swiglu_fp8_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
assert isinstance(w1, Tensor)
|
||||
assert isinstance(w3, Tensor)
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
assert isinstance(w2, Tensor)
|
||||
return (z @ w2.T).view(B, T, D)
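# Illustrative shape check for the unquantized fallback path above (made-up sizes): plain tensors
# are not Fp8ScaledWeights, so ffn_swiglu computes (silu(x @ w1.T) * (x @ w3.T)) @ w2.T.
#
#     B, T, D, HD_L = 2, 4, 8, 16
#     x = torch.randn(B, T, D)
#     w1, w3, w2 = torch.randn(HD_L, D), torch.randn(HD_L, D), torch.randn(D, HD_L)
#     out = ffn_swiglu(x, w1, w3, w2)
#     assert out.shape == (B, T, D)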
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def quantize_fp8(
|
||||
w: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
output_device: Optional[torch.device] = None,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Quantize [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input high precision tensor to quantize.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
|
||||
del w
|
||||
return Fp8RowwiseWeights(
|
||||
weight=wq,
|
||||
scale=w_scale,
|
||||
shape=wq.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_fp8(
|
||||
w: Tensor,
|
||||
w_scale: Tensor,
|
||||
fp8_activation_scale_ub: float,
|
||||
) -> Fp8RowwiseWeights:
|
||||
"""Load FP8 [n, k] weight tensor.
|
||||
|
||||
Args:
|
||||
w (Tensor): [n, k] input FP8.
|
||||
fp8_activation_scale_ub (float): Upper bound for activation max.
|
||||
"""
|
||||
activation_scale_ub = torch.tensor(
|
||||
[fp8_activation_scale_ub],
|
||||
dtype=torch.float,
|
||||
device="cuda",
|
||||
)
|
||||
return Fp8RowwiseWeights(
|
||||
weight=w.to(torch.float8_e4m3fn).to(device="cuda"),
|
||||
scale=w_scale.to(device="cuda"),
|
||||
shape=w.shape,
|
||||
activation_scale_ub=activation_scale_ub,
|
||||
)
|
||||
|
||||
|
||||
def fc_fp8_dynamic(
|
||||
x: Tensor,
|
||||
w: Fp8RowwiseWeights,
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Single w8a8 fc layer with dynamic row-wise scaling.
|
||||
"""
|
||||
if isinstance(w, Fp8RowwiseWeights):
|
||||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
|
||||
del xq
|
||||
return y
|
||||
|
||||
|
||||
def ffn_swiglu_fp8_dynamic(
|
||||
x: Tensor,
|
||||
w1: Fp8RowwiseWeights,
|
||||
w3: Fp8RowwiseWeights,
|
||||
w2: Fp8RowwiseWeights,
|
||||
activation_scale_ub: Optional[Tensor] = None,
|
||||
num_tokens: Optional[Tensor] = None,
|
||||
is_memory_bounded: bool = False,
|
||||
) -> Tensor:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
HD_L = w1.shape[0] # noqa: N806
|
||||
assert HD_L == w3.shape[0]
|
||||
x1 = fc_fp8_dynamic(
|
||||
x.view(B * T, D),
|
||||
w1,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
x2 = fc_fp8_dynamic(
|
||||
x.view(B * T, D),
|
||||
w3,
|
||||
activation_scale_ub,
|
||||
num_tokens,
|
||||
is_memory_bounded,
|
||||
)
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
del x1, x2
|
||||
|
||||
z_ = fc_fp8_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
|
||||
|
||||
return z_.view(B, T, D)
|
|
@@ -1,78 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
# The file gets a special treatment for now?
|
||||
# ruff: noqa: N803
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from fp8_impls import FfnQuantizeMode, ffn_swiglu_fp8_dynamic, quantize_fp8
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not torch.cuda.is_available() or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
|
||||
"Skip when H100 is not available",
|
||||
)
|
||||
class FP8Tests(unittest.TestCase):
|
||||
@settings(deadline=None)
|
||||
@given(
|
||||
D=st.sampled_from([4096, 8192]),
|
||||
HD_L=st.sampled_from([1280, 2560]),
|
||||
B=st.sampled_from([1, 2]),
|
||||
T=st.sampled_from([2048, 4096]),
|
||||
UB=st.sampled_from([1000, 10000]),
|
||||
)
|
||||
def test_fp8_ffn(
|
||||
self,
|
||||
D: int, # noqa
|
||||
HD_L: int,
|
||||
B: int,
|
||||
T: int,
|
||||
UB: float,
|
||||
) -> None:
|
||||
x = torch.randn(size=(B, T, D), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
w1 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w3 = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
|
||||
w2 = torch.randn(size=(D, HD_L), dtype=torch.bfloat16, device="cuda") * 0.1
|
||||
|
||||
x_q = quantize_fp8(x, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w1_q = quantize_fp8(w1, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w3_q = quantize_fp8(w3, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
w2_q = quantize_fp8(w2, UB, mode=FfnQuantizeMode.FP8_ROWWISE)
|
||||
|
||||
def ref_ffn(x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
|
||||
(B, T, D) = x.shape # noqa: N806
|
||||
(HD_L, D_) = w1.shape # noqa: N806
|
||||
assert D_ == D
|
||||
|
||||
x1 = x.view(B * T, D) @ w1.T
|
||||
x2 = x.view(B * T, D) @ w3.T
|
||||
|
||||
z = torch.nn.functional.silu(x1) * x2
|
||||
return (z @ w2.T).view(B, T, D).to(torch.bfloat16)
|
||||
|
||||
v = ffn_swiglu_fp8_dynamic(x, w1_q, w3_q, w2_q)
|
||||
|
||||
# Fake quant
|
||||
x = x_q.weight.bfloat16() * x_q.scale.unsqueeze(-1)
|
||||
w1 = w1_q.weight.bfloat16() * w1_q.scale.unsqueeze(-1)
|
||||
w3 = w3_q.weight.bfloat16() * w3_q.scale.unsqueeze(-1)
|
||||
w2 = w2_q.weight.bfloat16() * w2_q.scale.unsqueeze(-1)
|
||||
|
||||
v_ref = ref_ffn(x, w1, w3, w2)
|
||||
|
||||
torch.testing.assert_close(v_ref, v, atol=4.0e-3, rtol=4.0e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@@ -1,152 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama3.args import ModelArgs
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama3.model import Transformer, TransformerBlock
|
||||
from llama_stack.providers.inline.inference.meta_reference.quantization.fp8_impls import (
|
||||
quantize_fp8,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main(
|
||||
ckpt_dir: str,
|
||||
tokenizer_path: str,
|
||||
quantized_ckpt_dir: str,
|
||||
max_seq_len: Optional[int] = 512,
|
||||
max_batch_size: Optional[int] = 4,
|
||||
model_parallel_size: Optional[int] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
seed: int = 1,
|
||||
):
|
||||
""" """
|
||||
if not os.path.exists(quantized_ckpt_dir):
|
||||
os.makedirs(quantized_ckpt_dir)
|
||||
shutil.copy(
|
||||
os.path.join(ckpt_dir, "params.json"),
|
||||
os.path.join(quantized_ckpt_dir, "params.json"),
|
||||
)
|
||||
shutil.copy(
|
||||
os.path.join(ckpt_dir, "tokenizer.model"),
|
||||
os.path.join(quantized_ckpt_dir, "tokenizer.model"),
|
||||
)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group("nccl")
|
||||
if not model_parallel_is_initialized():
|
||||
if model_parallel_size is None:
|
||||
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(model_parallel_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# seed must be the same in all processes
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(checkpoints), (
|
||||
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
)
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
**params,
|
||||
)
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
assert model_args.vocab_size == tokenizer.n_words, (
|
||||
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
)
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
|
||||
log.info(ckpt_path)
|
||||
assert quantized_ckpt_dir is not None, "QUantized checkpoint directory should not be None"
|
||||
    fp8_scales = {}
    for block in model.layers:
        if isinstance(block, TransformerBlock):
            if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                continue

            fp8_weight = quantize_fp8(
                block.feed_forward.w1.weight,
                fp8_activation_scale_ub,
                output_device=torch.device("cpu"),
            )
            with torch.inference_mode():
                block.feed_forward.w1.weight = Parameter(fp8_weight.weight)
            fp8_scales[f"{block.layer_id}_feed_forward.w1_{get_model_parallel_rank()}"] = fp8_weight.scale

            fp8_weight = quantize_fp8(
                block.feed_forward.w3.weight,
                fp8_activation_scale_ub,
                output_device=torch.device("cpu"),
            )
            with torch.inference_mode():
                block.feed_forward.w3.weight = Parameter(fp8_weight.weight)
            fp8_scales[f"{block.layer_id}_feed_forward.w3_{get_model_parallel_rank()}"] = fp8_weight.scale

            fp8_weight = quantize_fp8(
                block.feed_forward.w2.weight,
                fp8_activation_scale_ub,
                output_device=torch.device("cpu"),
            )
            with torch.inference_mode():
                block.feed_forward.w2.weight = Parameter(fp8_weight.weight)
            fp8_scales[f"{block.layer_id}_feed_forward.w2_{get_model_parallel_rank()}"] = fp8_weight.scale

    fp8_scales_path = os.path.join(quantized_ckpt_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
    torch.save(fp8_scales, fp8_scales_path)

    ckpt_path = os.path.join(
        quantized_ckpt_dir,
        "consolidated.{:02d}.pth".format(get_model_parallel_rank()),
    )
    torch.save(model.state_dict(), ckpt_path)


if __name__ == "__main__":
    fire.Fire(main)
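Note: quantize_fp8 itself is imported from the provider's fp8 implementation and is not part of this diff. The snippet below is a minimal, self-contained sketch of per-row fp8 (e4m3) weight quantization of the kind the loop above performs; the names Fp8WeightSketch and quantize_fp8_sketch, the rowwise scaling choice, and using the scale upper bound as a clamp are assumptions for illustration, not the actual implementation.

# Illustrative sketch only -- not the quantize_fp8 used above.
# Assumes PyTorch >= 2.1 (torch.float8_e4m3fn) and a 2D weight matrix.
from dataclasses import dataclass

import torch


@dataclass
class Fp8WeightSketch:
    weight: torch.Tensor  # fp8 (e4m3) payload
    scale: torch.Tensor  # per-row scale; dequantized weight ~= weight.float() * scale[:, None]


def quantize_fp8_sketch(w: torch.Tensor, scale_ub: float, output_device: torch.device) -> Fp8WeightSketch:
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per output row, capped so that a single outlier row cannot
    # stretch the dynamic range (the role scale_ub is assumed to play here).
    row_amax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12, max=scale_ub)
    scale = row_amax / fp8_max
    q = (w / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return Fp8WeightSketch(weight=q.to(output_device), scale=scale.squeeze(1).to(output_device))

In the script above, the resulting fp8 weights replace the bf16 MLP projections in place, and the per-layer scales are collected into fp8_scales_{rank}.pt so they can be reloaded at inference time.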
@@ -1,31 +0,0 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail
set -x

cd $(dirname "$(realpath "$0")")

MASTER_HOST=$1
RUN_ID=$2
CKPT_DIR=$3
QUANT_CKPT_DIR=$4
TOKENIZER_PATH=$5
NNODES=$6
NPROC=$7

echo $MASTER_HOST, $RUN_ID, $CKPT_DIR, $QUANT_CKPT_DIR

NCCL_NET=Socket NCCL_SOCKET_IFNAME=eth TIKTOKEN_CACHE_DIR="" PYTHONPATH="/home/$USER/llama-stack" \
torchrun \
    --nnodes=$NNODES --nproc_per_node=$NPROC \
    --rdzv_id=$RUN_ID \
    --rdzv_conf='timeout=120' \
    --rdzv_backend=c10d \
    --rdzv_endpoint="${MASTER_HOST}:29502" \
    quantize_checkpoint.py $CKPT_DIR $TOKENIZER_PATH $QUANT_CKPT_DIR
@@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.inference import (
    CompletionResponse,
    Inference,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,

@@ -23,6 +24,10 @@ from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompletionToLlamaStackMixin,
)

from .config import SentenceTransformersInferenceConfig


@@ -30,6 +35,8 @@ log = logging.getLogger(__name__)


class SentenceTransformersInferenceImpl(
    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompletionToLlamaStackMixin,
    SentenceTransformerEmbeddingMixin,
    Inference,
    ModelsProtocolPrivate,

@@ -74,3 +81,25 @@ class SentenceTransformersInferenceImpl(
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        raise ValueError("Sentence transformers don't support chat completion")

    async def batch_completion(
        self,
        model_id: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ):
        raise NotImplementedError("Batch completion is not supported for Sentence Transformers")

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_config: Optional[ToolConfig] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ):
        raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
@@ -14,9 +14,10 @@ from llama_stack.apis.inference import (
    JsonSchemaResponseFormat,
    Message,
    ToolChoice,
    ToolDefinition,
    UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict,
    get_sampling_options,