refactor: move all llama code to models/llama out of meta reference (#1887)

# What does this PR do?

Moves code around. This makes the copies from llama-models _much_ easier
to maintain and ensures we don't entangle meta-reference-specific
tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls the
quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --env INFERENCE_CHECKPOINT_DIR=<DIR> \
  --env MODEL_PARALLEL_SIZE=4 \
  --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization, then point the integration
tests at it using:

```
pytest -s -v tests/integration/inference/test_text_inference.py \
  --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
Author: Ashwin Bharambe, 2025-04-07 15:03:58 -07:00 (committed by GitHub)
Commit: 530d4bdfe1 (parent: c52ccc4bbd)
85 changed files with 1267 additions and 1683 deletions


```
@@ -25,15 +25,64 @@ from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
    BuiltinTool,
    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
    ToolParamDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

register_schema(ToolCall)
register_schema(ToolParamDefinition)
register_schema(ToolDefinition)


@json_schema_type
class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


@json_schema_type
class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = Field(..., gt=0.0)
    top_p: Optional[float] = 0.95


@json_schema_type
class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int = Field(..., ge=1)


SamplingStrategy = Annotated[
    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
    Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")


@json_schema_type
class SamplingParams(BaseModel):
    """Sampling parameters.

    :param strategy: The sampling strategy.
    :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
        your prompt plus max_tokens cannot exceed the model's context length.
    :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
        based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    :param stop: Up to 4 sequences where the API will stop generating further tokens.
        The returned text will not contain the stop sequence.
    """

    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None


class LogProbConfig(BaseModel):
    """
```
```
@@ -48,18 +97,18 @@ class QuantizationType(Enum):
    """Type of model quantization to run inference with.

    :cvar bf16: BFloat16 typically this means _no_ quantization
    :cvar fp8: 8-bit floating point quantization
    :cvar int4: 4-bit integer quantization
    :cvar fp8_mixed: 8-bit floating point quantization with mixed precision
    :cvar int4_mixed: 4-bit integer quantization with mixed precision
    """

    bf16 = "bf16"
    fp8 = "fp8"
    int4 = "int4"
    fp8_mixed = "fp8_mixed"
    int4_mixed = "int4_mixed"


@json_schema_type
class Fp8QuantizationConfig(BaseModel):
    type: Literal["fp8"] = "fp8"
    type: Literal["fp8_mixed"] = "fp8_mixed"


@json_schema_type
@@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
    :param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
    """

    type: Literal["int4"] = "int4"
    type: Literal["int4_mixed"] = "int4_mixed"
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
```
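
These last two hunks rename the quantization discriminators from `fp8`/`int4` to `fp8_mixed`/`int4_mixed`, matching `QUANTIZATION_TYPE=fp8_mixed` in the test plan. A hedged sketch of what that looks like on the wire, again assuming the config classes are importable from `llama_stack.apis.inference`:

```
# Hedged sketch: the import path is an assumption, as in the sketch above.
from llama_stack.apis.inference import Fp8QuantizationConfig, Int4QuantizationConfig

# The default "type" tags now carry the *_mixed suffix introduced in this diff.
assert Fp8QuantizationConfig().model_dump()["type"] == "fp8_mixed"
assert Int4QuantizationConfig().model_dump() == {
    "type": "int4_mixed",
    "scheme": "int4_weight_int8_dynamic_activation",
}
```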