From a4b573d75072439cb392f5ea00a7cc7439a520b8 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Sat, 12 Apr 2025 16:29:02 -0400
Subject: [PATCH] Fix OpenAI API response format handling

This fixes the schema of OpenAI API chat completion response formats,
including how those response formats (and other nested parameters in the
chat completion request) get translated into parameters for calls to the
backend OpenAI-compatible providers.

Signed-off-by: Ben Browning
---
 docs/_static/llama-stack-spec.html            | 119 +++++++++++++++++-
 docs/_static/llama-stack-spec.yaml            |  74 ++++++++++-
 llama_stack/apis/inference/inference.py       |  44 ++++++-
 llama_stack/distribution/routers/routers.py   |   9 +-
 .../remote/inference/fireworks/fireworks.py   |   9 +-
 .../remote/inference/nvidia/nvidia.py         |   9 +-
 .../remote/inference/ollama/ollama.py         |   9 +-
 .../inference/passthrough/passthrough.py      |   9 +-
 .../remote/inference/together/together.py     |   9 +-
 .../providers/remote/inference/vllm/vllm.py   |   9 +-
 .../utils/inference/litellm_openai_mixin.py   |  13 +-
 .../utils/inference/openai_compat.py          |  24 +++-
 12 files changed, 307 insertions(+), 30 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 1927f2013..84a9bc67d 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -8965,6 +8965,50 @@
         ],
         "title": "OpenAIImageURL"
       },
+      "OpenAIJSONSchema": {
+        "type": "object",
+        "properties": {
+          "name": {
+            "type": "string"
+          },
+          "description": {
+            "type": "string"
+          },
+          "strict": {
+            "type": "boolean"
+          },
+          "schema": {
+            "type": "object",
+            "additionalProperties": {
+              "oneOf": [
+                {
+                  "type": "null"
+                },
+                {
+                  "type": "boolean"
+                },
+                {
+                  "type": "number"
+                },
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "array"
+                },
+                {
+                  "type": "object"
+                }
+              ]
+            }
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "name"
+        ],
+        "title": "OpenAIJSONSchema"
+      },
       "OpenAIMessageParam": {
         "oneOf": [
           {
@@ -8994,6 +9038,76 @@
           }
         }
       },
+      "OpenAIResponseFormatJSONObject": {
+        "type": "object",
+        "properties": {
+          "type": {
+            "type": "string",
+            "const": "json_object",
+            "default": "json_object"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "type"
+        ],
+        "title": "OpenAIResponseFormatJSONObject"
+      },
+      "OpenAIResponseFormatJSONSchema": {
+        "type": "object",
+        "properties": {
+          "type": {
+            "type": "string",
+            "const": "json_schema",
+            "default": "json_schema"
+          },
+          "json_schema": {
+            "$ref": "#/components/schemas/OpenAIJSONSchema"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "type",
+          "json_schema"
+        ],
+        "title": "OpenAIResponseFormatJSONSchema"
+      },
+      "OpenAIResponseFormatParam": {
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/OpenAIResponseFormatText"
+          },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseFormatJSONSchema"
+          },
+          {
+            "$ref": "#/components/schemas/OpenAIResponseFormatJSONObject"
+          }
+        ],
+        "discriminator": {
+          "propertyName": "type",
+          "mapping": {
+            "text": "#/components/schemas/OpenAIResponseFormatText",
+            "json_schema": "#/components/schemas/OpenAIResponseFormatJSONSchema",
+            "json_object": "#/components/schemas/OpenAIResponseFormatJSONObject"
+          }
+        }
+      },
+      "OpenAIResponseFormatText": {
+        "type": "object",
+        "properties": {
+          "type": {
+            "type": "string",
+            "const": "text",
+            "default": "text"
+          }
+        },
+        "additionalProperties": false,
+        "required": [
+          "type"
+        ],
+        "title": "OpenAIResponseFormatText"
+      },
       "OpenAISystemMessageParam": {
         "type": "object",
         "properties": {
@@ -9215,10 +9329,7 @@
             "description": "(Optional) The penalty for repeated tokens"
           },
           "response_format": {
-            "type": "object",
-            "additionalProperties": {
-              "type": "string"
-            },
+            "$ref": "#/components/schemas/OpenAIResponseFormatParam",
             "description": "(Optional) The response format to use"
           },
           "seed": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 1070b76a4..3fcc83f15 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -6157,6 +6157,29 @@ components:
       required:
         - url
       title: OpenAIImageURL
+    OpenAIJSONSchema:
+      type: object
+      properties:
+        name:
+          type: string
+        description:
+          type: string
+        strict:
+          type: boolean
+        schema:
+          type: object
+          additionalProperties:
+            oneOf:
+              - type: 'null'
+              - type: boolean
+              - type: number
+              - type: string
+              - type: array
+              - type: object
+      additionalProperties: false
+      required:
+        - name
+      title: OpenAIJSONSchema
     OpenAIMessageParam:
       oneOf:
         - $ref: '#/components/schemas/OpenAIUserMessageParam'
@@ -6172,6 +6195,53 @@ components:
           assistant: '#/components/schemas/OpenAIAssistantMessageParam'
           tool: '#/components/schemas/OpenAIToolMessageParam'
           developer: '#/components/schemas/OpenAIDeveloperMessageParam'
+    OpenAIResponseFormatJSONObject:
+      type: object
+      properties:
+        type:
+          type: string
+          const: json_object
+          default: json_object
+      additionalProperties: false
+      required:
+        - type
+      title: OpenAIResponseFormatJSONObject
+    OpenAIResponseFormatJSONSchema:
+      type: object
+      properties:
+        type:
+          type: string
+          const: json_schema
+          default: json_schema
+        json_schema:
+          $ref: '#/components/schemas/OpenAIJSONSchema'
+      additionalProperties: false
+      required:
+        - type
+        - json_schema
+      title: OpenAIResponseFormatJSONSchema
+    OpenAIResponseFormatParam:
+      oneOf:
+        - $ref: '#/components/schemas/OpenAIResponseFormatText'
+        - $ref: '#/components/schemas/OpenAIResponseFormatJSONSchema'
+        - $ref: '#/components/schemas/OpenAIResponseFormatJSONObject'
+      discriminator:
+        propertyName: type
+        mapping:
+          text: '#/components/schemas/OpenAIResponseFormatText'
+          json_schema: '#/components/schemas/OpenAIResponseFormatJSONSchema'
+          json_object: '#/components/schemas/OpenAIResponseFormatJSONObject'
+    OpenAIResponseFormatText:
+      type: object
+      properties:
+        type:
+          type: string
+          const: text
+          default: text
+      additionalProperties: false
+      required:
+        - type
+      title: OpenAIResponseFormatText
     OpenAISystemMessageParam:
       type: object
       properties:
@@ -6331,9 +6401,7 @@
         description: >-
           (Optional) The penalty for repeated tokens
         response_format:
-          type: object
-          additionalProperties:
-            type: string
+          $ref: '#/components/schemas/OpenAIResponseFormatParam'
           description: (Optional) The response format to use
         seed:
           type: integer
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 0e70c876e..4251d37ab 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -18,7 +18,7 @@ from typing import (
 )

 from pydantic import BaseModel, Field, field_validator
-from typing_extensions import Annotated
+from typing_extensions import Annotated, TypedDict

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
@@ -558,6 +558,46 @@ OpenAIMessageParam = Annotated[
 register_schema(OpenAIMessageParam, name="OpenAIMessageParam")


+@json_schema_type
+class OpenAIResponseFormatText(BaseModel):
+    type: Literal["text"] = "text"
+
+
+@json_schema_type
+class OpenAIJSONSchema(TypedDict, total=False):
+    name: str
+    description: Optional[str] = None
+    strict: Optional[bool] = None
+
+    # Pydantic BaseModel cannot be used with a schema param, since it already
+    # has one. And, we don't want to alias here because then we would have to
+    # handle that alias when converting to OpenAI params. So, to support
+    # schema, we use a TypedDict.
+    schema: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+class OpenAIResponseFormatJSONSchema(BaseModel):
+    type: Literal["json_schema"] = "json_schema"
+    json_schema: OpenAIJSONSchema
+
+
+@json_schema_type
+class OpenAIResponseFormatJSONObject(BaseModel):
+    type: Literal["json_object"] = "json_object"
+
+
+OpenAIResponseFormatParam = Annotated[
+    Union[
+        OpenAIResponseFormatText,
+        OpenAIResponseFormatJSONSchema,
+        OpenAIResponseFormatJSONObject,
+    ],
+    Field(discriminator="type"),
+]
+register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
+
+
 @json_schema_type
 class OpenAITopLogProb(BaseModel):
     """The top log probability for a token from an OpenAI-compatible chat completion response.
@@ -903,7 +943,7 @@ class Inference(Protocol):
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index b9623ef3c..b9f363be0 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -37,7 +37,12 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
@@ -530,7 +535,7 @@ class InferenceRouter(Inference):
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index b59e9f2cb..8385209f1 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -32,7 +32,12 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import (
@@ -336,7 +341,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index d6f717719..b2a244f11 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -35,7 +35,12 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.models.llama.datatypes import ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
@@ -329,7 +334,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 33b48af46..a24d35ab2 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -39,7 +39,12 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -393,7 +398,7 @@ class OllamaInferenceAdapter(
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 0eb38c395..63054ae0a 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -26,7 +26,12 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.apis.models import Model
 from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@@ -266,7 +271,7 @@ class PassthroughInferenceAdapter(Inference):
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 1615b8cd1..4ebf9956e 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -31,7 +31,12 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@@ -315,7 +320,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 0044d2e75..eca68e399 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -45,7 +45,12 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 from llama_stack.models.llama.sku_list import all_registered_models
@@ -487,7 +492,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index cd0f4ec67..6d98a0cb4 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -30,7 +30,12 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.apis.models.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
@@ -270,7 +275,7 @@ class LiteLLMOpenAIMixin(
         guided_choice: Optional[List[str]] = None,
         prompt_logprobs: Optional[int] = None,
     ) -> OpenAICompletion:
-        model_obj = await self._get_model(model)
+        model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
             prompt=prompt,
@@ -308,7 +313,7 @@ class LiteLLMOpenAIMixin(
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
@@ -320,7 +325,7 @@ class LiteLLMOpenAIMixin(
         top_p: Optional[float] = None,
         user: Optional[str] = None,
     ) -> OpenAIChatCompletion:
-        model_obj = await self._get_model(model)
+        model_obj = await self.model_store.get_model(model)
         params = await prepare_openai_completion_params(
             model=model_obj.provider_resource_id,
             messages=messages,
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index f33cb4443..1fa202475 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -85,7 +85,12 @@ from llama_stack.apis.inference import (
     TopPSamplingStrategy,
     UserMessage,
 )
-from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAICompletionChoice
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAICompletion,
+    OpenAICompletionChoice,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
     StopReason,
@@ -1080,7 +1085,20 @@ async def convert_openai_chat_completion_stream(


 async def prepare_openai_completion_params(**params):
-    completion_params = {k: v for k, v in params.items() if v is not None}
+    async def _prepare_value(value: Any) -> Any:
+        new_value = value
+        if isinstance(value, list):
+            new_value = [await _prepare_value(v) for v in value]
+        elif isinstance(value, dict):
+            new_value = {k: await _prepare_value(v) for k, v in value.items()}
+        elif isinstance(value, BaseModel):
+            new_value = value.model_dump(exclude_none=True)
+        return new_value
+
+    completion_params = {}
+    for k, v in params.items():
+        if v is not None:
+            completion_params[k] = await _prepare_value(v)

     return completion_params

@@ -1167,7 +1185,7 @@ class OpenAIChatCompletionUnsupportedMixin:
         n: Optional[int] = None,
         parallel_tool_calls: Optional[bool] = None,
         presence_penalty: Optional[float] = None,
-        response_format: Optional[Dict[str, str]] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
         seed: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stream: Optional[bool] = None,
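
For illustration only (not part of the patch): a minimal sketch of how the new typed response_format parameter behaves, assuming the classes added above in llama_stack/apis/inference/inference.py. The schema contents here are made-up placeholders.

    # Build a typed, discriminated response format instead of the old
    # loosely-typed Dict[str, str].  The json_schema payload is hypothetical.
    from llama_stack.apis.inference.inference import (
        OpenAIJSONSchema,
        OpenAIResponseFormatJSONSchema,
    )

    response_format = OpenAIResponseFormatJSONSchema(
        json_schema=OpenAIJSONSchema(
            name="weather_report",
            strict=True,
            schema={
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        ),
    )

    # prepare_openai_completion_params() (openai_compat.py above) now recursively
    # dumps pydantic models with exclude_none=True, so the nested parameter reaches
    # the backend OpenAI-compatible provider as a plain dict:
    print(response_format.model_dump(exclude_none=True))
    # {'type': 'json_schema', 'json_schema': {'name': 'weather_report',
    #  'strict': True, 'schema': {'type': 'object', ...}}}

The previous Dict[str, str] annotation could not represent the nested json_schema object at all, which is what the schema and parameter-translation changes above address.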