mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: add batch inference API to llama stack inference (#1945)
# What does this PR do? This PR adds two methods to the Inference API: - `batch_completion` - `batch_chat_completion` The motivation is for evaluations targeting a local inference engine (like meta-reference or vllm) where batch APIs provide for a substantial amount of acceleration. Why did I not add this to `Api.batch_inference` though? That just resulted in a _lot_ more book-keeping given the structure of Llama Stack. Had I done that, I would have needed to create a notion of a "batch model" resource, setup routing based on that, etc. This does not sound ideal. So what's the future of the batch inference API? I am not sure. Maybe we can keep it for true _asynchronous_ execution. So you can submit requests, and it can return a Job instance, etc. ## Test Plan Run meta-reference-gpu using: ```bash export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct-20250331210000 export MODEL_PARALLEL_SIZE=4 export MAX_BATCH_SIZE=32 export MAX_SEQ_LEN=6144 LLAMA_MODELS_DEBUG=1 llama stack run meta-reference-gpu ``` Then run the batch inference test case.
This commit is contained in:
parent
854c2ad264
commit
f34f22f8c7
23 changed files with 698 additions and 389 deletions
|
@ -226,7 +226,6 @@ class ChatFormat:
|
|||
arguments_json=json.dumps(tool_arguments),
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
||||
return RawMessage(
|
||||
role="assistant",
|
||||
|
|
|
@ -140,7 +140,12 @@ class Llama3:
|
|||
|
||||
return Llama3(model, tokenizer, model_args)
|
||||
|
||||
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
def __init__(
|
||||
self,
|
||||
model: Transformer | CrossAttentionTransformer,
|
||||
tokenizer: Tokenizer,
|
||||
args: ModelArgs,
|
||||
):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
@ -149,7 +154,7 @@ class Llama3:
|
|||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_inputs: List[LLMInput],
|
||||
llm_inputs: List[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
|
@ -164,15 +169,15 @@ class Llama3:
|
|||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input:
|
||||
for inp in model_inputs:
|
||||
for inp in llm_inputs:
|
||||
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
|
||||
cprint(
|
||||
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
||||
"red",
|
||||
)
|
||||
prompt_tokens = [inp.tokens for inp in model_inputs]
|
||||
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||
|
||||
bsz = len(model_inputs)
|
||||
bsz = len(llm_inputs)
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
|
@ -193,8 +198,8 @@ class Llama3:
|
|||
|
||||
is_vision = not isinstance(self.model, Transformer)
|
||||
if is_vision:
|
||||
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
|
||||
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
|
||||
images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
|
||||
mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
|
||||
|
||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
||||
batch_images=images,
|
||||
|
@ -229,7 +234,7 @@ class Llama3:
|
|||
for cur_pos in range(min_prompt_len, total_len):
|
||||
if is_vision:
|
||||
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
||||
text_only_inference = all(inp.vision is None for inp in model_inputs)
|
||||
text_only_inference = all(inp.vision is None for inp in llm_inputs)
|
||||
logits = self.model.forward(
|
||||
position_ids,
|
||||
tokens,
|
||||
|
@ -285,7 +290,7 @@ class Llama3:
|
|||
source="output",
|
||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||
batch_idx=idx,
|
||||
finished=eos_reached[idx],
|
||||
finished=eos_reached[idx].item(),
|
||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||
)
|
||||
)
|
||||
|
|
|
@ -301,7 +301,6 @@ class ChatFormat:
|
|||
arguments=tool_arguments,
|
||||
)
|
||||
)
|
||||
content = ""
|
||||
|
||||
return RawMessage(
|
||||
role="assistant",
|
||||
|
|
|
@ -233,7 +233,7 @@ class Llama4:
|
|||
source="output",
|
||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||
batch_idx=idx,
|
||||
finished=eos_reached[idx],
|
||||
finished=eos_reached[idx].item(),
|
||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue