forked from phoenix-oss/llama-stack-mirror
refactor: move all llama code to models/llama out of meta reference (#1887)
# What does this PR do? Move around bits. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference specific tidbits into llama-models code even by accident. Also, kills the meta-reference-quantized-gpu distro and rolls quantization deps into meta-reference-gpu. ## Test Plan ``` LLAMA_MODELS_DEBUG=1 \ with-proxy llama stack run meta-reference-gpu \ --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \ --env INFERENCE_CHECKPOINT_DIR=<DIR> \ --env MODEL_PARALLEL_SIZE=4 \ --env QUANTIZATION_TYPE=fp8_mixed ``` Start a server with and without quantization. Point integration tests to it using: ``` pytest -s -v tests/integration/inference/test_text_inference.py \ --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct ```
This commit is contained in:
parent
c52ccc4bbd
commit
530d4bdfe1
85 changed files with 1267 additions and 1683 deletions
|
@ -25,15 +25,64 @@ from llama_stack.apis.models import Model
|
|||
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
register_schema(ToolCall)
|
||||
register_schema(ToolParamDefinition)
|
||||
register_schema(ToolDefinition)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
type: Literal["greedy"] = "greedy"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopPSamplingStrategy(BaseModel):
|
||||
type: Literal["top_p"] = "top_p"
|
||||
temperature: Optional[float] = Field(..., gt=0.0)
|
||||
top_p: Optional[float] = 0.95
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopKSamplingStrategy(BaseModel):
|
||||
type: Literal["top_k"] = "top_k"
|
||||
top_k: int = Field(..., ge=1)
|
||||
|
||||
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SamplingParams(BaseModel):
|
||||
"""Sampling parameters.
|
||||
|
||||
:param strategy: The sampling strategy.
|
||||
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
|
||||
your prompt plus max_tokens cannot exceed the model's context length.
|
||||
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
|
||||
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
|
||||
:param stop: Up to 4 sequences where the API will stop generating further tokens.
|
||||
The returned text will not contain the stop sequence.
|
||||
"""
|
||||
|
||||
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
||||
|
||||
max_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
class LogProbConfig(BaseModel):
|
||||
"""
|
||||
|
@ -48,18 +97,18 @@ class QuantizationType(Enum):
|
|||
"""Type of model quantization to run inference with.
|
||||
|
||||
:cvar bf16: BFloat16 typically this means _no_ quantization
|
||||
:cvar fp8: 8-bit floating point quantization
|
||||
:cvar int4: 4-bit integer quantization
|
||||
:cvar fp8_mixed: 8-bit floating point quantization with mixed precision
|
||||
:cvar int4_mixed: 4-bit integer quantization with mixed precision
|
||||
"""
|
||||
|
||||
bf16 = "bf16"
|
||||
fp8 = "fp8"
|
||||
int4 = "int4"
|
||||
fp8_mixed = "fp8_mixed"
|
||||
int4_mixed = "int4_mixed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Fp8QuantizationConfig(BaseModel):
|
||||
type: Literal["fp8"] = "fp8"
|
||||
type: Literal["fp8_mixed"] = "fp8_mixed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
|
|||
:param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
|
||||
"""
|
||||
|
||||
type: Literal["int4"] = "int4"
|
||||
type: Literal["int4_mixed"] = "int4_mixed"
|
||||
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue