Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

feat: add batch inference API to llama stack inference

parent ed58a94b30, commit 0cfb2e2473
24 changed files with 1041 additions and 377 deletions
docs/_static/llama-stack-spec.html (vendored): 301 changed lines
@@ -85,7 +85,50 @@
         }
       }
     },
-    "/v1/batch-inference/chat-completion": {
+    "/v1/inference/batch-chat-completion": {
+      "post": {
+        "responses": {
+          "200": {
+            "description": "OK",
+            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/BatchChatCompletionResponse" } } }
+          },
+          "400": { "$ref": "#/components/responses/BadRequest400" },
+          "429": { "$ref": "#/components/responses/TooManyRequests429" },
+          "500": { "$ref": "#/components/responses/InternalServerError500" },
+          "default": { "$ref": "#/components/responses/DefaultError" }
+        },
+        "tags": [ "Inference" ],
+        "description": "",
+        "parameters": [],
+        "requestBody": {
+          "content": { "application/json": { "schema": { "$ref": "#/components/schemas/BatchChatCompletionRequest" } } },
+          "required": true
+        }
+      }
+    },
+    "/v1/batch-inference/chat-completion-inline": {
       "post": {
         "responses": {
           "200": {
@@ -120,7 +163,7 @@
             "content": {
               "application/json": {
                 "schema": {
-                  "$ref": "#/components/schemas/BatchChatCompletionRequest"
+                  "$ref": "#/components/schemas/BatchChatCompletionInlineRequest"
                 }
               }
             },
@@ -128,7 +171,50 @@
         }
       }
     },
-    "/v1/batch-inference/completion": {
+    "/v1/inference/batch-completion": {
+      "post": {
+        "responses": {
+          "200": {
+            "description": "OK",
+            "content": { "application/json": { "schema": { "$ref": "#/components/schemas/BatchCompletionResponse" } } }
+          },
+          "400": { "$ref": "#/components/responses/BadRequest400" },
+          "429": { "$ref": "#/components/responses/TooManyRequests429" },
+          "500": { "$ref": "#/components/responses/InternalServerError500" },
+          "default": { "$ref": "#/components/responses/DefaultError" }
+        },
+        "tags": [ "Inference" ],
+        "description": "",
+        "parameters": [],
+        "requestBody": {
+          "content": { "application/json": { "schema": { "$ref": "#/components/schemas/BatchCompletionRequest" } } },
+          "required": true
+        }
+      }
+    },
+    "/v1/batch-inference/completion-inline": {
       "post": {
         "responses": {
           "200": {
@@ -163,7 +249,7 @@
             "content": {
               "application/json": {
                 "schema": {
-                  "$ref": "#/components/schemas/BatchCompletionRequest"
+                  "$ref": "#/components/schemas/BatchCompletionInlineRequest"
                 }
               }
             },
@@ -4366,6 +4452,51 @@
       ],
       "title": "ToolCall"
     },
+    "ToolConfig": {
+      "type": "object",
+      "properties": {
+        "tool_choice": {
+          "oneOf": [
+            {
+              "type": "string",
+              "enum": [ "auto", "required", "none" ],
+              "title": "ToolChoice",
+              "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
+            },
+            { "type": "string" }
+          ],
+          "default": "auto",
+          "description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
+        },
+        "tool_prompt_format": {
+          "type": "string",
+          "enum": [ "json", "function_tag", "python_list" ],
+          "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
+        },
+        "system_message_behavior": {
+          "type": "string",
+          "enum": [ "append", "replace" ],
+          "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
+          "default": "append"
+        }
+      },
+      "additionalProperties": false,
+      "title": "ToolConfig",
+      "description": "Configuration for tool use."
+    },
     "ToolDefinition": {
       "type": "object",
       "properties": {
@@ -4554,7 +4685,7 @@
     "BatchChatCompletionRequest": {
       "type": "object",
       "properties": {
-        "model": {
+        "model_id": {
           "type": "string"
         },
         "messages_batch": {
@@ -4575,25 +4706,8 @@
             "$ref": "#/components/schemas/ToolDefinition"
           }
         },
-        "tool_choice": {
-          "type": "string",
-          "enum": [ "auto", "required", "none" ],
-          "title": "ToolChoice",
-          "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
-        },
-        "tool_prompt_format": {
-          "type": "string",
-          "enum": [ "json", "function_tag", "python_list" ],
-          "title": "ToolPromptFormat",
-          "description": "Prompt format for calling custom / zero shot tools."
+        "tool_config": {
+          "$ref": "#/components/schemas/ToolConfig"
         },
         "response_format": {
           "$ref": "#/components/schemas/ResponseFormat"
         },
@@ -4613,7 +4727,7 @@
       },
       "additionalProperties": false,
       "required": [
-        "model",
+        "model_id",
         "messages_batch"
       ],
       "title": "BatchChatCompletionRequest"
@@ -4707,12 +4821,62 @@
       "title": "TokenLogProbs",
       "description": "Log probabilities for generated tokens."
     },
-    "BatchCompletionRequest": {
+    "BatchChatCompletionInlineRequest": {
       "type": "object",
       "properties": {
         "model": {
           "type": "string"
         },
+        "messages_batch": {
+          "type": "array",
+          "items": { "type": "array", "items": { "$ref": "#/components/schemas/Message" } }
+        },
+        "sampling_params": { "$ref": "#/components/schemas/SamplingParams" },
+        "tools": { "type": "array", "items": { "$ref": "#/components/schemas/ToolDefinition" } },
+        "tool_config": { "$ref": "#/components/schemas/ToolConfig" },
+        "response_format": { "$ref": "#/components/schemas/ResponseFormat" },
+        "logprobs": {
+          "type": "object",
+          "properties": {
+            "top_k": {
+              "type": "integer",
+              "default": 0,
+              "description": "How many tokens (for each position) to return log probabilities for."
+            }
+          },
+          "additionalProperties": false,
+          "title": "LogProbConfig"
+        }
+      },
+      "additionalProperties": false,
+      "required": [ "model", "messages_batch" ],
+      "title": "BatchChatCompletionInlineRequest"
+    },
+    "BatchCompletionRequest": {
+      "type": "object",
+      "properties": {
+        "model_id": {
+          "type": "string"
+        },
         "content_batch": {
           "type": "array",
           "items": {
@@ -4740,7 +4904,7 @@
       },
       "additionalProperties": false,
       "required": [
-        "model",
+        "model_id",
         "content_batch"
       ],
       "title": "BatchCompletionRequest"
@@ -4799,6 +4963,44 @@
       "title": "CompletionResponse",
       "description": "Response from a completion request."
     },
+    "BatchCompletionInlineRequest": {
+      "type": "object",
+      "properties": {
+        "model": { "type": "string" },
+        "content_batch": { "type": "array", "items": { "$ref": "#/components/schemas/InterleavedContent" } },
+        "sampling_params": { "$ref": "#/components/schemas/SamplingParams" },
+        "response_format": { "$ref": "#/components/schemas/ResponseFormat" },
+        "logprobs": {
+          "type": "object",
+          "properties": {
+            "top_k": {
+              "type": "integer",
+              "default": 0,
+              "description": "How many tokens (for each position) to return log probabilities for."
+            }
+          },
+          "additionalProperties": false,
+          "title": "LogProbConfig"
+        }
+      },
+      "additionalProperties": false,
+      "required": [ "model", "content_batch" ],
+      "title": "BatchCompletionInlineRequest"
+    },
     "CancelTrainingJobRequest": {
       "type": "object",
       "properties": {
@@ -4812,51 +5014,6 @@
       ],
       "title": "CancelTrainingJobRequest"
     },
-    "ToolConfig": {
-      "type": "object",
-      "properties": {
-        "tool_choice": {
-          "oneOf": [
-            {
-              "type": "string",
-              "enum": [ "auto", "required", "none" ],
-              "title": "ToolChoice",
-              "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
-            },
-            { "type": "string" }
-          ],
-          "default": "auto",
-          "description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
-        },
-        "tool_prompt_format": {
-          "type": "string",
-          "enum": [ "json", "function_tag", "python_list" ],
-          "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
-        },
-        "system_message_behavior": {
-          "type": "string",
-          "enum": [ "append", "replace" ],
-          "description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
-          "default": "append"
-        }
-      },
-      "additionalProperties": false,
-      "title": "ToolConfig",
-      "description": "Configuration for tool use."
-    },
     "ChatCompletionRequest": {
       "type": "object",
       "properties": {
docs/_static/llama-stack-spec.yaml (vendored): 256 changed lines
@@ -40,7 +40,36 @@ paths:
           schema:
             $ref: '#/components/schemas/AppendRowsRequest'
       required: true
-  /v1/batch-inference/chat-completion:
+  /v1/inference/batch-chat-completion:
+    post:
+      responses:
+        '200':
+          description: OK
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/BatchChatCompletionResponse'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Inference
+      description: ''
+      parameters: []
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/BatchChatCompletionRequest'
+        required: true
+  /v1/batch-inference/chat-completion-inline:
     post:
       responses:
         '200':
@@ -67,9 +96,38 @@ paths:
        content:
          application/json:
            schema:
-              $ref: '#/components/schemas/BatchChatCompletionRequest'
+              $ref: '#/components/schemas/BatchChatCompletionInlineRequest'
        required: true
-  /v1/batch-inference/completion:
+  /v1/inference/batch-completion:
+    post:
+      responses:
+        '200':
+          description: OK
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/BatchCompletionResponse'
+        '400':
+          $ref: '#/components/responses/BadRequest400'
+        '429':
+          $ref: >-
+            #/components/responses/TooManyRequests429
+        '500':
+          $ref: >-
+            #/components/responses/InternalServerError500
+        default:
+          $ref: '#/components/responses/DefaultError'
+      tags:
+        - Inference
+      description: ''
+      parameters: []
+      requestBody:
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/BatchCompletionRequest'
+        required: true
+  /v1/batch-inference/completion-inline:
     post:
       responses:
         '200':
@@ -96,7 +154,7 @@ paths:
        content:
          application/json:
            schema:
-              $ref: '#/components/schemas/BatchCompletionRequest'
+              $ref: '#/components/schemas/BatchCompletionInlineRequest'
        required: true
   /v1/post-training/job/cancel:
     post:
@@ -3009,6 +3067,54 @@ components:
         - tool_name
         - arguments
       title: ToolCall
+    ToolConfig:
+      type: object
+      properties:
+        tool_choice:
+          oneOf:
+            - type: string
+              enum:
+                - auto
+                - required
+                - none
+              title: ToolChoice
+              description: >-
+                Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
+            - type: string
+          default: auto
+          description: >-
+            (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
+        tool_prompt_format:
+          type: string
+          enum:
+            - json
+            - function_tag
+            - python_list
+          description: >-
+            (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
+        system_message_behavior:
+          type: string
+          enum:
+            - append
+            - replace
+          description: >-
+            (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.
+          default: append
+      additionalProperties: false
+      title: ToolConfig
+      description: Configuration for tool use.
     ToolDefinition:
       type: object
       properties:
@@ -3145,7 +3251,7 @@ components:
     BatchChatCompletionRequest:
       type: object
       properties:
-        model:
+        model_id:
           type: string
         messages_batch:
           type: array
@@ -3159,26 +3265,8 @@ components:
           type: array
           items:
             $ref: '#/components/schemas/ToolDefinition'
-        tool_choice:
-          type: string
-          enum:
-            - auto
-            - required
-            - none
-          title: ToolChoice
-          description: >-
-            Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
-        tool_prompt_format:
-          type: string
-          enum:
-            - json
-            - function_tag
-            - python_list
-          title: ToolPromptFormat
-          description: >-
-            Prompt format for calling custom / zero shot tools.
+        tool_config:
+          $ref: '#/components/schemas/ToolConfig'
         response_format:
           $ref: '#/components/schemas/ResponseFormat'
         logprobs:
@@ -3193,7 +3281,7 @@ components:
           title: LogProbConfig
       additionalProperties: false
       required:
-        - model
+        - model_id
         - messages_batch
       title: BatchChatCompletionRequest
     BatchChatCompletionResponse:
@@ -3258,11 +3346,47 @@ components:
         - logprobs_by_token
       title: TokenLogProbs
       description: Log probabilities for generated tokens.
-    BatchCompletionRequest:
+    BatchChatCompletionInlineRequest:
      type: object
      properties:
        model:
          type: string
+        messages_batch:
+          type: array
+          items:
+            type: array
+            items:
+              $ref: '#/components/schemas/Message'
+        sampling_params:
+          $ref: '#/components/schemas/SamplingParams'
+        tools:
+          type: array
+          items:
+            $ref: '#/components/schemas/ToolDefinition'
+        tool_config:
+          $ref: '#/components/schemas/ToolConfig'
+        response_format:
+          $ref: '#/components/schemas/ResponseFormat'
+        logprobs:
+          type: object
+          properties:
+            top_k:
+              type: integer
+              default: 0
+              description: >-
+                How many tokens (for each position) to return log probabilities for.
+          additionalProperties: false
+          title: LogProbConfig
+      additionalProperties: false
+      required:
+        - model
+        - messages_batch
+      title: BatchChatCompletionInlineRequest
+    BatchCompletionRequest:
+      type: object
+      properties:
+        model_id:
+          type: string
        content_batch:
          type: array
          items:
@@ -3283,7 +3407,7 @@ components:
          title: LogProbConfig
      additionalProperties: false
      required:
-        - model
+        - model_id
        - content_batch
      title: BatchCompletionRequest
    BatchCompletionResponse:
@@ -3326,6 +3450,34 @@ components:
        - stop_reason
      title: CompletionResponse
      description: Response from a completion request.
+    BatchCompletionInlineRequest:
+      type: object
+      properties:
+        model:
+          type: string
+        content_batch:
+          type: array
+          items:
+            $ref: '#/components/schemas/InterleavedContent'
+        sampling_params:
+          $ref: '#/components/schemas/SamplingParams'
+        response_format:
+          $ref: '#/components/schemas/ResponseFormat'
+        logprobs:
+          type: object
+          properties:
+            top_k:
+              type: integer
+              default: 0
+              description: >-
+                How many tokens (for each position) to return log probabilities for.
+          additionalProperties: false
+          title: LogProbConfig
+      additionalProperties: false
+      required:
+        - model
+        - content_batch
+      title: BatchCompletionInlineRequest
    CancelTrainingJobRequest:
      type: object
      properties:
@@ -3335,54 +3487,6 @@ components:
      required:
        - job_uuid
      title: CancelTrainingJobRequest
-    ToolConfig:
-      type: object
-      properties:
-        tool_choice:
-          oneOf:
-            - type: string
-              enum:
-                - auto
-                - required
-                - none
-              title: ToolChoice
-              description: >-
-                Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.
-            - type: string
-          default: auto
-          description: >-
-            (Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
-        tool_prompt_format:
-          type: string
-          enum:
-            - json
-            - function_tag
-            - python_list
-          description: >-
-            (Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls.
-        system_message_behavior:
-          type: string
-          enum:
-            - append
-            - replace
-          description: >-
-            (Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.
-          default: append
-      additionalProperties: false
-      title: ToolConfig
-      description: Configuration for tool use.
    ChatCompletionRequest:
      type: object
      properties:
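For reference, a minimal client-side sketch of the new batch chat completion route defined in the spec above. The server address and model identifier are placeholders (assumptions, not part of this change); the request body follows the BatchChatCompletionRequest schema (model_id plus messages_batch) and the response is read per the BatchChatCompletionResponse schema.

# Sketch only: POST /v1/inference/batch-chat-completion over HTTP.
# The base URL and model id below are assumed placeholders, not part of this commit.
import requests

BASE_URL = "http://localhost:8321"  # assumed llama-stack server address

payload = {
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    "messages_batch": [
        [{"role": "user", "content": "Give me one fact about the moon."}],
        [{"role": "user", "content": "Summarize what a vector database is."}],
    ],
}

resp = requests.post(f"{BASE_URL}/v1/inference/batch-chat-completion", json=payload)
resp.raise_for_status()

# BatchChatCompletionResponse.batch holds one chat completion per input conversation.
for item in resp.json()["batch"]:
    print(item["completion_message"]["content"])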
@@ -6,37 +6,24 @@
 
 from typing import List, Optional, Protocol, runtime_checkable
 
-from pydantic import BaseModel
-
 from llama_stack.apis.inference import (
-    ChatCompletionResponse,
-    CompletionResponse,
+    BatchChatCompletionResponse,
+    BatchCompletionResponse,
     InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
-    ToolChoice,
+    ToolConfig,
     ToolDefinition,
-    ToolPromptFormat,
 )
-from llama_stack.schema_utils import json_schema_type, webmethod
-
-
-@json_schema_type
-class BatchCompletionResponse(BaseModel):
-    batch: List[CompletionResponse]
-
-
-@json_schema_type
-class BatchChatCompletionResponse(BaseModel):
-    batch: List[ChatCompletionResponse]
+from llama_stack.schema_utils import webmethod
 
 
 @runtime_checkable
 class BatchInference(Protocol):
-    @webmethod(route="/batch-inference/completion", method="POST")
-    async def batch_completion(
+    @webmethod(route="/batch-inference/completion-inline", method="POST")
+    async def batch_completion_inline(
         self,
         model: str,
         content_batch: List[InterleavedContent],
@@ -45,16 +32,14 @@ class BatchInference(Protocol):
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchCompletionResponse: ...
 
-    @webmethod(route="/batch-inference/chat-completion", method="POST")
-    async def batch_chat_completion(
+    @webmethod(route="/batch-inference/chat-completion-inline", method="POST")
+    async def batch_chat_completion_inline(
         self,
         model: str,
         messages_batch: List[List[Message]],
         sampling_params: Optional[SamplingParams] = None,
-        # zero-shot tool definitions as input to the model
         tools: Optional[List[ToolDefinition]] = list,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = None,
+        tool_config: Optional[ToolConfig] = None,
         response_format: Optional[ResponseFormat] = None,
         logprobs: Optional[LogProbConfig] = None,
     ) -> BatchChatCompletionResponse: ...
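Reassembled from the two hunks above, the inline chat-completion method of the BatchInference protocol now reads roughly as follows (a sketch of the net result; the completion_inline variant is omitted because part of its signature falls outside the hunks shown):

# Sketch of the net result: tool_choice/tool_prompt_format collapse into a single
# ToolConfig, and the inline methods move to "-inline" routes and names.
from typing import List, Optional, Protocol, runtime_checkable

from llama_stack.apis.inference import (
    BatchChatCompletionResponse,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    ToolConfig,
    ToolDefinition,
)
from llama_stack.schema_utils import webmethod


@runtime_checkable
class BatchInference(Protocol):
    @webmethod(route="/batch-inference/chat-completion-inline", method="POST")
    async def batch_chat_completion_inline(
        self,
        model: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = list,  # default carried over unchanged from the diff
        tool_config: Optional[ToolConfig] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchChatCompletionResponse: ...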
@@ -681,6 +681,16 @@ class EmbeddingTaskType(Enum):
     document = "document"
 
 
+@json_schema_type
+class BatchCompletionResponse(BaseModel):
+    batch: List[CompletionResponse]
+
+
+@json_schema_type
+class BatchChatCompletionResponse(BaseModel):
+    batch: List[ChatCompletionResponse]
+
+
 @runtime_checkable
 @trace_protocol
 class Inference(Protocol):
@@ -716,6 +726,17 @@ class Inference(Protocol):
         """
         ...
 
+    @webmethod(route="/inference/batch-completion", method="POST", experimental=True)
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> BatchCompletionResponse:
+        raise NotImplementedError("Batch completion is not implemented")
+
     @webmethod(route="/inference/chat-completion", method="POST")
     async def chat_completion(
         self,
@@ -756,6 +777,19 @@ class Inference(Protocol):
         """
         ...
 
+    @webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> BatchChatCompletionResponse:
+        raise NotImplementedError("Batch chat completion is not implemented")
+
     @webmethod(route="/inference/embeddings", method="POST")
     async def embeddings(
         self,
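A short usage sketch of the two experimental methods added above. Here `inference` stands for any object implementing the Inference protocol (for example the inference router of a running stack); the model id and prompts are placeholders.

# Sketch: calling the new experimental batch methods on an Inference implementation.
# `inference` and the model id are assumed placeholders.
import asyncio

from llama_stack.apis.inference import SamplingParams, UserMessage


async def run_batches(inference, model_id: str) -> None:
    chat = await inference.batch_chat_completion(
        model_id=model_id,
        messages_batch=[
            [UserMessage(content="Name one prime number.")],
            [UserMessage(content="What does HTTP stand for?")],
        ],
        sampling_params=SamplingParams(max_tokens=32),
    )
    # BatchChatCompletionResponse.batch: one ChatCompletionResponse per input conversation.
    for item in chat.batch:
        print(item.completion_message.content)

    completion = await inference.batch_completion(
        model_id=model_id,
        content_batch=["Once upon a time", "The capital of France is"],
        sampling_params=SamplingParams(max_tokens=16),
    )
    # BatchCompletionResponse.batch: one CompletionResponse per input prompt.
    for item in completion.batch:
        print(item.content)


# asyncio.run(run_batches(inference_impl, "registered-model-id"))  # placeholders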
@@ -400,6 +400,9 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
     mro = type(obj).__mro__
     for name, value in inspect.getmembers(protocol):
         if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
+            if value.__webmethod__.experimental:
+                continue
+
             if not hasattr(obj, name):
                 missing_methods.append((name, "missing"))
             elif not callable(getattr(obj, name)):
@@ -17,6 +17,8 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
 from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
 from llama_stack.apis.inference import (
+    BatchChatCompletionResponse,
+    BatchCompletionResponse,
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
@@ -334,6 +336,30 @@ class InferenceRouter(Inference):
             response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response
 
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> BatchChatCompletionResponse:
+        logger.debug(
+            f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
+        )
+        provider = self.routing_table.get_provider_impl(model_id)
+        return await provider.batch_chat_completion(
+            model_id=model_id,
+            messages_batch=messages_batch,
+            tools=tools,
+            tool_config=tool_config,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            logprobs=logprobs,
+        )
+
     async def completion(
         self,
         model_id: str,
@@ -398,6 +424,20 @@ class InferenceRouter(Inference):
             response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response
 
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> BatchCompletionResponse:
+        logger.debug(
+            f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
+        )
+        provider = self.routing_table.get_provider_impl(model_id)
+        return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
+
     async def embeddings(
         self,
         model_id: str,
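The router methods above do no inference themselves; they look up the provider registered for model_id and forward the call unchanged. A minimal illustration of that dispatch pattern, using stand-in stubs rather than the real routing table and provider classes:

# Illustration only: the dispatch pattern used by InferenceRouter.batch_chat_completion,
# with stub stand-ins for the routing table and provider.
import asyncio


class StubProvider:
    async def batch_chat_completion(self, model_id, messages_batch, **kwargs):
        # A real provider would run the model; the stub just reports the batch size.
        return f"{model_id}: handled {len(messages_batch)} conversations"


class StubRoutingTable:
    def __init__(self, providers):
        self._providers = providers

    def get_provider_impl(self, model_id):
        # The real routing table maps a registered model id to its provider implementation.
        return self._providers[model_id]


async def main():
    table = StubRoutingTable({"model-a": StubProvider()})
    provider = table.get_provider_impl("model-a")
    print(await provider.batch_chat_completion("model-a", [[{"role": "user", "content": "hi"}]]))


asyncio.run(main())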
@@ -140,7 +140,12 @@ class Llama3:
 
         return Llama3(model, tokenizer, model_args)
 
-    def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
+    def __init__(
+        self,
+        model: Transformer | CrossAttentionTransformer,
+        tokenizer: Tokenizer,
+        args: ModelArgs,
+    ):
         self.args = args
         self.model = model
         self.tokenizer = tokenizer
@@ -285,7 +290,7 @@ class Llama3:
                         source="output",
                         logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
                         batch_idx=idx,
-                        finished=eos_reached[idx],
+                        finished=eos_reached[idx].item(),
                         ignore_token=cur_pos < len(prompt_tokens[idx]),
                     )
                 )
@@ -233,7 +233,7 @@ class Llama4:
                         source="output",
                         logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
                         batch_idx=idx,
-                        finished=eos_reached[idx],
+                        finished=eos_reached[idx].item(),
                         ignore_token=cur_pos < len(prompt_tokens[idx]),
                     )
                 )
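Both generation loops above now tag every per-step result with batch_idx, finished, and ignore_token (with finished converted to a plain Python bool via .item()). Downstream code has to demultiplex the stream per batch element; a simplified, self-contained sketch of that pattern, using a hypothetical GenStep stand-in for the real token-result objects:

# Sketch of consuming batched generation output: each decoding step yields one result per
# batch element, and results are regrouped per batch_idx. GenStep is a hypothetical
# stand-in for the real per-token result objects.
from dataclasses import dataclass, field
from typing import Dict, Iterable, List


@dataclass
class GenStep:
    batch_idx: int
    token: int
    finished: bool
    ignore_token: bool


@dataclass
class ItemAccumulator:
    tokens: List[int] = field(default_factory=list)
    finished: bool = False


def demux(steps: Iterable[List[GenStep]], batch_size: int) -> Dict[int, List[int]]:
    states = [ItemAccumulator() for _ in range(batch_size)]
    for per_position in steps:          # one list of results per decoding position
        for step in per_position:
            state = states[step.batch_idx]
            if state.finished or step.ignore_token:
                continue                # skip prompt echo and already-finished rows
            state.finished = step.finished
            state.tokens.append(step.token)
    return {i: s.tokens for i, s in enumerate(states)}


# Example: two rows, where row 1 finishes one step earlier than row 0.
steps = [
    [GenStep(0, 11, False, False), GenStep(1, 21, False, False)],
    [GenStep(0, 12, False, False), GenStep(1, 22, True, False)],
    [GenStep(0, 13, True, False), GenStep(1, 0, True, True)],
]
print(demux(steps, batch_size=2))  # {0: [11, 12, 13], 1: [21, 22]}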
@@ -52,14 +52,17 @@ class MetaReferenceInferenceConfig(BaseModel):
         checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
         quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
         model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
+        max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
+        max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
         **kwargs,
     ) -> Dict[str, Any]:
         return {
             "model": model,
-            "max_seq_len": 4096,
             "checkpoint_dir": checkpoint_dir,
             "quantization": {
                 "type": quantization_type,
             },
             "model_parallel_size": model_parallel_size,
+            "max_batch_size": max_batch_size,
+            "max_seq_len": max_seq_len,
         }
@@ -22,7 +22,7 @@ from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
 from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
-from llama_stack.models.llama.sku_types import Model
+from llama_stack.models.llama.sku_types import Model, ModelFamily
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
@@ -113,8 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
     return get_default_tool_prompt_format(request.model)
 
 
-# TODO: combine Llama3 and Llama4 generators since they are almost identical now
-class Llama4Generator:
+class LlamaGenerator:
     def __init__(
         self,
         config: MetaReferenceInferenceConfig,
@@ -144,7 +143,8 @@ class Llama4Generator:
         else:
             quantization_mode = None
 
-        self.inner_generator = Llama4.build(
+        cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
+        self.inner_generator = cls.build(
             ckpt_dir=ckpt_dir,
             max_seq_len=config.max_seq_len,
             max_batch_size=config.max_batch_size,
@@ -158,142 +158,55 @@ class Llama4Generator:
 
     def completion(
         self,
-        request: CompletionRequestWithRawContent,
+        request_batch: List[CompletionRequestWithRawContent],
     ) -> Generator:
-        sampling_params = request.sampling_params or SamplingParams()
+        first_request = request_batch[0]
+        sampling_params = first_request.sampling_params or SamplingParams()
         max_gen_len = sampling_params.max_tokens
         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
             max_gen_len = self.args.max_seq_len - 1
 
         temperature, top_p = _infer_sampling_params(sampling_params)
         for result in self.inner_generator.generate(
-            llm_inputs=[self.formatter.encode_content(request.content)],
+            llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
-            logprobs=bool(request.logprobs),
+            logprobs=bool(first_request.logprobs),
             echo=False,
             logits_processor=get_logits_processor(
                 self.tokenizer,
                 self.args.vocab_size,
-                request.response_format,
+                first_request.response_format,
             ),
         ):
-            yield result[0]
+            yield result
 
     def chat_completion(
         self,
-        request: ChatCompletionRequestWithRawContent,
+        request_batch: List[ChatCompletionRequestWithRawContent],
     ) -> Generator:
-        sampling_params = request.sampling_params or SamplingParams()
+        first_request = request_batch[0]
+        sampling_params = first_request.sampling_params or SamplingParams()
         max_gen_len = sampling_params.max_tokens
         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
             max_gen_len = self.args.max_seq_len - 1
 
         temperature, top_p = _infer_sampling_params(sampling_params)
         for result in self.inner_generator.generate(
-            llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
+            llm_inputs=[
+                self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
+                for request in request_batch
+            ],
             max_gen_len=max_gen_len,
             temperature=temperature,
             top_p=top_p,
-            logprobs=bool(request.logprobs),
+            logprobs=bool(first_request.logprobs),
             echo=False,
             logits_processor=get_logits_processor(
                 self.tokenizer,
                 self.args.vocab_size,
-                request.response_format,
+                first_request.response_format,
             ),
         ):
-            yield result[0]
-
-
-class Llama3Generator:
-    def __init__(
-        self,
-        config: MetaReferenceInferenceConfig,
-        model_id: str,
-        llama_model: Model,
-    ):
-        if config.checkpoint_dir and config.checkpoint_dir != "null":
-            ckpt_dir = config.checkpoint_dir
-        else:
-            resolved_model = resolve_model(model_id)
-            if resolved_model is None:
-                # if the model is not a native llama model, get the default checkpoint_dir based on model id
-                ckpt_dir = model_checkpoint_dir(model_id)
-            else:
-                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
-                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
-
-        if config.quantization:
-            if config.quantization.type == "fp8_mixed":
-                quantization_mode = QuantizationMode.fp8_mixed
-            elif config.quantization.type == "int4_mixed":
-                quantization_mode = QuantizationMode.int4_mixed
-            elif config.quantization.type == "bf16":
-                quantization_mode = None
-            else:
-                raise ValueError(f"Unsupported quantization mode {config.quantization}")
-        else:
-            quantization_mode = None
-
-        self.inner_generator = Llama3.build(
-            ckpt_dir=ckpt_dir,
-            max_seq_len=config.max_seq_len,
-            max_batch_size=config.max_batch_size,
-            world_size=config.model_parallel_size or llama_model.pth_file_count,
-            quantization_mode=quantization_mode,
-        )
-        self.tokenizer = self.inner_generator.tokenizer
-        self.args = self.inner_generator.args
-        self.formatter = self.inner_generator.formatter
-
-    def completion(
-        self,
-        request: CompletionRequestWithRawContent,
-    ) -> Generator:
-        sampling_params = request.sampling_params or SamplingParams()
-        max_gen_len = sampling_params.max_tokens
-        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
-            max_gen_len = self.args.max_seq_len - 1
-
-        temperature, top_p = _infer_sampling_params(sampling_params)
-        for result in self.inner_generator.generate(
-            model_inputs=[self.formatter.encode_content(request.content)],
-            max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=bool(request.logprobs),
-            echo=False,
-            logits_processor=get_logits_processor(
-                self.tokenizer,
-                self.args.vocab_size,
-                request.response_format,
-            ),
-        ):
-            yield result[0]
-
-    def chat_completion(
-        self,
-        request: ChatCompletionRequestWithRawContent,
-    ) -> Generator:
-        sampling_params = request.sampling_params or SamplingParams()
-        max_gen_len = sampling_params.max_tokens
-        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
-            max_gen_len = self.args.max_seq_len - 1
-
-        temperature, top_p = _infer_sampling_params(sampling_params)
-        for result in self.inner_generator.generate(
-            model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
-            max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=bool(request.logprobs),
-            echo=False,
-            logits_processor=get_logits_processor(
-                self.tokenizer,
-                self.args.vocab_size,
-                request.response_format,
-            ),
-        ):
-            yield result[0]
@ -9,6 +9,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from typing import AsyncGenerator, List, Optional, Union
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
@ -17,6 +18,8 @@ from llama_stack.apis.common.content_types import (
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
BatchChatCompletionResponse,
|
||||||
|
BatchCompletionResponse,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseEvent,
|
ChatCompletionResponseEvent,
|
||||||
|
@ -38,6 +41,7 @@ from llama_stack.apis.inference import (
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||||
|
@ -65,7 +69,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import MetaReferenceInferenceConfig
|
from .config import MetaReferenceInferenceConfig
|
||||||
from .generators import Llama3Generator, Llama4Generator
|
from .generators import LlamaGenerator
|
||||||
from .model_parallel import LlamaModelParallelGenerator
|
from .model_parallel import LlamaModelParallelGenerator
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
@ -74,12 +78,8 @@ log = logging.getLogger(__name__)
|
||||||
SEMAPHORE = asyncio.Semaphore(1)
|
SEMAPHORE = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
|
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
|
||||||
return Llama3Generator(config, model_id, llama_model)
|
return LlamaGenerator(config, model_id, llama_model)
|
||||||
|
|
||||||
|
|
||||||
def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
|
|
||||||
return Llama4Generator(config, model_id, llama_model)
|
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceInferenceImpl(
|
class MetaReferenceInferenceImpl(
|
||||||
|
@ -139,24 +139,12 @@ class MetaReferenceInferenceImpl(
|
||||||
async def load_model(self, model_id, llama_model) -> None:
|
async def load_model(self, model_id, llama_model) -> None:
|
||||||
log.info(f"Loading model `{model_id}`")
|
log.info(f"Loading model `{model_id}`")
|
||||||
|
|
||||||
if llama_model.model_family in {
|
|
||||||
ModelFamily.llama3,
|
|
||||||
ModelFamily.llama3_1,
|
|
||||||
ModelFamily.llama3_2,
|
|
||||||
ModelFamily.llama3_3,
|
|
||||||
}:
|
|
||||||
builder_fn = llama3_builder_fn
|
|
||||||
elif llama_model.model_family == ModelFamily.llama4:
|
|
||||||
builder_fn = llama4_builder_fn
|
|
||||||
else:
|
|
||||||
            raise ValueError(f"Unsupported model family: {llama_model.model_family}")

        builder_params = [self.config, model_id, llama_model]

        if self.config.create_distributed_process_group:
            self.generator = LlamaModelParallelGenerator(
                model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
-               builder_fn=builder_fn,
+               builder_fn=llama_builder_fn,
                builder_params=builder_params,
                formatter=(
                    Llama4ChatFormat(Llama4Tokenizer.get_instance())
@@ -166,11 +154,24 @@ class MetaReferenceInferenceImpl(
            )
            self.generator.start()
        else:
-           self.generator = builder_fn(*builder_params)
+           self.generator = llama_builder_fn(*builder_params)

        self.model_id = model_id
        self.llama_model = llama_model

+       print("Warming up...")
+       await self.completion(
+           model_id=model_id,
+           content="Hello, world!",
+           sampling_params=SamplingParams(max_tokens=10),
+       )
+       await self.chat_completion(
+           model_id=model_id,
+           messages=[UserMessage(content="Hi how are you?")],
+           sampling_params=SamplingParams(max_tokens=20),
+       )
+       print("Warmed up!")

    def check_model(self, request) -> None:
        if self.model_id is None or self.llama_model is None:
            raise RuntimeError(
@@ -208,7 +209,43 @@ class MetaReferenceInferenceImpl(
        if request.stream:
            return self._stream_completion(request)
        else:
-           return await self._nonstream_completion(request)
+           results = await self._nonstream_completion([request])
+           return results[0]
+
+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       stream: Optional[bool] = False,
+       logprobs: Optional[LogProbConfig] = None,
+   ) -> BatchCompletionResponse:
+       if sampling_params is None:
+           sampling_params = SamplingParams()
+       if logprobs:
+           assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+       content_batch = [
+           augment_content_with_response_format_prompt(response_format, content) for content in content_batch
+       ]
+
+       request_batch = []
+       for content in content_batch:
+           request = CompletionRequest(
+               model=model_id,
+               content=content,
+               sampling_params=sampling_params,
+               response_format=response_format,
+               stream=stream,
+               logprobs=logprobs,
+           )
+           self.check_model(request)
+           request = await convert_request_to_raw(request)
+           request_batch.append(request)
+
+       results = await self._nonstream_completion(request_batch)
+       return BatchCompletionResponse(batch=results)
+
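For orientation, this is how the new batch completion surface is exercised from the client SDK, mirroring the integration test added later in this change. A minimal sketch; the base URL and model id are illustrative placeholders, not part of the diff:

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local stack server
    response = client.inference.batch_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # any registered text model
        content_batch=["Roses are red, violets are ", "The capital of France is "],
        sampling_params={"max_tokens": 50},
    )
    for i, r in enumerate(response.batch):  # one CompletionResponse per input, in order
        print(i, r.content)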
    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        tokenizer = self.generator.formatter.tokenizer
@@ -253,37 +290,54 @@ class MetaReferenceInferenceImpl(
            for x in impl():
                yield x

-   async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+   async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
        tokenizer = self.generator.formatter.tokenizer

+       first_request = request_batch[0]
+
+       class ItemState(BaseModel):
+           tokens: List[int] = []
+           logprobs: List[TokenLogProbs] = []
+           stop_reason: StopReason | None = None
+           finished: bool = False
+
        def impl():
-           tokens = []
-           logprobs = []
-           stop_reason = None
+           states = [ItemState() for _ in request_batch]

-           for token_result in self.generator.completion(request):
-               tokens.append(token_result.token)
-               if token_result.token == tokenizer.eot_id:
-                   stop_reason = StopReason.end_of_turn
-               elif token_result.token == tokenizer.eom_id:
-                   stop_reason = StopReason.end_of_message
+           results = []
+           for token_results in self.generator.completion(request_batch):
+               for result in token_results:
+                   idx = result.batch_idx
+                   state = states[idx]
+                   if state.finished or result.ignore_token:
+                       continue

-               if request.logprobs:
-                   assert len(token_result.logprobs) == 1
+                   state.finished = result.finished
+                   if first_request.logprobs:
+                       state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))

-                   logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+                   state.tokens.append(result.token)
+                   if result.token == tokenizer.eot_id:
+                       state.stop_reason = StopReason.end_of_turn
+                   elif result.token == tokenizer.eom_id:
+                       state.stop_reason = StopReason.end_of_message

-           if stop_reason is None:
-               stop_reason = StopReason.out_of_tokens
+           for state in states:
+               if state.stop_reason is None:
+                   state.stop_reason = StopReason.out_of_tokens

-           if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
-               tokens = tokens[:-1]
-           content = self.generator.formatter.tokenizer.decode(tokens)
-           return CompletionResponse(
-               content=content,
-               stop_reason=stop_reason,
-               logprobs=logprobs if request.logprobs else None,
-           )
+               if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
+                   state.tokens = state.tokens[:-1]
+               content = self.generator.formatter.tokenizer.decode(state.tokens)
+               results.append(
+                   CompletionResponse(
+                       content=content,
+                       stop_reason=state.stop_reason,
+                       logprobs=state.logprobs if first_request.logprobs else None,
+                   )
+               )
+
+           return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
@@ -318,7 +372,7 @@ class MetaReferenceInferenceImpl(
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
-           tool_config=tool_config,
+           tool_config=tool_config or ToolConfig(),
        )
        self.check_model(request)

@@ -334,44 +388,110 @@ class MetaReferenceInferenceImpl(
        if request.stream:
            return self._stream_chat_completion(request)
        else:
-           return await self._nonstream_chat_completion(request)
+           results = await self._nonstream_chat_completion([request])
+           return results[0]

-   async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       stream: Optional[bool] = False,
+       logprobs: Optional[LogProbConfig] = None,
+       tool_config: Optional[ToolConfig] = None,
+   ) -> BatchChatCompletionResponse:
+       if sampling_params is None:
+           sampling_params = SamplingParams()
+       if logprobs:
+           assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+       # wrapper request to make it easier to pass around (internal only, not exposed to API)
+       request_batch = []
+       for messages in messages_batch:
+           request = ChatCompletionRequest(
+               model=model_id,
+               messages=messages,
+               sampling_params=sampling_params,
+               tools=tools or [],
+               response_format=response_format,
+               logprobs=logprobs,
+               tool_config=tool_config or ToolConfig(),
+           )
+           self.check_model(request)
+
+           # augment and rewrite messages depending on the model
+           request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+           # download media and convert to raw content so we can send it to the model
+           request = await convert_request_to_raw(request)
+           request_batch.append(request)
+
+       if self.config.create_distributed_process_group:
+           if SEMAPHORE.locked():
+               raise RuntimeError("Only one concurrent request is supported")
+
+       results = await self._nonstream_chat_completion(request_batch)
+       return BatchChatCompletionResponse(batch=results)
+
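The batched chat entry point is called the same way; a hedged sketch reusing the client from the earlier example (message shapes follow the integration test added in this change, the model id is a placeholder):

    response = client.inference.batch_chat_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
        messages_batch=[
            [{"role": "user", "content": "What is the capital of France?"}],
            [{"role": "user", "content": "Who wrote the book '1984'?"}],
        ],
    )
    for r in response.batch:  # one ChatCompletionResponse per conversation
        print(r.completion_message.content)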
+   async def _nonstream_chat_completion(
+       self, request_batch: List[ChatCompletionRequest]
+   ) -> List[ChatCompletionResponse]:
        tokenizer = self.generator.formatter.tokenizer

+       first_request = request_batch[0]
+
+       class ItemState(BaseModel):
+           tokens: List[int] = []
+           logprobs: List[TokenLogProbs] = []
+           stop_reason: StopReason | None = None
+           finished: bool = False
+
        def impl():
-           tokens = []
-           logprobs = []
-           stop_reason = None
+           states = [ItemState() for _ in request_batch]

-           for token_result in self.generator.chat_completion(request):
-               if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
-                   cprint(token_result.text, "cyan", end="")
+           for token_results in self.generator.chat_completion(request_batch):
+               first = token_results[0]
+               if not first.finished:
+                   if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
+                       cprint(first.text, "cyan", end="")
+                   if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
+                       cprint(f"<{first.token}>", "magenta", end="")

-               tokens.append(token_result.token)
+               for result in token_results:
+                   idx = result.batch_idx
+                   state = states[idx]
+                   if state.finished or result.ignore_token:
+                       continue

-               if token_result.token == tokenizer.eot_id:
-                   stop_reason = StopReason.end_of_turn
-               elif token_result.token == tokenizer.eom_id:
-                   stop_reason = StopReason.end_of_message
+                   state.finished = result.finished
+                   if first_request.logprobs:
+                       state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))

-               if request.logprobs:
-                   assert len(token_result.logprobs) == 1
+                   state.tokens.append(result.token)
+                   if result.token == tokenizer.eot_id:
+                       state.stop_reason = StopReason.end_of_turn
+                   elif result.token == tokenizer.eom_id:
+                       state.stop_reason = StopReason.end_of_message

-                   logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+           results = []
+           for state in states:
+               if state.stop_reason is None:
+                   state.stop_reason = StopReason.out_of_tokens

-           if stop_reason is None:
-               stop_reason = StopReason.out_of_tokens
+               raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
+               results.append(
+                   ChatCompletionResponse(
+                       completion_message=CompletionMessage(
+                           content=raw_message.content,
+                           stop_reason=raw_message.stop_reason,
+                           tool_calls=raw_message.tool_calls,
+                       ),
+                       logprobs=state.logprobs if first_request.logprobs else None,
+                   )
+               )

-           raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-           return ChatCompletionResponse(
-               completion_message=CompletionMessage(
-                   content=raw_message.content,
-                   stop_reason=raw_message.stop_reason,
-                   tool_calls=raw_message.tool_calls,
-               ),
-               logprobs=logprobs if request.logprobs else None,
-           )
+           return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:
@@ -398,6 +518,22 @@ class MetaReferenceInferenceImpl(
            for token_result in self.generator.chat_completion(request):
                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
                    cprint(token_result.text, "cyan", end="")
+               if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
+                   cprint(f"<{token_result.token}>", "magenta", end="")
+
+               if token_result.token == tokenizer.eot_id:
+                   stop_reason = StopReason.end_of_turn
+                   text = ""
+               elif token_result.token == tokenizer.eom_id:
+                   stop_reason = StopReason.end_of_message
+                   text = ""
+               else:
+                   text = token_result.text
+
+               if request.logprobs:
+                   assert len(token_result.logprobs) == 1
+
+                   logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
+
                tokens.append(token_result.token)

@@ -6,7 +6,7 @@

 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, Generator
+from typing import Any, Callable, Generator, List

 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@@ -23,13 +23,13 @@ class ModelRunner:
        self.llama = llama

    # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
-   def __call__(self, req: Any):
-       if isinstance(req, ChatCompletionRequestWithRawContent):
-           return self.llama.chat_completion(req)
-       elif isinstance(req, CompletionRequestWithRawContent):
-           return self.llama.completion(req)
+   def __call__(self, task: Any):
+       if task[0] == "chat_completion":
+           return self.llama.chat_completion(task[1])
+       elif task[0] == "completion":
+           return self.llama.completion(task[1])
        else:
-           raise ValueError(f"Unexpected task type {type(req)}")
+           raise ValueError(f"Unexpected task type {task[0]}")


def init_model_cb(
@@ -82,16 +82,16 @@ class LlamaModelParallelGenerator:

    def completion(
        self,
-       request: CompletionRequestWithRawContent,
+       request_batch: List[CompletionRequestWithRawContent],
    ) -> Generator:
-       req_obj = deepcopy(request)
-       gen = self.group.run_inference(req_obj)
+       req_obj = deepcopy(request_batch)
+       gen = self.group.run_inference(("completion", req_obj))
        yield from gen

    def chat_completion(
        self,
-       request: ChatCompletionRequestWithRawContent,
+       request_batch: List[ChatCompletionRequestWithRawContent],
    ) -> Generator:
-       req_obj = deepcopy(request)
-       gen = self.group.run_inference(req_obj)
+       req_obj = deepcopy(request_batch)
+       gen = self.group.run_inference(("chat_completion", req_obj))
        yield from gen

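The worker side now receives a ("task_type", request_batch) tuple instead of a single typed request, so dispatch keys off the tag rather than isinstance checks and the same channel carries whole batches. A minimal standalone sketch of that pattern, with simplified names that are not the actual classes:

    from typing import Any, Callable, Dict, List, Tuple

    class TaggedTaskRunner:
        def __init__(self, handlers: Dict[str, Callable[[List[Any]], Any]]):
            # e.g. {"completion": llama.completion, "chat_completion": llama.chat_completion}
            self.handlers = handlers

        def __call__(self, task: Tuple[str, List[Any]]):
            kind, batch = task
            if kind not in self.handlers:
                raise ValueError(f"Unexpected task type {kind}")
            return self.handlers[kind](batch)  # the handler consumes the full batch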
@@ -19,7 +19,7 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Callable, Generator, Literal, Optional, Union
+from typing import Callable, Generator, List, Literal, Optional, Tuple, Union

 import torch
 import zmq
@@ -69,12 +69,12 @@ class CancelSentinel(BaseModel):

class TaskRequest(BaseModel):
    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-   task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
+   task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]


class TaskResponse(BaseModel):
    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-   result: GenerationResult
+   result: List[GenerationResult]


class ExceptionResponse(BaseModel):
@@ -331,7 +331,7 @@ class ModelParallelProcessGroup:

    def run_inference(
        self,
-       req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
+       req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
    ) -> Generator:
        assert not self.running, "inference already running"

@@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_stack.apis.inference import (
     CompletionResponse,
     Inference,
+    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -80,3 +81,25 @@ class SentenceTransformersInferenceImpl(
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        raise ValueError("Sentence transformers don't support chat completion")
+
+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
+
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       tool_config: Optional[ToolConfig] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")

39  llama_stack/providers/registry/batch_inference.py  Normal file
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.providers.datatypes import (
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+)
+
+META_REFERENCE_DEPS = [
+    "accelerate",
+    "blobfile",
+    "fairscale",
+    "torch",
+    "torchvision",
+    "transformers",
+    "zmq",
+    "lm-format-enforcer",
+    "sentence-transformers",
+    "torchao==0.5.0",
+    "fbgemm-gpu-genai==1.1.2",
+]
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="inline::meta-reference",
+            pip_packages=META_REFERENCE_DEPS,
+            module="llama_stack.providers.inline.batch_inference.meta_reference",
+            config_class="llama_stack.providers.inline.batch_inference.meta_reference.MetaReferenceInferenceConfig",
+        ),
+    ]

@@ -437,6 +437,28 @@ class OllamaInferenceAdapter(
        }
        return await self.openai_client.chat.completions.create(**params)  # type: ignore

+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch completion is not supported for Ollama")
+
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       tool_config: Optional[ToolConfig] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch chat completion is not supported for Ollama")
+

async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
    async def _convert_content(content) -> dict:

@@ -526,3 +526,25 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
            user=user,
        )
        return await self.client.chat.completions.create(**params)  # type: ignore
+
+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch completion is not supported for vLLM")
+
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       tool_config: Optional[ToolConfig] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch chat completion is not supported for vLLM")

@@ -347,3 +347,25 @@ class LiteLLMOpenAIMixin(
            user=user,
        )
        return litellm.completion(**params)
+
+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
+
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       tool_config: Optional[ToolConfig] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")

@@ -20,6 +20,7 @@ class WebMethod:
    raw_bytes_request_body: Optional[bool] = False
    # A descriptive name of the corresponding span created by tracing
    descriptive_name: Optional[str] = None
+   experimental: Optional[bool] = False


T = TypeVar("T", bound=Callable[..., Any])
@@ -33,6 +34,7 @@ def webmethod(
    response_examples: Optional[List[Any]] = None,
    raw_bytes_request_body: Optional[bool] = False,
    descriptive_name: Optional[str] = None,
+   experimental: Optional[bool] = False,
) -> Callable[[T], T]:
    """
    Decorator that supplies additional metadata to an endpoint operation function.
@@ -52,6 +54,7 @@ def webmethod(
            response_examples=response_examples,
            raw_bytes_request_body=raw_bytes_request_body,
            descriptive_name=descriptive_name,
+           experimental=experimental,
        )
        return func

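The new experimental flag lets individual routes be marked as unstable when they are declared. A plausible usage sketch on an API protocol; the route string, method keyword, and signature here are assumptions for illustration, not taken from this diff:

    from typing import List, Protocol

    from llama_stack.schema_utils import webmethod

    class BatchInferenceExample(Protocol):  # hypothetical protocol for illustration
        @webmethod(route="/inference/batch-completion", method="POST", experimental=True)
        async def batch_completion(self, model_id: str, content_batch: List[str]): ...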
@@ -16,11 +16,12 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.INFERENCE_MODEL}
-      max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+      max_batch_size: ${env.MAX_BATCH_SIZE:1}
+      max_seq_len: ${env.MAX_SEQ_LEN:4096}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}
@@ -28,11 +29,12 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.SAFETY_MODEL}
-      max_seq_len: 4096
       checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+      max_batch_size: ${env.MAX_BATCH_SIZE:1}
+      max_seq_len: ${env.MAX_SEQ_LEN:4096}
 vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -16,11 +16,12 @@ providers:
     provider_type: inline::meta-reference
     config:
       model: ${env.INFERENCE_MODEL}
-      max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: ${env.QUANTIZATION_TYPE:bf16}
       model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+      max_batch_size: ${env.MAX_BATCH_SIZE:1}
+      max_seq_len: ${env.MAX_SEQ_LEN:4096}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
     config: {}

111  tests/integration/inference/test_batch_inference.py  Normal file
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import pytest
+
+from llama_stack.models.llama.sku_list import resolve_model
+
+from ..test_cases.test_case import TestCase
+
+PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}
+
+
+def skip_if_model_doesnt_support_completion(client_with_models, model_id):
+    models = {m.identifier: m for m in client_with_models.models.list()}
+    models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
+    provider_id = models[model_id].provider_id
+    providers = {p.provider_id: p for p in client_with_models.providers.list()}
+    provider = providers[provider_id]
+    if provider.provider_type in (
+        "remote::openai",
+        "remote::anthropic",
+        "remote::gemini",
+        "remote::groq",
+        "remote::llama-openai-compat",
+    ):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
+
+
+def get_llama_model(client_with_models, model_id):
+    models = {}
+    for m in client_with_models.models.list():
+        models[m.identifier] = m
+        models[m.provider_resource_id] = m
+
+    assert model_id in models, f"Model {model_id} not found"
+
+    model = models[model_id]
+    ids = (model.identifier, model.provider_resource_id)
+    for mid in ids:
+        if resolve_model(mid):
+            return mid
+
+    return model.metadata.get("llama_model", None)
+
+
+def get_llama_tokenizer():
+    from llama_models.llama3.api.chat_format import ChatFormat
+    from llama_models.llama3.api.tokenizer import Tokenizer
+
+    tokenizer = Tokenizer.get_instance()
+    formatter = ChatFormat(tokenizer)
+    return tokenizer, formatter
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:batch_completion",
+    ],
+)
+def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+
+    content_batch = tc["contents"]
+    response = client_with_models.inference.batch_completion(
+        content_batch=content_batch,
+        model_id=text_model_id,
+        sampling_params={
+            "max_tokens": 50,
+        },
+    )
+    assert len(response.batch) == len(content_batch)
+    for i, r in enumerate(response.batch):
+        print(f"response {i}: {r.content}")
+        assert len(r.content) > 10
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:batch_completion",
+    ],
+)
+def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
+    tc = TestCase(test_case)
+    qa_pairs = tc["qa_pairs"]
+
+    message_batch = [
+        [
+            {
+                "role": "user",
+                "content": qa["question"],
+            }
+        ]
+        for qa in qa_pairs
+    ]
+
+    response = client_with_models.inference.batch_chat_completion(
+        messages_batch=message_batch,
+        model_id=text_model_id,
+    )
+    assert len(response.batch) == len(qa_pairs)
+    for i, r in enumerate(response.batch):
+        print(f"response {i}: {r.completion_message.content}")
+        assert len(r.completion_message.content) > 0
+        assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()

@@ -537,5 +537,31 @@
                }
            ]
        }
+    },
+    "batch_completion": {
+        "data": {
+            "qa_pairs": [
+                {
+                    "question": "What is the capital of France?",
+                    "answer": "Paris"
+                },
+                {
+                    "question": "Who wrote the book '1984'?",
+                    "answer": "George Orwell"
+                },
+                {
+                    "question": "Which planet has rings around it with a name starting with letter S?",
+                    "answer": "Saturn"
+                },
+                {
+                    "question": "When did the first moon landing happen?",
+                    "answer": "1969"
+                },
+                {
+                    "question": "What word says 'hello' in Spanish?",
+                    "answer": "Hola"
+                }
+            ]
+        }
    }
}

@@ -44,5 +44,18 @@
            "year_retired": "2003"
        }
    }
+    },
+    "batch_completion": {
+        "data": {
+            "contents": [
+                "Michael Jordan is born in ",
+                "Roses are red, violets are ",
+                "If you had a million dollars, what would you do with it? ",
+                "All you need is ",
+                "The capital of France is ",
+                "It is a good day to ",
+                "The answer to the universe is "
+            ]
+        }
    }
}