feat: add batch inference API to llama stack inference (#1945)

# What does this PR do?

This PR adds two methods to the Inference API:
- `batch_completion`
- `batch_chat_completion`

The motivation is evaluations that target a local inference engine (like
meta-reference or vllm), where batch APIs provide a substantial speedup.

Why did I not add this to `Api.batch_inference`? Doing so would have meant
a _lot_ more book-keeping given the structure of Llama Stack: I would have
needed to create a notion of a "batch model" resource, set up routing based
on it, and so on. That does not sound ideal.

So what's the future of the batch inference API? I am not sure. Maybe we
keep it for true _asynchronous_ execution: you submit requests, it returns
a Job instance, and you poll for completion.
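
For reference, calling the new endpoint over HTTP looks roughly like the sketch below. This is illustrative, not part of the PR; the base URL/port and the model id are assumptions and should match your running stack.

```python
# Illustrative sketch only: POST to the new batch chat-completion route.
# Assumptions: a Llama Stack server listening on localhost:8321 and a model id
# matching one registered with your stack.
import httpx

BASE_URL = "http://localhost:8321"

resp = httpx.post(
    f"{BASE_URL}/v1/inference/batch-chat-completion",
    json={
        "model_id": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "messages_batch": [
            [{"role": "user", "content": "What is the capital of France?"}],
            [{"role": "user", "content": "Write a haiku about batching."}],
        ],
    },
    timeout=300.0,
)
resp.raise_for_status()

# BatchChatCompletionResponse: one ChatCompletionResponse per input conversation.
for item in resp.json()["batch"]:
    print(item["completion_message"]["content"])
```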

## Test Plan

Run meta-reference-gpu using:
```bash
export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct
export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct-20250331210000
export MODEL_PARALLEL_SIZE=4
export MAX_BATCH_SIZE=32
export MAX_SEQ_LEN=6144

LLAMA_MODELS_DEBUG=1 llama stack run meta-reference-gpu
```

Then run the batch inference test case.
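
For example, a minimal test sketch (illustrative only; the actual test case in the repo may differ — the route and response shape follow the spec changes in this PR, while the base URL and model id are assumptions):

```python
# Hypothetical batch-completion test sketch, not the committed test file.
import httpx

BASE_URL = "http://localhost:8321"  # assumed server address
MODEL_ID = "meta-llama/Llama-4-Scout-17B-16E-Instruct"


def test_batch_completion_returns_one_result_per_prompt():
    prompts = ["1 + 1 =", "The capital of Japan is", "Roses are red,"]
    resp = httpx.post(
        f"{BASE_URL}/v1/inference/batch-completion",
        json={
            "model_id": MODEL_ID,
            "content_batch": prompts,
            "sampling_params": {"max_tokens": 20},
        },
        timeout=300.0,
    )
    resp.raise_for_status()

    batch = resp.json()["batch"]
    # One CompletionResponse per input prompt, in order.
    assert len(batch) == len(prompts)
    for item in batch:
        assert isinstance(item["content"], str)
```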


@ -85,7 +85,7 @@
} }
} }
}, },
"/v1/batch-inference/chat-completion": { "/v1/inference/batch-chat-completion": {
"post": { "post": {
"responses": { "responses": {
"200": { "200": {
@ -112,7 +112,7 @@
} }
}, },
"tags": [ "tags": [
"BatchInference (Coming Soon)" "Inference"
], ],
"description": "", "description": "",
"parameters": [], "parameters": [],
@ -128,7 +128,7 @@
} }
} }
}, },
"/v1/batch-inference/completion": { "/v1/inference/batch-completion": {
"post": { "post": {
"responses": { "responses": {
"200": { "200": {
@ -155,7 +155,7 @@
} }
}, },
"tags": [ "tags": [
"BatchInference (Coming Soon)" "Inference"
], ],
"description": "", "description": "",
"parameters": [], "parameters": [],
@ -239,7 +239,7 @@
} }
}, },
"tags": [ "tags": [
"Inference" "BatchInference (Coming Soon)"
], ],
"description": "Generate a chat completion for the given messages using the specified model.", "description": "Generate a chat completion for the given messages using the specified model.",
"parameters": [], "parameters": [],
@ -287,7 +287,7 @@
} }
}, },
"tags": [ "tags": [
"Inference" "BatchInference (Coming Soon)"
], ],
"description": "Generate a completion for the given content using the specified model.", "description": "Generate a completion for the given content using the specified model.",
"parameters": [], "parameters": [],
@ -4366,6 +4366,51 @@
], ],
"title": "ToolCall" "title": "ToolCall"
}, },
"ToolConfig": {
"type": "object",
"properties": {
"tool_choice": {
"oneOf": [
{
"type": "string",
"enum": [
"auto",
"required",
"none"
],
"title": "ToolChoice",
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
{
"type": "string"
}
],
"default": "auto",
"description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
},
"tool_prompt_format": {
"type": "string",
"enum": [
"json",
"function_tag",
"python_list"
],
"description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
},
"system_message_behavior": {
"type": "string",
"enum": [
"append",
"replace"
],
"description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
"default": "append"
}
},
"additionalProperties": false,
"title": "ToolConfig",
"description": "Configuration for tool use."
},
"ToolDefinition": { "ToolDefinition": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -4554,7 +4599,7 @@
"BatchChatCompletionRequest": { "BatchChatCompletionRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
"model": { "model_id": {
"type": "string" "type": "string"
}, },
"messages_batch": { "messages_batch": {
@ -4575,25 +4620,8 @@
"$ref": "#/components/schemas/ToolDefinition" "$ref": "#/components/schemas/ToolDefinition"
} }
}, },
"tool_choice": { "tool_config": {
"type": "string", "$ref": "#/components/schemas/ToolConfig"
"enum": [
"auto",
"required",
"none"
],
"title": "ToolChoice",
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
"tool_prompt_format": {
"type": "string",
"enum": [
"json",
"function_tag",
"python_list"
],
"title": "ToolPromptFormat",
"description": "Prompt format for calling custom / zero shot tools."
}, },
"response_format": { "response_format": {
"$ref": "#/components/schemas/ResponseFormat" "$ref": "#/components/schemas/ResponseFormat"
@ -4613,7 +4641,7 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"model", "model_id",
"messages_batch" "messages_batch"
], ],
"title": "BatchChatCompletionRequest" "title": "BatchChatCompletionRequest"
@ -4710,7 +4738,7 @@
"BatchCompletionRequest": { "BatchCompletionRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
"model": { "model_id": {
"type": "string" "type": "string"
}, },
"content_batch": { "content_batch": {
@ -4740,7 +4768,7 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"model", "model_id",
"content_batch" "content_batch"
], ],
"title": "BatchCompletionRequest" "title": "BatchCompletionRequest"
@ -4812,51 +4840,6 @@
], ],
"title": "CancelTrainingJobRequest" "title": "CancelTrainingJobRequest"
}, },
"ToolConfig": {
"type": "object",
"properties": {
"tool_choice": {
"oneOf": [
{
"type": "string",
"enum": [
"auto",
"required",
"none"
],
"title": "ToolChoice",
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
{
"type": "string"
}
],
"default": "auto",
"description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
},
"tool_prompt_format": {
"type": "string",
"enum": [
"json",
"function_tag",
"python_list"
],
"description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
},
"system_message_behavior": {
"type": "string",
"enum": [
"append",
"replace"
],
"description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
"default": "append"
}
},
"additionalProperties": false,
"title": "ToolConfig",
"description": "Configuration for tool use."
},
"ChatCompletionRequest": { "ChatCompletionRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -11173,7 +11156,9 @@
"x-displayName": "Agents API for creating and interacting with agentic systems." "x-displayName": "Agents API for creating and interacting with agentic systems."
}, },
{ {
"name": "BatchInference (Coming Soon)" "name": "BatchInference (Coming Soon)",
"description": "This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.\n\nNOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs\nincluding (post-training, evals, etc).",
"x-displayName": "Batch inference API for generating completions and chat completions."
}, },
{ {
"name": "Benchmarks" "name": "Benchmarks"


@ -40,7 +40,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/AppendRowsRequest' $ref: '#/components/schemas/AppendRowsRequest'
required: true required: true
/v1/batch-inference/chat-completion: /v1/inference/batch-chat-completion:
post: post:
responses: responses:
'200': '200':
@ -60,7 +60,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- BatchInference (Coming Soon) - Inference
description: '' description: ''
parameters: [] parameters: []
requestBody: requestBody:
@ -69,7 +69,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/BatchChatCompletionRequest' $ref: '#/components/schemas/BatchChatCompletionRequest'
required: true required: true
/v1/batch-inference/completion: /v1/inference/batch-completion:
post: post:
responses: responses:
'200': '200':
@ -89,7 +89,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- BatchInference (Coming Soon) - Inference
description: '' description: ''
parameters: [] parameters: []
requestBody: requestBody:
@ -148,7 +148,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Inference - BatchInference (Coming Soon)
description: >- description: >-
Generate a chat completion for the given messages using the specified model. Generate a chat completion for the given messages using the specified model.
parameters: [] parameters: []
@ -183,7 +183,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Inference - BatchInference (Coming Soon)
description: >- description: >-
Generate a completion for the given content using the specified model. Generate a completion for the given content using the specified model.
parameters: [] parameters: []
@ -3009,6 +3009,54 @@ components:
- tool_name - tool_name
- arguments - arguments
title: ToolCall title: ToolCall
ToolConfig:
type: object
properties:
tool_choice:
oneOf:
- type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following
capabilities of the model.
- type: string
default: auto
description: >-
(Optional) Whether tool use is automatic, required, or none. Can also
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
description: >-
(Optional) Instructs the model how to format tool calls. By default, Llama
Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
syntax -- a list of function calls.
system_message_behavior:
type: string
enum:
- append
- replace
description: >-
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
Replaces the default system prompt with the provided system message. The
system message can include the string '{{function_definitions}}' to indicate
where the function definitions should be inserted.
default: append
additionalProperties: false
title: ToolConfig
description: Configuration for tool use.
ToolDefinition: ToolDefinition:
type: object type: object
properties: properties:
@ -3145,7 +3193,7 @@ components:
BatchChatCompletionRequest: BatchChatCompletionRequest:
type: object type: object
properties: properties:
model: model_id:
type: string type: string
messages_batch: messages_batch:
type: array type: array
@ -3159,26 +3207,8 @@ components:
type: array type: array
items: items:
$ref: '#/components/schemas/ToolDefinition' $ref: '#/components/schemas/ToolDefinition'
tool_choice: tool_config:
type: string $ref: '#/components/schemas/ToolConfig'
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following capabilities
of the model.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
title: ToolPromptFormat
description: >-
Prompt format for calling custom / zero shot tools.
response_format: response_format:
$ref: '#/components/schemas/ResponseFormat' $ref: '#/components/schemas/ResponseFormat'
logprobs: logprobs:
@ -3193,7 +3223,7 @@ components:
title: LogProbConfig title: LogProbConfig
additionalProperties: false additionalProperties: false
required: required:
- model - model_id
- messages_batch - messages_batch
title: BatchChatCompletionRequest title: BatchChatCompletionRequest
BatchChatCompletionResponse: BatchChatCompletionResponse:
@ -3261,7 +3291,7 @@ components:
BatchCompletionRequest: BatchCompletionRequest:
type: object type: object
properties: properties:
model: model_id:
type: string type: string
content_batch: content_batch:
type: array type: array
@ -3283,7 +3313,7 @@ components:
title: LogProbConfig title: LogProbConfig
additionalProperties: false additionalProperties: false
required: required:
- model - model_id
- content_batch - content_batch
title: BatchCompletionRequest title: BatchCompletionRequest
BatchCompletionResponse: BatchCompletionResponse:
@ -3335,54 +3365,6 @@ components:
required: required:
- job_uuid - job_uuid
title: CancelTrainingJobRequest title: CancelTrainingJobRequest
ToolConfig:
type: object
properties:
tool_choice:
oneOf:
- type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following
capabilities of the model.
- type: string
default: auto
description: >-
(Optional) Whether tool use is automatic, required, or none. Can also
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
description: >-
(Optional) Instructs the model how to format tool calls. By default, Llama
Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
syntax -- a list of function calls.
system_message_behavior:
type: string
enum:
- append
- replace
description: >-
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
Replaces the default system prompt with the provided system message. The
system message can include the string '{{function_definitions}}' to indicate
where the function definitions should be inserted.
default: append
additionalProperties: false
title: ToolConfig
description: Configuration for tool use.
ChatCompletionRequest: ChatCompletionRequest:
type: object type: object
properties: properties:
@ -7632,6 +7614,17 @@ tags:
x-displayName: >- x-displayName: >-
Agents API for creating and interacting with agentic systems. Agents API for creating and interacting with agentic systems.
- name: BatchInference (Coming Soon) - name: BatchInference (Coming Soon)
description: >-
This is an asynchronous API. If the request is successful, the response will
be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with
other asynchronous APIs
including (post-training, evals, etc).
x-displayName: >-
Batch inference API for generating completions and chat completions.
- name: Benchmarks - name: Benchmarks
- name: DatasetIO - name: DatasetIO
- name: Datasets - name: Datasets


@ -6,11 +6,8 @@
from typing import List, Optional, Protocol, runtime_checkable from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionResponse,
InterleavedContent, InterleavedContent,
LogProbConfig, LogProbConfig,
Message, Message,
@ -20,41 +17,39 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.schema_utils import webmethod
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
@runtime_checkable @runtime_checkable
class BatchInference(Protocol): class BatchInference(Protocol):
"""Batch inference API for generating completions and chat completions.
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
including (post-training, evals, etc).
"""
@webmethod(route="/batch-inference/completion", method="POST") @webmethod(route="/batch-inference/completion", method="POST")
async def batch_completion( async def completion(
self, self,
model: str, model: str,
content_batch: List[InterleavedContent], content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ... ) -> Job: ...
@webmethod(route="/batch-inference/chat-completion", method="POST") @webmethod(route="/batch-inference/chat-completion", method="POST")
async def batch_chat_completion( async def chat_completion(
self, self,
model: str, model: str,
messages_batch: List[List[Message]], messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model # zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list, tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None, tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ... ) -> Job: ...


@ -681,6 +681,16 @@ class EmbeddingTaskType(Enum):
document = "document" document = "document"
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class Inference(Protocol): class Inference(Protocol):
@ -716,6 +726,17 @@ class Inference(Protocol):
""" """
... ...
@webmethod(route="/inference/batch-completion", method="POST")
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST") @webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion( async def chat_completion(
self, self,
@ -756,6 +777,19 @@ class Inference(Protocol):
""" """
... ...
@webmethod(route="/inference/batch-chat-completion", method="POST")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST") @webmethod(route="/inference/embeddings", method="POST")
async def embeddings( async def embeddings(
self, self,


@ -17,6 +17,8 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk, ChatCompletionResponseStreamChunk,
@ -334,6 +336,30 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response return response
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
tools=tools,
tool_config=tool_config,
sampling_params=sampling_params,
response_format=response_format,
logprobs=logprobs,
)
async def completion( async def completion(
self, self,
model_id: str, model_id: str,
@ -398,6 +424,20 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response return response
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings( async def embeddings(
self, self,
model_id: str, model_id: str,


@ -226,7 +226,6 @@ class ChatFormat:
arguments_json=json.dumps(tool_arguments), arguments_json=json.dumps(tool_arguments),
) )
) )
content = ""
return RawMessage( return RawMessage(
role="assistant", role="assistant",


@ -140,7 +140,12 @@ class Llama3:
return Llama3(model, tokenizer, model_args) return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs): def __init__(
self,
model: Transformer | CrossAttentionTransformer,
tokenizer: Tokenizer,
args: ModelArgs,
):
self.args = args self.args = args
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -149,7 +154,7 @@ class Llama3:
@torch.inference_mode() @torch.inference_mode()
def generate( def generate(
self, self,
model_inputs: List[LLMInput], llm_inputs: List[LLMInput],
temperature: float = 0.6, temperature: float = 0.6,
top_p: float = 0.9, top_p: float = 0.9,
max_gen_len: Optional[int] = None, max_gen_len: Optional[int] = None,
@ -164,15 +169,15 @@ class Llama3:
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1" print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input: if print_model_input:
for inp in model_inputs: for inp in llm_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens] tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint( cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n", "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red", "red",
) )
prompt_tokens = [inp.tokens for inp in model_inputs] prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = len(model_inputs) bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens) min_prompt_len = min(len(t) for t in prompt_tokens)
@ -193,8 +198,8 @@ class Llama3:
is_vision = not isinstance(self.model, Transformer) is_vision = not isinstance(self.model, Transformer)
if is_vision: if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs] images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs] mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks( xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images, batch_images=images,
@ -229,7 +234,7 @@ class Llama3:
for cur_pos in range(min_prompt_len, total_len): for cur_pos in range(min_prompt_len, total_len):
if is_vision: if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long) position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = all(inp.vision is None for inp in model_inputs) text_only_inference = all(inp.vision is None for inp in llm_inputs)
logits = self.model.forward( logits = self.model.forward(
position_ids, position_ids,
tokens, tokens,
@ -285,7 +290,7 @@ class Llama3:
source="output", source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None), logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx, batch_idx=idx,
finished=eos_reached[idx], finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]), ignore_token=cur_pos < len(prompt_tokens[idx]),
) )
) )


@ -301,7 +301,6 @@ class ChatFormat:
arguments=tool_arguments, arguments=tool_arguments,
) )
) )
content = ""
return RawMessage( return RawMessage(
role="assistant", role="assistant",


@ -233,7 +233,7 @@ class Llama4:
source="output", source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None), logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx, batch_idx=idx,
finished=eos_reached[idx], finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]), ignore_token=cur_pos < len(prompt_tokens[idx]),
) )
) )


@ -52,14 +52,17 @@ class MetaReferenceInferenceConfig(BaseModel):
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}", checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}", quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}", model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return { return {
"model": model, "model": model,
"max_seq_len": 4096,
"checkpoint_dir": checkpoint_dir, "checkpoint_dir": checkpoint_dir,
"quantization": { "quantization": {
"type": quantization_type, "type": quantization_type,
}, },
"model_parallel_size": model_parallel_size, "model_parallel_size": model_parallel_size,
"max_batch_size": max_batch_size,
"max_seq_len": max_seq_len,
} }


@ -22,7 +22,7 @@ from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent, ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent, CompletionRequestWithRawContent,
@ -113,8 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
return get_default_tool_prompt_format(request.model) return get_default_tool_prompt_format(request.model)
# TODO: combine Llama3 and Llama4 generators since they are almost identical now class LlamaGenerator:
class Llama4Generator:
def __init__( def __init__(
self, self,
config: MetaReferenceInferenceConfig, config: MetaReferenceInferenceConfig,
@ -144,7 +143,8 @@ class Llama4Generator:
else: else:
quantization_mode = None quantization_mode = None
self.inner_generator = Llama4.build( cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
self.inner_generator = cls.build(
ckpt_dir=ckpt_dir, ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len, max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size, max_batch_size=config.max_batch_size,
@ -158,142 +158,55 @@ class Llama4Generator:
def completion( def completion(
self, self,
request: CompletionRequestWithRawContent, request_batch: List[CompletionRequestWithRawContent],
) -> Generator: ) -> Generator:
sampling_params = request.sampling_params or SamplingParams() first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1 max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params) temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate( for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content)], llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
logprobs=bool(request.logprobs), logprobs=bool(first_request.logprobs),
echo=False, echo=False,
logits_processor=get_logits_processor( logits_processor=get_logits_processor(
self.tokenizer, self.tokenizer,
self.args.vocab_size, self.args.vocab_size,
request.response_format, first_request.response_format,
), ),
): ):
yield result[0] yield result
def chat_completion( def chat_completion(
self, self,
request: ChatCompletionRequestWithRawContent, request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator: ) -> Generator:
sampling_params = request.sampling_params or SamplingParams() first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1 max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params) temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate( for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))], llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
],
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
logprobs=bool(request.logprobs), logprobs=bool(first_request.logprobs),
echo=False, echo=False,
logits_processor=get_logits_processor( logits_processor=get_logits_processor(
self.tokenizer, self.tokenizer,
self.args.vocab_size, self.args.vocab_size,
request.response_format, first_request.response_format,
), ),
): ):
yield result[0] yield result
class Llama3Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
if config.quantization:
if config.quantization.type == "fp8_mixed":
quantization_mode = QuantizationMode.fp8_mixed
elif config.quantization.type == "int4_mixed":
quantization_mode = QuantizationMode.int4_mixed
elif config.quantization.type == "bf16":
quantization_mode = None
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
quantization_mode = None
self.inner_generator = Llama3.build(
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
model_inputs=[self.formatter.encode_content(request.content)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
):
yield result[0]
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
):
yield result[0]


@ -5,10 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import logging
import os import os
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
@ -17,6 +17,8 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus, ToolCallParseStatus,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseEvent, ChatCompletionResponseEvent,
@ -38,8 +40,10 @@ from llama_stack.apis.inference import (
ToolConfig, ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
UserMessage,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@ -65,21 +69,17 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
) )
from .config import MetaReferenceInferenceConfig from .config import MetaReferenceInferenceConfig
from .generators import Llama3Generator, Llama4Generator from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__) log = get_logger(__name__, category="inference")
# there's a single model parallel process running serving the model. for now, # there's a single model parallel process running serving the model. for now,
# we don't support multiple concurrent requests to this process. # we don't support multiple concurrent requests to this process.
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)
def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator: def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return Llama3Generator(config, model_id, llama_model) return LlamaGenerator(config, model_id, llama_model)
def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
return Llama4Generator(config, model_id, llama_model)
class MetaReferenceInferenceImpl( class MetaReferenceInferenceImpl(
@ -139,24 +139,12 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None: async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`") log.info(f"Loading model `{model_id}`")
if llama_model.model_family in {
ModelFamily.llama3,
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
}:
builder_fn = llama3_builder_fn
elif llama_model.model_family == ModelFamily.llama4:
builder_fn = llama4_builder_fn
else:
raise ValueError(f"Unsupported model family: {llama_model.model_family}")
builder_params = [self.config, model_id, llama_model] builder_params = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator( self.generator = LlamaModelParallelGenerator(
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count, model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
builder_fn=builder_fn, builder_fn=llama_builder_fn,
builder_params=builder_params, builder_params=builder_params,
formatter=( formatter=(
Llama4ChatFormat(Llama4Tokenizer.get_instance()) Llama4ChatFormat(Llama4Tokenizer.get_instance())
@ -166,11 +154,24 @@ class MetaReferenceInferenceImpl(
) )
self.generator.start() self.generator.start()
else: else:
self.generator = builder_fn(*builder_params) self.generator = llama_builder_fn(*builder_params)
self.model_id = model_id self.model_id = model_id
self.llama_model = llama_model self.llama_model = llama_model
log.info("Warming up...")
await self.completion(
model_id=model_id,
content="Hello, world!",
sampling_params=SamplingParams(max_tokens=10),
)
await self.chat_completion(
model_id=model_id,
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
)
log.info("Warmed up!")
def check_model(self, request) -> None: def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None: if self.model_id is None or self.llama_model is None:
raise RuntimeError( raise RuntimeError(
@ -208,7 +209,43 @@ class MetaReferenceInferenceImpl(
if request.stream: if request.stream:
return self._stream_completion(request) return self._stream_completion(request)
else: else:
return await self._nonstream_completion(request) results = await self._nonstream_completion([request])
return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer tokenizer = self.generator.formatter.tokenizer
@ -253,37 +290,54 @@ class MetaReferenceInferenceImpl(
for x in impl(): for x in impl():
yield x yield x
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl(): def impl():
tokens = [] states = [ItemState() for _ in request_batch]
logprobs = []
stop_reason = None
for token_result in self.generator.completion(request): results = []
tokens.append(token_result.token) for token_results in self.generator.completion(request_batch):
if token_result.token == tokenizer.eot_id: for result in token_results:
stop_reason = StopReason.end_of_turn idx = result.batch_idx
elif token_result.token == tokenizer.eom_id: state = states[idx]
stop_reason = StopReason.end_of_message if state.finished or result.ignore_token:
continue
if request.logprobs: state.finished = result.finished
assert len(token_result.logprobs) == 1 if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
if stop_reason is None: for state in states:
stop_reason = StopReason.out_of_tokens if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens: if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
tokens = tokens[:-1] state.tokens = state.tokens[:-1]
content = self.generator.formatter.tokenizer.decode(tokens) content = self.generator.formatter.tokenizer.decode(state.tokens)
return CompletionResponse( results.append(
content=content, CompletionResponse(
stop_reason=stop_reason, content=content,
logprobs=logprobs if request.logprobs else None, stop_reason=state.stop_reason,
) logprobs=state.logprobs if first_request.logprobs else None,
)
)
return results
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group:
async with SEMAPHORE: async with SEMAPHORE:
@ -318,7 +372,7 @@ class MetaReferenceInferenceImpl(
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config, tool_config=tool_config or ToolConfig(),
) )
self.check_model(request) self.check_model(request)
@ -334,44 +388,110 @@ class MetaReferenceInferenceImpl(
if request.stream: if request.stream:
return self._stream_chat_completion(request) return self._stream_chat_completion(request)
else: else:
return await self._nonstream_chat_completion(request) results = await self._nonstream_chat_completion([request])
return results[0]
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: List[ChatCompletionRequest]
) -> List[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl(): def impl():
tokens = [] states = [ItemState() for _ in request_batch]
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(request): for token_results in self.generator.chat_completion(request_batch):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": first = token_results[0]
cprint(token_result.text, "cyan", end="") if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", "magenta", end="")
tokens.append(token_result.token) for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
if token_result.token == tokenizer.eot_id: state.finished = result.finished
stop_reason = StopReason.end_of_turn if first_request.logprobs:
elif token_result.token == tokenizer.eom_id: state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
stop_reason = StopReason.end_of_message
if request.logprobs: state.tokens.append(result.token)
assert len(token_result.logprobs) == 1 if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if stop_reason is None: raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
stop_reason = StopReason.out_of_tokens results.append(
ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=state.logprobs if first_request.logprobs else None,
)
)
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason) return results
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls,
),
logprobs=logprobs if request.logprobs else None,
)
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group:
async with SEMAPHORE: async with SEMAPHORE:
@ -398,6 +518,22 @@ class MetaReferenceInferenceImpl(
for token_result in self.generator.chat_completion(request): for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="") cprint(token_result.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", "magenta", end="")
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token) tokens.append(token_result.token)


@ -6,7 +6,7 @@
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing import Any, Callable, Generator from typing import Any, Callable, Generator, List
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@ -23,13 +23,13 @@ class ModelRunner:
self.llama = llama self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, req: Any): def __call__(self, task: Any):
if isinstance(req, ChatCompletionRequestWithRawContent): if task[0] == "chat_completion":
return self.llama.chat_completion(req) return self.llama.chat_completion(task[1])
elif isinstance(req, CompletionRequestWithRawContent): elif task[0] == "completion":
return self.llama.completion(req) return self.llama.completion(task[1])
else: else:
raise ValueError(f"Unexpected task type {type(req)}") raise ValueError(f"Unexpected task type {task[0]}")
def init_model_cb( def init_model_cb(
@ -82,16 +82,16 @@ class LlamaModelParallelGenerator:
def completion( def completion(
self, self,
request: CompletionRequestWithRawContent, request_batch: List[CompletionRequestWithRawContent],
) -> Generator: ) -> Generator:
req_obj = deepcopy(request) req_obj = deepcopy(request_batch)
gen = self.group.run_inference(req_obj) gen = self.group.run_inference(("completion", req_obj))
yield from gen yield from gen
def chat_completion( def chat_completion(
self, self,
request: ChatCompletionRequestWithRawContent, request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator: ) -> Generator:
req_obj = deepcopy(request) req_obj = deepcopy(request_batch)
gen = self.group.run_inference(req_obj) gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen yield from gen


@@ -19,7 +19,7 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Callable, Generator, Literal, Optional, Union
+from typing import Callable, Generator, List, Literal, Optional, Tuple, Union

 import torch
 import zmq

@@ -69,12 +69,12 @@ class CancelSentinel(BaseModel):
 class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
+    task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]


 class TaskResponse(BaseModel):
     type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-    result: GenerationResult
+    result: List[GenerationResult]


 class ExceptionResponse(BaseModel):

@@ -331,7 +331,7 @@ class ModelParallelProcessGroup:
     def run_inference(
         self,
-        req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
+        req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
     ) -> Generator:
         assert not self.running, "inference already running"
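The wire-level messages change accordingly: `TaskRequest` now carries a tagged batch and `TaskResponse` a list of generation results. A hedged, simplified sketch of that shape is below; `CompletionReq`, `ChatCompletionReq`, and `TaskRequestSketch` are illustrative stand-ins for the real request models, not the types defined above.

```python
from typing import List, Literal, Tuple, Union

from pydantic import BaseModel


# Simplified stand-ins for the real request types, for illustration only.
class CompletionReq(BaseModel):
    content: str


class ChatCompletionReq(BaseModel):
    messages: List[dict]


class TaskRequestSketch(BaseModel):
    type: Literal["task_request"] = "task_request"
    # The payload is now a (task_type, batch) pair instead of a single request.
    task: Tuple[str, Union[List[CompletionReq], List[ChatCompletionReq]]]


req = TaskRequestSketch(task=("completion", [CompletionReq(content="All you need is ")]))
print(req)
```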

View file

@@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_stack.apis.inference import (
     CompletionResponse,
     Inference,
+    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,

@@ -80,3 +81,25 @@ class SentenceTransformersInferenceImpl(
         tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         raise ValueError("Sentence transformers don't support chat completion")
+
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
+
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@@ -437,6 +437,28 @@ class OllamaInferenceAdapter(
         }
         return await self.openai_client.chat.completions.create(**params)  # type: ignore

+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch completion is not supported for Ollama")
+
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch chat completion is not supported for Ollama")
+

 async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
     async def _convert_content(content) -> dict:

View file

@@ -526,3 +526,25 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             user=user,
         )
         return await self.client.chat.completions.create(**params)  # type: ignore
+
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch completion is not supported for vLLM")
+
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch chat completion is not supported for vLLM")

View file

@@ -347,3 +347,25 @@ class LiteLLMOpenAIMixin(
             user=user,
         )
         return litellm.completion(**params)
+
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
+
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ):
+        raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
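All of the remote and compat providers above stub the new methods out with `NotImplementedError`; only providers with a native batch path (meta-reference) implement them. As a purely hypothetical illustration of what a non-native fallback could look like, the sketch below loops the single-request path over a batch. This is not what the code above does; `naive_batch` and `fake_completion` are invented for the example.

```python
import asyncio
from typing import Any, Callable, List


async def naive_batch(call_one: Callable[[Any], Any], batch: List[Any]) -> List[Any]:
    """Hypothetical fallback: run the single-request path once per item.

    The providers above deliberately raise NotImplementedError instead; this only
    illustrates what a non-native 'batch' could look like.
    """
    return await asyncio.gather(*(call_one(item) for item in batch))


# Dummy coroutine standing in for a provider's single completion call.
async def fake_completion(prompt: str) -> str:
    await asyncio.sleep(0)  # pretend to do I/O
    return f"completion for: {prompt}"


print(asyncio.run(naive_batch(fake_completion, ["Roses are red, violets are ", "All you need is "])))
```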

View file

@@ -16,11 +16,12 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.INFERENCE_MODEL}
-      max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+      max_batch_size: ${env.MAX_BATCH_SIZE:1}
+      max_seq_len: ${env.MAX_SEQ_LEN:4096}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}

@@ -28,11 +29,12 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.SAFETY_MODEL}
-      max_seq_len: 4096
       checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+      max_batch_size: ${env.MAX_BATCH_SIZE:1}
+      max_seq_len: ${env.MAX_SEQ_LEN:4096}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss

View file

@@ -16,11 +16,12 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.INFERENCE_MODEL}
-      max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+      max_batch_size: ${env.MAX_BATCH_SIZE:1}
+      max_seq_len: ${env.MAX_SEQ_LEN:4096}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
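The run.yaml changes replace the hardcoded `max_seq_len: 4096` with env-substituted values, so batch size and sequence length can be tuned at launch time. The `${env.VAR:default}` syntax reads the variable and falls back to the value after the colon when it is unset. Below is a simplified sketch of that substitution convention, written only for illustration; the real Llama Stack resolver may differ in details (for example, how `:null` is handled).

```python
import os
import re

# Simplified sketch of the ${env.VAR:default} convention used in the run.yaml above.
_ENV_PATTERN = re.compile(r"\$\{env\.([A-Z0-9_]+)(?::([^}]*))?\}")


def resolve(value: str) -> str:
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default if default is not None else "")

    return _ENV_PATTERN.sub(_sub, value)


os.environ["MAX_BATCH_SIZE"] = "32"
print(resolve("${env.MAX_BATCH_SIZE:1}"))  # -> "32" (env var set)
print(resolve("${env.MAX_SEQ_LEN:4096}"))  # -> "4096" (falls back to the default)
```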

View file

@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from ..test_cases.test_case import TestCase


def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id):
    models = {m.identifier: m for m in client_with_models.models.list()}
    models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
    provider_id = models[model_id].provider_id
    providers = {p.provider_id: p for p in client_with_models.providers.list()}
    provider = providers[provider_id]
    if provider.provider_type not in ("inline::meta-reference",):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference")


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:batch_completion",
    ],
)
def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
    tc = TestCase(test_case)

    content_batch = tc["contents"]
    response = client_with_models.inference.batch_completion(
        content_batch=content_batch,
        model_id=text_model_id,
        sampling_params={
            "max_tokens": 50,
        },
    )
    assert len(response.batch) == len(content_batch)
    for i, r in enumerate(response.batch):
        print(f"response {i}: {r.content}")
        assert len(r.content) > 10


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:batch_completion",
    ],
)
def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
    tc = TestCase(test_case)
    qa_pairs = tc["qa_pairs"]

    message_batch = [
        [
            {
                "role": "user",
                "content": qa["question"],
            }
        ]
        for qa in qa_pairs
    ]

    response = client_with_models.inference.batch_chat_completion(
        messages_batch=message_batch,
        model_id=text_model_id,
    )
    assert len(response.batch) == len(qa_pairs)
    for i, r in enumerate(response.batch):
        print(f"response {i}: {r.completion_message.content}")
        assert len(r.completion_message.content) > 0
        assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()
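For a standalone client-side view of what this test exercises, here is a minimal sketch assuming the `llama-stack-client` Python package and a locally running server; the `base_url` and the use of `INFERENCE_MODEL` as the model id are assumptions, so adjust them to your deployment.

```python
import os

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # adjust to your server

response = client.inference.batch_chat_completion(
    model_id=os.environ["INFERENCE_MODEL"],  # assumes the same env var the server was started with
    messages_batch=[
        [{"role": "user", "content": "What is the capital of France?"}],
        [{"role": "user", "content": "Who wrote the book '1984'?"}],
    ],
)
for i, r in enumerate(response.batch):
    print(f"response {i}: {r.completion_message.content}")
```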

View file

@@ -537,5 +537,31 @@
         }
       ]
     }
+  },
+  "batch_completion": {
+    "data": {
+      "qa_pairs": [
+        {
+          "question": "What is the capital of France?",
+          "answer": "Paris"
+        },
+        {
+          "question": "Who wrote the book '1984'?",
+          "answer": "George Orwell"
+        },
+        {
+          "question": "Which planet has rings around it with a name starting with letter S?",
+          "answer": "Saturn"
+        },
+        {
+          "question": "When did the first moon landing happen?",
+          "answer": "1969"
+        },
+        {
+          "question": "What word says 'hello' in Spanish?",
+          "answer": "Hola"
+        }
+      ]
+    }
   }
 }

View file

@@ -44,5 +44,18 @@
         "year_retired": "2003"
       }
     }
+  },
+  "batch_completion": {
+    "data": {
+      "contents": [
+        "Michael Jordan was born in ",
+        "Roses are red, violets are ",
+        "If you had a million dollars, what would you do with it? ",
+        "All you need is ",
+        "The capital of France is ",
+        "It is a good day to ",
+        "The answer to the universe is "
+      ]
+    }
   }
 }