forked from phoenix-oss/llama-stack-mirror
feat: add batch inference API to llama stack inference (#1945)
# What does this PR do?

This PR adds two methods to the Inference API:
- `batch_completion`
- `batch_chat_completion`

The motivation is evaluations targeting a local inference engine (like meta-reference or vllm), where batch APIs provide a substantial amount of acceleration.

Why did I not add this to `Api.batch_inference` though? That just resulted in a _lot_ more book-keeping given the structure of Llama Stack. Had I done that, I would have needed to create a notion of a "batch model" resource, set up routing based on that, etc. This does not sound ideal.

So what's the future of the batch inference API? I am not sure. Maybe we can keep it for true _asynchronous_ execution, so you can submit requests and it returns a Job instance, etc.

## Test Plan

Run meta-reference-gpu using:

```bash
export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct
export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct-20250331210000
export MODEL_PARALLEL_SIZE=4
export MAX_BATCH_SIZE=32
export MAX_SEQ_LEN=6144

LLAMA_MODELS_DEBUG=1 llama stack run meta-reference-gpu
```

Then run the batch inference test case.
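For readers who want a concrete picture of the new surface, here is a minimal usage sketch (not part of the PR itself): it assumes you have an object implementing the `Inference` API in scope as `inference` (for example the meta-reference provider or the stack's `InferenceRouter`) and that the model id below is registered with your stack. The method names, argument names, and the `.batch` response field are taken from this PR; everything else is illustrative.

```python
# Hypothetical usage sketch for the batch methods added in this PR.
# Assumes `inference` implements the Inference protocol (e.g. the
# InferenceRouter of a running stack) and the model id is registered.
from llama_stack.apis.inference import SamplingParams, UserMessage


async def run_batch(inference):
    # One list of messages per conversation in the batch.
    messages_batch = [
        [UserMessage(content="Give me a two-sentence summary of the Llama Stack project.")],
        [UserMessage(content="What is batch inference useful for?")],
    ]

    response = await inference.batch_chat_completion(
        model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        messages_batch=messages_batch,
        sampling_params=SamplingParams(max_tokens=128),
    )
    # BatchChatCompletionResponse.batch holds one ChatCompletionResponse per input.
    for chat_response in response.batch:
        print(chat_response.completion_message.content)

    # The plain-completion variant works the same way, but takes a content batch.
    completion_response = await inference.batch_completion(
        model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        content_batch=["The capital of France is", "1 + 1 ="],
        sampling_params=SamplingParams(max_tokens=16),
    )
    for item in completion_response.batch:
        print(item.content)
```

Each entry of `messages_batch` is a full conversation, and the response carries one result per entry, which is what makes the API convenient for evaluation loops.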
parent 854c2ad264
commit f34f22f8c7
23 changed files with 698 additions and 389 deletions
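For orientation before the generated-spec and implementation diffs below, the two methods this PR adds to the `Inference` protocol are condensed here into a small self-contained stub (assuming the `llama_stack` package at this commit is importable; `_BatchMethodsSketch` is a hypothetical class name used only for illustration). The default bodies mirror what the PR adds: providers that do not implement batching keep raising `NotImplementedError`.

```python
# Condensed restatement of the two methods this PR adds to the Inference
# protocol (see the diffs below); not the full protocol definition.
from typing import List, Optional

from llama_stack.apis.inference import (
    BatchChatCompletionResponse,
    BatchCompletionResponse,
    InterleavedContent,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    ToolConfig,
    ToolDefinition,
)


class _BatchMethodsSketch:
    async def batch_completion(
        self,
        model_id: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchCompletionResponse:
        # Default added by this PR: providers opt in by overriding.
        raise NotImplementedError("Batch completion is not implemented")

    async def batch_chat_completion(
        self,
        model_id: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_config: Optional[ToolConfig] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchChatCompletionResponse:
        # Default added by this PR: providers opt in by overriding.
        raise NotImplementedError("Batch chat completion is not implemented")
```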
docs/_static/llama-stack-spec.html (vendored): 135 changed lines
|
@ -85,7 +85,7 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/v1/batch-inference/chat-completion": {
|
"/v1/inference/batch-chat-completion": {
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
|
@ -112,7 +112,7 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tags": [
|
"tags": [
|
||||||
"BatchInference (Coming Soon)"
|
"Inference"
|
||||||
],
|
],
|
||||||
"description": "",
|
"description": "",
|
||||||
"parameters": [],
|
"parameters": [],
|
||||||
|
@ -128,7 +128,7 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/v1/batch-inference/completion": {
|
"/v1/inference/batch-completion": {
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
|
@ -155,7 +155,7 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tags": [
|
"tags": [
|
||||||
"BatchInference (Coming Soon)"
|
"Inference"
|
||||||
],
|
],
|
||||||
"description": "",
|
"description": "",
|
||||||
"parameters": [],
|
"parameters": [],
|
||||||
|
@ -239,7 +239,7 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tags": [
|
"tags": [
|
||||||
"Inference"
|
"BatchInference (Coming Soon)"
|
||||||
],
|
],
|
||||||
"description": "Generate a chat completion for the given messages using the specified model.",
|
"description": "Generate a chat completion for the given messages using the specified model.",
|
||||||
"parameters": [],
|
"parameters": [],
|
||||||
|
@ -287,7 +287,7 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tags": [
|
"tags": [
|
||||||
"Inference"
|
"BatchInference (Coming Soon)"
|
||||||
],
|
],
|
||||||
"description": "Generate a completion for the given content using the specified model.",
|
"description": "Generate a completion for the given content using the specified model.",
|
||||||
"parameters": [],
|
"parameters": [],
|
||||||
|
@ -4366,6 +4366,51 @@
|
||||||
],
|
],
|
||||||
"title": "ToolCall"
|
"title": "ToolCall"
|
||||||
},
|
},
|
||||||
|
"ToolConfig": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"tool_choice": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"auto",
|
||||||
|
"required",
|
||||||
|
"none"
|
||||||
|
],
|
||||||
|
"title": "ToolChoice",
|
||||||
|
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"default": "auto",
|
||||||
|
"description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
|
||||||
|
},
|
||||||
|
"tool_prompt_format": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"json",
|
||||||
|
"function_tag",
|
||||||
|
"python_list"
|
||||||
|
],
|
||||||
|
"description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
|
||||||
|
},
|
||||||
|
"system_message_behavior": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"append",
|
||||||
|
"replace"
|
||||||
|
],
|
||||||
|
"description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
|
||||||
|
"default": "append"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"title": "ToolConfig",
|
||||||
|
"description": "Configuration for tool use."
|
||||||
|
},
|
||||||
"ToolDefinition": {
|
"ToolDefinition": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -4554,7 +4599,7 @@
|
||||||
"BatchChatCompletionRequest": {
|
"BatchChatCompletionRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"model": {
|
"model_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"messages_batch": {
|
"messages_batch": {
|
||||||
|
@ -4575,25 +4620,8 @@
|
||||||
"$ref": "#/components/schemas/ToolDefinition"
|
"$ref": "#/components/schemas/ToolDefinition"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tool_choice": {
|
"tool_config": {
|
||||||
"type": "string",
|
"$ref": "#/components/schemas/ToolConfig"
|
||||||
"enum": [
|
|
||||||
"auto",
|
|
||||||
"required",
|
|
||||||
"none"
|
|
||||||
],
|
|
||||||
"title": "ToolChoice",
|
|
||||||
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
|
|
||||||
},
|
|
||||||
"tool_prompt_format": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"json",
|
|
||||||
"function_tag",
|
|
||||||
"python_list"
|
|
||||||
],
|
|
||||||
"title": "ToolPromptFormat",
|
|
||||||
"description": "Prompt format for calling custom / zero shot tools."
|
|
||||||
},
|
},
|
||||||
"response_format": {
|
"response_format": {
|
||||||
"$ref": "#/components/schemas/ResponseFormat"
|
"$ref": "#/components/schemas/ResponseFormat"
|
||||||
|
@ -4613,7 +4641,7 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"model",
|
"model_id",
|
||||||
"messages_batch"
|
"messages_batch"
|
||||||
],
|
],
|
||||||
"title": "BatchChatCompletionRequest"
|
"title": "BatchChatCompletionRequest"
|
||||||
|
@ -4710,7 +4738,7 @@
|
||||||
"BatchCompletionRequest": {
|
"BatchCompletionRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"model": {
|
"model_id": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"content_batch": {
|
"content_batch": {
|
||||||
|
@ -4740,7 +4768,7 @@
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"required": [
|
"required": [
|
||||||
"model",
|
"model_id",
|
||||||
"content_batch"
|
"content_batch"
|
||||||
],
|
],
|
||||||
"title": "BatchCompletionRequest"
|
"title": "BatchCompletionRequest"
|
||||||
|
@ -4812,51 +4840,6 @@
|
||||||
],
|
],
|
||||||
"title": "CancelTrainingJobRequest"
|
"title": "CancelTrainingJobRequest"
|
||||||
},
|
},
|
||||||
"ToolConfig": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"tool_choice": {
|
|
||||||
"oneOf": [
|
|
||||||
{
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"auto",
|
|
||||||
"required",
|
|
||||||
"none"
|
|
||||||
],
|
|
||||||
"title": "ToolChoice",
|
|
||||||
"description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model."
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"default": "auto",
|
|
||||||
"description": "(Optional) Whether tool use is automatic, required, or none. Can also specify a tool name to use a specific tool. Defaults to ToolChoice.auto."
|
|
||||||
},
|
|
||||||
"tool_prompt_format": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"json",
|
|
||||||
"function_tag",
|
|
||||||
"python_list"
|
|
||||||
],
|
|
||||||
"description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name> tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls."
|
|
||||||
},
|
|
||||||
"system_message_behavior": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": [
|
|
||||||
"append",
|
|
||||||
"replace"
|
|
||||||
],
|
|
||||||
"description": "(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`: Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`: Replaces the default system prompt with the provided system message. The system message can include the string '{{function_definitions}}' to indicate where the function definitions should be inserted.",
|
|
||||||
"default": "append"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"title": "ToolConfig",
|
|
||||||
"description": "Configuration for tool use."
|
|
||||||
},
|
|
||||||
"ChatCompletionRequest": {
|
"ChatCompletionRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -11173,7 +11156,9 @@
|
||||||
"x-displayName": "Agents API for creating and interacting with agentic systems."
|
"x-displayName": "Agents API for creating and interacting with agentic systems."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "BatchInference (Coming Soon)"
|
"name": "BatchInference (Coming Soon)",
|
||||||
|
"description": "This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.\n\nNOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs\nincluding (post-training, evals, etc).",
|
||||||
|
"x-displayName": "Batch inference API for generating completions and chat completions."
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Benchmarks"
|
"name": "Benchmarks"
|
||||||
|
|
docs/_static/llama-stack-spec.yaml (vendored): 149 changed lines
|
@ -40,7 +40,7 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/AppendRowsRequest'
|
$ref: '#/components/schemas/AppendRowsRequest'
|
||||||
required: true
|
required: true
|
||||||
/v1/batch-inference/chat-completion:
|
/v1/inference/batch-chat-completion:
|
||||||
post:
|
post:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
|
@ -60,7 +60,7 @@ paths:
|
||||||
default:
|
default:
|
||||||
$ref: '#/components/responses/DefaultError'
|
$ref: '#/components/responses/DefaultError'
|
||||||
tags:
|
tags:
|
||||||
- BatchInference (Coming Soon)
|
- Inference
|
||||||
description: ''
|
description: ''
|
||||||
parameters: []
|
parameters: []
|
||||||
requestBody:
|
requestBody:
|
||||||
|
@ -69,7 +69,7 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/BatchChatCompletionRequest'
|
$ref: '#/components/schemas/BatchChatCompletionRequest'
|
||||||
required: true
|
required: true
|
||||||
/v1/batch-inference/completion:
|
/v1/inference/batch-completion:
|
||||||
post:
|
post:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
|
@ -89,7 +89,7 @@ paths:
|
||||||
default:
|
default:
|
||||||
$ref: '#/components/responses/DefaultError'
|
$ref: '#/components/responses/DefaultError'
|
||||||
tags:
|
tags:
|
||||||
- BatchInference (Coming Soon)
|
- Inference
|
||||||
description: ''
|
description: ''
|
||||||
parameters: []
|
parameters: []
|
||||||
requestBody:
|
requestBody:
|
||||||
|
@ -148,7 +148,7 @@ paths:
|
||||||
default:
|
default:
|
||||||
$ref: '#/components/responses/DefaultError'
|
$ref: '#/components/responses/DefaultError'
|
||||||
tags:
|
tags:
|
||||||
- Inference
|
- BatchInference (Coming Soon)
|
||||||
description: >-
|
description: >-
|
||||||
Generate a chat completion for the given messages using the specified model.
|
Generate a chat completion for the given messages using the specified model.
|
||||||
parameters: []
|
parameters: []
|
||||||
|
@ -183,7 +183,7 @@ paths:
|
||||||
default:
|
default:
|
||||||
$ref: '#/components/responses/DefaultError'
|
$ref: '#/components/responses/DefaultError'
|
||||||
tags:
|
tags:
|
||||||
- Inference
|
- BatchInference (Coming Soon)
|
||||||
description: >-
|
description: >-
|
||||||
Generate a completion for the given content using the specified model.
|
Generate a completion for the given content using the specified model.
|
||||||
parameters: []
|
parameters: []
|
||||||
|
@ -3009,6 +3009,54 @@ components:
|
||||||
- tool_name
|
- tool_name
|
||||||
- arguments
|
- arguments
|
||||||
title: ToolCall
|
title: ToolCall
|
||||||
|
ToolConfig:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
tool_choice:
|
||||||
|
oneOf:
|
||||||
|
- type: string
|
||||||
|
enum:
|
||||||
|
- auto
|
||||||
|
- required
|
||||||
|
- none
|
||||||
|
title: ToolChoice
|
||||||
|
description: >-
|
||||||
|
Whether tool use is required or automatic. This is a hint to the model
|
||||||
|
which may not be followed. It depends on the Instruction Following
|
||||||
|
capabilities of the model.
|
||||||
|
- type: string
|
||||||
|
default: auto
|
||||||
|
description: >-
|
||||||
|
(Optional) Whether tool use is automatic, required, or none. Can also
|
||||||
|
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
|
||||||
|
tool_prompt_format:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- json
|
||||||
|
- function_tag
|
||||||
|
- python_list
|
||||||
|
description: >-
|
||||||
|
(Optional) Instructs the model how to format tool calls. By default, Llama
|
||||||
|
Stack will attempt to use a format that is best adapted to the model.
|
||||||
|
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
||||||
|
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
|
||||||
|
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
|
||||||
|
syntax -- a list of function calls.
|
||||||
|
system_message_behavior:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- append
|
||||||
|
- replace
|
||||||
|
description: >-
|
||||||
|
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
|
||||||
|
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
|
||||||
|
Replaces the default system prompt with the provided system message. The
|
||||||
|
system message can include the string '{{function_definitions}}' to indicate
|
||||||
|
where the function definitions should be inserted.
|
||||||
|
default: append
|
||||||
|
additionalProperties: false
|
||||||
|
title: ToolConfig
|
||||||
|
description: Configuration for tool use.
|
||||||
ToolDefinition:
|
ToolDefinition:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -3145,7 +3193,7 @@ components:
|
||||||
BatchChatCompletionRequest:
|
BatchChatCompletionRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
model:
|
model_id:
|
||||||
type: string
|
type: string
|
||||||
messages_batch:
|
messages_batch:
|
||||||
type: array
|
type: array
|
||||||
|
@ -3159,26 +3207,8 @@ components:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
$ref: '#/components/schemas/ToolDefinition'
|
$ref: '#/components/schemas/ToolDefinition'
|
||||||
tool_choice:
|
tool_config:
|
||||||
type: string
|
$ref: '#/components/schemas/ToolConfig'
|
||||||
enum:
|
|
||||||
- auto
|
|
||||||
- required
|
|
||||||
- none
|
|
||||||
title: ToolChoice
|
|
||||||
description: >-
|
|
||||||
Whether tool use is required or automatic. This is a hint to the model
|
|
||||||
which may not be followed. It depends on the Instruction Following capabilities
|
|
||||||
of the model.
|
|
||||||
tool_prompt_format:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- json
|
|
||||||
- function_tag
|
|
||||||
- python_list
|
|
||||||
title: ToolPromptFormat
|
|
||||||
description: >-
|
|
||||||
Prompt format for calling custom / zero shot tools.
|
|
||||||
response_format:
|
response_format:
|
||||||
$ref: '#/components/schemas/ResponseFormat'
|
$ref: '#/components/schemas/ResponseFormat'
|
||||||
logprobs:
|
logprobs:
|
||||||
|
@ -3193,7 +3223,7 @@ components:
|
||||||
title: LogProbConfig
|
title: LogProbConfig
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- model
|
- model_id
|
||||||
- messages_batch
|
- messages_batch
|
||||||
title: BatchChatCompletionRequest
|
title: BatchChatCompletionRequest
|
||||||
BatchChatCompletionResponse:
|
BatchChatCompletionResponse:
|
||||||
|
@ -3261,7 +3291,7 @@ components:
|
||||||
BatchCompletionRequest:
|
BatchCompletionRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
model:
|
model_id:
|
||||||
type: string
|
type: string
|
||||||
content_batch:
|
content_batch:
|
||||||
type: array
|
type: array
|
||||||
|
@ -3283,7 +3313,7 @@ components:
|
||||||
title: LogProbConfig
|
title: LogProbConfig
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- model
|
- model_id
|
||||||
- content_batch
|
- content_batch
|
||||||
title: BatchCompletionRequest
|
title: BatchCompletionRequest
|
||||||
BatchCompletionResponse:
|
BatchCompletionResponse:
|
||||||
|
@ -3335,54 +3365,6 @@ components:
|
||||||
required:
|
required:
|
||||||
- job_uuid
|
- job_uuid
|
||||||
title: CancelTrainingJobRequest
|
title: CancelTrainingJobRequest
|
||||||
ToolConfig:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
tool_choice:
|
|
||||||
oneOf:
|
|
||||||
- type: string
|
|
||||||
enum:
|
|
||||||
- auto
|
|
||||||
- required
|
|
||||||
- none
|
|
||||||
title: ToolChoice
|
|
||||||
description: >-
|
|
||||||
Whether tool use is required or automatic. This is a hint to the model
|
|
||||||
which may not be followed. It depends on the Instruction Following
|
|
||||||
capabilities of the model.
|
|
||||||
- type: string
|
|
||||||
default: auto
|
|
||||||
description: >-
|
|
||||||
(Optional) Whether tool use is automatic, required, or none. Can also
|
|
||||||
specify a tool name to use a specific tool. Defaults to ToolChoice.auto.
|
|
||||||
tool_prompt_format:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- json
|
|
||||||
- function_tag
|
|
||||||
- python_list
|
|
||||||
description: >-
|
|
||||||
(Optional) Instructs the model how to format tool calls. By default, Llama
|
|
||||||
Stack will attempt to use a format that is best adapted to the model.
|
|
||||||
- `ToolPromptFormat.json`: The tool calls are formatted as a JSON object.
|
|
||||||
- `ToolPromptFormat.function_tag`: The tool calls are enclosed in a <function=function_name>
|
|
||||||
tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python
|
|
||||||
syntax -- a list of function calls.
|
|
||||||
system_message_behavior:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- append
|
|
||||||
- replace
|
|
||||||
description: >-
|
|
||||||
(Optional) Config for how to override the default system prompt. - `SystemMessageBehavior.append`:
|
|
||||||
Appends the provided system message to the default system prompt. - `SystemMessageBehavior.replace`:
|
|
||||||
Replaces the default system prompt with the provided system message. The
|
|
||||||
system message can include the string '{{function_definitions}}' to indicate
|
|
||||||
where the function definitions should be inserted.
|
|
||||||
default: append
|
|
||||||
additionalProperties: false
|
|
||||||
title: ToolConfig
|
|
||||||
description: Configuration for tool use.
|
|
||||||
ChatCompletionRequest:
|
ChatCompletionRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -7632,6 +7614,17 @@ tags:
|
||||||
x-displayName: >-
|
x-displayName: >-
|
||||||
Agents API for creating and interacting with agentic systems.
|
Agents API for creating and interacting with agentic systems.
|
||||||
- name: BatchInference (Coming Soon)
|
- name: BatchInference (Coming Soon)
|
||||||
|
description: >-
|
||||||
|
This is an asynchronous API. If the request is successful, the response will
|
||||||
|
be a job which can be polled for completion.
|
||||||
|
|
||||||
|
|
||||||
|
NOTE: This API is not yet implemented and is subject to change in concert with
|
||||||
|
other asynchronous APIs
|
||||||
|
|
||||||
|
including (post-training, evals, etc).
|
||||||
|
x-displayName: >-
|
||||||
|
Batch inference API for generating completions and chat completions.
|
||||||
- name: Benchmarks
|
- name: Benchmarks
|
||||||
- name: DatasetIO
|
- name: DatasetIO
|
||||||
- name: Datasets
|
- name: Datasets
|
||||||
|
|
|
@ -6,11 +6,8 @@
|
||||||
|
|
||||||
from typing import List, Optional, Protocol, runtime_checkable
|
from typing import List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from llama_stack.apis.common.job_types import Job
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionResponse,
|
|
||||||
CompletionResponse,
|
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
|
@ -20,41 +17,39 @@ from llama_stack.apis.inference import (
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
from llama_stack.schema_utils import webmethod
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class BatchCompletionResponse(BaseModel):
|
|
||||||
batch: List[CompletionResponse]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class BatchChatCompletionResponse(BaseModel):
|
|
||||||
batch: List[ChatCompletionResponse]
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class BatchInference(Protocol):
|
class BatchInference(Protocol):
|
||||||
|
"""Batch inference API for generating completions and chat completions.
|
||||||
|
|
||||||
|
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
|
||||||
|
|
||||||
|
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
|
||||||
|
including (post-training, evals, etc).
|
||||||
|
"""
|
||||||
|
|
||||||
@webmethod(route="/batch-inference/completion", method="POST")
|
@webmethod(route="/batch-inference/completion", method="POST")
|
||||||
async def batch_completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
content_batch: List[InterleavedContent],
|
content_batch: List[InterleavedContent],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> BatchCompletionResponse: ...
|
) -> Job: ...
|
||||||
|
|
||||||
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
||||||
async def batch_chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages_batch: List[List[Message]],
|
messages_batch: List[List[Message]],
|
||||||
sampling_params: Optional[SamplingParams] = None,
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
# zero-shot tool definitions as input to the model
|
# zero-shot tool definitions as input to the model
|
||||||
tools: Optional[List[ToolDefinition]] = list,
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
response_format: Optional[ResponseFormat] = None,
|
response_format: Optional[ResponseFormat] = None,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> BatchChatCompletionResponse: ...
|
) -> Job: ...
|
||||||
|
|
|
@ -681,6 +681,16 @@ class EmbeddingTaskType(Enum):
|
||||||
document = "document"
|
document = "document"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BatchCompletionResponse(BaseModel):
|
||||||
|
batch: List[CompletionResponse]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BatchChatCompletionResponse(BaseModel):
|
||||||
|
batch: List[ChatCompletionResponse]
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
@trace_protocol
|
@trace_protocol
|
||||||
class Inference(Protocol):
|
class Inference(Protocol):
|
||||||
|
@ -716,6 +726,17 @@ class Inference(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/inference/batch-completion", method="POST")
|
||||||
|
async def batch_completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
content_batch: List[InterleavedContent],
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> BatchCompletionResponse:
|
||||||
|
raise NotImplementedError("Batch completion is not implemented")
|
||||||
|
|
||||||
@webmethod(route="/inference/chat-completion", method="POST")
|
@webmethod(route="/inference/chat-completion", method="POST")
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
|
@ -756,6 +777,19 @@ class Inference(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/inference/batch-chat-completion", method="POST")
|
||||||
|
async def batch_chat_completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
messages_batch: List[List[Message]],
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
|
tool_config: Optional[ToolConfig] = None,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> BatchChatCompletionResponse:
|
||||||
|
raise NotImplementedError("Batch chat completion is not implemented")
|
||||||
|
|
||||||
@webmethod(route="/inference/embeddings", method="POST")
|
@webmethod(route="/inference/embeddings", method="POST")
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -17,6 +17,8 @@ from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||||
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
|
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
BatchChatCompletionResponse,
|
||||||
|
BatchCompletionResponse,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
ChatCompletionResponseStreamChunk,
|
ChatCompletionResponseStreamChunk,
|
||||||
|
@ -334,6 +336,30 @@ class InferenceRouter(Inference):
|
||||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def batch_chat_completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
messages_batch: List[List[Message]],
|
||||||
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
|
tool_config: Optional[ToolConfig] = None,
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> BatchChatCompletionResponse:
|
||||||
|
logger.debug(
|
||||||
|
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||||
|
)
|
||||||
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
|
return await provider.batch_chat_completion(
|
||||||
|
model_id=model_id,
|
||||||
|
messages_batch=messages_batch,
|
||||||
|
tools=tools,
|
||||||
|
tool_config=tool_config,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
response_format=response_format,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -398,6 +424,20 @@ class InferenceRouter(Inference):
|
||||||
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
response.metrics = metrics if response.metrics is None else response.metrics + metrics
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
async def batch_completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
content_batch: List[InterleavedContent],
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> BatchCompletionResponse:
|
||||||
|
logger.debug(
|
||||||
|
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||||
|
)
|
||||||
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
|
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
|
|
@ -226,7 +226,6 @@ class ChatFormat:
|
||||||
arguments_json=json.dumps(tool_arguments),
|
arguments_json=json.dumps(tool_arguments),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
content = ""
|
|
||||||
|
|
||||||
return RawMessage(
|
return RawMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
|
|
|
@ -140,7 +140,12 @@ class Llama3:
|
||||||
|
|
||||||
return Llama3(model, tokenizer, model_args)
|
return Llama3(model, tokenizer, model_args)
|
||||||
|
|
||||||
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Transformer | CrossAttentionTransformer,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
args: ModelArgs,
|
||||||
|
):
|
||||||
self.args = args
|
self.args = args
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
@ -149,7 +154,7 @@ class Llama3:
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
model_inputs: List[LLMInput],
|
llm_inputs: List[LLMInput],
|
||||||
temperature: float = 0.6,
|
temperature: float = 0.6,
|
||||||
top_p: float = 0.9,
|
top_p: float = 0.9,
|
||||||
max_gen_len: Optional[int] = None,
|
max_gen_len: Optional[int] = None,
|
||||||
|
@ -164,15 +169,15 @@ class Llama3:
|
||||||
|
|
||||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||||
if print_model_input:
|
if print_model_input:
|
||||||
for inp in model_inputs:
|
for inp in llm_inputs:
|
||||||
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
|
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
|
||||||
cprint(
|
cprint(
|
||||||
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
||||||
"red",
|
"red",
|
||||||
)
|
)
|
||||||
prompt_tokens = [inp.tokens for inp in model_inputs]
|
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||||
|
|
||||||
bsz = len(model_inputs)
|
bsz = len(llm_inputs)
|
||||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||||
|
|
||||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||||
|
@ -193,8 +198,8 @@ class Llama3:
|
||||||
|
|
||||||
is_vision = not isinstance(self.model, Transformer)
|
is_vision = not isinstance(self.model, Transformer)
|
||||||
if is_vision:
|
if is_vision:
|
||||||
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
|
images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
|
||||||
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
|
mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]
|
||||||
|
|
||||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
||||||
batch_images=images,
|
batch_images=images,
|
||||||
|
@ -229,7 +234,7 @@ class Llama3:
|
||||||
for cur_pos in range(min_prompt_len, total_len):
|
for cur_pos in range(min_prompt_len, total_len):
|
||||||
if is_vision:
|
if is_vision:
|
||||||
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
||||||
text_only_inference = all(inp.vision is None for inp in model_inputs)
|
text_only_inference = all(inp.vision is None for inp in llm_inputs)
|
||||||
logits = self.model.forward(
|
logits = self.model.forward(
|
||||||
position_ids,
|
position_ids,
|
||||||
tokens,
|
tokens,
|
||||||
|
@ -285,7 +290,7 @@ class Llama3:
|
||||||
source="output",
|
source="output",
|
||||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||||
batch_idx=idx,
|
batch_idx=idx,
|
||||||
finished=eos_reached[idx],
|
finished=eos_reached[idx].item(),
|
||||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -301,7 +301,6 @@ class ChatFormat:
|
||||||
arguments=tool_arguments,
|
arguments=tool_arguments,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
content = ""
|
|
||||||
|
|
||||||
return RawMessage(
|
return RawMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
|
|
|
@ -233,7 +233,7 @@ class Llama4:
|
||||||
source="output",
|
source="output",
|
||||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||||
batch_idx=idx,
|
batch_idx=idx,
|
||||||
finished=eos_reached[idx],
|
finished=eos_reached[idx].item(),
|
||||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -52,14 +52,17 @@ class MetaReferenceInferenceConfig(BaseModel):
|
||||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
||||||
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
|
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
|
||||||
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
|
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
|
||||||
|
max_batch_size: str = "${env.MAX_BATCH_SIZE:1}",
|
||||||
|
max_seq_len: str = "${env.MAX_SEQ_LEN:4096}",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"model": model,
|
"model": model,
|
||||||
"max_seq_len": 4096,
|
|
||||||
"checkpoint_dir": checkpoint_dir,
|
"checkpoint_dir": checkpoint_dir,
|
||||||
"quantization": {
|
"quantization": {
|
||||||
"type": quantization_type,
|
"type": quantization_type,
|
||||||
},
|
},
|
||||||
"model_parallel_size": model_parallel_size,
|
"model_parallel_size": model_parallel_size,
|
||||||
|
"max_batch_size": max_batch_size,
|
||||||
|
"max_seq_len": max_seq_len,
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ from llama_stack.models.llama.llama3.generation import Llama3
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||||
from llama_stack.models.llama.llama4.generation import Llama4
|
from llama_stack.models.llama.llama4.generation import Llama4
|
||||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||||
from llama_stack.models.llama.sku_types import Model
|
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
ChatCompletionRequestWithRawContent,
|
ChatCompletionRequestWithRawContent,
|
||||||
CompletionRequestWithRawContent,
|
CompletionRequestWithRawContent,
|
||||||
|
@ -113,8 +113,7 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
||||||
return get_default_tool_prompt_format(request.model)
|
return get_default_tool_prompt_format(request.model)
|
||||||
|
|
||||||
|
|
||||||
# TODO: combine Llama3 and Llama4 generators since they are almost identical now
|
class LlamaGenerator:
|
||||||
class Llama4Generator:
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: MetaReferenceInferenceConfig,
|
config: MetaReferenceInferenceConfig,
|
||||||
|
@ -144,7 +143,8 @@ class Llama4Generator:
|
||||||
else:
|
else:
|
||||||
quantization_mode = None
|
quantization_mode = None
|
||||||
|
|
||||||
self.inner_generator = Llama4.build(
|
cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
|
||||||
|
self.inner_generator = cls.build(
|
||||||
ckpt_dir=ckpt_dir,
|
ckpt_dir=ckpt_dir,
|
||||||
max_seq_len=config.max_seq_len,
|
max_seq_len=config.max_seq_len,
|
||||||
max_batch_size=config.max_batch_size,
|
max_batch_size=config.max_batch_size,
|
||||||
|
@ -158,142 +158,55 @@ class Llama4Generator:
|
||||||
|
|
||||||
def completion(
|
def completion(
|
||||||
self,
|
self,
|
||||||
request: CompletionRequestWithRawContent,
|
request_batch: List[CompletionRequestWithRawContent],
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
sampling_params = request.sampling_params or SamplingParams()
|
first_request = request_batch[0]
|
||||||
|
sampling_params = first_request.sampling_params or SamplingParams()
|
||||||
max_gen_len = sampling_params.max_tokens
|
max_gen_len = sampling_params.max_tokens
|
||||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||||
max_gen_len = self.args.max_seq_len - 1
|
max_gen_len = self.args.max_seq_len - 1
|
||||||
|
|
||||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||||
for result in self.inner_generator.generate(
|
for result in self.inner_generator.generate(
|
||||||
llm_inputs=[self.formatter.encode_content(request.content)],
|
llm_inputs=[self.formatter.encode_content(request.content) for request in request_batch],
|
||||||
max_gen_len=max_gen_len,
|
max_gen_len=max_gen_len,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
logprobs=bool(request.logprobs),
|
logprobs=bool(first_request.logprobs),
|
||||||
echo=False,
|
echo=False,
|
||||||
logits_processor=get_logits_processor(
|
logits_processor=get_logits_processor(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
self.args.vocab_size,
|
self.args.vocab_size,
|
||||||
request.response_format,
|
first_request.response_format,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
yield result[0]
|
yield result
|
||||||
|
|
||||||
def chat_completion(
|
def chat_completion(
|
||||||
self,
|
self,
|
||||||
request: ChatCompletionRequestWithRawContent,
|
request_batch: List[ChatCompletionRequestWithRawContent],
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
sampling_params = request.sampling_params or SamplingParams()
|
first_request = request_batch[0]
|
||||||
|
sampling_params = first_request.sampling_params or SamplingParams()
|
||||||
max_gen_len = sampling_params.max_tokens
|
max_gen_len = sampling_params.max_tokens
|
||||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||||
max_gen_len = self.args.max_seq_len - 1
|
max_gen_len = self.args.max_seq_len - 1
|
||||||
|
|
||||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||||
for result in self.inner_generator.generate(
|
for result in self.inner_generator.generate(
|
||||||
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
llm_inputs=[
|
||||||
|
self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))
|
||||||
|
for request in request_batch
|
||||||
|
],
|
||||||
max_gen_len=max_gen_len,
|
max_gen_len=max_gen_len,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
logprobs=bool(request.logprobs),
|
logprobs=bool(first_request.logprobs),
|
||||||
echo=False,
|
echo=False,
|
||||||
logits_processor=get_logits_processor(
|
logits_processor=get_logits_processor(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
self.args.vocab_size,
|
self.args.vocab_size,
|
||||||
request.response_format,
|
first_request.response_format,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
yield result[0]
|
yield result
|
||||||
|
|
||||||
|
|
||||||
class Llama3Generator:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: MetaReferenceInferenceConfig,
|
|
||||||
model_id: str,
|
|
||||||
llama_model: Model,
|
|
||||||
):
|
|
||||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
|
||||||
ckpt_dir = config.checkpoint_dir
|
|
||||||
else:
|
|
||||||
resolved_model = resolve_model(model_id)
|
|
||||||
if resolved_model is None:
|
|
||||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
|
||||||
ckpt_dir = model_checkpoint_dir(model_id)
|
|
||||||
else:
|
|
||||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
|
||||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
|
||||||
|
|
||||||
if config.quantization:
|
|
||||||
if config.quantization.type == "fp8_mixed":
|
|
||||||
quantization_mode = QuantizationMode.fp8_mixed
|
|
||||||
elif config.quantization.type == "int4_mixed":
|
|
||||||
quantization_mode = QuantizationMode.int4_mixed
|
|
||||||
elif config.quantization.type == "bf16":
|
|
||||||
quantization_mode = None
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
|
||||||
else:
|
|
||||||
quantization_mode = None
|
|
||||||
|
|
||||||
self.inner_generator = Llama3.build(
|
|
||||||
ckpt_dir=ckpt_dir,
|
|
||||||
max_seq_len=config.max_seq_len,
|
|
||||||
max_batch_size=config.max_batch_size,
|
|
||||||
world_size=config.model_parallel_size or llama_model.pth_file_count,
|
|
||||||
quantization_mode=quantization_mode,
|
|
||||||
)
|
|
||||||
self.tokenizer = self.inner_generator.tokenizer
|
|
||||||
self.args = self.inner_generator.args
|
|
||||||
self.formatter = self.inner_generator.formatter
|
|
||||||
|
|
||||||
def completion(
|
|
||||||
self,
|
|
||||||
request: CompletionRequestWithRawContent,
|
|
||||||
) -> Generator:
|
|
||||||
sampling_params = request.sampling_params or SamplingParams()
|
|
||||||
max_gen_len = sampling_params.max_tokens
|
|
||||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
|
||||||
max_gen_len = self.args.max_seq_len - 1
|
|
||||||
|
|
||||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
|
||||||
for result in self.inner_generator.generate(
|
|
||||||
model_inputs=[self.formatter.encode_content(request.content)],
|
|
||||||
max_gen_len=max_gen_len,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
logprobs=bool(request.logprobs),
|
|
||||||
echo=False,
|
|
||||||
logits_processor=get_logits_processor(
|
|
||||||
self.tokenizer,
|
|
||||||
self.args.vocab_size,
|
|
||||||
request.response_format,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
yield result[0]
|
|
||||||
|
|
||||||
def chat_completion(
|
|
||||||
self,
|
|
||||||
request: ChatCompletionRequestWithRawContent,
|
|
||||||
) -> Generator:
|
|
||||||
sampling_params = request.sampling_params or SamplingParams()
|
|
||||||
max_gen_len = sampling_params.max_tokens
|
|
||||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
|
||||||
max_gen_len = self.args.max_seq_len - 1
|
|
||||||
|
|
||||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
|
||||||
for result in self.inner_generator.generate(
|
|
||||||
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
|
||||||
max_gen_len=max_gen_len,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
logprobs=bool(request.logprobs),
|
|
||||||
echo=False,
|
|
||||||
logits_processor=get_logits_processor(
|
|
||||||
self.tokenizer,
|
|
||||||
self.args.vocab_size,
|
|
||||||
request.response_format,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
yield result[0]
|
|
||||||
|
|
|
@ -5,10 +5,10 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from typing import AsyncGenerator, List, Optional, Union
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from termcolor import cprint
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
@ -17,6 +17,8 @@ from llama_stack.apis.common.content_types import (
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
BatchChatCompletionResponse,
|
||||||
|
BatchCompletionResponse,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseEvent,
|
ChatCompletionResponseEvent,
|
||||||
|
@ -38,8 +40,10 @@ from llama_stack.apis.inference import (
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model, ModelType
|
||||||
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||||
|
@ -65,21 +69,17 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import MetaReferenceInferenceConfig
|
from .config import MetaReferenceInferenceConfig
|
||||||
from .generators import Llama3Generator, Llama4Generator
|
from .generators import LlamaGenerator
|
||||||
from .model_parallel import LlamaModelParallelGenerator
|
from .model_parallel import LlamaModelParallelGenerator
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = get_logger(__name__, category="inference")
|
||||||
# there's a single model parallel process running serving the model. for now,
|
# there's a single model parallel process running serving the model. for now,
|
||||||
# we don't support multiple concurrent requests to this process.
|
# we don't support multiple concurrent requests to this process.
|
||||||
SEMAPHORE = asyncio.Semaphore(1)
|
SEMAPHORE = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
def llama3_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama3Generator:
|
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
|
||||||
return Llama3Generator(config, model_id, llama_model)
|
return LlamaGenerator(config, model_id, llama_model)
|
||||||
|
|
||||||
|
|
||||||
def llama4_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> Llama4Generator:
|
|
||||||
return Llama4Generator(config, model_id, llama_model)
|
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceInferenceImpl(
|
class MetaReferenceInferenceImpl(
|
||||||
|
@ -139,24 +139,12 @@ class MetaReferenceInferenceImpl(
|
||||||
async def load_model(self, model_id, llama_model) -> None:
|
async def load_model(self, model_id, llama_model) -> None:
|
||||||
log.info(f"Loading model `{model_id}`")
|
log.info(f"Loading model `{model_id}`")
|
||||||
|
|
||||||
if llama_model.model_family in {
|
|
||||||
ModelFamily.llama3,
|
|
||||||
ModelFamily.llama3_1,
|
|
||||||
ModelFamily.llama3_2,
|
|
||||||
ModelFamily.llama3_3,
|
|
||||||
}:
|
|
||||||
builder_fn = llama3_builder_fn
|
|
||||||
elif llama_model.model_family == ModelFamily.llama4:
|
|
||||||
builder_fn = llama4_builder_fn
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported model family: {llama_model.model_family}")
|
|
||||||
|
|
||||||
builder_params = [self.config, model_id, llama_model]
|
builder_params = [self.config, model_id, llama_model]
|
||||||
|
|
||||||
if self.config.create_distributed_process_group:
|
if self.config.create_distributed_process_group:
|
||||||
self.generator = LlamaModelParallelGenerator(
|
self.generator = LlamaModelParallelGenerator(
|
||||||
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
|
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
|
||||||
builder_fn=builder_fn,
|
builder_fn=llama_builder_fn,
|
||||||
builder_params=builder_params,
|
builder_params=builder_params,
|
||||||
formatter=(
|
formatter=(
|
||||||
Llama4ChatFormat(Llama4Tokenizer.get_instance())
|
Llama4ChatFormat(Llama4Tokenizer.get_instance())
|
||||||
|
@@ -166,11 +154,24 @@ class MetaReferenceInferenceImpl(
            )
            self.generator.start()
        else:
-           self.generator = builder_fn(*builder_params)
+           self.generator = llama_builder_fn(*builder_params)

        self.model_id = model_id
        self.llama_model = llama_model

+       log.info("Warming up...")
+       await self.completion(
+           model_id=model_id,
+           content="Hello, world!",
+           sampling_params=SamplingParams(max_tokens=10),
+       )
+       await self.chat_completion(
+           model_id=model_id,
+           messages=[UserMessage(content="Hi how are you?")],
+           sampling_params=SamplingParams(max_tokens=20),
+       )
+       log.info("Warmed up!")
+
    def check_model(self, request) -> None:
        if self.model_id is None or self.llama_model is None:
            raise RuntimeError(

@@ -208,7 +209,43 @@ class MetaReferenceInferenceImpl(
        if request.stream:
            return self._stream_completion(request)
        else:
-           return await self._nonstream_completion(request)
+           results = await self._nonstream_completion([request])
+           return results[0]
+
+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       stream: Optional[bool] = False,
+       logprobs: Optional[LogProbConfig] = None,
+   ) -> BatchCompletionResponse:
+       if sampling_params is None:
+           sampling_params = SamplingParams()
+       if logprobs:
+           assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+       content_batch = [
+           augment_content_with_response_format_prompt(response_format, content) for content in content_batch
+       ]
+
+       request_batch = []
+       for content in content_batch:
+           request = CompletionRequest(
+               model=model_id,
+               content=content,
+               sampling_params=sampling_params,
+               response_format=response_format,
+               stream=stream,
+               logprobs=logprobs,
+           )
+           self.check_model(request)
+           request = await convert_request_to_raw(request)
+           request_batch.append(request)
+
+       results = await self._nonstream_completion(request_batch)
+       return BatchCompletionResponse(batch=results)
+
    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        tokenizer = self.generator.formatter.tokenizer

@@ -253,37 +290,54 @@ class MetaReferenceInferenceImpl(
                for x in impl():
                    yield x

-   async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
+   async def _nonstream_completion(self, request_batch: List[CompletionRequest]) -> List[CompletionResponse]:
        tokenizer = self.generator.formatter.tokenizer

+       first_request = request_batch[0]
+
+       class ItemState(BaseModel):
+           tokens: List[int] = []
+           logprobs: List[TokenLogProbs] = []
+           stop_reason: StopReason | None = None
+           finished: bool = False
+
        def impl():
-           tokens = []
-           logprobs = []
-           stop_reason = None
-
-           for token_result in self.generator.completion(request):
-               tokens.append(token_result.token)
-               if token_result.token == tokenizer.eot_id:
-                   stop_reason = StopReason.end_of_turn
-               elif token_result.token == tokenizer.eom_id:
-                   stop_reason = StopReason.end_of_message
-
-               if request.logprobs:
-                   assert len(token_result.logprobs) == 1
-
-                   logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
-
-           if stop_reason is None:
-               stop_reason = StopReason.out_of_tokens
-
-           if tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
-               tokens = tokens[:-1]
-           content = self.generator.formatter.tokenizer.decode(tokens)
-           return CompletionResponse(
-               content=content,
-               stop_reason=stop_reason,
-               logprobs=logprobs if request.logprobs else None,
-           )
+           states = [ItemState() for _ in request_batch]
+
+           results = []
+           for token_results in self.generator.completion(request_batch):
+               for result in token_results:
+                   idx = result.batch_idx
+                   state = states[idx]
+                   if state.finished or result.ignore_token:
+                       continue
+
+                   state.finished = result.finished
+                   if first_request.logprobs:
+                       state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+
+                   state.tokens.append(result.token)
+                   if result.token == tokenizer.eot_id:
+                       state.stop_reason = StopReason.end_of_turn
+                   elif result.token == tokenizer.eom_id:
+                       state.stop_reason = StopReason.end_of_message
+
+           for state in states:
+               if state.stop_reason is None:
+                   state.stop_reason = StopReason.out_of_tokens
+
+               if state.tokens[-1] in self.generator.formatter.tokenizer.stop_tokens:
+                   state.tokens = state.tokens[:-1]
+               content = self.generator.formatter.tokenizer.decode(state.tokens)
+               results.append(
+                   CompletionResponse(
+                       content=content,
+                       stop_reason=state.stop_reason,
+                       logprobs=state.logprobs if first_request.logprobs else None,
+                   )
+               )
+
+           return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:

@@ -318,7 +372,7 @@ class MetaReferenceInferenceImpl(
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
-           tool_config=tool_config,
+           tool_config=tool_config or ToolConfig(),
        )
        self.check_model(request)

@@ -334,44 +388,110 @@ class MetaReferenceInferenceImpl(
        if request.stream:
            return self._stream_chat_completion(request)
        else:
-           return await self._nonstream_chat_completion(request)
+           results = await self._nonstream_chat_completion([request])
+           return results[0]
+
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       stream: Optional[bool] = False,
+       logprobs: Optional[LogProbConfig] = None,
+       tool_config: Optional[ToolConfig] = None,
+   ) -> BatchChatCompletionResponse:
+       if sampling_params is None:
+           sampling_params = SamplingParams()
+       if logprobs:
+           assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
+
+       # wrapper request to make it easier to pass around (internal only, not exposed to API)
+       request_batch = []
+       for messages in messages_batch:
+           request = ChatCompletionRequest(
+               model=model_id,
+               messages=messages,
+               sampling_params=sampling_params,
+               tools=tools or [],
+               response_format=response_format,
+               logprobs=logprobs,
+               tool_config=tool_config or ToolConfig(),
+           )
+           self.check_model(request)
+
+           # augment and rewrite messages depending on the model
+           request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
+           # download media and convert to raw content so we can send it to the model
+           request = await convert_request_to_raw(request)
+           request_batch.append(request)
+
+       if self.config.create_distributed_process_group:
+           if SEMAPHORE.locked():
+               raise RuntimeError("Only one concurrent request is supported")
+
+       results = await self._nonstream_chat_completion(request_batch)
+       return BatchChatCompletionResponse(batch=results)

-   async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
+   async def _nonstream_chat_completion(
+       self, request_batch: List[ChatCompletionRequest]
+   ) -> List[ChatCompletionResponse]:
        tokenizer = self.generator.formatter.tokenizer

+       first_request = request_batch[0]
+
+       class ItemState(BaseModel):
+           tokens: List[int] = []
+           logprobs: List[TokenLogProbs] = []
+           stop_reason: StopReason | None = None
+           finished: bool = False
+
        def impl():
-           tokens = []
-           logprobs = []
-           stop_reason = None
-
-           for token_result in self.generator.chat_completion(request):
-               if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
-                   cprint(token_result.text, "cyan", end="")
-
-               tokens.append(token_result.token)
-
-               if token_result.token == tokenizer.eot_id:
-                   stop_reason = StopReason.end_of_turn
-               elif token_result.token == tokenizer.eom_id:
-                   stop_reason = StopReason.end_of_message
-
-               if request.logprobs:
-                   assert len(token_result.logprobs) == 1
-
-                   logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))
-
-           if stop_reason is None:
-               stop_reason = StopReason.out_of_tokens
-
-           raw_message = self.generator.formatter.decode_assistant_message(tokens, stop_reason)
-           return ChatCompletionResponse(
-               completion_message=CompletionMessage(
-                   content=raw_message.content,
-                   stop_reason=raw_message.stop_reason,
-                   tool_calls=raw_message.tool_calls,
-               ),
-               logprobs=logprobs if request.logprobs else None,
-           )
+           states = [ItemState() for _ in request_batch]
+
+           for token_results in self.generator.chat_completion(request_batch):
+               first = token_results[0]
+               if not first.finished and not first.ignore_token:
+                   if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
+                       cprint(first.text, "cyan", end="")
+                   if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
+                       cprint(f"<{first.token}>", "magenta", end="")
+
+               for result in token_results:
+                   idx = result.batch_idx
+                   state = states[idx]
+                   if state.finished or result.ignore_token:
+                       continue
+
+                   state.finished = result.finished
+                   if first_request.logprobs:
+                       state.logprobs.append(TokenLogProbs(logprobs_by_token={result.text: result.logprobs[0]}))
+
+                   state.tokens.append(result.token)
+                   if result.token == tokenizer.eot_id:
+                       state.stop_reason = StopReason.end_of_turn
+                   elif result.token == tokenizer.eom_id:
+                       state.stop_reason = StopReason.end_of_message
+
+           results = []
+           for state in states:
+               if state.stop_reason is None:
+                   state.stop_reason = StopReason.out_of_tokens
+
+               raw_message = self.generator.formatter.decode_assistant_message(state.tokens, state.stop_reason)
+               results.append(
+                   ChatCompletionResponse(
+                       completion_message=CompletionMessage(
+                           content=raw_message.content,
+                           stop_reason=raw_message.stop_reason,
+                           tool_calls=raw_message.tool_calls,
+                       ),
+                       logprobs=state.logprobs if first_request.logprobs else None,
+                   )
+               )
+
+           return results

        if self.config.create_distributed_process_group:
            async with SEMAPHORE:

@@ -398,6 +518,22 @@ class MetaReferenceInferenceImpl(
            for token_result in self.generator.chat_completion(request):
                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
                    cprint(token_result.text, "cyan", end="")
+               if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
+                   cprint(f"<{token_result.token}>", "magenta", end="")
+
+               if token_result.token == tokenizer.eot_id:
+                   stop_reason = StopReason.end_of_turn
+                   text = ""
+               elif token_result.token == tokenizer.eom_id:
+                   stop_reason = StopReason.end_of_message
+                   text = ""
+               else:
+                   text = token_result.text
+
+               if request.logprobs:
+                   assert len(token_result.logprobs) == 1
+
+                   logprobs.append(TokenLogProbs(logprobs_by_token={token_result.text: token_result.logprobs[0]}))

                tokens.append(token_result.token)

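For orientation, here is a minimal sketch of how a caller exercises the two new provider methods once this stack is running. It assumes `client` is an already-configured Llama Stack client exposing the same `inference.batch_completion` / `inference.batch_chat_completion` surface used by the integration test added below; `"my-model"` is a placeholder model id, not something defined in this commit.

```python
# Minimal usage sketch (not part of the diff). Assumes `client` is a configured
# Llama Stack client and "my-model" is a registered model identifier.
contents = ["Roses are red, violets are ", "The capital of France is "]

completion_batch = client.inference.batch_completion(
    model_id="my-model",
    content_batch=contents,
    sampling_params={"max_tokens": 20},
)
for item in completion_batch.batch:  # one CompletionResponse per input
    print(item.content, item.stop_reason)

chat_batch = client.inference.batch_chat_completion(
    model_id="my-model",
    messages_batch=[[{"role": "user", "content": c}] for c in contents],
)
for item in chat_batch.batch:  # one ChatCompletionResponse per input
    print(item.completion_message.content)
```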
@@ -6,7 +6,7 @@

 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, Generator
+from typing import Any, Callable, Generator, List

 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat

@@ -23,13 +23,13 @@ class ModelRunner:
        self.llama = llama

    # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
-   def __call__(self, req: Any):
-       if isinstance(req, ChatCompletionRequestWithRawContent):
-           return self.llama.chat_completion(req)
-       elif isinstance(req, CompletionRequestWithRawContent):
-           return self.llama.completion(req)
+   def __call__(self, task: Any):
+       if task[0] == "chat_completion":
+           return self.llama.chat_completion(task[1])
+       elif task[0] == "completion":
+           return self.llama.completion(task[1])
        else:
-           raise ValueError(f"Unexpected task type {type(req)}")
+           raise ValueError(f"Unexpected task type {task[0]}")


def init_model_cb(

@@ -82,16 +82,16 @@ class LlamaModelParallelGenerator:

    def completion(
        self,
-       request: CompletionRequestWithRawContent,
+       request_batch: List[CompletionRequestWithRawContent],
    ) -> Generator:
-       req_obj = deepcopy(request)
-       gen = self.group.run_inference(req_obj)
+       req_obj = deepcopy(request_batch)
+       gen = self.group.run_inference(("completion", req_obj))
        yield from gen

    def chat_completion(
        self,
-       request: ChatCompletionRequestWithRawContent,
+       request_batch: List[ChatCompletionRequestWithRawContent],
    ) -> Generator:
-       req_obj = deepcopy(request)
-       gen = self.group.run_inference(req_obj)
+       req_obj = deepcopy(request_batch)
+       gen = self.group.run_inference(("chat_completion", req_obj))
        yield from gen
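The `ModelRunner` change above replaces `isinstance` checks on a single request with a `(task_type, request_batch)` tuple, since a plain `List[...]` no longer identifies which API was called. A stripped-down sketch of that dispatch pattern is below; everything except the `"chat_completion"` / `"completion"` tags is illustrative and not taken from the diff.

```python
from typing import Any, Callable, List, Tuple


class ToyRunner:
    """Illustrative stand-in for ModelRunner's tagged-task dispatch."""

    def __init__(self, chat_fn: Callable, completion_fn: Callable) -> None:
        self._handlers = {
            "chat_completion": chat_fn,
            "completion": completion_fn,
        }

    def __call__(self, task: Tuple[str, List[Any]]):
        task_type, request_batch = task
        handler = self._handlers.get(task_type)
        if handler is None:
            raise ValueError(f"Unexpected task type {task_type}")
        # The real runner forwards the whole batch to Llama.chat_completion/completion.
        return handler(request_batch)


runner = ToyRunner(
    chat_fn=lambda batch: f"chat batch of {len(batch)}",
    completion_fn=lambda batch: f"completion batch of {len(batch)}",
)
print(runner(("completion", ["req-a", "req-b"])))  # completion batch of 2
```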
@@ -19,7 +19,7 @@ import tempfile
 import time
 import uuid
 from enum import Enum
-from typing import Callable, Generator, Literal, Optional, Union
+from typing import Callable, Generator, List, Literal, Optional, Tuple, Union

 import torch
 import zmq

@@ -69,12 +69,12 @@ class CancelSentinel(BaseModel):


class TaskRequest(BaseModel):
    type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-   task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
+   task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]


class TaskResponse(BaseModel):
    type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-   result: GenerationResult
+   result: List[GenerationResult]


class ExceptionResponse(BaseModel):

@@ -331,7 +331,7 @@ class ModelParallelProcessGroup:

    def run_inference(
        self,
-       req: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent],
+       req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
    ) -> Generator:
        assert not self.running, "inference already running"

@@ -10,6 +10,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_stack.apis.inference import (
     CompletionResponse,
     Inference,
+    InterleavedContent,
     LogProbConfig,
     Message,
     ResponseFormat,

@@ -80,3 +81,25 @@ class SentenceTransformersInferenceImpl(
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        raise ValueError("Sentence transformers don't support chat completion")
+
+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch completion is not supported for Sentence Transformers")
+
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       tool_config: Optional[ToolConfig] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
@@ -437,6 +437,28 @@ class OllamaInferenceAdapter(
        }
        return await self.openai_client.chat.completions.create(**params)  # type: ignore

+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch completion is not supported for Ollama")
+
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       tool_config: Optional[ToolConfig] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch chat completion is not supported for Ollama")


async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]:
    async def _convert_content(content) -> dict:
@@ -526,3 +526,25 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
            user=user,
        )
        return await self.client.chat.completions.create(**params)  # type: ignore
+
+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch completion is not supported for vLLM")
+
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       tool_config: Optional[ToolConfig] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch chat completion is not supported for vLLM")
@@ -347,3 +347,25 @@ class LiteLLMOpenAIMixin(
            user=user,
        )
        return litellm.completion(**params)
+
+   async def batch_completion(
+       self,
+       model_id: str,
+       content_batch: List[InterleavedContent],
+       sampling_params: Optional[SamplingParams] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch completion is not supported for OpenAI Compat")
+
+   async def batch_chat_completion(
+       self,
+       model_id: str,
+       messages_batch: List[List[Message]],
+       sampling_params: Optional[SamplingParams] = None,
+       tools: Optional[List[ToolDefinition]] = None,
+       tool_config: Optional[ToolConfig] = None,
+       response_format: Optional[ResponseFormat] = None,
+       logprobs: Optional[LogProbConfig] = None,
+   ):
+       raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat")
@@ -16,11 +16,12 @@ providers:
    provider_type: inline::meta-reference
    config:
      model: ${env.INFERENCE_MODEL}
-     max_seq_len: 4096
      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
      quantization:
        type: ${env.QUANTIZATION_TYPE:bf16}
      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+     max_batch_size: ${env.MAX_BATCH_SIZE:1}
+     max_seq_len: ${env.MAX_SEQ_LEN:4096}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}

@@ -28,11 +29,12 @@ providers:
    provider_type: inline::meta-reference
    config:
      model: ${env.SAFETY_MODEL}
-     max_seq_len: 4096
      checkpoint_dir: ${env.SAFETY_CHECKPOINT_DIR:null}
      quantization:
        type: ${env.QUANTIZATION_TYPE:bf16}
      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+     max_batch_size: ${env.MAX_BATCH_SIZE:1}
+     max_seq_len: ${env.MAX_SEQ_LEN:4096}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss

@@ -16,11 +16,12 @@ providers:
    provider_type: inline::meta-reference
    config:
      model: ${env.INFERENCE_MODEL}
-     max_seq_len: 4096
      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
      quantization:
        type: ${env.QUANTIZATION_TYPE:bf16}
      model_parallel_size: ${env.MODEL_PARALLEL_SIZE:0}
+     max_batch_size: ${env.MAX_BATCH_SIZE:1}
+     max_seq_len: ${env.MAX_SEQ_LEN:4096}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
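In the template configs above, `${env.MAX_BATCH_SIZE:1}` and `${env.MAX_SEQ_LEN:4096}` resolve to the environment variable when it is set and to the literal after the colon otherwise. A rough Python sketch of that resolution (the semantics only, not the stack's actual config loader):

```python
import os

# Sketch only: mirrors the "${env.NAME:default}" convention used in the run.yaml
# templates above; the real substitution happens inside the stack's config loading.
max_batch_size = int(os.environ.get("MAX_BATCH_SIZE", "1"))
max_seq_len = int(os.environ.get("MAX_SEQ_LEN", "4096"))
print(max_batch_size, max_seq_len)
```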
tests/integration/inference/test_batch_inference.py (new file, 76 lines)

@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import pytest
+
+from ..test_cases.test_case import TestCase
+
+
+def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id):
+    models = {m.identifier: m for m in client_with_models.models.list()}
+    models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
+    provider_id = models[model_id].provider_id
+    providers = {p.provider_id: p for p in client_with_models.providers.list()}
+    provider = providers[provider_id]
+    if provider.provider_type not in ("inline::meta-reference",):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:completion:batch_completion",
+    ],
+)
+def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
+    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+
+    content_batch = tc["contents"]
+    response = client_with_models.inference.batch_completion(
+        content_batch=content_batch,
+        model_id=text_model_id,
+        sampling_params={
+            "max_tokens": 50,
+        },
+    )
+    assert len(response.batch) == len(content_batch)
+    for i, r in enumerate(response.batch):
+        print(f"response {i}: {r.content}")
+        assert len(r.content) > 10
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        "inference:chat_completion:batch_completion",
+    ],
+)
+def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
+    skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
+    tc = TestCase(test_case)
+    qa_pairs = tc["qa_pairs"]
+
+    message_batch = [
+        [
+            {
+                "role": "user",
+                "content": qa["question"],
+            }
+        ]
+        for qa in qa_pairs
+    ]
+
+    response = client_with_models.inference.batch_chat_completion(
+        messages_batch=message_batch,
+        model_id=text_model_id,
+    )
+    assert len(response.batch) == len(qa_pairs)
+    for i, r in enumerate(response.batch):
+        print(f"response {i}: {r.completion_message.content}")
+        assert len(r.completion_message.content) > 0
+        assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()
@@ -537,5 +537,31 @@
                }
            ]
        }
+   },
+   "batch_completion": {
+       "data": {
+           "qa_pairs": [
+               {
+                   "question": "What is the capital of France?",
+                   "answer": "Paris"
+               },
+               {
+                   "question": "Who wrote the book '1984'?",
+                   "answer": "George Orwell"
+               },
+               {
+                   "question": "Which planet has rings around it with a name starting with letter S?",
+                   "answer": "Saturn"
+               },
+               {
+                   "question": "When did the first moon landing happen?",
+                   "answer": "1969"
+               },
+               {
+                   "question": "What word says 'hello' in Spanish?",
+                   "answer": "Hola"
+               }
+           ]
+       }
    }
}

@@ -44,5 +44,18 @@
        "year_retired": "2003"
        }
    }
+   },
+   "batch_completion": {
+       "data": {
+           "contents": [
+               "Micheael Jordan is born in ",
+               "Roses are red, violets are ",
+               "If you had a million dollars, what would you do with it? ",
+               "All you need is ",
+               "The capital of France is ",
+               "It is a good day to ",
+               "The answer to the universe is "
+           ]
+       }
    }
}