feat: add batch inference API to llama stack inference (#1945)

# What does this PR do?

This PR adds two methods to the Inference API:
- `batch_completion`
- `batch_chat_completion`

The motivation is evaluations targeting a local inference engine (like meta-reference or vllm), where batch APIs provide a substantial speedup.
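
For reference, here is a minimal sketch of how a client might call the new methods. The method names come from this PR; the client wiring, the `messages_batch` parameter name, and the `response.batch` field are assumptions to be checked against the API spec.

```python
# Sketch only: assumes the llama-stack-client SDK exposes the new batch methods
# and that the response carries a `batch` list of per-request results.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# One entry per conversation in the batch.
messages_batch = [
    [{"role": "user", "content": "Write a haiku about batching."}],
    [{"role": "user", "content": "Explain KV caching in one sentence."}],
]

response = client.inference.batch_chat_completion(
    model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages_batch=messages_batch,
)

for item in response.batch:
    print(item.completion_message.content)
```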

Why did I not add this to `Api.batch_inference` instead? That would have meant a _lot_ more book-keeping given the structure of Llama Stack: I would have needed to create a notion of a "batch model" resource, set up routing based on it, and so on. That does not sound ideal.

So what's the future of the batch inference API? I am not sure. Maybe we can keep it for true _asynchronous_ execution: you submit requests, it returns a Job instance, and you poll for results later.
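
Purely to illustrate that idea, such an API might look like the sketch below. None of these method or field names exist in Llama Stack today, and `client`/`messages_batch` are reused from the earlier sketch.

```python
# Hypothetical only: a Job-based asynchronous batch API, invented for illustration.
import time

job = client.batch_inference.submit_chat_completion(
    model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages_batch=messages_batch,  # reused from the sketch above
)

# Poll until the job finishes, then fetch the per-request results.
while client.batch_inference.get_job(job.job_id).status == "in_progress":
    time.sleep(10)

results = client.batch_inference.get_job_results(job.job_id)
```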

## Test Plan

Run meta-reference-gpu using:
```bash
export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct
export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct-20250331210000
export MODEL_PARALLEL_SIZE=4
export MAX_BATCH_SIZE=32
export MAX_SEQ_LEN=6144

LLAMA_MODELS_DEBUG=1 llama stack run meta-reference-gpu
```

Then run the batch inference test case.
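A sketch of what that test case presumably exercises (hypothetical test body; the fixture names and the `response.batch` shape are assumptions, not the actual test in the repo):

```python
# Hypothetical pytest-style test illustrating the batch path end to end.
def test_batch_chat_completion(client, model_id):
    messages_batch = [
        [{"role": "user", "content": f"State one fact about the number {i}."}]
        for i in range(8)
    ]
    response = client.inference.batch_chat_completion(
        model_id=model_id,
        messages_batch=messages_batch,
    )
    # Expect one completion per conversation, in order.
    assert len(response.batch) == len(messages_batch)
    for item in response.batch:
        assert item.completion_message.content
```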
Ashwin Bharambe, 2025-04-12 11:41:12 -07:00 (committed by GitHub) · parent 854c2ad264 · commit f34f22f8c7
23 changed files with 698 additions and 389 deletions

```diff
@@ -226,7 +226,6 @@ class ChatFormat:
                     arguments_json=json.dumps(tool_arguments),
                 )
             )
-            content = ""
 
         return RawMessage(
             role="assistant",
```

```diff
@@ -140,7 +140,12 @@ class Llama3:
         return Llama3(model, tokenizer, model_args)
 
-    def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
+    def __init__(
+        self,
+        model: Transformer | CrossAttentionTransformer,
+        tokenizer: Tokenizer,
+        args: ModelArgs,
+    ):
         self.args = args
         self.model = model
         self.tokenizer = tokenizer
@@ -149,7 +154,7 @@ class Llama3:
     @torch.inference_mode()
     def generate(
         self,
-        model_inputs: List[LLMInput],
+        llm_inputs: List[LLMInput],
         temperature: float = 0.6,
         top_p: float = 0.9,
         max_gen_len: Optional[int] = None,
@@ -164,15 +169,15 @@ class Llama3:
         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
         if print_model_input:
-            for inp in model_inputs:
+            for inp in llm_inputs:
                 tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
                 cprint(
                     "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
                     "red",
                 )
-        prompt_tokens = [inp.tokens for inp in model_inputs]
+        prompt_tokens = [inp.tokens for inp in llm_inputs]
 
-        bsz = len(model_inputs)
+        bsz = len(llm_inputs)
         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
 
         min_prompt_len = min(len(t) for t in prompt_tokens)
@@ -193,8 +198,8 @@ class Llama3:
         is_vision = not isinstance(self.model, Transformer)
         if is_vision:
-            images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
-            mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
+            images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
+            mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
 
             xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
                 batch_images=images,
@@ -229,7 +234,7 @@ class Llama3:
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
-                text_only_inference = all(inp.vision is None for inp in model_inputs)
+                text_only_inference = all(inp.vision is None for inp in llm_inputs)
                 logits = self.model.forward(
                     position_ids,
                     tokens,
@@ -285,7 +290,7 @@ class Llama3:
                         source="output",
                         logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
                         batch_idx=idx,
-                        finished=eos_reached[idx],
+                        finished=eos_reached[idx].item(),
                         ignore_token=cur_pos < len(prompt_tokens[idx]),
                     )
                 )
```

```diff
@@ -301,7 +301,6 @@ class ChatFormat:
                     arguments=tool_arguments,
                 )
             )
-            content = ""
 
         return RawMessage(
             role="assistant",
```

```diff
@@ -233,7 +233,7 @@ class Llama4:
                         source="output",
                         logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
                         batch_idx=idx,
-                        finished=eos_reached[idx],
+                        finished=eos_reached[idx].item(),
                         ignore_token=cur_pos < len(prompt_tokens[idx]),
                     )
                 )
```