forked from phoenix-oss/llama-stack-mirror
feat: add batch inference API to llama stack inference (#1945)
# What does this PR do? This PR adds two methods to the Inference API: - `batch_completion` - `batch_chat_completion` The motivation is for evaluations targeting a local inference engine (like meta-reference or vllm) where batch APIs provide for a substantial amount of acceleration. Why did I not add this to `Api.batch_inference` though? That just resulted in a _lot_ more book-keeping given the structure of Llama Stack. Had I done that, I would have needed to create a notion of a "batch model" resource, setup routing based on that, etc. This does not sound ideal. So what's the future of the batch inference API? I am not sure. Maybe we can keep it for true _asynchronous_ execution. So you can submit requests, and it can return a Job instance, etc. ## Test Plan Run meta-reference-gpu using: ```bash export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct-20250331210000 export MODEL_PARALLEL_SIZE=4 export MAX_BATCH_SIZE=32 export MAX_SEQ_LEN=6144 LLAMA_MODELS_DEBUG=1 llama stack run meta-reference-gpu ``` Then run the batch inference test case.
This commit is contained in:
parent
854c2ad264
commit
f34f22f8c7
23 changed files with 698 additions and 389 deletions
|
@ -22,7 +22,7 @@ from llama_stack.models.llama.llama3.generation import Llama3
|
|||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.generation import Llama4
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_types import Model
|
||||
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
|
@ -113,8 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
|||
return get_default_tool_prompt_format(request.model)
|
||||
|
||||
|
||||
# TODO: combine Llama3 and Llama4 generators since they are almost identical now
|
||||
class Llama4Generator:
|
||||
class LlamaGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
|
@ -144,7 +143,8 @@ class Llama4Generator:
|
|||
else:
|
||||
quantization_mode = None
|
||||
|
||||
self.inner_generator = Llama4.build(
|
||||
cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
|
||||
self.inner_generator = cls.build(
|
||||
ckpt_dir=ckpt_dir,
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
|
@ -158,142 +158,55 @@ class Llama4Generator:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequestWithRawContent,
|
||||
request_batch: List[CompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params or SamplingParams()
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content)],
|
||||
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
logprobs=bool(first_request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
first_request.response_format,
|
||||
),
|
||||
):
|
||||
yield result[0]
|
||||
yield result
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
request_batch: List[ChatCompletionRequestWithRawContent],
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params or SamplingParams()
|
||||
first_request = request_batch[0]
|
||||
sampling_params = first_request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
llm_inputs=[
|
||||
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
|
||||
for request in request_batch
|
||||
],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
logprobs=bool(first_request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
first_request.response_format,
|
||||
),
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
|
||||
class Llama3Generator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
ckpt_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
if config.quantization:
|
||||
if config.quantization.type == "fp8_mixed":
|
||||
quantization_mode = QuantizationMode.fp8_mixed
|
||||
elif config.quantization.type == "int4_mixed":
|
||||
quantization_mode = QuantizationMode.int4_mixed
|
||||
elif config.quantization.type == "bf16":
|
||||
quantization_mode = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
||||
else:
|
||||
quantization_mode = None
|
||||
|
||||
self.inner_generator = Llama3.build(
|
||||
ckpt_dir=ckpt_dir,
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
world_size=config.model_parallel_size or llama_model.pth_file_count,
|
||||
quantization_mode=quantization_mode,
|
||||
)
|
||||
self.tokenizer = self.inner_generator.tokenizer
|
||||
self.args = self.inner_generator.args
|
||||
self.formatter = self.inner_generator.formatter
|
||||
|
||||
def completion(
|
||||
self,
|
||||
request: CompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
model_inputs=[self.formatter.encode_content(request.content)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequestWithRawContent,
|
||||
) -> Generator:
|
||||
sampling_params = request.sampling_params or SamplingParams()
|
||||
max_gen_len = sampling_params.max_tokens
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
for result in self.inner_generator.generate(
|
||||
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
logprobs=bool(request.logprobs),
|
||||
echo=False,
|
||||
logits_processor=get_logits_processor(
|
||||
self.tokenizer,
|
||||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
):
|
||||
yield result[0]
|
||||
yield result
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue