forked from phoenix-oss/llama-stack-mirror
refactor: move all llama code to models/llama out of meta reference (#1887)
# What does this PR do?

Moves all llama model code out of the meta-reference provider and into `models/llama`. This makes the copies from llama-models _much_ easier to maintain and ensures we don't entangle meta-reference-specific tidbits into llama-models code, even by accident.

Also kills the meta-reference-quantized-gpu distro and rolls its quantization deps into meta-reference-gpu.

## Test Plan

```
LLAMA_MODELS_DEBUG=1 \
  with-proxy llama stack run meta-reference-gpu \
  --env INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct \
  --env INFERENCE_CHECKPOINT_DIR=<DIR> \
  --env MODEL_PARALLEL_SIZE=4 \
  --env QUANTIZATION_TYPE=fp8_mixed
```

Start a server with and without quantization. Point integration tests to it using:

```
pytest -s -v tests/integration/inference/test_text_inference.py \
  --stack-config http://localhost:8321 --text-model meta-llama/Llama-4-Scout-17B-16E-Instruct
```
This commit is contained in:
parent
c52ccc4bbd
commit
530d4bdfe1
85 changed files with 1267 additions and 1683 deletions
|
@ -11,19 +11,18 @@ import torch
|
|||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
GreedySamplingStrategy,
|
||||
JsonSchemaResponseFormat,
|
||||
ResponseFormat,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
Model,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode
|
||||
from llama_stack.models.llama.llama3.generation import Llama3
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.generation import Llama4
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_types import Model
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
|
@ -31,10 +30,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .inference import resolve_model
|
||||
from .llama3.generation import Llama3
|
||||
from .llama4.generation import Llama4
|
||||
|
||||
Tokenizer = Llama4Tokenizer | Llama3Tokenizer
|
||||
|
||||
|
@ -116,10 +113,11 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
|||
return get_default_tool_prompt_format(request.model)
|
||||
|
||||
|
||||
# TODO: combine Llama3 and Llama4 generators since they are almost identical now
|
||||
class Llama4Generator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
|
@ -134,11 +132,13 @@ class Llama4Generator:
|
|||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
quantization_mode = "fp8_mixed"
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
quantization_mode = "int4_mixed"
|
||||
if config.quantization:
|
||||
if config.quantization.type == "fp8_mixed":
|
||||
quantization_mode = QuantizationMode.fp8_mixed
|
||||
elif config.quantization.type == "int4_mixed":
|
||||
quantization_mode = QuantizationMode.int4_mixed
|
||||
elif config.quantization.type == "bf16":
|
||||
quantization_mode = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
||||
else:
|
||||
|
@ -148,7 +148,7 @@ class Llama4Generator:
|
|||
ckpt_dir=ckpt_dir,
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
world_size=llama_model.pth_file_count,
|
||||
world_size=config.model_parallel_size or llama_model.pth_file_count,
|
||||
quantization_mode=quantization_mode,
|
||||
)
|
||||
|
||||
|
@ -166,8 +166,8 @@ class Llama4Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_input=self.formatter.encode_content(request.content),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -178,7 +178,8 @@ class Llama4Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
|
@ -190,8 +191,8 @@ class Llama4Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -202,20 +203,46 @@ class Llama4Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
|
||||
class Llama3Generator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
ckpt_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
if config.quantization:
|
||||
if config.quantization.type == "fp8_mixed":
|
||||
quantization_mode = QuantizationMode.fp8_mixed
|
||||
elif config.quantization.type == "int4_mixed":
|
||||
quantization_mode = QuantizationMode.int4_mixed
|
||||
elif config.quantization.type == "bf16":
|
||||
quantization_mode = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
||||
else:
|
||||
quantization_mode = None
|
||||
|
||||
self.inner_generator = Llama3.build(
|
||||
config=config,
|
||||
model_id=model_id,
|
||||
llama_model=llama_model,
|
||||
ckpt_dir=ckpt_dir,
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
world_size=config.model_parallel_size or llama_model.pth_file_count,
|
||||
quantization_mode=quantization_mode,
|
||||
)
|
||||
self.tokenizer = self.inner_generator.tokenizer
|
||||
self.args = self.inner_generator.args
|
||||
|
@ -231,8 +258,8 @@ class Llama3Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
model_input=self.formatter.encode_content(request.content),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -243,7 +270,8 @@ class Llama3Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
|
@ -255,8 +283,8 @@ class Llama3Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -267,4 +295,5 @@ class Llama3Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue