feat: add batch inference API to llama stack inference (#1945)

# What does this PR do?

This PR adds two methods to the Inference API:
- `batch_completion`
- `batch_chat_completion`

The motivation is for evaluations targeting a local inference engine
(like meta-reference or vllm) where batch APIs provide for a substantial
amount of acceleration.

Why did I not add this to `Api.batch_inference` though? That just
resulted in a _lot_ more book-keeping given the structure of Llama
Stack. Had I done that, I would have needed to create a notion of a
"batch model" resource, setup routing based on that, etc. This does not
sound ideal.

So what's the future of the batch inference API? I am not sure. Maybe we
can keep it for true _asynchronous_ execution. So you can submit
requests, and it can return a Job instance, etc.

## Test Plan

Run meta-reference-gpu using:
```bash
export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct
export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct-20250331210000
export MODEL_PARALLEL_SIZE=4
export MAX_BATCH_SIZE=32
export MAX_SEQ_LEN=6144

LLAMA_MODELS_DEBUG=1 llama stack run meta-reference-gpu
```

Then run the batch inference test case.
This commit is contained in:
Ashwin Bharambe 2025-04-12 11:41:12 -07:00 committed by GitHub
parent 854c2ad264
commit f34f22f8c7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 698 additions and 389 deletions

View file

@ -85,7 +85,7 @@
}
}
},
"/v1/batch-inference/chat-completion": {
"/v1/inference/batch-chat-completion": {
"post": {
"responses": {
"200": {
@ -112,7 +112,7 @@
}
},
"tags": [
"BatchInference (Coming Soon)"
"Inference"
],
"description": "",
"parameters": [],
@ -128,7 +128,7 @@
}
}
},
"/v1/batch-inference/completion": {
"/v1/inference/batch-completion": {
"post": {
"responses": {
"200": {
@ -155,7 +155,7 @@
}
},
"tags": [
"BatchInference (Coming Soon)"
"Inference"
],
"description": "",
"parameters": [],
@ -239,7 +239,7 @@
}
},
"tags": [
"Inference"
"BatchInference (Coming Soon)"
],
"description": "Generate a chat completion for the given messages using the specified model.",
"parameters": [],
@ -287,7 +287,7 @@
}
},
"tags": [
"Inference"
"BatchInference (Coming Soon)"
],
"description": "Generate a completion for the given content using the specified model.",
"parameters": [],
@ -4366,6 +4366,51 @@
],
"title": "ToolCall"
},
"ToolConfig": {
"type": "object",
"properties": {
"tool_choice": {
"oneOf": [
{
"type": "string",
"enum": [
"auto",
"required",
"none"
],
"title": "ToolChoice",
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
{
"type": "string"
}
],
"default": "auto",
"description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
},
"tool_prompt_format": {
"type": "string",
"enum": [
"json",
"function_tag",
"python_list"
],
"description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
},
"system_message_behavior": {
"type": "string",
"enum": [
"append",
"replace"
],
"description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
"default": "append"
}
},
"additionalProperties": false,
"title": "ToolConfig",
"description": "Configuration for tool use."
},
"ToolDefinition": {
"type": "object",
"properties": {
@ -4554,7 +4599,7 @@
"BatchChatCompletionRequest": {
"type": "object",
"properties": {
"model": {
"model_id": {
"type": "string"
},
"messages_batch": {
@ -4575,25 +4620,8 @@
"$ref": "#/components/schemas/ToolDefinition"
}
},
"tool_choice": {
"type": "string",
"enum": [
"auto",
"required",
"none"
],
"title": "ToolChoice",
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
"tool_prompt_format": {
"type": "string",
"enum": [
"json",
"function_tag",
"python_list"
],
"title": "ToolPromptFormat",
"description": "Prompt format for calling custom / zero shot tools."
"tool_config": {
"$ref": "#/components/schemas/ToolConfig"
},
"response_format": {
"$ref": "#/components/schemas/ResponseFormat"
@ -4613,7 +4641,7 @@
},
"additionalProperties": false,
"required": [
"model",
"model_id",
"messages_batch"
],
"title": "BatchChatCompletionRequest"
@ -4710,7 +4738,7 @@
"BatchCompletionRequest": {
"type": "object",
"properties": {
"model": {
"model_id": {
"type": "string"
},
"content_batch": {
@ -4740,7 +4768,7 @@
},
"additionalProperties": false,
"required": [
"model",
"model_id",
"content_batch"
],
"title": "BatchCompletionRequest"
@ -4812,51 +4840,6 @@
],
"title": "CancelTrainingJobRequest"
},
"ToolConfig": {
"type": "object",
"properties": {
"tool_choice": {
"oneOf": [
{
"type": "string",
"enum": [
"auto",
"required",
"none"
],
"title": "ToolChoice",
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
{
"type": "string"
}
],
"default": "auto",
"description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
},
"tool_prompt_format": {
"type": "string",
"enum": [
"json",
"function_tag",
"python_list"
],
"description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
},
"system_message_behavior": {
"type": "string",
"enum": [
"append",
"replace"
],
"description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
"default": "append"
}
},
"additionalProperties": false,
"title": "ToolConfig",
"description": "Configuration for tool use."
},
"ChatCompletionRequest": {
"type": "object",
"properties": {
@ -11173,7 +11156,9 @@
"x-displayName": "Agents API for creating and interacting with agentic systems."
},
{
"name": "BatchInference (Coming Soon)"
"name": "BatchInference (Coming Soon)",
"description": "This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.\n\nNOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs\nincluding (post-training, evals, etc).",
"x-displayName": "Batch inference API for generating completions and chat completions."
},
{
"name": "Benchmarks"

View file

@ -40,7 +40,7 @@ paths:
schema:
$ref: '#/components/schemas/AppendRowsRequest'
required: true
/v1/batch-inference/chat-completion:
/v1/inference/batch-chat-completion:
post:
responses:
'200':
@ -60,7 +60,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- BatchInference (Coming Soon)
- Inference
description: ''
parameters: []
requestBody:
@ -69,7 +69,7 @@ paths:
schema:
$ref: '#/components/schemas/BatchChatCompletionRequest'
required: true
/v1/batch-inference/completion:
/v1/inference/batch-completion:
post:
responses:
'200':
@ -89,7 +89,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- BatchInference (Coming Soon)
- Inference
description: ''
parameters: []
requestBody:
@ -148,7 +148,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
- BatchInference (Coming Soon)
description: >-
Generate a chat completion for the given messages using the specified model.
parameters: []
@ -183,7 +183,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
- BatchInference (Coming Soon)
description: >-
Generate a completion for the given content using the specified model.
parameters: []
@ -3009,6 +3009,54 @@ components:
- tool_name
- arguments
title: ToolCall
ToolConfig:
type: object
properties:
tool_choice:
oneOf:
- type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following
capabilities of the model.
- type: string
default: auto
description: >-
(Optional) Whether tool use is automatic, required, or none. Can also
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
description: >-
(Optional) Instructs the model how to format tool calls. By default, Llama
Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
syntax -- a list of function calls.
system_message_behavior:
type: string
enum:
- append
- replace
description: >-
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
Replaces the default system prompt with the provided system message. The
system message can include the string '{{function_definitions}}' to indicate
where the function definitions should be inserted.
default: append
additionalProperties: false
title: ToolConfig
description: Configuration for tool use.
ToolDefinition:
type: object
properties:
@ -3145,7 +3193,7 @@ components:
BatchChatCompletionRequest:
type: object
properties:
model:
model_id:
type: string
messages_batch:
type: array
@ -3159,26 +3207,8 @@ components:
type: array
items:
$ref: '#/components/schemas/ToolDefinition'
tool_choice:
type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following capabilities
of the model.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
title: ToolPromptFormat
description: >-
Prompt format for calling custom / zero shot tools.
tool_config:
$ref: '#/components/schemas/ToolConfig'
response_format:
$ref: '#/components/schemas/ResponseFormat'
logprobs:
@ -3193,7 +3223,7 @@ components:
title: LogProbConfig
additionalProperties: false
required:
- model
- model_id
- messages_batch
title: BatchChatCompletionRequest
BatchChatCompletionResponse:
@ -3261,7 +3291,7 @@ components:
BatchCompletionRequest:
type: object
properties:
model:
model_id:
type: string
content_batch:
type: array
@ -3283,7 +3313,7 @@ components:
title: LogProbConfig
additionalProperties: false
required:
- model
- model_id
- content_batch
title: BatchCompletionRequest
BatchCompletionResponse:
@ -3335,54 +3365,6 @@ components:
required:
- job_uuid
title: CancelTrainingJobRequest
ToolConfig:
type: object
properties:
tool_choice:
oneOf:
- type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following
capabilities of the model.
- type: string
default: auto
description: >-
(Optional) Whether tool use is automatic, required, or none. Can also
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
description: >-
(Optional) Instructs the model how to format tool calls. By default, Llama
Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
syntax -- a list of function calls.
system_message_behavior:
type: string
enum:
- append
- replace
description: >-
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
Replaces the default system prompt with the provided system message. The
system message can include the string '{{function_definitions}}' to indicate
where the function definitions should be inserted.
default: append
additionalProperties: false
title: ToolConfig
description: Configuration for tool use.
ChatCompletionRequest:
type: object
properties:
@ -7632,6 +7614,17 @@ tags:
x-displayName: >-
Agents API for creating and interacting with agentic systems.
- name: BatchInference (Coming Soon)
description: >-
This is an asynchronous API. If the request is successful, the response will
be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with
other asynchronous APIs
including (post-training, evals, etc).
x-displayName: >-
Batch inference API for generating completions and chat completions.
- name: Benchmarks
- name: DatasetIO
- name: Datasets

View file

@ -6,11 +6,8 @@
from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionResponse,
InterleavedContent,
LogProbConfig,
Message,
@ -20,41 +17,39 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
from llama_stack.schema_utils import webmethod
@runtime_checkable
class BatchInference(Protocol):
"""Batch inference API for generating completions and chat completions.
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
including (post-training, evals, etc).
"""
@webmethod(route="/batch-inference/completion", method="POST")
async def batch_completion(
async def completion(
self,
model: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ...
) -> Job: ...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def batch_chat_completion(
async def chat_completion(
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ...
) -> Job: ...

View file

@ -681,6 +681,16 @@ class EmbeddingTaskType(Enum):
document = "document"
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
@runtime_checkable
@trace_protocol
class Inference(Protocol):
@ -716,6 +726,17 @@ class Inference(Protocol):
"""
...
@webmethod(route="/inference/batch-completion", method="POST")
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
self,
@ -756,6 +777,19 @@ class Inference(Protocol):
"""
...
@webmethod(route="/inference/batch-chat-completion", method="POST")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
self,

View file

@ -17,6 +17,8 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
@ -334,6 +336,30 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
tools=tools,
tool_config=tool_config,
sampling_params=sampling_params,
response_format=response_format,
logprobs=logprobs,
)
async def completion(
self,
model_id: str,
@ -398,6 +424,20 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings(
self,
model_id: str,

View file

@ -226,7 +226,6 @@ class ChatFormat:
arguments_json=json.dumps(tool_arguments),
)
)
content = ""
return RawMessage(
role="assistant",

View file

@ -140,7 +140,12 @@ class Llama3:
return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
def __init__(
self,
model: Transformer | CrossAttentionTransformer,
tokenizer: Tokenizer,
args: ModelArgs,
):
self.args = args
self.model = model
self.tokenizer = tokenizer
@ -149,7 +154,7 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
model_inputs: List[LLMInput],
llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
@ -164,15 +169,15 @@ class Llama3:
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
for inp in model_inputs:
for inp in llm_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [inp.tokens for inp in model_inputs]
prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = len(model_inputs)
bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@ -193,8 +198,8 @@ class Llama3:
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
@ -229,7 +234,7 @@ class Llama3:
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = all(inp.vision is None for inp in model_inputs)
text_only_inference = all(inp.vision is None for inp in llm_inputs)
logits = self.model.forward(
position_ids,
tokens,
@ -285,7 +290,7 @@ class Llama3:
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)

View file

@ -301,7 +301,6 @@ class ChatFormat:
arguments=tool_arguments,
)
)
content = ""
return RawMessage(
role="assistant",

View file

@ -233,7 +233,7 @@ class Llama4:
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)

View file

@ -52,14 +52,17 @@ class MetaReferenceInferenceConfig(BaseModel):
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
"max_seq_len": 4096,
"checkpoint_dir": checkpoint_dir,
"quantization": {
"type": quantization_type,
},
"model_parallel_size": model_parallel_size,
"max_batch_size": max_batch_size,
"max_seq_len": max_seq_len,
}

View file

@ -22,7 +22,7 @@ from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model
from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
@ -113,8 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
return get_default_tool_prompt_format(request.model)
# TODO: combine Llama3 and Llama4 generators since they are almost identical now
class Llama4Generator:
class LlamaGenerator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
@ -144,7 +143,8 @@ class Llama4Generator:
else:
quantization_mode = None
self.inner_generator = Llama4.build(
cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
self.inner_generator = cls.build(
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
@ -158,142 +158,55 @@ class Llama4Generator:
def completion(
self,
request: CompletionRequestWithRawContent,
request_batch: List[CompletionRequestWithRawContent],
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content)],
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
first_request.response_format,
),
):
yield result[0]
yield result
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
first_request.response_format,
),
):
yield result[0]
class Llama3Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
if config.quantization:
if config.quantization.type == "fp8_mixed":
quantization_mode = QuantizationMode.fp8_mixed
elif config.quantization.type == "int4_mixed":
quantization_mode = QuantizationMode.int4_mixed
elif config.quantization.type == "bf16":
quantization_mode = None
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
quantization_mode = None
self.inner_generator = Llama3.build(
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
model_inputs=[self.formatter.encode_content(request.content)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
):
yield result[0]
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
):
yield result[0]
yield result

View file

@ -5,10 +5,10 @@
# the root directory of this source tree.
import asyncio
import logging
import os
from typing import AsyncGenerator, List, Optional, Union
from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
@ -17,6 +17,8 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
@ -38,8 +40,10 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@ -65,21 +69,17 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import MetaReferenceInferenceConfig
from .generators import Llama3Generator, Llama4Generator
from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__)
log = get_logger(__name__, category="inference")
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
return Llama3Generator(config, model_id, llama_model)
def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
return Llama4Generator(config, model_id, llama_model)
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return LlamaGenerator(config, model_id, llama_model)
class MetaReferenceInferenceImpl(
@ -139,24 +139,12 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
if llama_model.model_family in {
ModelFamily.llama3,
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
}:
builder_fn = llama3_builder_fn
elif llama_model.model_family == ModelFamily.llama4:
builder_fn = llama4_builder_fn
else:
raise ValueError(f"Unsupported model family: {llama_model.model_family}")
builder_params = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
builder_fn=builder_fn,
builder_fn=llama_builder_fn,
builder_params=builder_params,
formatter=(
Llama4ChatFormat(Llama4Tokenizer.get_instance())
@ -166,11 +154,24 @@ class MetaReferenceInferenceImpl(
)
self.generator.start()
else:
self.generator = builder_fn(*builder_params)
self.generator = llama_builder_fn(*builder_params)
self.model_id = model_id
self.llama_model = llama_model
log.info("Warming up...")
await self.completion(
model_id=model_id,
content="Hello, world!",
sampling_params=SamplingParams(max_tokens=10),
)
await self.chat_completion(
model_id=model_id,
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
)
log.info("Warmed up!")
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
@ -208,7 +209,43 @@ class MetaReferenceInferenceImpl(
if request.stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
results = await self._nonstream_completion([request])
return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
@ -253,37 +290,54 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
tokens = []
logprobs = []
stop_reason = None
states = [ItemState() for _ in request_batch]
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
results = []
for token_results in self.generator.completion(request_batch):
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
if request.logprobs:
assert len(token_result.logprobs) == 1
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
tokens = tokens[:-1]
content = self.generator.formatter.tokenizer.decode(tokens)
return CompletionResponse(
content=content,
stop_reason=stop_reason,
logprobs=logprobs if request.logprobs else None,
)
if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
state.tokens = state.tokens[:-1]
content = self.generator.formatter.tokenizer.decode(state.tokens)
results.append(
CompletionResponse(
content=content,
stop_reason=state.stop_reason,
logprobs=state.logprobs if first_request.logprobs else None,
)
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
@ -318,7 +372,7 @@ class MetaReferenceInferenceImpl(
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
@ -334,44 +388,110 @@ class MetaReferenceInferenceImpl(
if request.stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
results = await self._nonstream_chat_completion([request])
return results[0]
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: List[ChatCompletionRequest]
) -> List[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl():
tokens = []
logprobs = []
stop_reason = None
states = [ItemState() for _ in request_batch]
for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
for token_results in self.generator.chat_completion(request_batch):
first = token_results[0]
if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", "magenta", end="")
tokens.append(token_result.token)
for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
state.finished = result.finished
if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
if request.logprobs:
assert len(token_result.logprobs) == 1
state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
)
)
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=logprobs if request.logprobs else None,
)
return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
@ -398,6 +518,22 @@ class MetaReferenceInferenceImpl(
for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", "magenta", end="")
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)

View file

@ -6,7 +6,7 @@
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Generator
from typing import Any, Callable, Generator, List
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@ -23,13 +23,13 @@ class ModelRunner:
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequestWithRawContent):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequestWithRawContent):
return self.llama.completion(req)
def __call__(self, task: Any):
if task[0] == "chat_completion":
return self.llama.chat_completion(task[1])
elif task[0] == "completion":
return self.llama.completion(task[1])
else:
raise ValueError(f"Unexpected task type {type(req)}")
raise ValueError(f"Unexpected task type {task[0]}")
def init_model_cb(
@ -82,16 +82,16 @@ class LlamaModelParallelGenerator:
def completion(
self,
request: CompletionRequestWithRawContent,
request_batch: List[CompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("completion", req_obj))
yield from gen
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
req_obj = deepcopy(request_batch)
gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen

View file

@ -19,7 +19,7 @@ import tempfile
import time
import uuid
from enum import Enum
from typing import Callable, Generator, Literal, Optional, Union
from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
import torch
import zmq
@ -69,12 +69,12 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: GenerationResult
result: List[GenerationResult]
class ExceptionResponse(BaseModel):
@ -331,7 +331,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
) -> Generator:
assert not self.running, "inference already running"

View file

@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
@ -80,3 +81,25 @@ class SentenceTransformersInferenceImpl(
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@ -437,6 +437,28 @@ class OllamaInferenceAdapter(
}
return await self.openai_client.chat.completions.create(**params) # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
async def _convert_content(content) -> dict:

View file

@ -526,3 +526,25 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
user=user,
)
return await self.client.chat.completions.create(**params) # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")

View file

@ -347,3 +347,25 @@ class LiteLLMOpenAIMixin(
user=user,
)
return litellm.completion(**params)
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")

View file

@ -16,11 +16,12 @@ providers:
provider_type: inline::meta-reference
config:
model: ${env.INFERENCE_MODEL}
max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
max_batch_size: ${env.MAX_BATCH_SIZE:1}
max_seq_len: ${env.MAX_SEQ_LEN:4096}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
@ -28,11 +29,12 @@ providers:
provider_type: inline::meta-reference
config:
model: ${env.SAFETY_MODEL}
max_seq_len: 4096
checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
max_batch_size: ${env.MAX_BATCH_SIZE:1}
max_seq_len: ${env.MAX_SEQ_LEN:4096}
vector_io:
- provider_id: faiss
provider_type: inline::faiss

View file

@ -16,11 +16,12 @@ providers:
provider_type: inline::meta-reference
config:
model: ${env.INFERENCE_MODEL}
max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
max_batch_size: ${env.MAX_BATCH_SIZE:1}
max_seq_len: ${env.MAX_SEQ_LEN:4096}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}

View file

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..test_cases.test_case import TestCase
def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id):
models = {m.identifier: m for m in client_with_models.models.list()}
models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id]
if provider.provider_type not in ("inline::meta-reference",):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference")
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:batch_completion",
],
)
def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
tc = TestCase(test_case)
content_batch = tc["contents"]
response = client_with_models.inference.batch_completion(
content_batch=content_batch,
model_id=text_model_id,
sampling_params={
"max_tokens": 50,
},
)
assert len(response.batch) == len(content_batch)
for i, r in enumerate(response.batch):
print(f"response {i}: {r.content}")
assert len(r.content) > 10
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:batch_completion",
],
)
def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
tc = TestCase(test_case)
qa_pairs = tc["qa_pairs"]
message_batch = [
[
{
"role": "user",
"content": qa["question"],
}
]
for qa in qa_pairs
]
response = client_with_models.inference.batch_chat_completion(
messages_batch=message_batch,
model_id=text_model_id,
)
assert len(response.batch) == len(qa_pairs)
for i, r in enumerate(response.batch):
print(f"response {i}: {r.completion_message.content}")
assert len(r.completion_message.content) > 0
assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()

View file

@ -537,5 +537,31 @@
}
]
}
},
"batch_completion": {
"data": {
"qa_pairs": [
{
"question": "What is the capital of France?",
"answer": "Paris"
},
{
"question": "Who wrote the book '1984'?",
"answer": "George Orwell"
},
{
"question": "Which planet has rings around it with a name starting with letter S?",
"answer": "Saturn"
},
{
"question": "When did the first moon landing happen?",
"answer": "1969"
},
{
"question": "What word says 'hello' in Spanish?",
"answer": "Hola"
}
]
}
}
}

View file

@ -44,5 +44,18 @@
"year_retired": "2003"
}
}
},
"batch_completion": {
"data": {
"contents": [
"Micheael Jordan is born in ",
"Roses are red, violets are ",
"If you had a million dollars, what would you do with it? ",
"All you need is ",
"The capital of France is ",
"It is a good day to ",
"The answer to the universe is "
]
}
}
}