feat: add batch inference API to llama stack inference

Ashwin Bharambe 2025-04-08 13:50:52 -07:00
parent ed58a94b30
commit 0cfb2e2473
24 changed files with 1041 additions and 377 deletions
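For context on the feature itself: a batch inference API accepts several prompts in one request and runs them through the model as a single batched decode, instead of one forward pass per request. The sketch below is hypothetical, not code from this commit; the client class, the batch_completion method, and all field names are assumptions made purely for illustration.

# Hypothetical usage sketch -- the client class, method, and field names
# are assumptions for illustration, not the API surface added here.
from llama_stack_client import LlamaStackClient  # assumed client package

client = LlamaStackClient(base_url="http://localhost:8321")

# One request carries the whole batch; the server decodes all sequences
# together and returns one completion per prompt, in order.
response = client.inference.batch_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    content_batch=[
        "Write a haiku about batching.",
        "Explain KV caching in one sentence.",
    ],
    sampling_params={"max_tokens": 64},
)

for completion in response.batch:
    print(completion.content)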


@@ -140,7 +140,12 @@ class Llama3:
         return Llama3(model, tokenizer, model_args)
 
-    def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
+    def __init__(
+        self,
+        model: Transformer | CrossAttentionTransformer,
+        tokenizer: Tokenizer,
+        args: ModelArgs,
+    ):
         self.args = args
         self.model = model
         self.tokenizer = tokenizer
@@ -285,7 +290,7 @@ class Llama3:
                         source="output",
                         logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
                         batch_idx=idx,
-                        finished=eos_reached[idx],
+                        finished=eos_reached[idx].item(),
                         ignore_token=cur_pos < len(prompt_tokens[idx]),
                     )
                 )
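Why the `.item()` change: `eos_reached[idx]` is a zero-dimensional torch tensor, not a Python bool, so it cannot be JSON-serialized the way a batch inference response needs to be. Calling `.item()` unwraps it into a native bool. A minimal sketch of the difference, assuming only that eos_reached is a torch bool tensor as in the generation loop above:

import json

import torch

# Per-sequence end-of-sequence flags for a batch of three prompts,
# mirroring the eos_reached tensor in the generation loop above.
eos_reached = torch.tensor([True, False, True])

flag = eos_reached[0]           # zero-dim tensor(True), still a torch.Tensor
native = eos_reached[0].item()  # plain Python bool

print(type(flag))    # <class 'torch.Tensor'>
print(type(native))  # <class 'bool'>

json.dumps({"finished": native})  # fine
# json.dumps({"finished": flag})  # would raise TypeError: not JSON serializable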


@@ -233,7 +233,7 @@ class Llama4:
                         source="output",
                         logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
                         batch_idx=idx,
-                        finished=eos_reached[idx],
+                        finished=eos_reached[idx].item(),
                         ignore_token=cur_pos < len(prompt_tokens[idx]),
                     )
                 )