diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index a1f6a6f30..8106a54dc 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -86,92 +86,6 @@ } } }, - "/v1/inference/batch-chat-completion": { - "post": { - "responses": { - "200": { - "description": "A BatchChatCompletionResponse with the full completions.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/BatchChatCompletionResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Inference" - ], - "description": "Generate chat completions for a batch of messages using the specified model.", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/BatchChatCompletionRequest" - } - } - }, - "required": true - } - } - }, - "/v1/inference/batch-completion": { - "post": { - "responses": { - "200": { - "description": "A BatchCompletionResponse with the full completions.", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/BatchCompletionResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "Inference" - ], - "description": "Generate completions for a batch of content using the specified model.", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/BatchCompletionRequest" - } - } - }, - "required": true - } - } - }, "/v1/post-training/job/cancel": { "post": { "responses": { @@ -240,7 +154,7 @@ } }, "tags": [ - "BatchInference (Coming Soon)" + "Inference" ], "description": "Generate a chat completion for the given messages using the specified model.", "parameters": [], @@ -288,7 +202,7 @@ } }, "tags": [ - "BatchInference (Coming Soon)" + "Inference" ], "description": "Generate a completion for the given content using the specified model.", "parameters": [], @@ -5176,6 +5090,20 @@ ], "title": "AppendRowsRequest" }, + "CancelTrainingJobRequest": { + "type": "object", + "properties": { + "job_uuid": { + "type": "string", + "description": "The UUID of the job to cancel." + } + }, + "additionalProperties": false, + "required": [ + "job_uuid" + ], + "title": "CancelTrainingJobRequest" + }, "CompletionMessage": { "type": "object", "properties": { @@ -5881,26 +5809,23 @@ "title": "UserMessage", "description": "A message from the user in a chat conversation." }, - "BatchChatCompletionRequest": { + "ChatCompletionRequest": { "type": "object", "properties": { "model_id": { "type": "string", "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." }, - "messages_batch": { + "messages": { "type": "array", "items": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Message" - } + "$ref": "#/components/schemas/Message" }, - "description": "The messages to generate completions for." + "description": "List of messages in the conversation." 
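A minimal sketch of calling the retagged `/v1/inference/chat-completion` endpoint over plain HTTP (the base URL, port, and model id are illustrative assumptions; the body follows the `ChatCompletionRequest` schema above):

```python
# Sketch of a request to the endpoint retagged from
# "BatchInference (Coming Soon)" to "Inference".
import requests

resp = requests.post(
    "http://localhost:8321/v1/inference/chat-completion",  # assumed local server
    json={
        "model_id": "my-model",  # must be registered and visible via /models
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
resp.raise_for_status()
# Per the ChatCompletionResponse schema, completion_message is the only
# required field of the response body.
print(resp.json()["completion_message"])
```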
}, "sampling_params": { "$ref": "#/components/schemas/SamplingParams", - "description": "(Optional) Parameters to control the sampling strategy." + "description": "Parameters to control the sampling strategy." }, "tools": { "type": "array", @@ -5909,13 +5834,31 @@ }, "description": "(Optional) List of tool definitions available to the model." }, - "tool_config": { - "$ref": "#/components/schemas/ToolConfig", - "description": "(Optional) Configuration for tool use." + "tool_choice": { + "type": "string", + "enum": [ + "auto", + "required", + "none" + ], + "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead." + }, + "tool_prompt_format": { + "type": "string", + "enum": [ + "json", + "function_tag", + "python_list" + ], + "description": "(Optional) Instructs the model how to format tool calls. By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. .. deprecated:: Use tool_config instead." }, "response_format": { "$ref": "#/components/schemas/ResponseFormat", - "description": "(Optional) Grammar specification for guided (structured) decoding." + "description": "(Optional) Grammar specification for guided (structured) decoding. There are two options: - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format. - `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it." + }, + "stream": { + "type": "boolean", + "description": "(Optional) If True, generate an SSE event stream of the response. Defaults to False." }, "logprobs": { "type": "object", @@ -5928,32 +5871,18 @@ }, "additionalProperties": false, "description": "(Optional) If specified, log probabilities for each token position will be returned." + }, + "tool_config": { + "$ref": "#/components/schemas/ToolConfig", + "description": "(Optional) Configuration for tool use." } }, "additionalProperties": false, "required": [ "model_id", - "messages_batch" + "messages" ], - "title": "BatchChatCompletionRequest" - }, - "BatchChatCompletionResponse": { - "type": "object", - "properties": { - "batch": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ChatCompletionResponse" - }, - "description": "List of chat completion responses, one for each conversation in the batch" - } - }, - "additionalProperties": false, - "required": [ - "batch" - ], - "title": "BatchChatCompletionResponse", - "description": "Response from a batch chat completion request." + "title": "ChatCompletionRequest" }, "ChatCompletionResponse": { "type": "object", @@ -6033,194 +5962,6 @@ "title": "TokenLogProbs", "description": "Log probabilities for generated tokens." }, - "BatchCompletionRequest": { - "type": "object", - "properties": { - "model_id": { - "type": "string", - "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." - }, - "content_batch": { - "type": "array", - "items": { - "$ref": "#/components/schemas/InterleavedContent" - }, - "description": "The content to generate completions for." 
- }, - "sampling_params": { - "$ref": "#/components/schemas/SamplingParams", - "description": "(Optional) Parameters to control the sampling strategy." - }, - "response_format": { - "$ref": "#/components/schemas/ResponseFormat", - "description": "(Optional) Grammar specification for guided (structured) decoding." - }, - "logprobs": { - "type": "object", - "properties": { - "top_k": { - "type": "integer", - "default": 0, - "description": "How many tokens (for each position) to return log probabilities for." - } - }, - "additionalProperties": false, - "description": "(Optional) If specified, log probabilities for each token position will be returned." - } - }, - "additionalProperties": false, - "required": [ - "model_id", - "content_batch" - ], - "title": "BatchCompletionRequest" - }, - "BatchCompletionResponse": { - "type": "object", - "properties": { - "batch": { - "type": "array", - "items": { - "$ref": "#/components/schemas/CompletionResponse" - }, - "description": "List of completion responses, one for each input in the batch" - } - }, - "additionalProperties": false, - "required": [ - "batch" - ], - "title": "BatchCompletionResponse", - "description": "Response from a batch completion request." - }, - "CompletionResponse": { - "type": "object", - "properties": { - "metrics": { - "type": "array", - "items": { - "$ref": "#/components/schemas/MetricInResponse" - }, - "description": "(Optional) List of metrics associated with the API response" - }, - "content": { - "type": "string", - "description": "The generated completion text" - }, - "stop_reason": { - "type": "string", - "enum": [ - "end_of_turn", - "end_of_message", - "out_of_tokens" - ], - "description": "Reason why generation stopped" - }, - "logprobs": { - "type": "array", - "items": { - "$ref": "#/components/schemas/TokenLogProbs" - }, - "description": "Optional log probabilities for generated tokens" - } - }, - "additionalProperties": false, - "required": [ - "content", - "stop_reason" - ], - "title": "CompletionResponse", - "description": "Response from a completion request." - }, - "CancelTrainingJobRequest": { - "type": "object", - "properties": { - "job_uuid": { - "type": "string", - "description": "The UUID of the job to cancel." - } - }, - "additionalProperties": false, - "required": [ - "job_uuid" - ], - "title": "CancelTrainingJobRequest" - }, - "ChatCompletionRequest": { - "type": "object", - "properties": { - "model_id": { - "type": "string", - "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Message" - }, - "description": "List of messages in the conversation." - }, - "sampling_params": { - "$ref": "#/components/schemas/SamplingParams", - "description": "Parameters to control the sampling strategy." - }, - "tools": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolDefinition" - }, - "description": "(Optional) List of tool definitions available to the model." - }, - "tool_choice": { - "type": "string", - "enum": [ - "auto", - "required", - "none" - ], - "description": "(Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. .. deprecated:: Use tool_config instead." - }, - "tool_prompt_format": { - "type": "string", - "enum": [ - "json", - "function_tag", - "python_list" - ], - "description": "(Optional) Instructs the model how to format tool calls. 
By default, Llama Stack will attempt to use a format that is best adapted to the model. - `ToolPromptFormat.json`: The tool calls are formatted as a JSON object. - `ToolPromptFormat.function_tag`: The tool calls are enclosed in a tag. - `ToolPromptFormat.python_list`: The tool calls are output as Python syntax -- a list of function calls. .. deprecated:: Use tool_config instead." - }, - "response_format": { - "$ref": "#/components/schemas/ResponseFormat", - "description": "(Optional) Grammar specification for guided (structured) decoding. There are two options: - `ResponseFormat.json_schema`: The grammar is a JSON schema. Most providers support this format. - `ResponseFormat.grammar`: The grammar is a BNF grammar. This format is more flexible, but not all providers support it." - }, - "stream": { - "type": "boolean", - "description": "(Optional) If True, generate an SSE event stream of the response. Defaults to False." - }, - "logprobs": { - "type": "object", - "properties": { - "top_k": { - "type": "integer", - "default": 0, - "description": "How many tokens (for each position) to return log probabilities for." - } - }, - "additionalProperties": false, - "description": "(Optional) If specified, log probabilities for each token position will be returned." - }, - "tool_config": { - "$ref": "#/components/schemas/ToolConfig", - "description": "(Optional) Configuration for tool use." - } - }, - "additionalProperties": false, - "required": [ - "model_id", - "messages" - ], - "title": "ChatCompletionRequest" - }, "ChatCompletionResponseEvent": { "type": "object", "properties": { @@ -6433,6 +6174,45 @@ ], "title": "CompletionRequest" }, + "CompletionResponse": { + "type": "object", + "properties": { + "metrics": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MetricInResponse" + }, + "description": "(Optional) List of metrics associated with the API response" + }, + "content": { + "type": "string", + "description": "The generated completion text" + }, + "stop_reason": { + "type": "string", + "enum": [ + "end_of_turn", + "end_of_message", + "out_of_tokens" + ], + "description": "Reason why generation stopped" + }, + "logprobs": { + "type": "array", + "items": { + "$ref": "#/components/schemas/TokenLogProbs" + }, + "description": "Optional log probabilities for generated tokens" + } + }, + "additionalProperties": false, + "required": [ + "content", + "stop_reason" + ], + "title": "CompletionResponse", + "description": "Response from a completion request." + }, "CompletionResponseStreamChunk": { "type": "object", "properties": { @@ -17480,11 +17260,6 @@ "description": "Main functionalities provided by this API:\n- Create agents with specific instructions and ability to use tools.\n- Interactions with agents are grouped into sessions (\"threads\"), and each interaction is called a \"turn\".\n- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).\n- Agents can be provided with various shields (see the Safety API for more details).\n- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.", "x-displayName": "Agents API for creating and interacting with agentic systems." }, - { - "name": "BatchInference (Coming Soon)", - "description": "This is an asynchronous API. 
If the request is successful, the response will be a job which can be polled for completion.\n\nNOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs\nincluding (post-training, evals, etc).", - "x-displayName": "Batch inference API for generating completions and chat completions." - }, { "name": "Benchmarks" }, @@ -17555,7 +17330,6 @@ "name": "Operations", "tags": [ "Agents", - "BatchInference (Coming Soon)", "Benchmarks", "DatasetIO", "Datasets", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 33142e3ff..f10af5e44 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -42,68 +42,6 @@ paths: schema: $ref: '#/components/schemas/AppendRowsRequest' required: true - /v1/inference/batch-chat-completion: - post: - responses: - '200': - description: >- - A BatchChatCompletionResponse with the full completions. - content: - application/json: - schema: - $ref: '#/components/schemas/BatchChatCompletionResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Inference - description: >- - Generate chat completions for a batch of messages using the specified model. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/BatchChatCompletionRequest' - required: true - /v1/inference/batch-completion: - post: - responses: - '200': - description: >- - A BatchCompletionResponse with the full completions. - content: - application/json: - schema: - $ref: '#/components/schemas/BatchCompletionResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Inference - description: >- - Generate completions for a batch of content using the specified model. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/BatchCompletionRequest' - required: true /v1/post-training/job/cancel: post: responses: @@ -154,7 +92,7 @@ paths: default: $ref: '#/components/responses/DefaultError' tags: - - BatchInference (Coming Soon) + - Inference description: >- Generate a chat completion for the given messages using the specified model. parameters: [] @@ -189,7 +127,7 @@ paths: default: $ref: '#/components/responses/DefaultError' tags: - - BatchInference (Coming Soon) + - Inference description: >- Generate a completion for the given content using the specified model. parameters: [] @@ -3668,6 +3606,16 @@ components: required: - rows title: AppendRowsRequest + CancelTrainingJobRequest: + type: object + properties: + job_uuid: + type: string + description: The UUID of the job to cancel. + additionalProperties: false + required: + - job_uuid + title: CancelTrainingJobRequest CompletionMessage: type: object properties: @@ -4185,224 +4133,6 @@ components: title: UserMessage description: >- A message from the user in a chat conversation. - BatchChatCompletionRequest: - type: object - properties: - model_id: - type: string - description: >- - The identifier of the model to use. The model must be registered with - Llama Stack and available via the /models endpoint. 
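A sketch of exercising the newly documented `CancelTrainingJobRequest` schema against `/v1/post-training/job/cancel`; the base URL and job UUID are illustrative:

```python
# job_uuid is the schema's only (and required) field.
import requests

resp = requests.post(
    "http://localhost:8321/v1/post-training/job/cancel",  # assumed local server
    json={"job_uuid": "123e4567-e89b-12d3-a456-426614174000"},
)
resp.raise_for_status()
```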
- messages_batch: - type: array - items: - type: array - items: - $ref: '#/components/schemas/Message' - description: >- - The messages to generate completions for. - sampling_params: - $ref: '#/components/schemas/SamplingParams' - description: >- - (Optional) Parameters to control the sampling strategy. - tools: - type: array - items: - $ref: '#/components/schemas/ToolDefinition' - description: >- - (Optional) List of tool definitions available to the model. - tool_config: - $ref: '#/components/schemas/ToolConfig' - description: (Optional) Configuration for tool use. - response_format: - $ref: '#/components/schemas/ResponseFormat' - description: >- - (Optional) Grammar specification for guided (structured) decoding. - logprobs: - type: object - properties: - top_k: - type: integer - default: 0 - description: >- - How many tokens (for each position) to return log probabilities for. - additionalProperties: false - description: >- - (Optional) If specified, log probabilities for each token position will - be returned. - additionalProperties: false - required: - - model_id - - messages_batch - title: BatchChatCompletionRequest - BatchChatCompletionResponse: - type: object - properties: - batch: - type: array - items: - $ref: '#/components/schemas/ChatCompletionResponse' - description: >- - List of chat completion responses, one for each conversation in the batch - additionalProperties: false - required: - - batch - title: BatchChatCompletionResponse - description: >- - Response from a batch chat completion request. - ChatCompletionResponse: - type: object - properties: - metrics: - type: array - items: - $ref: '#/components/schemas/MetricInResponse' - description: >- - (Optional) List of metrics associated with the API response - completion_message: - $ref: '#/components/schemas/CompletionMessage' - description: The complete response message - logprobs: - type: array - items: - $ref: '#/components/schemas/TokenLogProbs' - description: >- - Optional log probabilities for generated tokens - additionalProperties: false - required: - - completion_message - title: ChatCompletionResponse - description: Response from a chat completion request. - MetricInResponse: - type: object - properties: - metric: - type: string - description: The name of the metric - value: - oneOf: - - type: integer - - type: number - description: The numeric value of the metric - unit: - type: string - description: >- - (Optional) The unit of measurement for the metric value - additionalProperties: false - required: - - metric - - value - title: MetricInResponse - description: >- - A metric value included in API responses. - TokenLogProbs: - type: object - properties: - logprobs_by_token: - type: object - additionalProperties: - type: number - description: >- - Dictionary mapping tokens to their log probabilities - additionalProperties: false - required: - - logprobs_by_token - title: TokenLogProbs - description: Log probabilities for generated tokens. - BatchCompletionRequest: - type: object - properties: - model_id: - type: string - description: >- - The identifier of the model to use. The model must be registered with - Llama Stack and available via the /models endpoint. - content_batch: - type: array - items: - $ref: '#/components/schemas/InterleavedContent' - description: The content to generate completions for. - sampling_params: - $ref: '#/components/schemas/SamplingParams' - description: >- - (Optional) Parameters to control the sampling strategy. 
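For readers of the `sampling_params` field, an illustrative payload: only `max_tokens` is attested elsewhere in this patch (in the deleted integration test), and the `strategy` sub-object is an assumption about the broader `SamplingParams` schema:

```python
# Illustrative sampling_params dict, as the integration tests pass it.
sampling_params = {
    "strategy": {"type": "greedy"},  # assumed shape; not defined in this patch
    "max_tokens": 50,                # matches the deleted test's usage
}
```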
- response_format: - $ref: '#/components/schemas/ResponseFormat' - description: >- - (Optional) Grammar specification for guided (structured) decoding. - logprobs: - type: object - properties: - top_k: - type: integer - default: 0 - description: >- - How many tokens (for each position) to return log probabilities for. - additionalProperties: false - description: >- - (Optional) If specified, log probabilities for each token position will - be returned. - additionalProperties: false - required: - - model_id - - content_batch - title: BatchCompletionRequest - BatchCompletionResponse: - type: object - properties: - batch: - type: array - items: - $ref: '#/components/schemas/CompletionResponse' - description: >- - List of completion responses, one for each input in the batch - additionalProperties: false - required: - - batch - title: BatchCompletionResponse - description: >- - Response from a batch completion request. - CompletionResponse: - type: object - properties: - metrics: - type: array - items: - $ref: '#/components/schemas/MetricInResponse' - description: >- - (Optional) List of metrics associated with the API response - content: - type: string - description: The generated completion text - stop_reason: - type: string - enum: - - end_of_turn - - end_of_message - - out_of_tokens - description: Reason why generation stopped - logprobs: - type: array - items: - $ref: '#/components/schemas/TokenLogProbs' - description: >- - Optional log probabilities for generated tokens - additionalProperties: false - required: - - content - - stop_reason - title: CompletionResponse - description: Response from a completion request. - CancelTrainingJobRequest: - type: object - properties: - job_uuid: - type: string - description: The UUID of the job to cancel. - additionalProperties: false - required: - - job_uuid - title: CancelTrainingJobRequest ChatCompletionRequest: type: object properties: @@ -4481,6 +4211,65 @@ components: - model_id - messages title: ChatCompletionRequest + ChatCompletionResponse: + type: object + properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricInResponse' + description: >- + (Optional) List of metrics associated with the API response + completion_message: + $ref: '#/components/schemas/CompletionMessage' + description: The complete response message + logprobs: + type: array + items: + $ref: '#/components/schemas/TokenLogProbs' + description: >- + Optional log probabilities for generated tokens + additionalProperties: false + required: + - completion_message + title: ChatCompletionResponse + description: Response from a chat completion request. + MetricInResponse: + type: object + properties: + metric: + type: string + description: The name of the metric + value: + oneOf: + - type: integer + - type: number + description: The numeric value of the metric + unit: + type: string + description: >- + (Optional) The unit of measurement for the metric value + additionalProperties: false + required: + - metric + - value + title: MetricInResponse + description: >- + A metric value included in API responses. + TokenLogProbs: + type: object + properties: + logprobs_by_token: + type: object + additionalProperties: + type: number + description: >- + Dictionary mapping tokens to their log probabilities + additionalProperties: false + required: + - logprobs_by_token + title: TokenLogProbs + description: Log probabilities for generated tokens. 
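A sketch of consuming the optional `metrics` and `logprobs` fields of a `ChatCompletionResponse`, per the `MetricInResponse` and `TokenLogProbs` schemas above; the concrete metric name, tokens, and values are illustrative:

```python
# `payload` stands in for resp.json() from a chat-completion call.
payload = {
    "completion_message": {"role": "assistant", "content": "Hi!"},  # abbreviated
    "metrics": [{"metric": "prompt_tokens", "value": 12}],           # illustrative
    "logprobs": [{"logprobs_by_token": {"Hi": -0.11, "!": -0.35}}],  # TokenLogProbs
}

for m in payload.get("metrics", []):
    # unit is the schema's only optional field
    print(m["metric"], m["value"], m.get("unit", ""))

for position, entry in enumerate(payload.get("logprobs", [])):
    # logprobs_by_token maps candidate tokens to their log probabilities
    for token, logprob in entry["logprobs_by_token"].items():
        print(position, repr(token), logprob)
```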
ChatCompletionResponseEvent: type: object properties: @@ -4658,6 +4447,37 @@ components: - model_id - content title: CompletionRequest + CompletionResponse: + type: object + properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricInResponse' + description: >- + (Optional) List of metrics associated with the API response + content: + type: string + description: The generated completion text + stop_reason: + type: string + enum: + - end_of_turn + - end_of_message + - out_of_tokens + description: Reason why generation stopped + logprobs: + type: array + items: + $ref: '#/components/schemas/TokenLogProbs' + description: >- + Optional log probabilities for generated tokens + additionalProperties: false + required: + - content + - stop_reason + title: CompletionResponse + description: Response from a completion request. CompletionResponseStreamChunk: type: object properties: @@ -12981,18 +12801,6 @@ tags: the RAG Tool and Vector IO APIs for more details. x-displayName: >- Agents API for creating and interacting with agentic systems. - - name: BatchInference (Coming Soon) - description: >- - This is an asynchronous API. If the request is successful, the response will - be a job which can be polled for completion. - - - NOTE: This API is not yet implemented and is subject to change in concert with - other asynchronous APIs - - including (post-training, evals, etc). - x-displayName: >- - Batch inference API for generating completions and chat completions. - name: Benchmarks - name: DatasetIO - name: Datasets @@ -13032,7 +12840,6 @@ x-tagGroups: - name: Operations tags: - Agents - - BatchInference (Coming Soon) - Benchmarks - DatasetIO - Datasets diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index e2c73e33c..144eb00f7 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -549,7 +549,6 @@ class Generator: if op.defining_class.__name__ in [ "SyntheticDataGeneration", "PostTraining", - "BatchInference", ]: op.defining_class.__name__ = f"{op.defining_class.__name__} (Coming Soon)" print(op.defining_class.__name__) diff --git a/docs/source/references/python_sdk_reference/index.md b/docs/source/references/python_sdk_reference/index.md index b1a9396fe..e0b29363e 100644 --- a/docs/source/references/python_sdk_reference/index.md +++ b/docs/source/references/python_sdk_reference/index.md @@ -139,18 +139,7 @@ Methods: - client.agents.turn.create(session_id, \*, agent_id, \*\*params) -> TurnCreateResponse - client.agents.turn.retrieve(turn_id, \*, agent_id, session_id) -> Turn -## BatchInference -Types: - -```python -from llama_stack_client.types import BatchInferenceChatCompletionResponse -``` - -Methods: - -- client.batch_inference.chat_completion(\*\*params) -> BatchInferenceChatCompletionResponse -- client.batch_inference.completion(\*\*params) -> BatchCompletion ## Datasets diff --git a/llama_stack/apis/batch_inference/__init__.py b/llama_stack/apis/batch_inference/__init__.py deleted file mode 100644 index b9b2944b2..000000000 --- a/llama_stack/apis/batch_inference/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
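With the `client.batch_inference.*` surface dropped from the SDK reference, callers can fan a batch out over the per-request API instead; a sketch, assuming a local server and a registered model:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed server

conversations = [
    [{"role": "user", "content": "What is the capital of France?"}],
    [{"role": "user", "content": "Name a primary color."}],
]

# One chat_completion call per conversation, in place of the removed
# client.batch_inference.chat_completion(**params).
responses = [
    client.inference.chat_completion(model_id="my-model", messages=messages)
    for messages in conversations
]
for r in responses:
    print(r.completion_message.content)
```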
- -from .batch_inference import * diff --git a/llama_stack/apis/batch_inference/batch_inference.py b/llama_stack/apis/batch_inference/batch_inference.py deleted file mode 100644 index b2aa637e2..000000000 --- a/llama_stack/apis/batch_inference/batch_inference.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Protocol, runtime_checkable - -from llama_stack.apis.common.job_types import Job -from llama_stack.apis.inference import ( - InterleavedContent, - LogProbConfig, - Message, - ResponseFormat, - SamplingParams, - ToolChoice, - ToolDefinition, - ToolPromptFormat, -) -from llama_stack.schema_utils import webmethod - - -@runtime_checkable -class BatchInference(Protocol): - """Batch inference API for generating completions and chat completions. - - This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion. - - NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs - including (post-training, evals, etc). - """ - - @webmethod(route="/batch-inference/completion", method="POST") - async def completion( - self, - model: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ) -> Job: - """Generate completions for a batch of content. - - :param model: The model to use for the completion. - :param content_batch: The content to complete. - :param sampling_params: The sampling parameters to use for the completion. - :param response_format: The response format to use for the completion. - :param logprobs: The logprobs to use for the completion. - :returns: A job for the completion. - """ - ... - - @webmethod(route="/batch-inference/chat-completion", method="POST") - async def chat_completion( - self, - model: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - # zero-shot tool definitions as input to the model - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ) -> Job: - """Generate chat completions for a batch of messages. - - :param model: The model to use for the chat completion. - :param messages_batch: The messages to complete. - :param sampling_params: The sampling parameters to use for the completion. - :param tools: The tools to use for the chat completion. - :param tool_choice: The tool choice to use for the chat completion. - :param tool_prompt_format: The tool prompt format to use for the chat completion. - :param response_format: The response format to use for the chat completion. - :param logprobs: The logprobs to use for the chat completion. - :returns: A job for the chat completion. - """ - ... diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index bd4737ca7..9eb00549f 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -973,26 +973,6 @@ class EmbeddingTaskType(Enum): document = "document" -@json_schema_type -class BatchCompletionResponse(BaseModel): - """Response from a batch completion request. 
- - :param batch: List of completion responses, one for each input in the batch - """ - - batch: list[CompletionResponse] - - -@json_schema_type -class BatchChatCompletionResponse(BaseModel): - """Response from a batch chat completion request. - - :param batch: List of chat completion responses, one for each conversation in the batch - """ - - batch: list[ChatCompletionResponse] - - class OpenAICompletionWithInputMessages(OpenAIChatCompletion): input_messages: list[OpenAIMessageParam] @@ -1049,27 +1029,6 @@ class InferenceProvider(Protocol): """ ... - @webmethod(route="/inference/batch-completion", method="POST", experimental=True) - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ) -> BatchCompletionResponse: - """Generate completions for a batch of content using the specified model. - - :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. - :param content_batch: The content to generate completions for. - :param sampling_params: (Optional) Parameters to control the sampling strategy. - :param response_format: (Optional) Grammar specification for guided (structured) decoding. - :param logprobs: (Optional) If specified, log probabilities for each token position will be returned. - :returns: A BatchCompletionResponse with the full completions. - """ - raise NotImplementedError("Batch completion is not implemented") - return # this is so mypy's safe-super rule will consider the method concrete - @webmethod(route="/inference/chat-completion", method="POST") async def chat_completion( self, @@ -1110,31 +1069,6 @@ class InferenceProvider(Protocol): """ ... - @webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True) - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ) -> BatchChatCompletionResponse: - """Generate chat completions for a batch of messages using the specified model. - - :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. - :param messages_batch: The messages to generate completions for. - :param sampling_params: (Optional) Parameters to control the sampling strategy. - :param tools: (Optional) List of tool definitions available to the model. - :param tool_config: (Optional) Configuration for tool use. - :param response_format: (Optional) Grammar specification for guided (structured) decoding. - :param logprobs: (Optional) If specified, log probabilities for each token position will be returned. - :returns: A BatchChatCompletionResponse with the full completions. 
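A sketch (not part of this patch) of recovering the removed batch behavior by fanning out over the surviving `chat_completion` method of any `InferenceProvider` implementation:

```python
import asyncio

async def chat_completion_batch(provider, model_id, messages_batch):
    # `provider` is any object implementing InferenceProvider.chat_completion.
    tasks = [
        provider.chat_completion(model_id=model_id, messages=messages)
        for messages in messages_batch
    ]
    # gather preserves input order, mirroring the removed
    # BatchChatCompletionResponse.batch semantics.
    return await asyncio.gather(*tasks)
```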
- """ - raise NotImplementedError("Batch chat completion is not implemented") - return # this is so mypy's safe-super rule will consider the method concrete - @webmethod(route="/inference/embeddings", method="POST") async def embeddings( self, diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 4b66601bb..2954f5080 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -20,8 +20,6 @@ from llama_stack.apis.common.content_types import ( ) from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.inference import ( - BatchChatCompletionResponse, - BatchCompletionResponse, ChatCompletionResponse, ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, @@ -268,30 +266,6 @@ class InferenceRouter(Inference): ) return response - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ) -> BatchChatCompletionResponse: - logger.debug( - f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", - ) - provider = await self.routing_table.get_provider_impl(model_id) - return await provider.batch_chat_completion( - model_id=model_id, - messages_batch=messages_batch, - tools=tools, - tool_config=tool_config, - sampling_params=sampling_params, - response_format=response_format, - logprobs=logprobs, - ) - async def completion( self, model_id: str, @@ -333,20 +307,6 @@ class InferenceRouter(Inference): return response - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ) -> BatchCompletionResponse: - logger.debug( - f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", - ) - provider = await self.routing_table.get_provider_impl(model_id) - return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs) - async def embeddings( self, model_id: str, diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 87a3978c1..1ed23a12a 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -14,7 +14,6 @@ from typing import Any import yaml from llama_stack.apis.agents import Agents -from llama_stack.apis.batch_inference import BatchInference from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -52,7 +51,6 @@ class LlamaStack( Providers, VectorDBs, Inference, - BatchInference, Agents, Safety, SyntheticDataGeneration, diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 88d7a98ec..f9e295014 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -18,8 +18,6 @@ from llama_stack.apis.common.content_types import ( ToolCallParseStatus, ) from llama_stack.apis.inference import ( - BatchChatCompletionResponse, - 
BatchCompletionResponse, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseEvent, @@ -219,41 +217,6 @@ class MetaReferenceInferenceImpl( results = await self._nonstream_completion([request]) return results[0] - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - ) -> BatchCompletionResponse: - if sampling_params is None: - sampling_params = SamplingParams() - if logprobs: - assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" - - content_batch = [ - augment_content_with_response_format_prompt(response_format, content) for content in content_batch - ] - - request_batch = [] - for content in content_batch: - request = CompletionRequest( - model=model_id, - content=content, - sampling_params=sampling_params, - response_format=response_format, - stream=stream, - logprobs=logprobs, - ) - self.check_model(request) - request = await convert_request_to_raw(request) - request_batch.append(request) - - results = await self._nonstream_completion(request_batch) - return BatchCompletionResponse(batch=results) - async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: tokenizer = self.generator.formatter.tokenizer @@ -399,49 +362,6 @@ class MetaReferenceInferenceImpl( results = await self._nonstream_chat_completion([request]) return results[0] - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - tools: list[ToolDefinition] | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> BatchChatCompletionResponse: - if sampling_params is None: - sampling_params = SamplingParams() - if logprobs: - assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}" - - # wrapper request to make it easier to pass around (internal only, not exposed to API) - request_batch = [] - for messages in messages_batch: - request = ChatCompletionRequest( - model=model_id, - messages=messages, - sampling_params=sampling_params, - tools=tools or [], - response_format=response_format, - logprobs=logprobs, - tool_config=tool_config or ToolConfig(), - ) - self.check_model(request) - - # augment and rewrite messages depending on the model - request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value) - # download media and convert to raw content so we can send it to the model - request = await convert_request_to_raw(request) - request_batch.append(request) - - if self.config.create_distributed_process_group: - if SEMAPHORE.locked(): - raise RuntimeError("Only one concurrent request is supported") - - results = await self._nonstream_chat_completion(request_batch) - return BatchChatCompletionResponse(batch=results) - async def _nonstream_chat_completion( self, request_batch: list[ChatCompletionRequest] ) -> list[ChatCompletionResponse]: diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 0f73c9321..57934a9c8 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -22,8 +22,6 @@ logger = get_logger(name=__name__, category="inference::openai") # | 
completion | LiteLLMOpenAIMixin | # | chat_completion | LiteLLMOpenAIMixin | # | embedding | LiteLLMOpenAIMixin | -# | batch_completion | LiteLLMOpenAIMixin | -# | batch_chat_completion | LiteLLMOpenAIMixin | # | openai_completion | OpenAIMixin | # | openai_chat_completion | OpenAIMixin | # | openai_embeddings | OpenAIMixin | diff --git a/tests/integration/inference/test_batch_inference.py b/tests/integration/inference/test_batch_inference.py deleted file mode 100644 index 9a1a62ce0..000000000 --- a/tests/integration/inference/test_batch_inference.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -import pytest - -from ..test_cases.test_case import TestCase - - -def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id): - models = {m.identifier: m for m in client_with_models.models.list()} - models.update({m.provider_resource_id: m for m in client_with_models.models.list()}) - provider_id = models[model_id].provider_id - providers = {p.provider_id: p for p in client_with_models.providers.list()} - provider = providers[provider_id] - if provider.provider_type not in ("inline::meta-reference",): - pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference") - - -@pytest.mark.parametrize( - "test_case", - [ - "inference:completion:batch_completion", - ], -) -def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case): - skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id) - tc = TestCase(test_case) - - content_batch = tc["contents"] - response = client_with_models.inference.batch_completion( - content_batch=content_batch, - model_id=text_model_id, - sampling_params={ - "max_tokens": 50, - }, - ) - assert len(response.batch) == len(content_batch) - for i, r in enumerate(response.batch): - print(f"response {i}: {r.content}") - assert len(r.content) > 10 - - -@pytest.mark.parametrize( - "test_case", - [ - "inference:chat_completion:batch_completion", - ], -) -def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case): - skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id) - tc = TestCase(test_case) - qa_pairs = tc["qa_pairs"] - - message_batch = [ - [ - { - "role": "user", - "content": qa["question"], - } - ] - for qa in qa_pairs - ] - - response = client_with_models.inference.batch_chat_completion( - messages_batch=message_batch, - model_id=text_model_id, - ) - assert len(response.batch) == len(qa_pairs) - for i, r in enumerate(response.batch): - print(f"response {i}: {r.completion_message.content}") - assert len(r.completion_message.content) > 0 - assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()
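If equivalent coverage is ever wanted without the batch endpoints, a sketch against the per-request API, reusing the fixtures the deleted test depended on (the inline `qa_pairs` stand in for the old `TestCase` data):

```python
def test_chat_completion_over_batch(client_with_models, text_model_id):
    # Inline stand-in for the deleted test's qa_pairs test-case data.
    qa_pairs = [
        {"question": "What is the capital of France?", "answer": "Paris"},
        {"question": "How many legs does a spider have?", "answer": "eight"},
    ]
    for qa in qa_pairs:
        response = client_with_models.inference.chat_completion(
            messages=[{"role": "user", "content": qa["question"]}],
            model_id=text_model_id,
        )
        content = response.completion_message.content
        assert len(content) > 0
        assert qa["answer"].lower() in content.lower()
```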