refactor: move all llama code to models/llama out of meta reference (#1887)

# What does this PR do?

Move around bits. This makes the copies from llama-models _much_ easier
to maintain and ensures we don't entangle meta-reference-specific
tidbits into llama-models code even by accident.

Also, kills the meta-reference-quantized-gpu distro and rolls
quantization deps into meta-reference-gpu.
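
Since quantization now lives in the single meta-reference provider, its config carries an
optional quantization block instead of a separate `MetaReferenceQuantizedInferenceConfig`.
A rough sketch of the shape the generators in the diff below rely on (field names inferred
from what they read; illustrative sketch, not the actual class definitions):

```python
from typing import Literal

from pydantic import BaseModel


class QuantizationSketch(BaseModel):
    # "bf16" means run unquantized; the other values map to QuantizationMode in the diff below
    type: Literal["bf16", "fp8_mixed", "int4_mixed"] = "bf16"


class MetaReferenceConfigSketch(BaseModel):
    # fields the generators read in this PR; defaults here are illustrative only
    checkpoint_dir: str | None = None
    max_seq_len: int = 4096
    max_batch_size: int = 1
    model_parallel_size: int | None = None
    quantization: QuantizationSketch | None = None
```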

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
   --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
   --env INFERENCE_CHECKPOINT_DIR=<DIR> \
   --env MODEL_PARALLEL_SIZE=4 \
   --env QUANTIZATION_TYPE=fp8_mixed
```

Start the server both with and without quantization, then point the integration tests
at it:

```
pytest -s -v  tests/integration/inference/test_text_inference.py \
   --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
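
Before pointing pytest at the server, a quick smoke test can confirm it is serving. This
sketch uses the `llama_stack_client` Python package; the client call and response accessor
follow its documentation at the time and are an assumption of this write-up, not part of
the PR:

```python
from llama_stack_client import LlamaStackClient

# assumes the server started above is listening on the default port
client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.chat_completion(
    model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```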
Commit 530d4bdfe1 (parent c52ccc4bbd) by Ashwin Bharambe, 2025-04-07 15:03:58 -07:00, committed via GitHub. 85 changed files with 1267 additions and 1683 deletions.

Diff excerpt (meta-reference inference generators):

```diff
@@ -11,19 +11,18 @@ import torch
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
 
 from llama_stack.apis.inference import (
-    Fp8QuantizationConfig,
-    Int4QuantizationConfig,
+    GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     ResponseFormat,
+    SamplingParams,
+    TopPSamplingStrategy,
 )
-from llama_stack.models.llama.datatypes import (
-    GreedySamplingStrategy,
-    Model,
-    SamplingParams,
-    TopPSamplingStrategy,
-)
+from llama_stack.models.llama.datatypes import QuantizationMode
+from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
+from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
+from llama_stack.models.llama.sku_types import Model
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
@@ -31,10 +30,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 
 from .common import model_checkpoint_dir
-from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
+from .config import MetaReferenceInferenceConfig
 from .inference import resolve_model
-from .llama3.generation import Llama3
-from .llama4.generation import Llama4
 
 
 Tokenizer = Llama4Tokenizer | Llama3Tokenizer
@@ -116,10 +113,11 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
     return get_default_tool_prompt_format(request.model)
 
 
+# TODO: combine Llama3 and Llama4 generators since they are almost identical now
 class Llama4Generator:
     def __init__(
         self,
-        config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
+        config: MetaReferenceInferenceConfig,
         model_id: str,
         llama_model: Model,
     ):
@@ -134,11 +132,13 @@ class Llama4Generator:
                 # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
                 ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
 
-        if isinstance(config, MetaReferenceQuantizedInferenceConfig):
-            if isinstance(config.quantization, Fp8QuantizationConfig):
-                quantization_mode = "fp8_mixed"
-            elif isinstance(config.quantization, Int4QuantizationConfig):
-                quantization_mode = "int4_mixed"
+        if config.quantization:
+            if config.quantization.type == "fp8_mixed":
+                quantization_mode = QuantizationMode.fp8_mixed
+            elif config.quantization.type == "int4_mixed":
+                quantization_mode = QuantizationMode.int4_mixed
+            elif config.quantization.type == "bf16":
+                quantization_mode = None
             else:
                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
         else:
@@ -148,7 +148,7 @@ class Llama4Generator:
             ckpt_dir=ckpt_dir,
             max_seq_len=config.max_seq_len,
             max_batch_size=config.max_batch_size,
-            world_size=llama_model.pth_file_count,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
             quantization_mode=quantization_mode,
         )
 
@@ -166,8 +166,8 @@ class Llama4Generator:
         max_gen_len = self.args.max_seq_len - 1
 
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -178,7 +178,8 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
 
     def chat_completion(
         self,
@@ -190,8 +191,8 @@ class Llama4Generator:
         max_gen_len = self.args.max_seq_len - 1
 
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -202,20 +203,46 @@ class Llama4Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
 
 
 class Llama3Generator:
     def __init__(
         self,
-        config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
+        config: MetaReferenceInferenceConfig,
         model_id: str,
         llama_model: Model,
     ):
+        if config.checkpoint_dir and config.checkpoint_dir != "null":
+            ckpt_dir = config.checkpoint_dir
+        else:
+            resolved_model = resolve_model(model_id)
+            if resolved_model is None:
+                # if the model is not a native llama model, get the default checkpoint_dir based on model id
+                ckpt_dir = model_checkpoint_dir(model_id)
+            else:
+                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
+
+        if config.quantization:
+            if config.quantization.type == "fp8_mixed":
+                quantization_mode = QuantizationMode.fp8_mixed
+            elif config.quantization.type == "int4_mixed":
+                quantization_mode = QuantizationMode.int4_mixed
+            elif config.quantization.type == "bf16":
+                quantization_mode = None
+            else:
+                raise ValueError(f"Unsupported quantization mode {config.quantization}")
+        else:
+            quantization_mode = None
+
         self.inner_generator = Llama3.build(
-            config=config,
-            model_id=model_id,
-            llama_model=llama_model,
+            ckpt_dir=ckpt_dir,
+            max_seq_len=config.max_seq_len,
+            max_batch_size=config.max_batch_size,
+            world_size=config.model_parallel_size or llama_model.pth_file_count,
+            quantization_mode=quantization_mode,
         )
         self.tokenizer = self.inner_generator.tokenizer
         self.args = self.inner_generator.args
@@ -231,8 +258,8 @@ class Llama3Generator:
         max_gen_len = self.args.max_seq_len - 1
 
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_content(request.content),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_content(request.content)],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -243,7 +270,8 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
 
     def chat_completion(
         self,
@@ -255,8 +283,8 @@ class Llama3Generator:
         max_gen_len = self.args.max_seq_len - 1
 
         temperature, top_p = _infer_sampling_params(sampling_params)
-        yield from self.inner_generator.generate(
-            model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
+        for result in self.inner_generator.generate(
+            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
@@ -267,4 +295,5 @@ class Llama3Generator:
                 self.args.vocab_size,
                 request.response_format,
             ),
-        )
+        ):
+            yield result[0]
```
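
As the TODO in the diff notes, the two generators are now nearly identical; in particular
the quantization-mode mapping is duplicated verbatim in both `__init__` methods. A minimal
sketch of the shared helper it could be factored into (hypothetical function, not part of
this commit):

```python
from llama_stack.models.llama.datatypes import QuantizationMode


def resolve_quantization_mode(quantization) -> QuantizationMode | None:
    """Map the provider config's quantization setting to a QuantizationMode.

    Mirrors the branch duplicated in Llama3Generator and Llama4Generator above;
    hypothetical helper, not in this commit.
    """
    if quantization is None:
        return None
    if quantization.type == "fp8_mixed":
        return QuantizationMode.fp8_mixed
    if quantization.type == "int4_mixed":
        return QuantizationMode.int4_mixed
    if quantization.type == "bf16":
        # bf16 means no quantization
        return None
    raise ValueError(f"Unsupported quantization mode {quantization}")
```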