feat: add batch inference API to llama stack inference

Ashwin Bharambe 2025-04-08 13:50:52 -07:00
parent ed58a94b30
commit 0cfb2e2473
24 changed files with 1041 additions and 377 deletions

View file

@ -85,7 +85,50 @@
} }
} }
}, },
"/v1/batch-inference/chat-completion": { "/v1/inference/batch-chat-completion": {
"post": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BatchChatCompletionResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Inference"
],
"description": "",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BatchChatCompletionRequest"
}
}
},
"required": true
}
}
},
"/v1/batch-inference/chat-completion-inline": {
"post": { "post": {
"responses": { "responses": {
"200": { "200": {
@ -120,7 +163,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"$ref": "#/components/schemas/BatchChatCompletionRequest" "$ref": "#/components/schemas/BatchChatCompletionInlineRequest"
} }
} }
}, },
@ -128,7 +171,50 @@
} }
} }
}, },
"/v1/batch-inference/completion": { "/v1/inference/batch-completion": {
"post": {
"responses": {
"200": {
"description": "OK",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BatchCompletionResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Inference"
],
"description": "",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/BatchCompletionRequest"
}
}
},
"required": true
}
}
},
"/v1/batch-inference/completion-inline": {
"post": { "post": {
"responses": { "responses": {
"200": { "200": {
@ -163,7 +249,7 @@
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"$ref": "#/components/schemas/BatchCompletionRequest" "$ref": "#/components/schemas/BatchCompletionInlineRequest"
} }
} }
}, },
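As a quick sanity check of the new route, a minimal client sketch is shown below. It assumes a Llama Stack server is already listening on the default local address and that the model id is registered; the URL, port, and model id are illustrative assumptions, not part of this change.

# Hedged sketch: POST a BatchChatCompletionRequest to the new endpoint.
# The base URL and model id below are assumptions for illustration only.
import requests

BASE_URL = "http://localhost:8321"  # assumed local `llama stack run` address

payload = {
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",  # assumed registered model
    "messages_batch": [
        [{"role": "user", "content": "Write a haiku about batching."}],
        [{"role": "user", "content": "What is 2 + 2?"}],
    ],
    "sampling_params": {"max_tokens": 64},
}

resp = requests.post(f"{BASE_URL}/v1/inference/batch-chat-completion", json=payload)
resp.raise_for_status()
# BatchChatCompletionResponse.batch holds one ChatCompletionResponse per conversation.
for item in resp.json()["batch"]:
    print(item["completion_message"]["content"])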
@ -4366,6 +4452,51 @@
], ],
"title": "ToolCall" "title": "ToolCall"
}, },
"ToolConfig": {
"type": "object",
"properties": {
"tool_choice": {
"oneOf": [
{
"type": "string",
"enum": [
"auto",
"required",
"none"
],
"title": "ToolChoice",
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
{
"type": "string"
}
],
"default": "auto",
"description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
},
"tool_prompt_format": {
"type": "string",
"enum": [
"json",
"function_tag",
"python_list"
],
"description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
},
"system_message_behavior": {
"type": "string",
"enum": [
"append",
"replace"
],
"description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
"default": "append"
}
},
"additionalProperties": false,
"title": "ToolConfig",
"description": "Configuration for tool use."
},
"ToolDefinition": { "ToolDefinition": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -4554,7 +4685,7 @@
"BatchChatCompletionRequest": { "BatchChatCompletionRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
"model": { "model_id": {
"type": "string" "type": "string"
}, },
"messages_batch": { "messages_batch": {
@ -4575,25 +4706,8 @@
"$ref": "#/components/schemas/ToolDefinition" "$ref": "#/components/schemas/ToolDefinition"
} }
}, },
"tool_choice": { "tool_config": {
"type": "string", "$ref": "#/components/schemas/ToolConfig"
"enum": [
"auto",
"required",
"none"
],
"title": "ToolChoice",
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
"tool_prompt_format": {
"type": "string",
"enum": [
"json",
"function_tag",
"python_list"
],
"title": "ToolPromptFormat",
"description": "Prompt format for calling custom / zero shot tools."
}, },
"response_format": { "response_format": {
"$ref": "#/components/schemas/ResponseFormat" "$ref": "#/components/schemas/ResponseFormat"
@ -4613,7 +4727,7 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"model", "model_id",
"messages_batch" "messages_batch"
], ],
"title": "BatchChatCompletionRequest" "title": "BatchChatCompletionRequest"
@ -4707,12 +4821,62 @@
"title": "TokenLogProbs", "title": "TokenLogProbs",
"description": "Log probabilities for generated tokens." "description": "Log probabilities for generated tokens."
}, },
"BatchCompletionRequest": { "BatchChatCompletionInlineRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
"model": { "model": {
"type": "string" "type": "string"
}, },
"messages_batch": {
"type": "array",
"items": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Message"
}
}
},
"sampling_params": {
"$ref": "#/components/schemas/SamplingParams"
},
"tools": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolDefinition"
}
},
"tool_config": {
"$ref": "#/components/schemas/ToolConfig"
},
"response_format": {
"$ref": "#/components/schemas/ResponseFormat"
},
"logprobs": {
"type": "object",
"properties": {
"top_k": {
"type": "integer",
"default": 0,
"description": "How many tokens (for each position) to return log probabilities for."
}
},
"additionalProperties": false,
"title": "LogProbConfig"
}
},
"additionalProperties": false,
"required": [
"model",
"messages_batch"
],
"title": "BatchChatCompletionInlineRequest"
},
"BatchCompletionRequest": {
"type": "object",
"properties": {
"model_id": {
"type": "string"
},
"content_batch": { "content_batch": {
"type": "array", "type": "array",
"items": { "items": {
@ -4740,7 +4904,7 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"model", "model_id",
"content_batch" "content_batch"
], ],
"title": "BatchCompletionRequest" "title": "BatchCompletionRequest"
@ -4799,6 +4963,44 @@
"title": "CompletionResponse", "title": "CompletionResponse",
"description": "Response from a completion request." "description": "Response from a completion request."
}, },
"BatchCompletionInlineRequest": {
"type": "object",
"properties": {
"model": {
"type": "string"
},
"content_batch": {
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContent"
}
},
"sampling_params": {
"$ref": "#/components/schemas/SamplingParams"
},
"response_format": {
"$ref": "#/components/schemas/ResponseFormat"
},
"logprobs": {
"type": "object",
"properties": {
"top_k": {
"type": "integer",
"default": 0,
"description": "How many tokens (for each position) to return log probabilities for."
}
},
"additionalProperties": false,
"title": "LogProbConfig"
}
},
"additionalProperties": false,
"required": [
"model",
"content_batch"
],
"title": "BatchCompletionInlineRequest"
},
"CancelTrainingJobRequest": { "CancelTrainingJobRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -4812,51 +5014,6 @@
], ],
"title": "CancelTrainingJobRequest" "title": "CancelTrainingJobRequest"
}, },
"ToolConfig": {
"type": "object",
"properties": {
"tool_choice": {
"oneOf": [
{
"type": "string",
"enum": [
"auto",
"required",
"none"
],
"title": "ToolChoice",
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
},
{
"type": "string"
}
],
"default": "auto",
"description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
},
"tool_prompt_format": {
"type": "string",
"enum": [
"json",
"function_tag",
"python_list"
],
"description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
},
"system_message_behavior": {
"type": "string",
"enum": [
"append",
"replace"
],
"description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
"default": "append"
}
},
"additionalProperties": false,
"title": "ToolConfig",
"description": "Configuration for tool use."
},
"ChatCompletionRequest": { "ChatCompletionRequest": {
"type": "object", "type": "object",
"properties": { "properties": {

View file

@ -40,7 +40,36 @@ paths:
schema: schema:
$ref: '#/components/schemas/AppendRowsRequest' $ref: '#/components/schemas/AppendRowsRequest'
required: true required: true
/v1/batch-inference/chat-completion: /v1/inference/batch-chat-completion:
post:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/BatchChatCompletionResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
description: ''
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/BatchChatCompletionRequest'
required: true
/v1/batch-inference/chat-completion-inline:
post: post:
responses: responses:
'200': '200':
@ -67,9 +96,38 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/BatchChatCompletionRequest' $ref: '#/components/schemas/BatchChatCompletionInlineRequest'
required: true required: true
/v1/batch-inference/completion: /v1/inference/batch-completion:
post:
responses:
'200':
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/BatchCompletionResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
description: ''
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/BatchCompletionRequest'
required: true
/v1/batch-inference/completion-inline:
post: post:
responses: responses:
'200': '200':
@ -96,7 +154,7 @@ paths:
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/BatchCompletionRequest' $ref: '#/components/schemas/BatchCompletionInlineRequest'
required: true required: true
/v1/post-training/job/cancel: /v1/post-training/job/cancel:
post: post:
@ -3009,6 +3067,54 @@ components:
- tool_name - tool_name
- arguments - arguments
title: ToolCall title: ToolCall
ToolConfig:
type: object
properties:
tool_choice:
oneOf:
- type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following
capabilities of the model.
- type: string
default: auto
description: >-
(Optional) Whether tool use is automatic, required, or none. Can also
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
description: >-
(Optional) Instructs the model how to format tool calls. By default, Llama
Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
syntax -- a list of function calls.
system_message_behavior:
type: string
enum:
- append
- replace
description: >-
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
Replaces the default system prompt with the provided system message. The
system message can include the string '{{function_definitions}}' to indicate
where the function definitions should be inserted.
default: append
additionalProperties: false
title: ToolConfig
description: Configuration for tool use.
ToolDefinition: ToolDefinition:
type: object type: object
properties: properties:
@ -3145,7 +3251,7 @@ components:
BatchChatCompletionRequest: BatchChatCompletionRequest:
type: object type: object
properties: properties:
model: model_id:
type: string type: string
messages_batch: messages_batch:
type: array type: array
@ -3159,26 +3265,8 @@ components:
type: array type: array
items: items:
$ref: '#/components/schemas/ToolDefinition' $ref: '#/components/schemas/ToolDefinition'
tool_choice: tool_config:
type: string $ref: '#/components/schemas/ToolConfig'
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following capabilities
of the model.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
title: ToolPromptFormat
description: >-
Prompt format for calling custom / zero shot tools.
response_format: response_format:
$ref: '#/components/schemas/ResponseFormat' $ref: '#/components/schemas/ResponseFormat'
logprobs: logprobs:
@ -3193,7 +3281,7 @@ components:
title: LogProbConfig title: LogProbConfig
additionalProperties: false additionalProperties: false
required: required:
- model - model_id
- messages_batch - messages_batch
title: BatchChatCompletionRequest title: BatchChatCompletionRequest
BatchChatCompletionResponse: BatchChatCompletionResponse:
@ -3258,11 +3346,47 @@ components:
- logprobs_by_token - logprobs_by_token
title: TokenLogProbs title: TokenLogProbs
description: Log probabilities for generated tokens. description: Log probabilities for generated tokens.
BatchCompletionRequest: BatchChatCompletionInlineRequest:
type: object type: object
properties: properties:
model: model:
type: string type: string
messages_batch:
type: array
items:
type: array
items:
$ref: '#/components/schemas/Message'
sampling_params:
$ref: '#/components/schemas/SamplingParams'
tools:
type: array
items:
$ref: '#/components/schemas/ToolDefinition'
tool_config:
$ref: '#/components/schemas/ToolConfig'
response_format:
$ref: '#/components/schemas/ResponseFormat'
logprobs:
type: object
properties:
top_k:
type: integer
default: 0
description: >-
How many tokens (for each position) to return log probabilities for.
additionalProperties: false
title: LogProbConfig
additionalProperties: false
required:
- model
- messages_batch
title: BatchChatCompletionInlineRequest
BatchCompletionRequest:
type: object
properties:
model_id:
type: string
content_batch: content_batch:
type: array type: array
items: items:
@ -3283,7 +3407,7 @@ components:
title: LogProbConfig title: LogProbConfig
additionalProperties: false additionalProperties: false
required: required:
- model - model_id
- content_batch - content_batch
title: BatchCompletionRequest title: BatchCompletionRequest
BatchCompletionResponse: BatchCompletionResponse:
@ -3326,6 +3450,34 @@ components:
- stop_reason - stop_reason
title: CompletionResponse title: CompletionResponse
description: Response from a completion request. description: Response from a completion request.
BatchCompletionInlineRequest:
type: object
properties:
model:
type: string
content_batch:
type: array
items:
$ref: '#/components/schemas/InterleavedContent'
sampling_params:
$ref: '#/components/schemas/SamplingParams'
response_format:
$ref: '#/components/schemas/ResponseFormat'
logprobs:
type: object
properties:
top_k:
type: integer
default: 0
description: >-
How many tokens (for each position) to return log probabilities for.
additionalProperties: false
title: LogProbConfig
additionalProperties: false
required:
- model
- content_batch
title: BatchCompletionInlineRequest
CancelTrainingJobRequest: CancelTrainingJobRequest:
type: object type: object
properties: properties:
@ -3335,54 +3487,6 @@ components:
required: required:
- job_uuid - job_uuid
title: CancelTrainingJobRequest title: CancelTrainingJobRequest
ToolConfig:
type: object
properties:
tool_choice:
oneOf:
- type: string
enum:
- auto
- required
- none
title: ToolChoice
description: >-
Whether tool use is required or automatic. This is a hint to the model
which may not be followed. It depends on the Instruction Following
capabilities of the model.
- type: string
default: auto
description: >-
(Optional) Whether tool use is automatic, required, or none. Can also
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
tool_prompt_format:
type: string
enum:
- json
- function_tag
- python_list
description: >-
(Optional) Instructs the model how to format tool calls. By default, Llama
Stack will attempt to use a format that is best adapted to the model.
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
syntax -- a list of function calls.
system_message_behavior:
type: string
enum:
- append
- replace
description: >-
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
Replaces the default system prompt with the provided system message. The
system message can include the string '{{function_definitions}}' to indicate
where the function definitions should be inserted.
default: append
additionalProperties: false
title: ToolConfig
description: Configuration for tool use.
ChatCompletionRequest: ChatCompletionRequest:
type: object type: object
properties: properties:
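The YAML document mirrors the JSON spec above. For callers, the visible changes to the batch request schema are the rename of model to model_id and the consolidation of the flat tool_choice / tool_prompt_format fields into a nested tool_config object. A hedged before/after sketch of a request body (values are illustrative):

# Hedged before/after sketch of a batch chat-completion body; values are illustrative.
old_body = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "messages_batch": [[{"role": "user", "content": "hi"}]],
    "tool_choice": "auto",
    "tool_prompt_format": "json",
}

new_body = {
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",   # renamed from "model"
    "messages_batch": [[{"role": "user", "content": "hi"}]],
    "tool_config": {                                  # replaces the two flat fields
        "tool_choice": "auto",
        "tool_prompt_format": "json",
        "system_message_behavior": "append",
    },
}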

View file

@ -6,37 +6,24 @@
from typing import List, Optional, Protocol, runtime_checkable from typing import List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponse, BatchChatCompletionResponse,
CompletionResponse, BatchCompletionResponse,
InterleavedContent, InterleavedContent,
LogProbConfig, LogProbConfig,
Message, Message,
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
ToolChoice, ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat,
) )
from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.schema_utils import webmethod
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
@runtime_checkable @runtime_checkable
class BatchInference(Protocol): class BatchInference(Protocol):
@webmethod(route="/batch-inference/completion", method="POST") @webmethod(route="/batch-inference/completion-inline", method="POST")
async def batch_completion( async def batch_completion_inline(
self, self,
model: str, model: str,
content_batch: List[InterleavedContent], content_batch: List[InterleavedContent],
@ -45,16 +32,14 @@ class BatchInference(Protocol):
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse: ... ) -> BatchCompletionResponse: ...
@webmethod(route="/batch-inference/chat-completion", method="POST") @webmethod(route="/batch-inference/chat-completion-inline", method="POST")
async def batch_chat_completion( async def batch_chat_completion_inline(
self, self,
model: str, model: str,
messages_batch: List[List[Message]], messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[SamplingParams] = None,
# zero-shot tool definitions as input to the model
tools: Optional[List[ToolDefinition]] = list, tools: Optional[List[ToolDefinition]] = list,
tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_config: Optional[ToolConfig] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse: ... ) -> BatchChatCompletionResponse: ...
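The inline BatchInference protocol keeps the older model parameter but now takes a ToolConfig and returns the shared response types from llama_stack.apis.inference. Below is a minimal sketch of a provider satisfying the renamed chat method; the class name and its inference_api attribute are hypothetical, used only to show the shape of the delegation.

# Hedged sketch of a provider for the renamed inline method; the class and its
# `inference_api` attribute are hypothetical, not part of this change.
from typing import List, Optional

from llama_stack.apis.inference import (
    BatchChatCompletionResponse,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    ToolConfig,
    ToolDefinition,
)


class PassthroughBatchInference:
    def __init__(self, inference_api):
        self.inference_api = inference_api  # any Inference impl that supports batching

    async def batch_chat_completion_inline(
        self,
        model: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_config: Optional[ToolConfig] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchChatCompletionResponse:
        # Delegate to the new Inference.batch_chat_completion, which takes model_id.
        return await self.inference_api.batch_chat_completion(
            model_id=model,
            messages_batch=messages_batch,
            sampling_params=sampling_params,
            tools=tools,
            tool_config=tool_config,
            response_format=response_format,
            logprobs=logprobs,
        )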

View file

@ -681,6 +681,16 @@ class EmbeddingTaskType(Enum):
document = "document" document = "document"
@json_schema_type
class BatchCompletionResponse(BaseModel):
batch: List[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
batch: List[ChatCompletionResponse]
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class Inference(Protocol): class Inference(Protocol):
@ -716,6 +726,17 @@ class Inference(Protocol):
""" """
... ...
@webmethod(route="/inference/batch-completion", method="POST", experimental=True)
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST") @webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion( async def chat_completion(
self, self,
@ -756,6 +777,19 @@ class Inference(Protocol):
""" """
... ...
@webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST") @webmethod(route="/inference/embeddings", method="POST")
async def embeddings( async def embeddings(
self, self,

View file

@ -400,6 +400,9 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
mro = type(obj).__mro__ mro = type(obj).__mro__
for name, value in inspect.getmembers(protocol): for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"): if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
if value.__webmethod__.experimental:
continue
if not hasattr(obj, name): if not hasattr(obj, name):
missing_methods.append((name, "missing")) missing_methods.append((name, "missing"))
elif not callable(getattr(obj, name)): elif not callable(getattr(obj, name)):
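The compliance check now skips webmethods flagged experimental, so providers are not forced to implement the new batch routes. A minimal, self-contained reproduction of that skip follows; the decorator below is a stand-in, not the real llama_stack.schema_utils.webmethod.

# Stand-in reproduction of the "skip experimental webmethods" check above.
import inspect
from types import SimpleNamespace


def webmethod_stub(route, method="POST", experimental=False):
    def wrap(fn):
        fn.__webmethod__ = SimpleNamespace(route=route, method=method, experimental=experimental)
        return fn
    return wrap


class FakeProtocol:
    @webmethod_stub("/inference/chat-completion")
    def chat_completion(self): ...

    @webmethod_stub("/inference/batch-chat-completion", experimental=True)
    def batch_chat_completion(self): ...


class ProviderWithoutBatching:
    def chat_completion(self): ...


missing = []
for name, value in inspect.getmembers(FakeProtocol):
    if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
        if value.__webmethod__.experimental:
            continue  # experimental methods are optional
        if not hasattr(ProviderWithoutBatching, name):
            missing.append(name)

assert missing == []  # the experimental batch method is not reported as missing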

View file

@ -17,6 +17,8 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk, ChatCompletionResponseStreamChunk,
@ -334,6 +336,30 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response return response
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
tools=tools,
tool_config=tool_config,
sampling_params=sampling_params,
response_format=response_format,
logprobs=logprobs,
)
async def completion( async def completion(
self, self,
model_id: str, model_id: str,
@ -398,6 +424,20 @@ class InferenceRouter(Inference):
response.metrics = metrics if response.metrics is None else response.metrics + metrics response.metrics = metrics if response.metrics is None else response.metrics + metrics
return response return response
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
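The router's batch methods follow the same lookup-then-forward pattern as the existing ones: resolve the provider from the model id, then pass the arguments through untouched. A stand-in sketch of that pattern follows; the fake classes are simplifications, not the real routing table or provider types.

# Stand-in sketch of the lookup-then-forward pattern used by the router above.
import asyncio


class FakeProvider:
    async def batch_completion(self, model_id, content_batch, sampling_params, response_format, logprobs):
        return [f"echo:{c}" for c in content_batch]


class FakeRoutingTable:
    def __init__(self, providers):
        self.providers = providers  # model_id -> provider

    def get_provider_impl(self, model_id):
        return self.providers[model_id]


async def batch_completion(routing_table, model_id, content_batch):
    provider = routing_table.get_provider_impl(model_id)  # same resolution step as the router
    return await provider.batch_completion(model_id, content_batch, None, None, None)


table = FakeRoutingTable({"demo-model": FakeProvider()})
print(asyncio.run(batch_completion(table, "demo-model", ["hello", "world"])))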
async def embeddings( async def embeddings(
self, self,
model_id: str, model_id: str,

View file

@ -140,7 +140,12 @@ class Llama3:
return Llama3(model, tokenizer, model_args) return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs): def __init__(
self,
model: Transformer | CrossAttentionTransformer,
tokenizer: Tokenizer,
args: ModelArgs,
):
self.args = args self.args = args
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
@ -285,7 +290,7 @@ class Llama3:
source="output", source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None), logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx, batch_idx=idx,
finished=eos_reached[idx], finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]), ignore_token=cur_pos < len(prompt_tokens[idx]),
) )
) )

View file

@ -233,7 +233,7 @@ class Llama4:
source="output", source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None), logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx, batch_idx=idx,
finished=eos_reached[idx], finished=eos_reached[idx].item(),
ignore_token=cur_pos < len(prompt_tokens[idx]), ignore_token=cur_pos < len(prompt_tokens[idx]),
) )
) )
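In both the Llama3 and Llama4 generators, eos_reached[idx] is a zero-dimensional boolean tensor; .item() converts it to a plain Python bool before it is stored on the per-token result, presumably so the value serializes cleanly when results cross the model-parallel process boundary. A tiny illustration (requires torch):

# Tiny illustration of what the .item() change does (requires torch).
import torch

eos_reached = torch.tensor([False, True])
print(type(eos_reached[1]))         # <class 'torch.Tensor'>, a 0-dim tensor
print(type(eos_reached[1].item()))  # <class 'bool'>, a plain Python value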

View file

@ -52,14 +52,17 @@ class MetaReferenceInferenceConfig(BaseModel):
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}", checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}", quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}", model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return { return {
"model": model, "model": model,
"max_seq_len": 4096,
"checkpoint_dir": checkpoint_dir, "checkpoint_dir": checkpoint_dir,
"quantization": { "quantization": {
"type": quantization_type, "type": quantization_type,
}, },
"model_parallel_size": model_parallel_size, "model_parallel_size": model_parallel_size,
"max_batch_size": max_batch_size,
"max_seq_len": max_seq_len,
} }
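max_batch_size and max_seq_len are now surfaced as environment-driven template values instead of max_seq_len being hard-coded to 4096. With default arguments, the dict returned by sample_run_config should look roughly like the sketch below; the model value is illustrative.

# Hedged sketch of the dict sample_run_config now returns with default arguments;
# the "model" value is illustrative.
expected = {
    "model": "Llama3.2-3B-Instruct",
    "checkpoint_dir": "${env.CHECKPOINT_DIR:null}",
    "quantization": {"type": "${env.QUANTIZATION_TYPE:bf16}"},
    "model_parallel_size": "${env.MODEL_PARALLEL_SIZE:0}",
    "max_batch_size": "${env.MAX_BATCH_SIZE:1}",
    "max_seq_len": "${env.MAX_SEQ_LEN:4096}",
}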

View file

@ -22,7 +22,7 @@ from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4 from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model from llama_stack.models.llama.sku_types import Model, ModelFamily
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent, ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent, CompletionRequestWithRawContent,
@ -113,8 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
return get_default_tool_prompt_format(request.model) return get_default_tool_prompt_format(request.model)
# TODO: combine Llama3 and Llama4 generators since they are almost identical now class LlamaGenerator:
class Llama4Generator:
def __init__( def __init__(
self, self,
config: MetaReferenceInferenceConfig, config: MetaReferenceInferenceConfig,
@ -144,7 +143,8 @@ class Llama4Generator:
else: else:
quantization_mode = None quantization_mode = None
self.inner_generator = Llama4.build( cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
self.inner_generator = cls.build(
ckpt_dir=ckpt_dir, ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len, max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size, max_batch_size=config.max_batch_size,
@ -158,142 +158,55 @@ class Llama4Generator:
def completion( def completion(
self, self,
request: CompletionRequestWithRawContent, request_batch: List[CompletionRequestWithRawContent],
) -> Generator: ) -> Generator:
sampling_params = request.sampling_params or SamplingParams() first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1 max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params) temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate( for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_content(request.content)], llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
logprobs=bool(request.logprobs), logprobs=bool(first_request.logprobs),
echo=False, echo=False,
logits_processor=get_logits_processor( logits_processor=get_logits_processor(
self.tokenizer, self.tokenizer,
self.args.vocab_size, self.args.vocab_size,
request.response_format, first_request.response_format,
), ),
): ):
yield result[0] yield result
def chat_completion( def chat_completion(
self, self,
request: ChatCompletionRequestWithRawContent, request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator: ) -> Generator:
sampling_params = request.sampling_params or SamplingParams() first_request = request_batch[0]
sampling_params = first_request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len: if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1 max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params) temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate( for result in self.inner_generator.generate(
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))], llm_inputs=[
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
for request in request_batch
],
max_gen_len=max_gen_len, max_gen_len=max_gen_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
logprobs=bool(request.logprobs), logprobs=bool(first_request.logprobs),
echo=False, echo=False,
logits_processor=get_logits_processor( logits_processor=get_logits_processor(
self.tokenizer, self.tokenizer,
self.args.vocab_size, self.args.vocab_size,
request.response_format, first_request.response_format,
), ),
): ):
yield result[0] yield result
class Llama3Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
if config.quantization:
if config.quantization.type == "fp8_mixed":
quantization_mode = QuantizationMode.fp8_mixed
elif config.quantization.type == "int4_mixed":
quantization_mode = QuantizationMode.int4_mixed
elif config.quantization.type == "bf16":
quantization_mode = None
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
quantization_mode = None
self.inner_generator = Llama3.build(
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
self.formatter = self.inner_generator.formatter
def completion(
self,
request: CompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
model_inputs=[self.formatter.encode_content(request.content)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
):
yield result[0]
def chat_completion(
self,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
sampling_params = request.sampling_params or SamplingParams()
max_gen_len = sampling_params.max_tokens
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
temperature, top_p = _infer_sampling_params(sampling_params)
for result in self.inner_generator.generate(
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=bool(request.logprobs),
echo=False,
logits_processor=get_logits_processor(
self.tokenizer,
self.args.vocab_size,
request.response_format,
),
):
yield result[0]
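With the merged LlamaGenerator, completion and chat_completion take a list of requests and yield, per decoding step, the full list of GenerationResult objects (one per batch item, tagged with batch_idx) instead of result[0]. Below is a hedged consumer sketch that demultiplexes such a stream; the fields used (batch_idx, token, finished, ignore_token) are the ones referenced in this change.

# Hedged consumer sketch for the per-step lists the unified generator now yields.
# `generator_steps` stands in for LlamaGenerator.completion(request_batch).
def demux(generator_steps, batch_size):
    tokens = [[] for _ in range(batch_size)]
    finished = [False] * batch_size
    for step_results in generator_steps:      # one list of results per decoding step
        for result in step_results:
            idx = result.batch_idx
            if finished[idx] or result.ignore_token:
                continue
            tokens[idx].append(result.token)
            finished[idx] = result.finished
    return tokens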

View file

@ -9,6 +9,7 @@ import logging
import os import os
from typing import AsyncGenerator, List, Optional, Union from typing import AsyncGenerator, List, Optional, Union
from pydantic import BaseModel
from termcolor import cprint from termcolor import cprint
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
@ -17,6 +18,8 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus, ToolCallParseStatus,
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ChatCompletionResponseEvent, ChatCompletionResponseEvent,
@ -38,6 +41,7 @@ from llama_stack.apis.inference import (
ToolConfig, ToolConfig,
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
UserMessage,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
@ -65,7 +69,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
) )
from .config import MetaReferenceInferenceConfig from .config import MetaReferenceInferenceConfig
from .generators import Llama3Generator, Llama4Generator from .generators import LlamaGenerator
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -74,12 +78,8 @@ log = logging.getLogger(__name__)
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)
def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator: def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
return Llama3Generator(config, model_id, llama_model) return LlamaGenerator(config, model_id, llama_model)
def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
return Llama4Generator(config, model_id, llama_model)
class MetaReferenceInferenceImpl( class MetaReferenceInferenceImpl(
@ -139,24 +139,12 @@ class MetaReferenceInferenceImpl(
async def load_model(self, model_id, llama_model) -> None: async def load_model(self, model_id, llama_model) -> None:
log.info(f"Loading model `{model_id}`") log.info(f"Loading model `{model_id}`")
if llama_model.model_family in {
ModelFamily.llama3,
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
}:
builder_fn = llama3_builder_fn
elif llama_model.model_family == ModelFamily.llama4:
builder_fn = llama4_builder_fn
else:
raise ValueError(f"Unsupported model family: {llama_model.model_family}")
builder_params = [self.config, model_id, llama_model] builder_params = [self.config, model_id, llama_model]
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator( self.generator = LlamaModelParallelGenerator(
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count, model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
builder_fn=builder_fn, builder_fn=llama_builder_fn,
builder_params=builder_params, builder_params=builder_params,
formatter=( formatter=(
Llama4ChatFormat(Llama4Tokenizer.get_instance()) Llama4ChatFormat(Llama4Tokenizer.get_instance())
@ -166,11 +154,24 @@ class MetaReferenceInferenceImpl(
) )
self.generator.start() self.generator.start()
else: else:
self.generator = builder_fn(*builder_params) self.generator = llama_builder_fn(*builder_params)
self.model_id = model_id self.model_id = model_id
self.llama_model = llama_model self.llama_model = llama_model
print("Warming up...")
await self.completion(
model_id=model_id,
content="Hello, world!",
sampling_params=SamplingParams(max_tokens=10),
)
await self.chat_completion(
model_id=model_id,
messages=[UserMessage(content="Hi how are you?")],
sampling_params=SamplingParams(max_tokens=20),
)
print("Warmed up!")
def check_model(self, request) -> None: def check_model(self, request) -> None:
if self.model_id is None or self.llama_model is None: if self.model_id is None or self.llama_model is None:
raise RuntimeError( raise RuntimeError(
@ -208,7 +209,43 @@ class MetaReferenceInferenceImpl(
if request.stream: if request.stream:
return self._stream_completion(request) return self._stream_completion(request)
else: else:
return await self._nonstream_completion(request) results = await self._nonstream_completion([request])
return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
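A hedged usage sketch of the new provider-level method follows; impl is assumed to be an initialized MetaReferenceInferenceImpl with a loaded model, and the model id is illustrative.

# Hedged usage sketch; `impl` is assumed to be an initialized
# MetaReferenceInferenceImpl with a loaded model, and the model id is illustrative.
from llama_stack.apis.inference import SamplingParams


async def run_batch(impl):
    response = await impl.batch_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        content_batch=["Write one sentence about GPUs.", "Finish this: roses are red,"],
        sampling_params=SamplingParams(max_tokens=32),
    )
    for completion in response.batch:  # one CompletionResponse per input
        print(completion.content, completion.stop_reason)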
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer tokenizer = self.generator.formatter.tokenizer
@ -253,37 +290,54 @@ class MetaReferenceInferenceImpl(
for x in impl(): for x in impl():
yield x yield x
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
tokenizer = self.generator.formatter.tokenizer tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl(): def impl():
tokens = [] states = [ItemState() for _ in request_batch]
logprobs = []
stop_reason = None
for token_result in self.generator.completion(request): results = []
tokens.append(token_result.token) for token_results in self.generator.completion(request_batch):
if token_result.token == tokenizer.eot_id: for result in token_results:
stop_reason = StopReason.end_of_turn idx = result.batch_idx
elif token_result.token == tokenizer.eom_id: state = states[idx]
stop_reason = StopReason.end_of_message if state.finished or result.ignore_token:
continue
if request.logprobs: state.finished = result.finished
assert len(token_result.logprobs) == 1 if first_request.logprobs:
state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) state.tokens.append(result.token)
if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
if stop_reason is None: for state in states:
stop_reason = StopReason.out_of_tokens if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens: if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
tokens = tokens[:-1] state.tokens = state.tokens[:-1]
content = self.generator.formatter.tokenizer.decode(tokens) content = self.generator.formatter.tokenizer.decode(state.tokens)
return CompletionResponse( results.append(
CompletionResponse(
content=content, content=content,
stop_reason=stop_reason, stop_reason=state.stop_reason,
logprobs=logprobs if request.logprobs else None, logprobs=state.logprobs if first_request.logprobs else None,
) )
)
return results
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group:
async with SEMAPHORE: async with SEMAPHORE:
@ -318,7 +372,7 @@ class MetaReferenceInferenceImpl(
response_format=response_format, response_format=response_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
tool_config=tool_config, tool_config=tool_config or ToolConfig(),
) )
self.check_model(request) self.check_model(request)
@ -334,44 +388,110 @@ class MetaReferenceInferenceImpl(
if request.stream: if request.stream:
return self._stream_chat_completion(request) return self._stream_chat_completion(request)
else: else:
return await self._nonstream_chat_completion(request) results = await self._nonstream_chat_completion([request])
return results[0]
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: List[ChatCompletionRequest]
) -> List[ChatCompletionResponse]:
tokenizer = self.generator.formatter.tokenizer tokenizer = self.generator.formatter.tokenizer
first_request = request_batch[0]
class ItemState(BaseModel):
tokens: List[int] = []
logprobs: List[TokenLogProbs] = []
stop_reason: StopReason | None = None
finished: bool = False
def impl(): def impl():
tokens = [] states = [ItemState() for _ in request_batch]
logprobs = []
stop_reason = None
for token_result in self.generator.chat_completion(request): for token_results in self.generator.chat_completion(request_batch):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": first = token_results[0]
cprint(token_result.text, "cyan", end="") if not first.finished:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", "magenta", end="")
tokens.append(token_result.token) for result in token_results:
idx = result.batch_idx
state = states[idx]
if state.finished or result.ignore_token:
continue
if token_result.token == tokenizer.eot_id: state.finished = result.finished
stop_reason = StopReason.end_of_turn if first_request.logprobs:
elif token_result.token == tokenizer.eom_id: state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
stop_reason = StopReason.end_of_message
if request.logprobs: state.tokens.append(result.token)
assert len(token_result.logprobs) == 1 if result.token == tokenizer.eot_id:
state.stop_reason = StopReason.end_of_turn
elif result.token == tokenizer.eom_id:
state.stop_reason = StopReason.end_of_message
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]})) results = []
for state in states:
if state.stop_reason is None:
state.stop_reason = StopReason.out_of_tokens
if stop_reason is None: raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
stop_reason = StopReason.out_of_tokens results.append(
ChatCompletionResponse(
raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
return ChatCompletionResponse(
completion_message=CompletionMessage( completion_message=CompletionMessage(
content=raw_message.content, content=raw_message.content,
stop_reason=raw_message.stop_reason, stop_reason=raw_message.stop_reason,
tool_calls=raw_message.tool_calls, tool_calls=raw_message.tool_calls,
), ),
logprobs=logprobs if request.logprobs else None, logprobs=state.logprobs if first_request.logprobs else None,
) )
)
return results
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group:
async with SEMAPHORE: async with SEMAPHORE:
@ -398,6 +518,22 @@ class MetaReferenceInferenceImpl(
for token_result in self.generator.chat_completion(request): for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="") cprint(token_result.text, "cyan", end="")
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", "magenta", end="")
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""
elif token_result.token == tokenizer.eom_id:
stop_reason = StopReason.end_of_message
text = ""
else:
text = token_result.text
if request.logprobs:
assert len(token_result.logprobs) == 1
logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
tokens.append(token_result.token) tokens.append(token_result.token)

View file

@ -6,7 +6,7 @@
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing import Any, Callable, Generator from typing import Any, Callable, Generator, List
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@ -23,13 +23,13 @@ class ModelRunner:
self.llama = llama self.llama = llama
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()` # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, req: Any): def __call__(self, task: Any):
if isinstance(req, ChatCompletionRequestWithRawContent): if task[0] == "chat_completion":
return self.llama.chat_completion(req) return self.llama.chat_completion(task[1])
elif isinstance(req, CompletionRequestWithRawContent): elif task[0] == "completion":
return self.llama.completion(req) return self.llama.completion(task[1])
else: else:
raise ValueError(f"Unexpected task type {type(req)}") raise ValueError(f"Unexpected task type {task[0]}")
def init_model_cb( def init_model_cb(
@ -82,16 +82,16 @@ class LlamaModelParallelGenerator:
def completion( def completion(
self, self,
request: CompletionRequestWithRawContent, request_batch: List[CompletionRequestWithRawContent],
) -> Generator: ) -> Generator:
req_obj = deepcopy(request) req_obj = deepcopy(request_batch)
gen = self.group.run_inference(req_obj) gen = self.group.run_inference(("completion", req_obj))
yield from gen yield from gen
def chat_completion( def chat_completion(
self, self,
request: ChatCompletionRequestWithRawContent, request_batch: List[ChatCompletionRequestWithRawContent],
) -> Generator: ) -> Generator:
req_obj = deepcopy(request) req_obj = deepcopy(request_batch)
gen = self.group.run_inference(req_obj) gen = self.group.run_inference(("chat_completion", req_obj))
yield from gen yield from gen
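Inference tasks handed to the worker group are now ("completion" | "chat_completion", request_batch) tuples rather than bare request objects, and ModelRunner dispatches on the tag. A minimal stand-in of that tagged dispatch follows; the classes below are simplifications, not the real generator or worker objects.

# Minimal stand-in for the tagged-task dispatch shown above; `FakeLlama` is a
# simplification of LlamaGenerator, not the real class.
class FakeLlama:
    def completion(self, request_batch):
        return [f"completion:{r}" for r in request_batch]

    def chat_completion(self, request_batch):
        return [f"chat:{r}" for r in request_batch]


class ModelRunnerSketch:
    def __init__(self, llama):
        self.llama = llama

    def __call__(self, task):
        kind, request_batch = task
        if kind == "completion":
            return self.llama.completion(request_batch)
        elif kind == "chat_completion":
            return self.llama.chat_completion(request_batch)
        raise ValueError(f"Unexpected task type {kind}")


runner = ModelRunnerSketch(FakeLlama())
print(runner(("chat_completion", ["req-a", "req-b"])))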

View file

@ -19,7 +19,7 @@ import tempfile
import time import time
import uuid import uuid
from enum import Enum from enum import Enum
from typing import Callable, Generator, Literal, Optional, Union from typing import Callable, Generator, List, Literal, Optional, Tuple, Union
import torch import torch
import zmq import zmq
@ -69,12 +69,12 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel): class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent] task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
class TaskResponse(BaseModel): class TaskResponse(BaseModel):
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
result: GenerationResult result: List[GenerationResult]
class ExceptionResponse(BaseModel): class ExceptionResponse(BaseModel):
@ -331,7 +331,7 @@ class ModelParallelProcessGroup:
def run_inference( def run_inference(
self, self,
req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent], req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
) -> Generator: ) -> Generator:
assert not self.running, "inference already running" assert not self.running, "inference already running"

View file

@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionResponse, CompletionResponse,
Inference, Inference,
InterleavedContent,
LogProbConfig, LogProbConfig,
Message, Message,
ResponseFormat, ResponseFormat,
@ -80,3 +81,25 @@ class SentenceTransformersInferenceImpl(
tool_config: Optional[ToolConfig] = None, tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion") raise ValueError("Sentence transformers don't support chat completion")
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

View file

@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.providers.datatypes import (
Api,
InlineProviderSpec,
ProviderSpec,
)
META_REFERENCE_DEPS = [
"accelerate",
"blobfile",
"fairscale",
"torch",
"torchvision",
"transformers",
"zmq",
"lm-format-enforcer",
"sentence-transformers",
"torchao==0.5.0",
"fbgemm-gpu-genai==1.1.2",
]
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.inference,
provider_type="inline::meta-reference",
pip_packages=META_REFERENCE_DEPS,
module="llama_stack.providers.inline.batch_inference.meta_reference",
config_class="llama_stack.providers.inline.batch_inference.meta_reference.MetaReferenceInferenceConfig",
),
]
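This new registry module declares the meta-reference provider for the batch_inference API. Purely as an illustration of what it exposes; the import path is an assumption since the diff does not show the file name, everything else uses names defined above:
# Print the registered batch_inference provider specs and their dependencies.
from llama_stack.providers.registry.batch_inference import available_providers  # assumed module path

for spec in available_providers():
    print(f"{spec.provider_type}: module={spec.module}")
    print("  pip packages:", ", ".join(spec.pip_packages))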

View file

@@ -437,6 +437,28 @@ class OllamaInferenceAdapter(
}
return await self.openai_client.chat.completions.create(**params)  # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")
async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
async def _convert_content(content) -> dict:
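Remote adapters like Ollama only stub the new batch methods for now, so provider-agnostic callers need a fallback. An illustrative pattern, assuming impl is any object implementing the Inference protocol (this helper is not part of the diff):
async def chat_completion_batch_or_sequential(impl, model_id, messages_batch):
    # Prefer the batch endpoint; fall back to one chat_completion call per conversation.
    try:
        response = await impl.batch_chat_completion(model_id=model_id, messages_batch=messages_batch)
        return response.batch
    except NotImplementedError:
        return [await impl.chat_completion(model_id=model_id, messages=messages) for messages in messages_batch]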

View file

@@ -526,3 +526,25 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
user=user,
)
return await self.client.chat.completions.create(**params)  # type: ignore
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for Ollama")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for Ollama")

View file

@@ -347,3 +347,25 @@ class LiteLLMOpenAIMixin(
user=user,
)
return litellm.completion(**params)
async def batch_completion(
self,
model_id: str,
content_batch: List[InterleavedContent],
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
async def batch_chat_completion(
self,
model_id: str,
messages_batch: List[List[Message]],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_config: Optional[ToolConfig] = None,
response_format: Optional[ResponseFormat] = None,
logprobs: Optional[LogProbConfig] = None,
):
raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")

View file

@@ -20,6 +20,7 @@ class WebMethod:
raw_bytes_request_body: Optional[bool] = False
# A descriptive name of the corresponding span created by tracing
descriptive_name: Optional[str] = None
+experimental: Optional[bool] = False
T = TypeVar("T", bound=Callable[..., Any])
@@ -33,6 +34,7 @@ def webmethod(
response_examples: Optional[List[Any]] = None,
raw_bytes_request_body: Optional[bool] = False,
descriptive_name: Optional[str] = None,
+experimental: Optional[bool] = False,
) -> Callable[[T], T]:
"""
Decorator that supplies additional metadata to an endpoint operation function.
@@ -52,6 +54,7 @@ def webmethod(
response_examples=response_examples,
raw_bytes_request_body=raw_bytes_request_body,
descriptive_name=descriptive_name,
+experimental=experimental,
)
return func
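The new experimental flag gives endpoint authors a way to mark a route as unstable in the generated metadata. A hedged example of how a batch endpoint might opt in, assuming webmethod's existing route and method keywords and reusing parameter names from this diff (the exact protocol signature and route string are assumptions):
@webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
async def batch_chat_completion(
    self,
    model_id: str,
    messages_batch: List[List[Message]],
    sampling_params: Optional[SamplingParams] = None,
) -> BatchChatCompletionResponse: ...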

View file

@@ -16,11 +16,12 @@ providers:
provider_type: inline::meta-reference
config:
model: ${env.INFERENCE_MODEL}
-max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+max_batch_size: ${env.MAX_BATCH_SIZE:1}
+max_seq_len: ${env.MAX_SEQ_LEN:4096}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
@@ -28,11 +29,12 @@ providers:
provider_type: inline::meta-reference
config:
model: ${env.SAFETY_MODEL}
-max_seq_len: 4096
checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+max_batch_size: ${env.MAX_BATCH_SIZE:1}
+max_seq_len: ${env.MAX_SEQ_LEN:4096}
vector_io:
- provider_id: faiss
provider_type: inline::faiss

View file

@@ -16,11 +16,12 @@ providers:
provider_type: inline::meta-reference
config:
model: ${env.INFERENCE_MODEL}
-max_seq_len: 4096
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
quantization:
type: ${env.QUANTIZATION_TYPE:bf16}
model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+max_batch_size: ${env.MAX_BATCH_SIZE:1}
+max_seq_len: ${env.MAX_SEQ_LEN:4096}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
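Both run configurations now read max_batch_size and max_seq_len from the environment with inline defaults. A minimal sketch of the ${env.NAME:default} substitution these entries rely on, assuming a plain os.environ lookup (the stack's real resolver may differ):
import os
import re

def resolve_env_placeholders(value: str) -> str:
    # Replace "${env.NAME:default}" with the value of NAME from the environment, or the default.
    return re.sub(
        r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):([^}]*)\}",
        lambda m: os.environ.get(m.group(1), m.group(2)),
        value,
    )

print(resolve_env_placeholders("${env.MAX_BATCH_SIZE:1}"))  # "1" unless MAX_BATCH_SIZE is set
print(resolve_env_placeholders("${env.MAX_SEQ_LEN:4096}"))  # "4096" unless MAX_SEQ_LEN is set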

View file

@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.models.llama.sku_list import resolve_model
from ..test_cases.test_case import TestCase
PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
def skip_if_model_doesnt_support_completion(client_with_models, model_id):
models = {m.identifier: m for m in client_with_models.models.list()}
models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id]
if provider.provider_type in (
"remote::openai",
"remote::anthropic",
"remote::gemini",
"remote::groq",
"remote::llama-openai-compat",
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
def get_llama_model(client_with_models, model_id):
models = {}
for m in client_with_models.models.list():
models[m.identifier] = m
models[m.provider_resource_id] = m
assert model_id in models, f"Model {model_id} not found"
model = models[model_id]
ids = (model.identifier, model.provider_resource_id)
for mid in ids:
if resolve_model(mid):
return mid
return model.metadata.get("llama_model", None)
def get_llama_tokenizer():
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)
return tokenizer, formatter
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:batch_completion",
],
)
def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
content_batch = tc["contents"]
response = client_with_models.inference.batch_completion(
content_batch=content_batch,
model_id=text_model_id,
sampling_params={
"max_tokens": 50,
},
)
assert len(response.batch) == len(content_batch)
for i, r in enumerate(response.batch):
print(f"response {i}: {r.content}")
assert len(r.content) > 10
@pytest.mark.parametrize(
"test_case",
[
"inference:chat_completion:batch_completion",
],
)
def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
qa_pairs = tc["qa_pairs"]
message_batch = [
[
{
"role": "user",
"content": qa["question"],
}
]
for qa in qa_pairs
]
response = client_with_models.inference.batch_chat_completion(
messages_batch=message_batch,
model_id=text_model_id,
)
assert len(response.batch) == len(qa_pairs)
for i, r in enumerate(response.batch):
print(f"response {i}: {r.completion_message.content}")
assert len(r.completion_message.content) > 0
assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()

View file

@@ -537,5 +537,31 @@
}
]
}
},
"batch_completion": {
"data": {
"qa_pairs": [
{
"question": "What is the capital of France?",
"answer": "Paris"
},
{
"question": "Who wrote the book '1984'?",
"answer": "George Orwell"
},
{
"question": "Which planet has rings around it with a name starting with letter S?",
"answer": "Saturn"
},
{
"question": "When did the first moon landing happen?",
"answer": "1969"
},
{
"question": "What word says 'hello' in Spanish?",
"answer": "Hola"
}
]
}
}
}

View file

@@ -44,5 +44,18 @@
"year_retired": "2003"
}
}
},
"batch_completion": {
"data": {
"contents": [
"Micheael Jordan is born in ",
"Roses are red, violets are ",
"If you had a million dollars, what would you do with it? ",
"All you need is ",
"The capital of France is ",
"It is a good day to ",
"The answer to the universe is "
]
}
}
}