diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index cdd6b3b53..542fb5be5 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -85,7 +85,7 @@
}
}
},
- "/v1/batch-inference/chat-completion": {
+ "/v1/inference/batch-chat-completion": {
"post": {
"responses": {
"200": {
@@ -112,7 +112,7 @@
}
},
"tags": [
- "BatchInference (Coming Soon)"
+ "Inference"
],
"description": "",
"parameters": [],
@@ -128,7 +128,7 @@
}
}
},
- "/v1/batch-inference/completion": {
+ "/v1/inference/batch-completion": {
"post": {
"responses": {
"200": {
@@ -155,7 +155,7 @@
}
},
"tags": [
- "BatchInference (Coming Soon)"
+ "Inference"
],
"description": "",
"parameters": [],
@@ -239,7 +239,7 @@
}
},
"tags": [
- "Inference"
+ "BatchInference (Coming Soon)"
],
"description": "Generate a chat completion for the given messages using the specified model.",
"parameters": [],
@@ -287,7 +287,7 @@
}
},
"tags": [
- "Inference"
+ "BatchInference (Coming Soon)"
],
"description": "Generate a completion for the given content using the specified model.",
"parameters": [],
@@ -4366,6 +4366,51 @@
],
"title": "ToolCall"
},
+ "ToolConfig": {
+ "type": "object",
+ "properties": {
+ "tool_choice": {
+ "oneOf": [
+ {
+ "type": "string",
+ "enum": [
+ "auto",
+ "required",
+ "none"
+ ],
+ "title": "ToolChoice",
+ "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
+ },
+ {
+ "type": "string"
+ }
+ ],
+ "default": "auto",
+ "description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
+ },
+ "tool_prompt_format": {
+ "type": "string",
+ "enum": [
+ "json",
+ "function_tag",
+ "python_list"
+ ],
+ "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
+ },
+ "system_message_behavior": {
+ "type": "string",
+ "enum": [
+ "append",
+ "replace"
+ ],
+ "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
+ "default": "append"
+ }
+ },
+ "additionalProperties": false,
+ "title": "ToolConfig",
+ "description": "Configuration for tool use."
+ },
"ToolDefinition": {
"type": "object",
"properties": {
@@ -4554,7 +4599,7 @@
"BatchChatCompletionRequest": {
"type": "object",
"properties": {
- "model": {
+ "model_id": {
"type": "string"
},
"messages_batch": {
@@ -4575,25 +4620,8 @@
"$ref": "#/components/schemas/ToolDefinition"
}
},
- "tool_choice": {
- "type": "string",
- "enum": [
- "auto",
- "required",
- "none"
- ],
- "title": "ToolChoice",
- "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
- },
- "tool_prompt_format": {
- "type": "string",
- "enum": [
- "json",
- "function_tag",
- "python_list"
- ],
- "title": "ToolPromptFormat",
- "description": "Prompt format for calling custom / zero shot tools."
+ "tool_config": {
+ "$ref": "#/components/schemas/ToolConfig"
},
"response_format": {
"$ref": "#/components/schemas/ResponseFormat"
@@ -4613,7 +4641,7 @@
},
"additionalProperties": false,
"required": [
- "model",
+ "model_id",
"messages_batch"
],
"title": "BatchChatCompletionRequest"
@@ -4710,7 +4738,7 @@
"BatchCompletionRequest": {
"type": "object",
"properties": {
- "model": {
+ "model_id": {
"type": "string"
},
"content_batch": {
@@ -4740,7 +4768,7 @@
},
"additionalProperties": false,
"required": [
- "model",
+ "model_id",
"content_batch"
],
"title": "BatchCompletionRequest"
@@ -4812,51 +4840,6 @@
],
"title": "CancelTrainingJobRequest"
},
- "ToolConfig": {
- "type": "object",
- "properties": {
- "tool_choice": {
- "oneOf": [
- {
- "type": "string",
- "enum": [
- "auto",
- "required",
- "none"
- ],
- "title": "ToolChoice",
- "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
- },
- {
- "type": "string"
- }
- ],
- "default": "auto",
- "description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
- },
- "tool_prompt_format": {
- "type": "string",
- "enum": [
- "json",
- "function_tag",
- "python_list"
- ],
- "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
- },
- "system_message_behavior": {
- "type": "string",
- "enum": [
- "append",
- "replace"
- ],
- "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
- "default": "append"
- }
- },
- "additionalProperties": false,
- "title": "ToolConfig",
- "description": "Configuration for tool use."
- },
"ChatCompletionRequest": {
"type": "object",
"properties": {
@@ -11173,7 +11156,9 @@
"x-displayName": "Agents API for creating and interacting with agentic systems."
},
{
- "name": "BatchInference (Coming Soon)"
+ "name": "BatchInference (Coming Soon)",
+ "description": "This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.\n\nNOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs\nincluding (post-training, evals, etc).",
+ "x-displayName": "Batch inference API for generating completions and chat completions."
},
{
"name": "Benchmarks"
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index aa8d9456e..fa7b130e2 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -40,7 +40,7 @@ paths:
schema:
$ref: '#/components/schemas/AppendRowsRequest'
required: true
- /v1/batch-inference/chat-completion:
+ /v1/inference/batch-chat-completion:
post:
responses:
'200':
@@ -60,7 +60,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- - BatchInference (Coming Soon)
+ - Inference
description: ''
parameters: []
requestBody:
@@ -69,7 +69,7 @@ paths:
schema:
$ref: '#/components/schemas/BatchChatCompletionRequest'
required: true
- /v1/batch-inference/completion:
+ /v1/inference/batch-completion:
post:
responses:
'200':
@@ -89,7 +89,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- - BatchInference (Coming Soon)
+ - Inference
description: ''
parameters: []
requestBody:
@@ -148,7 +148,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- - Inference
+ - BatchInference (Coming Soon)
description: >-
Generate a chat completion for the given messages using the specified model.
parameters: []
@@ -183,7 +183,7 @@ paths:
default:
$ref: '#/components/responses/DefaultError'
tags:
- - Inference
+ - BatchInference (Coming Soon)
description: >-
Generate a completion for the given content using the specified model.
parameters: []
@@ -3009,6 +3009,54 @@ components:
- tool_name
- arguments
title: ToolCall
+ ToolConfig:
+ type: object
+ properties:
+ tool_choice:
+ oneOf:
+ - type: string
+ enum:
+ - auto
+ - required
+ - none
+ title: ToolChoice
+ description: >-
+ Whether tool use is required or automatic. This is a hint to the model
+ which may not be followed. It depends on the Instruction Following
+ capabilities of the model.
+ - type: string
+ default: auto
+ description: >-
+ (Optional) Whether tool use is automatic, required, or none. Can also
+ specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
+ tool_prompt_format:
+ type: string
+ enum:
+ - json
+ - function_tag
+ - python_list
+ description: >-
+ (Optional) Instructs the model how to format tool calls. By default, Llama
+ Stack will attempt to use a format that is best adapted to the model.
+ - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
+ - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a
+ tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
+ syntax -- a list of function calls.
+ system_message_behavior:
+ type: string
+ enum:
+ - append
+ - replace
+ description: >-
+ (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
+ Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
+ Replaces the default system prompt with the provided system message. The
+ system message can include the string '{{function_definitions}}' to indicate
+ where the function definitions should be inserted.
+ default: append
+ additionalProperties: false
+ title: ToolConfig
+ description: Configuration for tool use.
ToolDefinition:
type: object
properties:
@@ -3145,7 +3193,7 @@ components:
BatchChatCompletionRequest:
type: object
properties:
- model:
+ model_id:
type: string
messages_batch:
type: array
@@ -3159,26 +3207,8 @@ components:
type: array
items:
$ref: '#/components/schemas/ToolDefinition'
- tool_choice:
- type: string
- enum:
- - auto
- - required
- - none
- title: ToolChoice
- description: >-
- Whether tool use is required or automatic. This is a hint to the model
- which may not be followed. It depends on the Instruction Following capabilities
- of the model.
- tool_prompt_format:
- type: string
- enum:
- - json
- - function_tag
- - python_list
- title: ToolPromptFormat
- description: >-
- Prompt format for calling custom / zero shot tools.
+ tool_config:
+ $ref: '#/components/schemas/ToolConfig'
response_format:
$ref: '#/components/schemas/ResponseFormat'
logprobs:
@@ -3193,7 +3223,7 @@ components:
title: LogProbConfig
additionalProperties: false
required:
- - model
+ - model_id
- messages_batch
title: BatchChatCompletionRequest
BatchChatCompletionResponse:
@@ -3261,7 +3291,7 @@ components:
BatchCompletionRequest:
type: object
properties:
- model:
+ model_id:
type: string
content_batch:
type: array
@@ -3283,7 +3313,7 @@ components:
title: LogProbConfig
additionalProperties: false
required:
- - model
+ - model_id
- content_batch
title: BatchCompletionRequest
BatchCompletionResponse:
@@ -3335,54 +3365,6 @@ components:
required:
- job_uuid
title: CancelTrainingJobRequest
- ToolConfig:
- type: object
- properties:
- tool_choice:
- oneOf:
- - type: string
- enum:
- - auto
- - required
- - none
- title: ToolChoice
- description: >-
- Whether tool use is required or automatic. This is a hint to the model
- which may not be followed. It depends on the Instruction Following
- capabilities of the model.
- - type: string
- default: auto
- description: >-
- (Optional) Whether tool use is automatic, required, or none. Can also
- specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
- tool_prompt_format:
- type: string
- enum:
- - json
- - function_tag
- - python_list
- description: >-
- (Optional) Instructs the model how to format tool calls. By default, Llama
- Stack will attempt to use a format that is best adapted to the model.
- - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a
- tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
- syntax -- a list of function calls.
- system_message_behavior:
- type: string
- enum:
- - append
- - replace
- description: >-
- (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
- Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
- Replaces the default system prompt with the provided system message. The
- system message can include the string '{{function_definitions}}' to indicate
- where the function definitions should be inserted.
- default: append
- additionalProperties: false
- title: ToolConfig
- description: Configuration for tool use.
ChatCompletionRequest:
type: object
properties:
@@ -7632,6 +7614,17 @@ tags:
x-displayName: >-
Agents API for creating and interacting with agentic systems.
- name: BatchInference (Coming Soon)
+ description: >-
+ This is an asynchronous API. If the request is successful, the response will
+ be a job which can be polled for completion.
+
+
+ NOTE: This API is not yet implemented and is subject to change in concert with
+ other asynchronous APIs
+
+    (including post-training, evals, etc.).
+ x-displayName: >-
+ Batch inference API for generating completions and chat completions.
- name: Benchmarks
- name: DatasetIO
- name: Datasets
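
With the schema changes above, a `BatchChatCompletionRequest` takes `model_id` (rather than `model`) and groups tool behavior under a single `tool_config` object. A minimal sketch of posting such a payload to the relocated route, assuming a locally running Llama Stack server (the base URL, port, and model id below are illustrative, not part of this change):

```python
# Illustrative only: POST a BatchChatCompletionRequest to the relocated route.
import requests

payload = {
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",  # example model id
    "messages_batch": [
        [{"role": "user", "content": "What is the capital of France?"}],
        [{"role": "user", "content": "Who wrote the book '1984'?"}],
    ],
    # tool_choice / tool_prompt_format now live inside tool_config
    "tool_config": {"tool_choice": "auto", "tool_prompt_format": "json"},
}

resp = requests.post(
    "http://localhost:8321/v1/inference/batch-chat-completion",  # assumed host/port
    json=payload,
    timeout=120,
)
resp.raise_for_status()
for item in resp.json()["batch"]:
    print(item["completion_message"]["content"])
```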
diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py
index 330a683ba..7a324128d 100644
--- a/llama_stack/apis/batch_inference/batch_inference.py
+++ b/llama_stack/apis/batch_inference/batch_inference.py
@@ -6,11 +6,8 @@
from typing import List, Optional, Protocol, runtime_checkable
-from pydantic import BaseModel
-
+from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
- ChatCompletionResponse,
- CompletionResponse,
InterleavedContent,
LogProbConfig,
Message,
@@ -20,41 +17,39 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
-from llama_stack.schema_utils import json_schema_type, webmethod
-
-
-@json_schema_type
-class BatchCompletionResponse(BaseModel):
- batch: List[CompletionResponse]
-
-
-@json_schema_type
-class BatchChatCompletionResponse(BaseModel):
- batch: List[ChatCompletionResponse]
+from llama_stack.schema_utils import webmethod
@runtime_checkable
class BatchInference(Protocol):
+ """Batch inference API for generating completions and chat completions.
+
+ This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
+
+ NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
+    (including post-training, evals, etc.).
+ """
+
@webmethod(route="/batch-inference/completion", method="POST")
- async def batch_completion(
+ async def completion(
self,
model: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
- ) -> BatchCompletionResponse: ...
+ ) -> Job: ...
@webmethod(route="/batch-inference/chat-completion", method="POST")
- async def batch_chat_completion(
+ async def chat_completion(
self,
model: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model
- tools: Optional[List[ToolDefinition]] = list,
+ tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
- ) -> BatchChatCompletionResponse: ...
+ ) -> Job: ...
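
The BatchInference protocol itself is now explicitly asynchronous: both methods return a `Job` handle rather than inline batch responses. A minimal sketch of a provider stub conforming to the new signatures; the `Job(job_id=...)` constructor shape is an assumption about `llama_stack.apis.common.job_types`, and the in-memory queue is purely illustrative:

```python
# Sketch of a provider stub for the revised BatchInference protocol.
import uuid
from typing import Dict, List, Optional

from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
    InterleavedContent,
    LogProbConfig,
    ResponseFormat,
    SamplingParams,
)


class QueueingBatchInference:
    """Accepts batch completion work and hands back a pollable job handle."""

    def __init__(self) -> None:
        self._pending: Dict[str, List[InterleavedContent]] = {}

    async def completion(
        self,
        model: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Job:
        job_id = str(uuid.uuid4())
        self._pending[job_id] = content_batch  # a background worker would drain this
        return Job(job_id=job_id)  # assumed field name; see note above
```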
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 3390a3fef..9eb3910c6 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -681,6 +681,16 @@ class EmbeddingTaskType(Enum):
document = "document"
+@json_schema_type
+class BatchCompletionResponse(BaseModel):
+ batch: List[CompletionResponse]
+
+
+@json_schema_type
+class BatchChatCompletionResponse(BaseModel):
+ batch: List[ChatCompletionResponse]
+
+
@runtime_checkable
@trace_protocol
class Inference(Protocol):
@@ -716,6 +726,17 @@ class Inference(Protocol):
"""
...
+ @webmethod(route="/inference/batch-completion", method="POST")
+ async def batch_completion(
+ self,
+ model_id: str,
+ content_batch: List[InterleavedContent],
+ sampling_params: Optional[SamplingParams] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ) -> BatchCompletionResponse:
+ raise NotImplementedError("Batch completion is not implemented")
+
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
self,
@@ -756,6 +777,19 @@ class Inference(Protocol):
"""
...
+ @webmethod(route="/inference/batch-chat-completion", method="POST")
+ async def batch_chat_completion(
+ self,
+ model_id: str,
+ messages_batch: List[List[Message]],
+ sampling_params: Optional[SamplingParams] = None,
+ tools: Optional[List[ToolDefinition]] = None,
+ tool_config: Optional[ToolConfig] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ) -> BatchChatCompletionResponse:
+ raise NotImplementedError("Batch chat completion is not implemented")
+
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
self,
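
On the synchronous side, the Inference protocol gains `batch_completion` and `batch_chat_completion` with `NotImplementedError` defaults, so only providers that opt in (inline::meta-reference in this change) actually serve them. A sketch of calling them from the client, mirroring the integration test added later in this diff; the server URL and model id are assumptions:

```python
# Illustrative client usage of the new synchronous batch methods.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed address

completions = client.inference.batch_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # example model id
    content_batch=["Roses are red, violets are ", "The capital of France is "],
    sampling_params={"max_tokens": 50},
)
for r in completions.batch:
    print(r.content)

chats = client.inference.batch_chat_completion(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    messages_batch=[[{"role": "user", "content": "What is the capital of France?"}]],
)
for r in chats.batch:
    print(r.completion_message.content)
```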
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index bc313036f..b9623ef3c 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -17,6 +17,8 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
+ BatchChatCompletionResponse,
+ BatchCompletionResponse,
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
@@ -334,6 +336,30 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
+ async def batch_chat_completion(
+ self,
+ model_id: str,
+ messages_batch: List[List[Message]],
+ tools: Optional[List[ToolDefinition]] = None,
+ tool_config: Optional[ToolConfig] = None,
+ sampling_params: Optional[SamplingParams] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ) -> BatchChatCompletionResponse:
+ logger.debug(
+ f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
+ )
+ provider = self.routing_table.get_provider_impl(model_id)
+ return await provider.batch_chat_completion(
+ model_id=model_id,
+ messages_batch=messages_batch,
+ tools=tools,
+ tool_config=tool_config,
+ sampling_params=sampling_params,
+ response_format=response_format,
+ logprobs=logprobs,
+ )
+
async def completion(
self,
model_id: str,
@@ -398,6 +424,20 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response
+ async def batch_completion(
+ self,
+ model_id: str,
+ content_batch: List[InterleavedContent],
+ sampling_params: Optional[SamplingParams] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ) -> BatchCompletionResponse:
+ logger.debug(
+ f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
+ )
+ provider = self.routing_table.get_provider_impl(model_id)
+ return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
+
async def embeddings(
self,
model_id: str,
diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py
index f55cd5e1c..fe7a7a898 100644
--- a/llama_stack/models/llama/llama3/chat_format.py
+++ b/llama_stack/models/llama/llama3/chat_format.py
@@ -226,7 +226,6 @@ class ChatFormat:
arguments_json=json.dumps(tool_arguments),
)
)
- content = ""
return RawMessage(
role="assistant",
diff --git a/llama_stack/models/llama/llama3/generation.py b/llama_stack/models/llama/llama3/generation.py
index 8c6aa242b..35c140707 100644
--- a/llama_stack/models/llama/llama3/generation.py
+++ b/llama_stack/models/llama/llama3/generation.py
@@ -140,7 +140,12 @@ class Llama3:
return Llama3(model, tokenizer, model_args)
- def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
+ def __init__(
+ self,
+ model: Transformer | CrossAttentionTransformer,
+ tokenizer: Tokenizer,
+ args: ModelArgs,
+ ):
self.args = args
self.model = model
self.tokenizer = tokenizer
@@ -149,7 +154,7 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
- model_inputs: List[LLMInput],
+ llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
@@ -164,15 +169,15 @@ class Llama3:
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
- for inp in model_inputs:
+ for inp in llm_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
- prompt_tokens = [inp.tokens for inp in model_inputs]
+ prompt_tokens = [inp.tokens for inp in llm_inputs]
- bsz = len(model_inputs)
+ bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@@ -193,8 +198,8 @@ class Llama3:
is_vision = not isinstance(self.model, Transformer)
if is_vision:
- images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
- mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
+ images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
+ mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
@@ -229,7 +234,7 @@ class Llama3:
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
- text_only_inference = all(inp.vision is None for inp in model_inputs)
+ text_only_inference = all(inp.vision is None for inp in llm_inputs)
logits = self.model.forward(
position_ids,
tokens,
@@ -285,7 +290,7 @@ class Llama3:
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
- finished=eos_reached[idx],
+ finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
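
The `finished` flag is now materialized with `.item()`: indexing a boolean tensor returns a 0-d tensor, not a Python `bool`, and `.item()` converts it to a native value before it is attached to the per-token result. A minimal illustration of the distinction:

```python
import torch

eos_reached = torch.tensor([True, False])
print(type(eos_reached[0]))         # <class 'torch.Tensor'>
print(type(eos_reached[0].item()))  # <class 'bool'>
```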
diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py
index 160bb00f8..9d60d00e9 100644
--- a/llama_stack/models/llama/llama4/chat_format.py
+++ b/llama_stack/models/llama/llama4/chat_format.py
@@ -301,7 +301,6 @@ class ChatFormat:
arguments=tool_arguments,
)
)
- content = ""
return RawMessage(
role="assistant",
diff --git a/llama_stack/models/llama/llama4/generation.py b/llama_stack/models/llama/llama4/generation.py
index 7a4087c8f..8e94bb33a 100644
--- a/llama_stack/models/llama/llama4/generation.py
+++ b/llama_stack/models/llama/llama4/generation.py
@@ -233,7 +233,7 @@ class Llama4:
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
- finished=eos_reached[idx],
+ finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py
index 315667506..6f796d0d4 100644
--- a/llama_stack/providers/inline/inference/meta_reference/config.py
+++ b/llama_stack/providers/inline/inference/meta_reference/config.py
@@ -52,14 +52,17 @@ class MetaReferenceInferenceConfig(BaseModel):
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
+ max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
+ max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
- "max_seq_len": 4096,
"checkpoint_dir": checkpoint_dir,
"quantization": {
"type": quantization_type,
},
"model_parallel_size": model_parallel_size,
+ "max_batch_size": max_batch_size,
+ "max_seq_len": max_seq_len,
}
diff --git a/llama_stack/providers/inline/inference/meta_reference/generators.py b/llama_stack/providers/inline/inference/meta_reference/generators.py
index 34dd58a9a..0a928ce73 100644
--- a/llama_stack/providers/inline/inference/meta_reference/generators.py
+++ b/llama_stack/providers/inline/inference/meta_reference/generators.py
@@ -22,7 +22,7 @@ from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
-from llama_stack.models.llama.sku_types import Model
+from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
@@ -113,8 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
return get_default_tool_prompt_format(request.model)
-# TODO: combine Llama3 and Llama4 generators since they are almost identical now
-class Llama4Generator:
+class LlamaGenerator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
@@ -144,7 +143,8 @@ class Llama4Generator:
else:
quantization_mode = None
- self.inner_generator = Llama4.build(
+ cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
+ self.inner_generator = cls.build(
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
@@ -158,142 +158,55 @@ class Llama4Generator:
def completion(
self,
- request: CompletionRequestWithRawContent,
+ request_batch: List[CompletionRequestWithRawContent],
) -> Generator:
- sampling_params = request.sampling_params or SamplingParams()
+ first_request = request_batch[0]
+ sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
- llm_inputs=[self.formatter.encode_content(request.content)],
+ llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
- logprobs=bool(request.logprobs),
+ logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
- request.response_format,
+ first_request.response_format,
),
):
- yield result[0]
+ yield result
def chat_completion(
self,
- request: ChatCompletionRequestWithRawContent,
+ request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator:
- sampling_params = request.sampling_params or SamplingParams()
+ first_request = request_batch[0]
+ sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
- llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
+ llm_inputs=[
+ self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
+ for request in request_batch
+ ],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
- logprobs=bool(request.logprobs),
+ logprobs=bool(first_request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
- request.response_format,
+ first_request.response_format,
),
):
- yield result[0]
-
-
-class Llama3Generator:
- def __init__(
- self,
- config: MetaReferenceInferenceConfig,
- model_id: str,
- llama_model: Model,
- ):
- if config.checkpoint_dir and config.checkpoint_dir != "null":
- ckpt_dir = config.checkpoint_dir
- else:
- resolved_model = resolve_model(model_id)
- if resolved_model is None:
- # if the model is not a native llama model, get the default checkpoint_dir based on model id
- ckpt_dir = model_checkpoint_dir(model_id)
- else:
- # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
- ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
-
- if config.quantization:
- if config.quantization.type == "fp8_mixed":
- quantization_mode = QuantizationMode.fp8_mixed
- elif config.quantization.type == "int4_mixed":
- quantization_mode = QuantizationMode.int4_mixed
- elif config.quantization.type == "bf16":
- quantization_mode = None
- else:
- raise ValueError(f"Unsupported quantization mode {config.quantization}")
- else:
- quantization_mode = None
-
- self.inner_generator = Llama3.build(
- ckpt_dir=ckpt_dir,
- max_seq_len=config.max_seq_len,
- max_batch_size=config.max_batch_size,
- world_size=config.model_parallel_size or llama_model.pth_file_count,
- quantization_mode=quantization_mode,
- )
- self.tokenizer = self.inner_generator.tokenizer
- self.args = self.inner_generator.args
- self.formatter = self.inner_generator.formatter
-
- def completion(
- self,
- request: CompletionRequestWithRawContent,
- ) -> Generator:
- sampling_params = request.sampling_params or SamplingParams()
- max_gen_len = sampling_params.max_tokens
- if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
- max_gen_len = self.args.max_seq_len - 1
-
- temperature, top_p = _infer_sampling_params(sampling_params)
- for result in self.inner_generator.generate(
- model_inputs=[self.formatter.encode_content(request.content)],
- max_gen_len=max_gen_len,
- temperature=temperature,
- top_p=top_p,
- logprobs=bool(request.logprobs),
- echo=False,
- logits_processor=get_logits_processor(
- self.tokenizer,
- self.args.vocab_size,
- request.response_format,
- ),
- ):
- yield result[0]
-
- def chat_completion(
- self,
- request: ChatCompletionRequestWithRawContent,
- ) -> Generator:
- sampling_params = request.sampling_params or SamplingParams()
- max_gen_len = sampling_params.max_tokens
- if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
- max_gen_len = self.args.max_seq_len - 1
-
- temperature, top_p = _infer_sampling_params(sampling_params)
- for result in self.inner_generator.generate(
- model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
- max_gen_len=max_gen_len,
- temperature=temperature,
- top_p=top_p,
- logprobs=bool(request.logprobs),
- echo=False,
- logits_processor=get_logits_processor(
- self.tokenizer,
- self.args.vocab_size,
- request.response_format,
- ),
- ):
- yield result[0]
+ yield result
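
With the Llama3 and Llama4 paths merged into a single `LlamaGenerator`, `completion` and `chat_completion` accept a list of requests and yield the full list of per-item results at every decoding step (previously `result[0]`). A minimal sketch of demultiplexing such a stream by `batch_idx`, using a stand-in result type since the real per-token result class is not part of this diff:

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List


@dataclass
class StepResult:
    """Stand-in for the real per-token generation result."""

    batch_idx: int
    token: int
    finished: bool


def collect_tokens(steps: Iterable[List[StepResult]]) -> Dict[int, List[int]]:
    """Accumulate each batch item's tokens until that item reports finished."""
    tokens: Dict[int, List[int]] = {}
    done: set = set()
    for step in steps:
        for result in step:
            if result.batch_idx in done:
                continue
            tokens.setdefault(result.batch_idx, []).append(result.token)
            if result.finished:
                done.add(result.batch_idx)
    return tokens
```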
diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py
index 3a7632065..0b56ba1f7 100644
--- a/llama_stack/providers/inline/inference/meta_reference/inference.py
+++ b/llama_stack/providers/inline/inference/meta_reference/inference.py
@@ -5,10 +5,10 @@
# the root directory of this source tree.
import asyncio
-import logging
import os
from typing import AsyncGenerator, List, Optional, Union
+from pydantic import BaseModel
from termcolor import cprint
from llama_stack.apis.common.content_types import (
@@ -17,6 +17,8 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
+ BatchChatCompletionResponse,
+ BatchCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
@@ -38,8 +40,10 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
+ UserMessage,
)
from llama_stack.apis.models import Model, ModelType
+from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@@ -65,21 +69,17 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from .config import MetaReferenceInferenceConfig
-from .generators import Llama3Generator, Llama4Generator
+from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator
-log = logging.getLogger(__name__)
+log = get_logger(__name__, category="inference")
# there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1)
-def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
- return Llama3Generator(config, model_id, llama_model)
-
-
-def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
- return Llama4Generator(config, model_id, llama_model)
+def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
+ return LlamaGenerator(config, model_id, llama_model)
class MetaReferenceInferenceImpl(
@@ -139,24 +139,12 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`")
- if llama_model.model_family in {
- ModelFamily.llama3,
- ModelFamily.llama3_1,
- ModelFamily.llama3_2,
- ModelFamily.llama3_3,
- }:
- builder_fn = llama3_builder_fn
- elif llama_model.model_family == ModelFamily.llama4:
- builder_fn = llama4_builder_fn
- else:
- raise ValueError(f"Unsupported model family: {llama_model.model_family}")
-
builder_params = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
- builder_fn=builder_fn,
+ builder_fn=llama_builder_fn,
builder_params=builder_params,
formatter=(
Llama4ChatFormat(Llama4Tokenizer.get_instance())
@@ -166,11 +154,24 @@ class MetaReferenceInferenceImpl(
)
self.generator.start()
else:
- self.generator = builder_fn(*builder_params)
+ self.generator = llama_builder_fn(*builder_params)
self.model_id = model_id
self.llama_model = llama_model
+ log.info("Warming up...")
+ await self.completion(
+ model_id=model_id,
+ content="Hello, world!",
+ sampling_params=SamplingParams(max_tokens=10),
+ )
+ await self.chat_completion(
+ model_id=model_id,
+ messages=[UserMessage(content="Hi how are you?")],
+ sampling_params=SamplingParams(max_tokens=20),
+ )
+ log.info("Warmed up!")
+
def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None:
raise RuntimeError(
@@ -208,7 +209,43 @@ class MetaReferenceInferenceImpl(
if request.stream:
return self._stream_completion(request)
else:
- return await self._nonstream_completion(request)
+ results = await self._nonstream_completion([request])
+ return results[0]
+
+ async def batch_completion(
+ self,
+ model_id: str,
+ content_batch: List[InterleavedContent],
+ sampling_params: Optional[SamplingParams] = None,
+ response_format: Optional[ResponseFormat] = None,
+ stream: Optional[bool] = False,
+ logprobs: Optional[LogProbConfig] = None,
+ ) -> BatchCompletionResponse:
+ if sampling_params is None:
+ sampling_params = SamplingParams()
+ if logprobs:
+ assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+ content_batch = [
+ augment_content_with_response_format_prompt(response_format, content) for content in content_batch
+ ]
+
+ request_batch = []
+ for content in content_batch:
+ request = CompletionRequest(
+ model=model_id,
+ content=content,
+ sampling_params=sampling_params,
+ response_format=response_format,
+ stream=stream,
+ logprobs=logprobs,
+ )
+ self.check_model(request)
+ request = await convert_request_to_raw(request)
+ request_batch.append(request)
+
+ results = await self._nonstream_completion(request_batch)
+ return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
@@ -253,37 +290,54 @@ class MetaReferenceInferenceImpl(
for x in impl():
yield x
- async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+ async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
+ first_request = request_batch[0]
+
+ class ItemState(BaseModel):
+ tokens: List[int] = []
+ logprobs: List[TokenLogProbs] = []
+ stop_reason: StopReason | None = None
+ finished: bool = False
+
def impl():
- tokens = []
- logprobs = []
- stop_reason = None
+ states = [ItemState() for _ in request_batch]
- for token_result in self.generator.completion(request):
- tokens.append(token_result.token)
- if token_result.token == tokenizer.eot_id:
- stop_reason = StopReason.end_of_turn
- elif token_result.token == tokenizer.eom_id:
- stop_reason = StopReason.end_of_message
+ results = []
+ for token_results in self.generator.completion(request_batch):
+ for result in token_results:
+ idx = result.batch_idx
+ state = states[idx]
+ if state.finished or result.ignore_token:
+ continue
- if request.logprobs:
- assert len(token_result.logprobs) == 1
+ state.finished = result.finished
+ if first_request.logprobs:
+ state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
- logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+ state.tokens.append(result.token)
+ if result.token == tokenizer.eot_id:
+ state.stop_reason = StopReason.end_of_turn
+ elif result.token == tokenizer.eom_id:
+ state.stop_reason = StopReason.end_of_message
- if stop_reason is None:
- stop_reason = StopReason.out_of_tokens
+ for state in states:
+ if state.stop_reason is None:
+ state.stop_reason = StopReason.out_of_tokens
- if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
- tokens = tokens[:-1]
- content = self.generator.formatter.tokenizer.decode(tokens)
- return CompletionResponse(
- content=content,
- stop_reason=stop_reason,
- logprobs=logprobs if request.logprobs else None,
- )
+ if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
+ state.tokens = state.tokens[:-1]
+ content = self.generator.formatter.tokenizer.decode(state.tokens)
+ results.append(
+ CompletionResponse(
+ content=content,
+ stop_reason=state.stop_reason,
+ logprobs=state.logprobs if first_request.logprobs else None,
+ )
+ )
+
+ return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
@@ -318,7 +372,7 @@ class MetaReferenceInferenceImpl(
response_format=response_format,
stream=stream,
logprobs=logprobs,
- tool_config=tool_config,
+ tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
@@ -334,44 +388,110 @@ class MetaReferenceInferenceImpl(
if request.stream:
return self._stream_chat_completion(request)
else:
- return await self._nonstream_chat_completion(request)
+ results = await self._nonstream_chat_completion([request])
+ return results[0]
- async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+ async def batch_chat_completion(
+ self,
+ model_id: str,
+ messages_batch: List[List[Message]],
+ sampling_params: Optional[SamplingParams] = None,
+ response_format: Optional[ResponseFormat] = None,
+ tools: Optional[List[ToolDefinition]] = None,
+ stream: Optional[bool] = False,
+ logprobs: Optional[LogProbConfig] = None,
+ tool_config: Optional[ToolConfig] = None,
+ ) -> BatchChatCompletionResponse:
+ if sampling_params is None:
+ sampling_params = SamplingParams()
+ if logprobs:
+ assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+ # wrapper request to make it easier to pass around (internal only, not exposed to API)
+ request_batch = []
+ for messages in messages_batch:
+ request = ChatCompletionRequest(
+ model=model_id,
+ messages=messages,
+ sampling_params=sampling_params,
+ tools=tools or [],
+ response_format=response_format,
+ logprobs=logprobs,
+ tool_config=tool_config or ToolConfig(),
+ )
+ self.check_model(request)
+
+ # augment and rewrite messages depending on the model
+ request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+ # download media and convert to raw content so we can send it to the model
+ request = await convert_request_to_raw(request)
+ request_batch.append(request)
+
+ if self.config.create_distributed_process_group:
+ if SEMAPHORE.locked():
+ raise RuntimeError("Only one concurrent request is supported")
+
+ results = await self._nonstream_chat_completion(request_batch)
+ return BatchChatCompletionResponse(batch=results)
+
+ async def _nonstream_chat_completion(
+ self, request_batch: List[ChatCompletionRequest]
+ ) -> List[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer
+ first_request = request_batch[0]
+
+ class ItemState(BaseModel):
+ tokens: List[int] = []
+ logprobs: List[TokenLogProbs] = []
+ stop_reason: StopReason | None = None
+ finished: bool = False
+
def impl():
- tokens = []
- logprobs = []
- stop_reason = None
+ states = [ItemState() for _ in request_batch]
- for token_result in self.generator.chat_completion(request):
- if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
- cprint(token_result.text, "cyan", end="")
+ for token_results in self.generator.chat_completion(request_batch):
+ first = token_results[0]
+ if not first.finished and not first.ignore_token:
+ if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
+ cprint(first.text, "cyan", end="")
+ if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
+ cprint(f"<{first.token}>", "magenta", end="")
- tokens.append(token_result.token)
+ for result in token_results:
+ idx = result.batch_idx
+ state = states[idx]
+ if state.finished or result.ignore_token:
+ continue
- if token_result.token == tokenizer.eot_id:
- stop_reason = StopReason.end_of_turn
- elif token_result.token == tokenizer.eom_id:
- stop_reason = StopReason.end_of_message
+ state.finished = result.finished
+ if first_request.logprobs:
+ state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
- if request.logprobs:
- assert len(token_result.logprobs) == 1
+ state.tokens.append(result.token)
+ if result.token == tokenizer.eot_id:
+ state.stop_reason = StopReason.end_of_turn
+ elif result.token == tokenizer.eom_id:
+ state.stop_reason = StopReason.end_of_message
- logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+ results = []
+ for state in states:
+ if state.stop_reason is None:
+ state.stop_reason = StopReason.out_of_tokens
- if stop_reason is None:
- stop_reason = StopReason.out_of_tokens
+ raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
+ results.append(
+ ChatCompletionResponse(
+ completion_message=CompletionMessage(
+ content=raw_message.content,
+ stop_reason=raw_message.stop_reason,
+ tool_calls=raw_message.tool_calls,
+ ),
+ logprobs=state.logprobs if first_request.logprobs else None,
+ )
+ )
- raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
- return ChatCompletionResponse(
- completion_message=CompletionMessage(
- content=raw_message.content,
- stop_reason=raw_message.stop_reason,
- tool_calls=raw_message.tool_calls,
- ),
- logprobs=logprobs if request.logprobs else None,
- )
+ return results
if self.config.create_distributed_process_group:
async with SEMAPHORE:
@@ -398,6 +518,22 @@ class MetaReferenceInferenceImpl(
for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")
+ if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
+ cprint(f"<{token_result.token}>", "magenta", end="")
+
+ if token_result.token == tokenizer.eot_id:
+ stop_reason = StopReason.end_of_turn
+ text = ""
+ elif token_result.token == tokenizer.eom_id:
+ stop_reason = StopReason.end_of_message
+ text = ""
+ else:
+ text = token_result.text
+
+ if request.logprobs:
+ assert len(token_result.logprobs) == 1
+
+ logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token)
diff --git a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
index bed3025a8..50640c6d1 100644
--- a/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
+++ b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py
@@ -6,7 +6,7 @@
from copy import deepcopy
from functools import partial
-from typing import Any, Callable, Generator
+from typing import Any, Callable, Generator, List
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@@ -23,13 +23,13 @@ class ModelRunner:
self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
- def __call__(self, req: Any):
- if isinstance(req, ChatCompletionRequestWithRawContent):
- return self.llama.chat_completion(req)
- elif isinstance(req, CompletionRequestWithRawContent):
- return self.llama.completion(req)
+ def __call__(self, task: Any):
+ if task[0] == "chat_completion":
+ return self.llama.chat_completion(task[1])
+ elif task[0] == "completion":
+ return self.llama.completion(task[1])
else:
- raise ValueError(f"Unexpected task type {type(req)}")
+ raise ValueError(f"Unexpected task type {task[0]}")
def init_model_cb(
@@ -82,16 +82,16 @@ class LlamaModelParallelGenerator:
def completion(
self,
- request: CompletionRequestWithRawContent,
+ request_batch: List[CompletionRequestWithRawContent],
) -> Generator:
- req_obj = deepcopy(request)
- gen = self.group.run_inference(req_obj)
+ req_obj = deepcopy(request_batch)
+ gen = self.group.run_inference(("completion", req_obj))
yield from gen
def chat_completion(
self,
- request: ChatCompletionRequestWithRawContent,
+ request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator:
- req_obj = deepcopy(request)
- gen = self.group.run_inference(req_obj)
+ req_obj = deepcopy(request_batch)
+ gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen
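
`ModelRunner` now dispatches on a `(task_name, request_batch)` tuple rather than on `isinstance`, since a batch is a plain list whose element type is awkward to inspect (and an empty batch has no element at all). A small stand-alone illustration of the tagged-dispatch shape, with placeholder handlers rather than the real generator methods:

```python
from typing import Any, Callable, Dict, List, Tuple

Handler = Callable[[List[Any]], Any]


def run_task(task: Tuple[str, List[Any]], handlers: Dict[str, Handler]) -> Any:
    """Dispatch on the task tag, mirroring ModelRunner.__call__ above."""
    name, batch = task
    if name not in handlers:
        raise ValueError(f"Unexpected task type {name}")
    return handlers[name](batch)


handlers = {
    "completion": lambda batch: [f"completion for: {p}" for p in batch],
    "chat_completion": lambda batch: [f"chat for: {m}" for m in batch],
}
print(run_task(("completion", ["The capital of France is "]), handlers))
```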
diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
index 74fc49d5e..8752f06f3 100644
--- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
+++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -19,7 +19,7 @@ import tempfile
import time
import uuid
from enum import Enum
-from typing import Callable, Generator, Literal, Optional, Union
+from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
import torch
import zmq
@@ -69,12 +69,12 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
- task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
+ task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
- result: GenerationResult
+ result: List[GenerationResult]
class ExceptionResponse(BaseModel):
@@ -331,7 +331,7 @@ class ModelParallelProcessGroup:
def run_inference(
self,
- req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
+ req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
) -> Generator:
assert not self.running, "inference already running"
diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
index 9c370b6c5..5bc20e3c2 100644
--- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
+++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
@@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
+ InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
@@ -80,3 +81,25 @@ class SentenceTransformersInferenceImpl(
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")
+
+ async def batch_completion(
+ self,
+ model_id: str,
+ content_batch: List[InterleavedContent],
+ sampling_params: Optional[SamplingParams] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ):
+ raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
+
+ async def batch_chat_completion(
+ self,
+ model_id: str,
+ messages_batch: List[List[Message]],
+ sampling_params: Optional[SamplingParams] = None,
+ tools: Optional[List[ToolDefinition]] = None,
+ tool_config: Optional[ToolConfig] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ):
+ raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index b8671197e..33b48af46 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -437,6 +437,28 @@ class OllamaInferenceAdapter(
}
return await self.openai_client.chat.completions.create(**params) # type: ignore
+ async def batch_completion(
+ self,
+ model_id: str,
+ content_batch: List[InterleavedContent],
+ sampling_params: Optional[SamplingParams] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ):
+ raise NotImplementedError("Batch completion is not supported for Ollama")
+
+ async def batch_chat_completion(
+ self,
+ model_id: str,
+ messages_batch: List[List[Message]],
+ sampling_params: Optional[SamplingParams] = None,
+ tools: Optional[List[ToolDefinition]] = None,
+ tool_config: Optional[ToolConfig] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ):
+ raise NotImplementedError("Batch chat completion is not supported for Ollama")
+
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
async def _convert_content(content) -> dict:
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 79f92adce..0044d2e75 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -526,3 +526,25 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
user=user,
)
return await self.client.chat.completions.create(**params) # type: ignore
+
+ async def batch_completion(
+ self,
+ model_id: str,
+ content_batch: List[InterleavedContent],
+ sampling_params: Optional[SamplingParams] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ):
+ raise NotImplementedError("Batch completion is not supported for Ollama")
+
+ async def batch_chat_completion(
+ self,
+ model_id: str,
+ messages_batch: List[List[Message]],
+ sampling_params: Optional[SamplingParams] = None,
+ tools: Optional[List[ToolDefinition]] = None,
+ tool_config: Optional[ToolConfig] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ):
+ raise NotImplementedError("Batch chat completion is not supported for Ollama")
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 2d2f0400a..cd0f4ec67 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -347,3 +347,25 @@ class LiteLLMOpenAIMixin(
user=user,
)
return litellm.completion(**params)
+
+ async def batch_completion(
+ self,
+ model_id: str,
+ content_batch: List[InterleavedContent],
+ sampling_params: Optional[SamplingParams] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ):
+ raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
+
+ async def batch_chat_completion(
+ self,
+ model_id: str,
+ messages_batch: List[List[Message]],
+ sampling_params: Optional[SamplingParams] = None,
+ tools: Optional[List[ToolDefinition]] = None,
+ tool_config: Optional[ToolConfig] = None,
+ response_format: Optional[ResponseFormat] = None,
+ logprobs: Optional[LogProbConfig] = None,
+ ):
+ raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
index 9f97158f8..63177ab09 100644
--- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
@@ -16,11 +16,12 @@ providers:
provider_type: inline::meta-reference
config:
model: ${env.INFERENCE_MODEL}
- max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+ max_batch_size: ${env.MAX_BATCH_SIZE:1}
+ max_seq_len: ${env.MAX_SEQ_LEN:4096}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
@@ -28,11 +29,12 @@ providers:
provider_type: inline::meta-reference
config:
model: ${env.SAFETY_MODEL}
- max_seq_len: 4096
checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+ max_batch_size: ${env.MAX_BATCH_SIZE:1}
+ max_seq_len: ${env.MAX_SEQ_LEN:4096}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml
index eda332123..380d83060 100644
--- a/llama_stack/templates/meta-reference-gpu/run.yaml
+++ b/llama_stack/templates/meta-reference-gpu/run.yaml
@@ -16,11 +16,12 @@ providers:
provider_type: inline::meta-reference
config:
model: ${env.INFERENCE_MODEL}
- max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+ max_batch_size: ${env.MAX_BATCH_SIZE:1}
+ max_seq_len: ${env.MAX_SEQ_LEN:4096}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
diff --git a/tests/integration/inference/test_batch_inference.py b/tests/integration/inference/test_batch_inference.py
new file mode 100644
index 000000000..9a1a62ce0
--- /dev/null
+++ b/tests/integration/inference/test_batch_inference.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import pytest
+
+from ..test_cases.test_case import TestCase
+
+
+def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id):
+ models = {m.identifier: m for m in client_with_models.models.list()}
+ models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
+ provider_id = models[model_id].provider_id
+ providers = {p.provider_id: p for p in client_with_models.providers.list()}
+ provider = providers[provider_id]
+ if provider.provider_type not in ("inline::meta-reference",):
+ pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference")
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ "inference:completion:batch_completion",
+ ],
+)
+def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
+ skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
+ tc = TestCase(test_case)
+
+ content_batch = tc["contents"]
+ response = client_with_models.inference.batch_completion(
+ content_batch=content_batch,
+ model_id=text_model_id,
+ sampling_params={
+ "max_tokens": 50,
+ },
+ )
+ assert len(response.batch) == len(content_batch)
+ for i, r in enumerate(response.batch):
+ print(f"response {i}: {r.content}")
+ assert len(r.content) > 10
+
+
+@pytest.mark.parametrize(
+ "test_case",
+ [
+ "inference:chat_completion:batch_completion",
+ ],
+)
+def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
+ skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
+ tc = TestCase(test_case)
+ qa_pairs = tc["qa_pairs"]
+
+ message_batch = [
+ [
+ {
+ "role": "user",
+ "content": qa["question"],
+ }
+ ]
+ for qa in qa_pairs
+ ]
+
+ response = client_with_models.inference.batch_chat_completion(
+ messages_batch=message_batch,
+ model_id=text_model_id,
+ )
+ assert len(response.batch) == len(qa_pairs)
+ for i, r in enumerate(response.batch):
+ print(f"response {i}: {r.completion_message.content}")
+ assert len(r.completion_message.content) > 0
+ assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()
diff --git a/tests/integration/test_cases/inference/chat_completion.json b/tests/integration/test_cases/inference/chat_completion.json
index 01956bd59..5663089fb 100644
--- a/tests/integration/test_cases/inference/chat_completion.json
+++ b/tests/integration/test_cases/inference/chat_completion.json
@@ -537,5 +537,31 @@
}
]
}
+ },
+ "batch_completion": {
+ "data": {
+ "qa_pairs": [
+ {
+ "question": "What is the capital of France?",
+ "answer": "Paris"
+ },
+ {
+ "question": "Who wrote the book '1984'?",
+ "answer": "George Orwell"
+ },
+ {
+ "question": "Which planet has rings around it with a name starting with letter S?",
+ "answer": "Saturn"
+ },
+ {
+ "question": "When did the first moon landing happen?",
+ "answer": "1969"
+ },
+ {
+ "question": "What word says 'hello' in Spanish?",
+ "answer": "Hola"
+ }
+ ]
+ }
}
}
diff --git a/tests/integration/test_cases/inference/completion.json b/tests/integration/test_cases/inference/completion.json
index 06abbdc8b..731ceddbc 100644
--- a/tests/integration/test_cases/inference/completion.json
+++ b/tests/integration/test_cases/inference/completion.json
@@ -44,5 +44,18 @@
"year_retired": "2003"
}
}
+ },
+ "batch_completion": {
+ "data": {
+ "contents": [
+ "Micheael Jordan is born in ",
+ "Roses are red, violets are ",
+ "If you had a million dollars, what would you do with it? ",
+ "All you need is ",
+ "The capital of France is ",
+ "It is a good day to ",
+ "The answer to the universe is "
+ ]
+ }
}
}