From 972f2395a14d1f522d983192238d37863b7df131 Mon Sep 17 00:00:00 2001 From: Eric Huang Date: Thu, 9 Oct 2025 17:28:44 -0700 Subject: [PATCH] test # What does this PR do? ## Test Plan --- docs/static/deprecated-llama-stack-spec.html | 545 +++++++++--------- docs/static/deprecated-llama-stack-spec.yaml | 309 +++++----- docs/static/llama-stack-spec.html | 545 +++++++++--------- docs/static/llama-stack-spec.yaml | 309 +++++----- docs/static/stainless-llama-stack-spec.html | 545 +++++++++--------- docs/static/stainless-llama-stack-spec.yaml | 309 +++++----- llama_stack/apis/inference/inference.py | 170 +++--- llama_stack/core/library_client.py | 31 +- llama_stack/core/routers/inference.py | 151 ++--- llama_stack/core/server/server.py | 66 ++- .../agents/meta_reference/agent_instance.py | 4 +- .../meta_reference/responses/streaming.py | 4 +- .../inline/batches/reference/batches.py | 8 +- .../inline/eval/meta_reference/eval.py | 15 +- .../inference/meta_reference/inference.py | 35 +- .../sentence_transformers.py | 53 +- .../inline/safety/llama_guard/llama_guard.py | 22 +- .../scoring_fn/llm_as_judge_scoring_fn.py | 5 +- .../tool_runtime/rag/context_retriever.py | 5 +- .../remote/inference/bedrock/bedrock.py | 53 +- .../remote/inference/databricks/databricks.py | 26 +- .../inference/llama_openai_compat/llama.py | 26 +- .../inference/passthrough/passthrough.py | 114 +--- .../remote/inference/runpod/runpod.py | 66 +-- .../providers/remote/inference/vllm/vllm.py | 65 +-- .../utils/inference/litellm_openai_mixin.py | 155 ++--- .../providers/utils/inference/openai_mixin.py | 171 +++--- .../providers/inference/test_remote_vllm.py | 60 +- .../utils/inference/test_openai_mixin.py | 8 +- 29 files changed, 1726 insertions(+), 2149 deletions(-) diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index 04a3dca9b..34298bc94 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -7343,6 +7343,233 @@ "title": "OpenAIUserMessageParam", "description": "A message from the user in an OpenAI-compatible chat completion request." 
}, + "OpenAIChatCompletionRequestParams": { + "type": "object", + "properties": { + "model": { + "type": "string" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIMessageParam" + } + }, + "frequency_penalty": { + "type": "number" + }, + "function_call": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + ] + }, + "functions": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "logit_bias": { + "type": "object", + "additionalProperties": { + "type": "number" + } + }, + "logprobs": { + "type": "boolean" + }, + "max_completion_tokens": { + "type": "integer" + }, + "max_tokens": { + "type": "integer" + }, + "n": { + "type": "integer" + }, + "parallel_tool_calls": { + "type": "boolean" + }, + "presence_penalty": { + "type": "number" + }, + "response_format": { + "$ref": "#/components/schemas/OpenAIResponseFormatParam" + }, + "seed": { + "type": "integer" + }, + "stop": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ] + }, + "stream": { + "type": "boolean" + }, + "stream_options": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "temperature": { + "type": "number" + }, + "tool_choice": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + ] + }, + "tools": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "top_logprobs": { + "type": "integer" + }, + "top_p": { + "type": "number" + }, + "user": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "model", + "messages" + ], + "title": "OpenAIChatCompletionRequestParams", + "description": "Request parameters for OpenAI-compatible chat completion endpoint.\nThis model uses extra=\"allow\" to capture provider-specific parameters\nwhich are passed through as extra_body." + }, "OpenAIJSONSchema": { "type": "object", "properties": { @@ -7472,249 +7699,14 @@ "OpenaiChatCompletionRequest": { "type": "object", "properties": { - "model": { - "type": "string", - "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIMessageParam" - }, - "description": "List of messages in the conversation." - }, - "frequency_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." 
- }, - "function_call": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - ], - "description": "(Optional) The function call to use." - }, - "functions": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "(Optional) List of functions to use." - }, - "logit_bias": { - "type": "object", - "additionalProperties": { - "type": "number" - }, - "description": "(Optional) The logit bias to use." - }, - "logprobs": { - "type": "boolean", - "description": "(Optional) The log probabilities to use." - }, - "max_completion_tokens": { - "type": "integer", - "description": "(Optional) The maximum number of tokens to generate." - }, - "max_tokens": { - "type": "integer", - "description": "(Optional) The maximum number of tokens to generate." - }, - "n": { - "type": "integer", - "description": "(Optional) The number of completions to generate." - }, - "parallel_tool_calls": { - "type": "boolean", - "description": "(Optional) Whether to parallelize tool calls." - }, - "presence_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." - }, - "response_format": { - "$ref": "#/components/schemas/OpenAIResponseFormatParam", - "description": "(Optional) The response format to use." - }, - "seed": { - "type": "integer", - "description": "(Optional) The seed to use." - }, - "stop": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array", - "items": { - "type": "string" - } - } - ], - "description": "(Optional) The stop tokens to use." - }, - "stream": { - "type": "boolean", - "description": "(Optional) Whether to stream the response." - }, - "stream_options": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - }, - "description": "(Optional) The stream options to use." - }, - "temperature": { - "type": "number", - "description": "(Optional) The temperature to use." - }, - "tool_choice": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - ], - "description": "(Optional) The tool choice to use." - }, - "tools": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "(Optional) The tools to use." - }, - "top_logprobs": { - "type": "integer", - "description": "(Optional) The top log probabilities to use." - }, - "top_p": { - "type": "number", - "description": "(Optional) The top p to use." - }, - "user": { - "type": "string", - "description": "(Optional) The user to use." 
+ "params": { + "$ref": "#/components/schemas/OpenAIChatCompletionRequestParams", + "description": "Request parameters including model, messages, and optional parameters. Use params.get_extra_body() to extract provider-specific parameters (e.g., chat_template_kwargs for vLLM)." } }, "additionalProperties": false, "required": [ - "model", - "messages" + "params" ], "title": "OpenaiChatCompletionRequest" }, @@ -7900,12 +7892,11 @@ ], "title": "OpenAICompletionWithInputMessages" }, - "OpenaiCompletionRequest": { + "OpenAICompletionRequestParams": { "type": "object", "properties": { "model": { - "type": "string", - "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." + "type": "string" }, "prompt": { "oneOf": [ @@ -7933,47 +7924,37 @@ } } } - ], - "description": "The prompt to generate a completion for." + ] }, "best_of": { - "type": "integer", - "description": "(Optional) The number of completions to generate." + "type": "integer" }, "echo": { - "type": "boolean", - "description": "(Optional) Whether to echo the prompt." + "type": "boolean" }, "frequency_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." + "type": "number" }, "logit_bias": { "type": "object", "additionalProperties": { "type": "number" - }, - "description": "(Optional) The logit bias to use." + } }, "logprobs": { - "type": "boolean", - "description": "(Optional) The log probabilities to use." + "type": "boolean" }, "max_tokens": { - "type": "integer", - "description": "(Optional) The maximum number of tokens to generate." + "type": "integer" }, "n": { - "type": "integer", - "description": "(Optional) The number of completions to generate." + "type": "integer" }, "presence_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." + "type": "number" }, "seed": { - "type": "integer", - "description": "(Optional) The seed to use." + "type": "integer" }, "stop": { "oneOf": [ @@ -7986,12 +7967,10 @@ "type": "string" } } - ], - "description": "(Optional) The stop tokens to use." + ] }, "stream": { - "type": "boolean", - "description": "(Optional) Whether to stream the response." + "type": "boolean" }, "stream_options": { "type": "object", @@ -8016,20 +7995,19 @@ "type": "object" } ] - }, - "description": "(Optional) The stream options to use." + } }, "temperature": { - "type": "number", - "description": "(Optional) The temperature to use." + "type": "number" }, "top_p": { - "type": "number", - "description": "(Optional) The top p to use." + "type": "number" }, "user": { - "type": "string", - "description": "(Optional) The user to use." + "type": "string" + }, + "suffix": { + "type": "string" }, "guided_choice": { "type": "array", @@ -8039,10 +8017,6 @@ }, "prompt_logprobs": { "type": "integer" - }, - "suffix": { - "type": "string", - "description": "(Optional) The suffix that should be appended to the completion." } }, "additionalProperties": false, @@ -8050,6 +8024,21 @@ "model", "prompt" ], + "title": "OpenAICompletionRequestParams", + "description": "Request parameters for OpenAI-compatible completion endpoint.\nThis model uses extra=\"allow\" to capture provider-specific parameters\n(like vLLM's guided_choice) which are passed through as extra_body." 
+ }, + "OpenaiCompletionRequest": { + "type": "object", + "properties": { + "params": { + "$ref": "#/components/schemas/OpenAICompletionRequestParams", + "description": "Request parameters including model, prompt, and optional parameters. Use params.get_extra_body() to extract provider-specific parameters." + } + }, + "additionalProperties": false, + "required": [ + "params" + ], "title": "OpenaiCompletionRequest" }, "OpenAICompletion": { diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 1a215b877..744d9f460 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -5437,6 +5437,122 @@ components: title: OpenAIUserMessageParam description: >- A message from the user in an OpenAI-compatible chat completion request. + OpenAIChatCompletionRequestParams: + type: object + properties: + model: + type: string + messages: + type: array + items: + $ref: '#/components/schemas/OpenAIMessageParam' + frequency_penalty: + type: number + function_call: + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + functions: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + logit_bias: + type: object + additionalProperties: + type: number + logprobs: + type: boolean + max_completion_tokens: + type: integer + max_tokens: + type: integer + n: + type: integer + parallel_tool_calls: + type: boolean + presence_penalty: + type: number + response_format: + $ref: '#/components/schemas/OpenAIResponseFormatParam' + seed: + type: integer + stop: + oneOf: + - type: string + - type: array + items: + type: string + stream: + type: boolean + stream_options: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + temperature: + type: number + tool_choice: + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + tools: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + top_logprobs: + type: integer + top_p: + type: number + user: + type: string + additionalProperties: false + required: + - model + - messages + title: OpenAIChatCompletionRequestParams + description: >- + Request parameters for OpenAI-compatible chat completion endpoint. + + This model uses extra="allow" to capture provider-specific parameters + + which are passed through as extra_body. OpenAIJSONSchema: type: object properties: @@ -5531,145 +5647,15 @@ components: OpenaiChatCompletionRequest: type: object properties: - model: - type: string + params: + $ref: '#/components/schemas/OpenAIChatCompletionRequestParams' description: >- - The identifier of the model to use. The model must be registered with - Llama Stack and available via the /models endpoint. - messages: - type: array - items: - $ref: '#/components/schemas/OpenAIMessageParam' - description: List of messages in the conversation. - frequency_penalty: - type: number - description: >- - (Optional) The penalty for repeated tokens. 
- function_call: - oneOf: - - type: string - - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The function call to use. - functions: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) List of functions to use. - logit_bias: - type: object - additionalProperties: - type: number - description: (Optional) The logit bias to use. - logprobs: - type: boolean - description: (Optional) The log probabilities to use. - max_completion_tokens: - type: integer - description: >- - (Optional) The maximum number of tokens to generate. - max_tokens: - type: integer - description: >- - (Optional) The maximum number of tokens to generate. - n: - type: integer - description: >- - (Optional) The number of completions to generate. - parallel_tool_calls: - type: boolean - description: >- - (Optional) Whether to parallelize tool calls. - presence_penalty: - type: number - description: >- - (Optional) The penalty for repeated tokens. - response_format: - $ref: '#/components/schemas/OpenAIResponseFormatParam' - description: (Optional) The response format to use. - seed: - type: integer - description: (Optional) The seed to use. - stop: - oneOf: - - type: string - - type: array - items: - type: string - description: (Optional) The stop tokens to use. - stream: - type: boolean - description: >- - (Optional) Whether to stream the response. - stream_options: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The stream options to use. - temperature: - type: number - description: (Optional) The temperature to use. - tool_choice: - oneOf: - - type: string - - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The tool choice to use. - tools: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The tools to use. - top_logprobs: - type: integer - description: >- - (Optional) The top log probabilities to use. - top_p: - type: number - description: (Optional) The top p to use. - user: - type: string - description: (Optional) The user to use. + Request parameters including model, messages, and optional parameters. + Use params.get_extra_body() to extract provider-specific parameters (e.g., + chat_template_kwargs for vLLM). additionalProperties: false required: - - model - - messages + - params title: OpenaiChatCompletionRequest OpenAIChatCompletion: type: object @@ -5824,14 +5810,11 @@ components: - model - input_messages title: OpenAICompletionWithInputMessages - OpenaiCompletionRequest: + OpenAICompletionRequestParams: type: object properties: model: type: string - description: >- - The identifier of the model to use. The model must be registered with - Llama Stack and available via the /models endpoint. prompt: oneOf: - type: string @@ -5846,52 +5829,34 @@ components: type: array items: type: integer - description: The prompt to generate a completion for. best_of: type: integer - description: >- - (Optional) The number of completions to generate. 
echo: type: boolean - description: (Optional) Whether to echo the prompt. frequency_penalty: type: number - description: >- - (Optional) The penalty for repeated tokens. logit_bias: type: object additionalProperties: type: number - description: (Optional) The logit bias to use. logprobs: type: boolean - description: (Optional) The log probabilities to use. max_tokens: type: integer - description: >- - (Optional) The maximum number of tokens to generate. n: type: integer - description: >- - (Optional) The number of completions to generate. presence_penalty: type: number - description: >- - (Optional) The penalty for repeated tokens. seed: type: integer - description: (Optional) The seed to use. stop: oneOf: - type: string - type: array items: type: string - description: (Optional) The stop tokens to use. stream: type: boolean - description: >- - (Optional) Whether to stream the response. stream_options: type: object additionalProperties: @@ -5902,30 +5867,42 @@ components: - type: string - type: array - type: object - description: (Optional) The stream options to use. temperature: type: number - description: (Optional) The temperature to use. top_p: type: number - description: (Optional) The top p to use. user: type: string - description: (Optional) The user to use. + suffix: + type: string guided_choice: type: array items: type: string prompt_logprobs: type: integer - suffix: - type: string - description: >- - (Optional) The suffix that should be appended to the completion. additionalProperties: false required: - model - prompt + title: OpenAICompletionRequestParams + description: >- + Request parameters for OpenAI-compatible completion endpoint. + + This model uses extra="allow" to capture provider-specific parameters + + (like vLLM's guided_choice) which are passed through as extra_body. + OpenaiCompletionRequest: + type: object + properties: + params: + $ref: '#/components/schemas/OpenAICompletionRequestParams' + description: >- + Request parameters including model, prompt, and optional parameters. Use + params.get_extra_body() to extract provider-specific parameters. + additionalProperties: false + required: + - params title: OpenaiCompletionRequest OpenAICompletion: type: object diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 9cd526176..2ad23dceb 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -4839,6 +4839,233 @@ "title": "OpenAIUserMessageParam", "description": "A message from the user in an OpenAI-compatible chat completion request." 
}, + "OpenAIChatCompletionRequestParams": { + "type": "object", + "properties": { + "model": { + "type": "string" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIMessageParam" + } + }, + "frequency_penalty": { + "type": "number" + }, + "function_call": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + ] + }, + "functions": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "logit_bias": { + "type": "object", + "additionalProperties": { + "type": "number" + } + }, + "logprobs": { + "type": "boolean" + }, + "max_completion_tokens": { + "type": "integer" + }, + "max_tokens": { + "type": "integer" + }, + "n": { + "type": "integer" + }, + "parallel_tool_calls": { + "type": "boolean" + }, + "presence_penalty": { + "type": "number" + }, + "response_format": { + "$ref": "#/components/schemas/OpenAIResponseFormatParam" + }, + "seed": { + "type": "integer" + }, + "stop": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ] + }, + "stream": { + "type": "boolean" + }, + "stream_options": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "temperature": { + "type": "number" + }, + "tool_choice": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + ] + }, + "tools": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "top_logprobs": { + "type": "integer" + }, + "top_p": { + "type": "number" + }, + "user": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "model", + "messages" + ], + "title": "OpenAIChatCompletionRequestParams", + "description": "Request parameters for OpenAI-compatible chat completion endpoint.\nThis model uses extra=\"allow\" to capture provider-specific parameters\nwhich are passed through as extra_body." + }, "OpenAIJSONSchema": { "type": "object", "properties": { @@ -4968,249 +5195,14 @@ "OpenaiChatCompletionRequest": { "type": "object", "properties": { - "model": { - "type": "string", - "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIMessageParam" - }, - "description": "List of messages in the conversation." - }, - "frequency_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." 
- }, - "function_call": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - ], - "description": "(Optional) The function call to use." - }, - "functions": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "(Optional) List of functions to use." - }, - "logit_bias": { - "type": "object", - "additionalProperties": { - "type": "number" - }, - "description": "(Optional) The logit bias to use." - }, - "logprobs": { - "type": "boolean", - "description": "(Optional) The log probabilities to use." - }, - "max_completion_tokens": { - "type": "integer", - "description": "(Optional) The maximum number of tokens to generate." - }, - "max_tokens": { - "type": "integer", - "description": "(Optional) The maximum number of tokens to generate." - }, - "n": { - "type": "integer", - "description": "(Optional) The number of completions to generate." - }, - "parallel_tool_calls": { - "type": "boolean", - "description": "(Optional) Whether to parallelize tool calls." - }, - "presence_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." - }, - "response_format": { - "$ref": "#/components/schemas/OpenAIResponseFormatParam", - "description": "(Optional) The response format to use." - }, - "seed": { - "type": "integer", - "description": "(Optional) The seed to use." - }, - "stop": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array", - "items": { - "type": "string" - } - } - ], - "description": "(Optional) The stop tokens to use." - }, - "stream": { - "type": "boolean", - "description": "(Optional) Whether to stream the response." - }, - "stream_options": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - }, - "description": "(Optional) The stream options to use." - }, - "temperature": { - "type": "number", - "description": "(Optional) The temperature to use." - }, - "tool_choice": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - ], - "description": "(Optional) The tool choice to use." - }, - "tools": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "(Optional) The tools to use." - }, - "top_logprobs": { - "type": "integer", - "description": "(Optional) The top log probabilities to use." - }, - "top_p": { - "type": "number", - "description": "(Optional) The top p to use." - }, - "user": { - "type": "string", - "description": "(Optional) The user to use." 
+ "params": { + "$ref": "#/components/schemas/OpenAIChatCompletionRequestParams", + "description": "Request parameters including model, messages, and optional parameters. Use params.get_extra_body() to extract provider-specific parameters (e.g., chat_template_kwargs for vLLM)." } }, "additionalProperties": false, "required": [ - "model", - "messages" + "params" ], "title": "OpenaiChatCompletionRequest" }, @@ -5396,12 +5388,11 @@ ], "title": "OpenAICompletionWithInputMessages" }, - "OpenaiCompletionRequest": { + "OpenAICompletionRequestParams": { "type": "object", "properties": { "model": { - "type": "string", - "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." + "type": "string" }, "prompt": { "oneOf": [ @@ -5429,47 +5420,37 @@ } } } - ], - "description": "The prompt to generate a completion for." + ] }, "best_of": { - "type": "integer", - "description": "(Optional) The number of completions to generate." + "type": "integer" }, "echo": { - "type": "boolean", - "description": "(Optional) Whether to echo the prompt." + "type": "boolean" }, "frequency_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." + "type": "number" }, "logit_bias": { "type": "object", "additionalProperties": { "type": "number" - }, - "description": "(Optional) The logit bias to use." + } }, "logprobs": { - "type": "boolean", - "description": "(Optional) The log probabilities to use." + "type": "boolean" }, "max_tokens": { - "type": "integer", - "description": "(Optional) The maximum number of tokens to generate." + "type": "integer" }, "n": { - "type": "integer", - "description": "(Optional) The number of completions to generate." + "type": "integer" }, "presence_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." + "type": "number" }, "seed": { - "type": "integer", - "description": "(Optional) The seed to use." + "type": "integer" }, "stop": { "oneOf": [ @@ -5482,12 +5463,10 @@ "type": "string" } } - ], - "description": "(Optional) The stop tokens to use." + ] }, "stream": { - "type": "boolean", - "description": "(Optional) Whether to stream the response." + "type": "boolean" }, "stream_options": { "type": "object", @@ -5512,20 +5491,19 @@ "type": "object" } ] - }, - "description": "(Optional) The stream options to use." + } }, "temperature": { - "type": "number", - "description": "(Optional) The temperature to use." + "type": "number" }, "top_p": { - "type": "number", - "description": "(Optional) The top p to use." + "type": "number" }, "user": { - "type": "string", - "description": "(Optional) The user to use." + "type": "string" + }, + "suffix": { + "type": "string" }, "guided_choice": { "type": "array", @@ -5535,10 +5513,6 @@ }, "prompt_logprobs": { "type": "integer" - }, - "suffix": { - "type": "string", - "description": "(Optional) The suffix that should be appended to the completion." } }, "additionalProperties": false, @@ -5546,6 +5520,21 @@ "model", "prompt" ], + "title": "OpenAICompletionRequestParams", + "description": "Request parameters for OpenAI-compatible completion endpoint.\nThis model uses extra=\"allow\" to capture provider-specific parameters\n(like vLLM's guided_choice) which are passed through as extra_body." 
+ }, + "OpenaiCompletionRequest": { + "type": "object", + "properties": { + "params": { + "$ref": "#/components/schemas/OpenAICompletionRequestParams", + "description": "Request parameters including model, prompt, and optional parameters. Use params.get_extra_body() to extract provider-specific parameters." + } + }, + "additionalProperties": false, + "required": [ + "params" + ], "title": "OpenaiCompletionRequest" }, "OpenAICompletion": { diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 66ce8e38a..299ce4fad 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -3686,6 +3686,122 @@ components: title: OpenAIUserMessageParam description: >- A message from the user in an OpenAI-compatible chat completion request. + OpenAIChatCompletionRequestParams: + type: object + properties: + model: + type: string + messages: + type: array + items: + $ref: '#/components/schemas/OpenAIMessageParam' + frequency_penalty: + type: number + function_call: + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + functions: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + logit_bias: + type: object + additionalProperties: + type: number + logprobs: + type: boolean + max_completion_tokens: + type: integer + max_tokens: + type: integer + n: + type: integer + parallel_tool_calls: + type: boolean + presence_penalty: + type: number + response_format: + $ref: '#/components/schemas/OpenAIResponseFormatParam' + seed: + type: integer + stop: + oneOf: + - type: string + - type: array + items: + type: string + stream: + type: boolean + stream_options: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + temperature: + type: number + tool_choice: + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + tools: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + top_logprobs: + type: integer + top_p: + type: number + user: + type: string + additionalProperties: false + required: + - model + - messages + title: OpenAIChatCompletionRequestParams + description: >- + Request parameters for OpenAI-compatible chat completion endpoint. + + This model uses extra="allow" to capture provider-specific parameters + + which are passed through as extra_body. OpenAIJSONSchema: type: object properties: @@ -3780,145 +3896,15 @@ components: OpenaiChatCompletionRequest: type: object properties: - model: - type: string + params: + $ref: '#/components/schemas/OpenAIChatCompletionRequestParams' description: >- - The identifier of the model to use. The model must be registered with - Llama Stack and available via the /models endpoint. - messages: - type: array - items: - $ref: '#/components/schemas/OpenAIMessageParam' - description: List of messages in the conversation. - frequency_penalty: - type: number - description: >- - (Optional) The penalty for repeated tokens. 
- function_call: - oneOf: - - type: string - - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The function call to use. - functions: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) List of functions to use. - logit_bias: - type: object - additionalProperties: - type: number - description: (Optional) The logit bias to use. - logprobs: - type: boolean - description: (Optional) The log probabilities to use. - max_completion_tokens: - type: integer - description: >- - (Optional) The maximum number of tokens to generate. - max_tokens: - type: integer - description: >- - (Optional) The maximum number of tokens to generate. - n: - type: integer - description: >- - (Optional) The number of completions to generate. - parallel_tool_calls: - type: boolean - description: >- - (Optional) Whether to parallelize tool calls. - presence_penalty: - type: number - description: >- - (Optional) The penalty for repeated tokens. - response_format: - $ref: '#/components/schemas/OpenAIResponseFormatParam' - description: (Optional) The response format to use. - seed: - type: integer - description: (Optional) The seed to use. - stop: - oneOf: - - type: string - - type: array - items: - type: string - description: (Optional) The stop tokens to use. - stream: - type: boolean - description: >- - (Optional) Whether to stream the response. - stream_options: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The stream options to use. - temperature: - type: number - description: (Optional) The temperature to use. - tool_choice: - oneOf: - - type: string - - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The tool choice to use. - tools: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The tools to use. - top_logprobs: - type: integer - description: >- - (Optional) The top log probabilities to use. - top_p: - type: number - description: (Optional) The top p to use. - user: - type: string - description: (Optional) The user to use. + Request parameters including model, messages, and optional parameters. + Use params.get_extra_body() to extract provider-specific parameters (e.g., + chat_template_kwargs for vLLM). additionalProperties: false required: - - model - - messages + - params title: OpenaiChatCompletionRequest OpenAIChatCompletion: type: object @@ -4073,14 +4059,11 @@ components: - model - input_messages title: OpenAICompletionWithInputMessages - OpenaiCompletionRequest: + OpenAICompletionRequestParams: type: object properties: model: type: string - description: >- - The identifier of the model to use. The model must be registered with - Llama Stack and available via the /models endpoint. prompt: oneOf: - type: string @@ -4095,52 +4078,34 @@ components: type: array items: type: integer - description: The prompt to generate a completion for. best_of: type: integer - description: >- - (Optional) The number of completions to generate. 
echo: type: boolean - description: (Optional) Whether to echo the prompt. frequency_penalty: type: number - description: >- - (Optional) The penalty for repeated tokens. logit_bias: type: object additionalProperties: type: number - description: (Optional) The logit bias to use. logprobs: type: boolean - description: (Optional) The log probabilities to use. max_tokens: type: integer - description: >- - (Optional) The maximum number of tokens to generate. n: type: integer - description: >- - (Optional) The number of completions to generate. presence_penalty: type: number - description: >- - (Optional) The penalty for repeated tokens. seed: type: integer - description: (Optional) The seed to use. stop: oneOf: - type: string - type: array items: type: string - description: (Optional) The stop tokens to use. stream: type: boolean - description: >- - (Optional) Whether to stream the response. stream_options: type: object additionalProperties: @@ -4151,30 +4116,42 @@ components: - type: string - type: array - type: object - description: (Optional) The stream options to use. temperature: type: number - description: (Optional) The temperature to use. top_p: type: number - description: (Optional) The top p to use. user: type: string - description: (Optional) The user to use. + suffix: + type: string guided_choice: type: array items: type: string prompt_logprobs: type: integer - suffix: - type: string - description: >- - (Optional) The suffix that should be appended to the completion. additionalProperties: false required: - model - prompt + title: OpenAICompletionRequestParams + description: >- + Request parameters for OpenAI-compatible completion endpoint. + + This model uses extra="allow" to capture provider-specific parameters + + (like vLLM's guided_choice) which are passed through as extra_body. + OpenaiCompletionRequest: + type: object + properties: + params: + $ref: '#/components/schemas/OpenAICompletionRequestParams' + description: >- + Request parameters including model, prompt, and optional parameters. Use + params.get_extra_body() to extract provider-specific parameters. + additionalProperties: false + required: + - params title: OpenaiCompletionRequest OpenAICompletion: type: object diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 3478d3338..b56cbf421 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -6848,6 +6848,233 @@ "title": "OpenAIUserMessageParam", "description": "A message from the user in an OpenAI-compatible chat completion request." 
}, + "OpenAIChatCompletionRequestParams": { + "type": "object", + "properties": { + "model": { + "type": "string" + }, + "messages": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIMessageParam" + } + }, + "frequency_penalty": { + "type": "number" + }, + "function_call": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + ] + }, + "functions": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "logit_bias": { + "type": "object", + "additionalProperties": { + "type": "number" + } + }, + "logprobs": { + "type": "boolean" + }, + "max_completion_tokens": { + "type": "integer" + }, + "max_tokens": { + "type": "integer" + }, + "n": { + "type": "integer" + }, + "parallel_tool_calls": { + "type": "boolean" + }, + "presence_penalty": { + "type": "number" + }, + "response_format": { + "$ref": "#/components/schemas/OpenAIResponseFormatParam" + }, + "seed": { + "type": "integer" + }, + "stop": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ] + }, + "stream": { + "type": "boolean" + }, + "stream_options": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "temperature": { + "type": "number" + }, + "tool_choice": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + ] + }, + "tools": { + "type": "array", + "items": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "top_logprobs": { + "type": "integer" + }, + "top_p": { + "type": "number" + }, + "user": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "model", + "messages" + ], + "title": "OpenAIChatCompletionRequestParams", + "description": "Request parameters for OpenAI-compatible chat completion endpoint.\nThis model uses extra=\"allow\" to capture provider-specific parameters\nwhich are passed through as extra_body." + }, "OpenAIJSONSchema": { "type": "object", "properties": { @@ -6977,249 +7204,14 @@ "OpenaiChatCompletionRequest": { "type": "object", "properties": { - "model": { - "type": "string", - "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." - }, - "messages": { - "type": "array", - "items": { - "$ref": "#/components/schemas/OpenAIMessageParam" - }, - "description": "List of messages in the conversation." - }, - "frequency_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." 
- }, - "function_call": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - ], - "description": "(Optional) The function call to use." - }, - "functions": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "(Optional) List of functions to use." - }, - "logit_bias": { - "type": "object", - "additionalProperties": { - "type": "number" - }, - "description": "(Optional) The logit bias to use." - }, - "logprobs": { - "type": "boolean", - "description": "(Optional) The log probabilities to use." - }, - "max_completion_tokens": { - "type": "integer", - "description": "(Optional) The maximum number of tokens to generate." - }, - "max_tokens": { - "type": "integer", - "description": "(Optional) The maximum number of tokens to generate." - }, - "n": { - "type": "integer", - "description": "(Optional) The number of completions to generate." - }, - "parallel_tool_calls": { - "type": "boolean", - "description": "(Optional) Whether to parallelize tool calls." - }, - "presence_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." - }, - "response_format": { - "$ref": "#/components/schemas/OpenAIResponseFormatParam", - "description": "(Optional) The response format to use." - }, - "seed": { - "type": "integer", - "description": "(Optional) The seed to use." - }, - "stop": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "array", - "items": { - "type": "string" - } - } - ], - "description": "(Optional) The stop tokens to use." - }, - "stream": { - "type": "boolean", - "description": "(Optional) Whether to stream the response." - }, - "stream_options": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - }, - "description": "(Optional) The stream options to use." - }, - "temperature": { - "type": "number", - "description": "(Optional) The temperature to use." - }, - "tool_choice": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - } - ], - "description": "(Optional) The tool choice to use." - }, - "tools": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "(Optional) The tools to use." - }, - "top_logprobs": { - "type": "integer", - "description": "(Optional) The top log probabilities to use." - }, - "top_p": { - "type": "number", - "description": "(Optional) The top p to use." - }, - "user": { - "type": "string", - "description": "(Optional) The user to use." 
+ "params": { + "$ref": "#/components/schemas/OpenAIChatCompletionRequestParams", + "description": "Request parameters including model, messages, and optional parameters. Use params.get_extra_body() to extract provider-specific parameters (e.g., chat_template_kwargs for vLLM)." } }, "additionalProperties": false, "required": [ - "model", - "messages" + "params" ], "title": "OpenaiChatCompletionRequest" }, @@ -7405,12 +7397,11 @@ ], "title": "OpenAICompletionWithInputMessages" }, - "OpenaiCompletionRequest": { + "OpenAICompletionRequestParams": { "type": "object", "properties": { "model": { - "type": "string", - "description": "The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint." + "type": "string" }, "prompt": { "oneOf": [ @@ -7438,47 +7429,37 @@ } } } - ], - "description": "The prompt to generate a completion for." + ] }, "best_of": { - "type": "integer", - "description": "(Optional) The number of completions to generate." + "type": "integer" }, "echo": { - "type": "boolean", - "description": "(Optional) Whether to echo the prompt." + "type": "boolean" }, "frequency_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." + "type": "number" }, "logit_bias": { "type": "object", "additionalProperties": { "type": "number" - }, - "description": "(Optional) The logit bias to use." + } }, "logprobs": { - "type": "boolean", - "description": "(Optional) The log probabilities to use." + "type": "boolean" }, "max_tokens": { - "type": "integer", - "description": "(Optional) The maximum number of tokens to generate." + "type": "integer" }, "n": { - "type": "integer", - "description": "(Optional) The number of completions to generate." + "type": "integer" }, "presence_penalty": { - "type": "number", - "description": "(Optional) The penalty for repeated tokens." + "type": "number" }, "seed": { - "type": "integer", - "description": "(Optional) The seed to use." + "type": "integer" }, "stop": { "oneOf": [ @@ -7491,12 +7472,10 @@ "type": "string" } } - ], - "description": "(Optional) The stop tokens to use." + ] }, "stream": { - "type": "boolean", - "description": "(Optional) Whether to stream the response." + "type": "boolean" }, "stream_options": { "type": "object", @@ -7521,20 +7500,19 @@ "type": "object" } ] - }, - "description": "(Optional) The stream options to use." + } }, "temperature": { - "type": "number", - "description": "(Optional) The temperature to use." + "type": "number" }, "top_p": { - "type": "number", - "description": "(Optional) The top p to use." + "type": "number" }, "user": { - "type": "string", - "description": "(Optional) The user to use." + "type": "string" + }, + "suffix": { + "type": "string" }, "guided_choice": { "type": "array", @@ -7544,10 +7522,6 @@ }, "prompt_logprobs": { "type": "integer" - }, - "suffix": { - "type": "string", - "description": "(Optional) The suffix that should be appended to the completion." } }, "additionalProperties": false, @@ -7555,6 +7529,21 @@ "model", "prompt" ], + "title": "OpenAICompletionRequestParams", + "description": "Request parameters for OpenAI-compatible completion endpoint.\nThis model uses extra=\"allow\" to capture provider-specific parameters\n(like vLLM's guided_choice) which are passed through as extra_body." 
+ }, + "OpenaiCompletionRequest": { + "type": "object", + "properties": { + "params": { + "$ref": "#/components/schemas/OpenAICompletionRequestParams", + "description": "Request parameters including model, prompt, and optional parameters. Use params.get_extra_body() to extract provider-specific parameters." + } + }, + "additionalProperties": false, + "required": [ + "params" + ], "title": "OpenaiCompletionRequest" }, "OpenAICompletion": { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 6c04542bf..fc5ea0dc1 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -5131,6 +5131,122 @@ components: title: OpenAIUserMessageParam description: >- A message from the user in an OpenAI-compatible chat completion request. + OpenAIChatCompletionRequestParams: + type: object + properties: + model: + type: string + messages: + type: array + items: + $ref: '#/components/schemas/OpenAIMessageParam' + frequency_penalty: + type: number + function_call: + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + functions: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + logit_bias: + type: object + additionalProperties: + type: number + logprobs: + type: boolean + max_completion_tokens: + type: integer + max_tokens: + type: integer + n: + type: integer + parallel_tool_calls: + type: boolean + presence_penalty: + type: number + response_format: + $ref: '#/components/schemas/OpenAIResponseFormatParam' + seed: + type: integer + stop: + oneOf: + - type: string + - type: array + items: + type: string + stream: + type: boolean + stream_options: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + temperature: + type: number + tool_choice: + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + tools: + type: array + items: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + top_logprobs: + type: integer + top_p: + type: number + user: + type: string + additionalProperties: false + required: + - model + - messages + title: OpenAIChatCompletionRequestParams + description: >- + Request parameters for OpenAI-compatible chat completion endpoint. + + This model uses extra="allow" to capture provider-specific parameters + + which are passed through as extra_body. OpenAIJSONSchema: type: object properties: @@ -5225,145 +5341,15 @@ components: OpenaiChatCompletionRequest: type: object properties: - model: - type: string + params: + $ref: '#/components/schemas/OpenAIChatCompletionRequestParams' description: >- - The identifier of the model to use. The model must be registered with - Llama Stack and available via the /models endpoint. - messages: - type: array - items: - $ref: '#/components/schemas/OpenAIMessageParam' - description: List of messages in the conversation. - frequency_penalty: - type: number - description: >- - (Optional) The penalty for repeated tokens. 
- function_call: - oneOf: - - type: string - - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The function call to use. - functions: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) List of functions to use. - logit_bias: - type: object - additionalProperties: - type: number - description: (Optional) The logit bias to use. - logprobs: - type: boolean - description: (Optional) The log probabilities to use. - max_completion_tokens: - type: integer - description: >- - (Optional) The maximum number of tokens to generate. - max_tokens: - type: integer - description: >- - (Optional) The maximum number of tokens to generate. - n: - type: integer - description: >- - (Optional) The number of completions to generate. - parallel_tool_calls: - type: boolean - description: >- - (Optional) Whether to parallelize tool calls. - presence_penalty: - type: number - description: >- - (Optional) The penalty for repeated tokens. - response_format: - $ref: '#/components/schemas/OpenAIResponseFormatParam' - description: (Optional) The response format to use. - seed: - type: integer - description: (Optional) The seed to use. - stop: - oneOf: - - type: string - - type: array - items: - type: string - description: (Optional) The stop tokens to use. - stream: - type: boolean - description: >- - (Optional) Whether to stream the response. - stream_options: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The stream options to use. - temperature: - type: number - description: (Optional) The temperature to use. - tool_choice: - oneOf: - - type: string - - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The tool choice to use. - tools: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: (Optional) The tools to use. - top_logprobs: - type: integer - description: >- - (Optional) The top log probabilities to use. - top_p: - type: number - description: (Optional) The top p to use. - user: - type: string - description: (Optional) The user to use. + Request parameters including model, messages, and optional parameters. + Use params.get_extra_body() to extract provider-specific parameters (e.g., + chat_template_kwargs for vLLM). additionalProperties: false required: - - model - - messages + - params title: OpenaiChatCompletionRequest OpenAIChatCompletion: type: object @@ -5518,14 +5504,11 @@ components: - model - input_messages title: OpenAICompletionWithInputMessages - OpenaiCompletionRequest: + OpenAICompletionRequestParams: type: object properties: model: type: string - description: >- - The identifier of the model to use. The model must be registered with - Llama Stack and available via the /models endpoint. prompt: oneOf: - type: string @@ -5540,52 +5523,34 @@ components: type: array items: type: integer - description: The prompt to generate a completion for. best_of: type: integer - description: >- - (Optional) The number of completions to generate. 
echo: type: boolean - description: (Optional) Whether to echo the prompt. frequency_penalty: type: number - description: >- - (Optional) The penalty for repeated tokens. logit_bias: type: object additionalProperties: type: number - description: (Optional) The logit bias to use. logprobs: type: boolean - description: (Optional) The log probabilities to use. max_tokens: type: integer - description: >- - (Optional) The maximum number of tokens to generate. n: type: integer - description: >- - (Optional) The number of completions to generate. presence_penalty: type: number - description: >- - (Optional) The penalty for repeated tokens. seed: type: integer - description: (Optional) The seed to use. stop: oneOf: - type: string - type: array items: type: string - description: (Optional) The stop tokens to use. stream: type: boolean - description: >- - (Optional) Whether to stream the response. stream_options: type: object additionalProperties: @@ -5596,30 +5561,42 @@ components: - type: string - type: array - type: object - description: (Optional) The stream options to use. temperature: type: number - description: (Optional) The temperature to use. top_p: type: number - description: (Optional) The top p to use. user: type: string - description: (Optional) The user to use. + suffix: + type: string guided_choice: type: array items: type: string prompt_logprobs: type: integer - suffix: - type: string - description: >- - (Optional) The suffix that should be appended to the completion. additionalProperties: false required: - model - prompt + title: OpenAICompletionRequestParams + description: >- + Request parameters for OpenAI-compatible completion endpoint. + + This model uses extra="allow" to capture provider-specific parameters + + (like vLLM's guided_choice) which are passed through as extra_body. + OpenaiCompletionRequest: + type: object + properties: + params: + $ref: '#/components/schemas/OpenAICompletionRequestParams' + description: >- + Request parameters including model, prompt, and optional parameters. Use + params.get_extra_body() to extract provider-specific parameters. + additionalProperties: false + required: + - params title: OpenaiCompletionRequest OpenAICompletion: type: object diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 62a988ea6..0a48fc456 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -14,7 +14,7 @@ from typing import ( runtime_checkable, ) -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from typing_extensions import TypedDict from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent @@ -995,6 +995,81 @@ class ListOpenAIChatCompletionResponse(BaseModel): object: Literal["list"] = "list" +@json_schema_type +class OpenAICompletionRequestParams(BaseModel): + """Request parameters for OpenAI-compatible completion endpoint. + + This model uses extra="allow" to capture provider-specific parameters + (like vLLM's guided_choice) which are passed through as extra_body. 
+ """ + + model_config = ConfigDict(extra="allow") + + # Required parameters + model: str + prompt: str | list[str] | list[int] | list[list[int]] + + # Standard OpenAI completion parameters + best_of: int | None = None + echo: bool | None = None + frequency_penalty: float | None = None + logit_bias: dict[str, float] | None = None + logprobs: bool | None = None + max_tokens: int | None = None + n: int | None = None + presence_penalty: float | None = None + seed: int | None = None + stop: str | list[str] | None = None + stream: bool | None = None + stream_options: dict[str, Any] | None = None + temperature: float | None = None + top_p: float | None = None + user: str | None = None + suffix: str | None = None + + # vLLM-specific parameters (documented here but also allowed via extra fields) + guided_choice: list[str] | None = None + prompt_logprobs: int | None = None + + +@json_schema_type +class OpenAIChatCompletionRequestParams(BaseModel): + """Request parameters for OpenAI-compatible chat completion endpoint. + + This model uses extra="allow" to capture provider-specific parameters + which are passed through as extra_body. + """ + + model_config = ConfigDict(extra="allow") + + # Required parameters + model: str + messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)] + + # Standard OpenAI chat completion parameters + frequency_penalty: float | None = None + function_call: str | dict[str, Any] | None = None + functions: list[dict[str, Any]] | None = None + logit_bias: dict[str, float] | None = None + logprobs: bool | None = None + max_completion_tokens: int | None = None + max_tokens: int | None = None + n: int | None = None + parallel_tool_calls: bool | None = None + presence_penalty: float | None = None + response_format: OpenAIResponseFormatParam | None = None + seed: int | None = None + stop: str | list[str] | None = None + stream: bool | None = None + stream_options: dict[str, Any] | None = None + temperature: float | None = None + tool_choice: str | dict[str, Any] | None = None + tools: list[dict[str, Any]] | None = None + top_logprobs: int | None = None + top_p: float | None = None + user: str | None = None + + @runtime_checkable @trace_protocol class InferenceProvider(Protocol): @@ -1029,52 +1104,14 @@ class InferenceProvider(Protocol): @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1) async def openai_completion( self, - # Standard OpenAI completion parameters - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - # vLLM-specific parameters - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - # for fill-in-the-middle type completion - suffix: str | None = None, + params: OpenAICompletionRequestParams, ) -> OpenAICompletion: """Create completion. Generate an OpenAI-compatible completion for the given prompt using the specified model. - :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. 
- :param prompt: The prompt to generate a completion for. - :param best_of: (Optional) The number of completions to generate. - :param echo: (Optional) Whether to echo the prompt. - :param frequency_penalty: (Optional) The penalty for repeated tokens. - :param logit_bias: (Optional) The logit bias to use. - :param logprobs: (Optional) The log probabilities to use. - :param max_tokens: (Optional) The maximum number of tokens to generate. - :param n: (Optional) The number of completions to generate. - :param presence_penalty: (Optional) The penalty for repeated tokens. - :param seed: (Optional) The seed to use. - :param stop: (Optional) The stop tokens to use. - :param stream: (Optional) Whether to stream the response. - :param stream_options: (Optional) The stream options to use. - :param temperature: (Optional) The temperature to use. - :param top_p: (Optional) The top p to use. - :param user: (Optional) The user to use. - :param suffix: (Optional) The suffix that should be appended to the completion. + :param params: Request parameters including model, prompt, and optional parameters. + Use params.get_extra_body() to extract provider-specific parameters. :returns: An OpenAICompletion. """ ... @@ -1083,57 +1120,14 @@ class InferenceProvider(Protocol): @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1) async def openai_chat_completion( self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, + params: OpenAIChatCompletionRequestParams, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: """Create chat completions. Generate an OpenAI-compatible chat completion for the given messages using the specified model. - :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. - :param messages: List of messages in the conversation. - :param frequency_penalty: (Optional) The penalty for repeated tokens. - :param function_call: (Optional) The function call to use. - :param functions: (Optional) List of functions to use. - :param logit_bias: (Optional) The logit bias to use. - :param logprobs: (Optional) The log probabilities to use. - :param max_completion_tokens: (Optional) The maximum number of tokens to generate. - :param max_tokens: (Optional) The maximum number of tokens to generate. - :param n: (Optional) The number of completions to generate. - :param parallel_tool_calls: (Optional) Whether to parallelize tool calls. - :param presence_penalty: (Optional) The penalty for repeated tokens. - :param response_format: (Optional) The response format to use. - :param seed: (Optional) The seed to use. - :param stop: (Optional) The stop tokens to use. 
- :param stream: (Optional) Whether to stream the response. - :param stream_options: (Optional) The stream options to use. - :param temperature: (Optional) The temperature to use. - :param tool_choice: (Optional) The tool choice to use. - :param tools: (Optional) The tools to use. - :param top_logprobs: (Optional) The top log probabilities to use. - :param top_p: (Optional) The top p to use. - :param user: (Optional) The user to use. + :param params: Request parameters including model, messages, and optional parameters. + Use params.get_extra_body() to extract provider-specific parameters (e.g., chat_template_kwargs for vLLM). :returns: An OpenAIChatCompletion. """ ... diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index 0d9f9f134..370882fe9 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -383,7 +383,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body, field_names = self._handle_file_uploads(options, body) - body = self._convert_body(path, options.method, body, exclude_params=set(field_names)) + body = self._convert_body(matched_func, body, exclude_params=set(field_names)) trace_path = webmethod.descriptive_name or route_path await start_trace(trace_path, {"__location__": "library_client"}) @@ -446,7 +446,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params - body = self._convert_body(path, options.method, body) + # Prepare body for the function call (handles both Pydantic and traditional params) + body = self._convert_body(func, body) trace_path = webmethod.descriptive_name or route_path await start_trace(trace_path, {"__location__": "library_client"}) @@ -493,18 +494,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return await response.parse() - def _convert_body( - self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None - ) -> dict: + def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict: if not body: return {} - assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy exclude_params = exclude_params or set() - - func, _, _, _ = find_matching_route(method, path, self.route_impls) sig = inspect.signature(func) + params_list = [p for p in sig.parameters.values() if p.name != "self"] + # Check if the method expects a single Pydantic model parameter + if len(params_list) == 1: + param = params_list[0] + param_type = param.annotation + try: + if isinstance(param_type, type) and issubclass(param_type, BaseModel): + # Strip NOT_GIVENs before passing to Pydantic + clean_body = {k: v for k, v in body.items() if v is not NOT_GIVEN} + + # If the body has a single key matching the parameter name, unwrap it + if len(clean_body) == 1 and param.name in clean_body: + clean_body = clean_body[param.name] + + return {param.name: param_type(**clean_body)} + except (TypeError, AttributeError): + pass + + # Traditional parameter conversion path # Strip NOT_GIVENs to use the defaults in signature body = {k: v for k, v in body.items() if v is not NOT_GIVEN} diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 847f6a2d2..b02c0b788 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -8,11 +8,11 @@ import asyncio import time from 
collections.abc import AsyncGenerator, AsyncIterator from datetime import UTC, datetime -from typing import Annotated, Any +from typing import Any from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam -from pydantic import Field, TypeAdapter +from pydantic import TypeAdapter from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -31,15 +31,16 @@ from llama_stack.apis.inference import ( OpenAIAssistantMessageParam, OpenAIChatCompletion, OpenAIChatCompletionChunk, + OpenAIChatCompletionRequestParams, OpenAIChatCompletionToolCall, OpenAIChatCompletionToolCallFunction, OpenAIChoice, OpenAIChoiceLogprobs, OpenAICompletion, + OpenAICompletionRequestParams, OpenAICompletionWithInputMessages, OpenAIEmbeddingsResponse, OpenAIMessageParam, - OpenAIResponseFormatParam, Order, StopReason, ToolPromptFormat, @@ -181,61 +182,23 @@ class InferenceRouter(Inference): async def openai_completion( self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, + params: OpenAICompletionRequestParams, ) -> OpenAICompletion: logger.debug( - f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}", - ) - model_obj = await self._get_model(model, ModelType.llm) - params = dict( - model=model_obj.identifier, - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, - presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - guided_choice=guided_choice, - prompt_logprobs=prompt_logprobs, - suffix=suffix, + f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}", ) + model_obj = await self._get_model(params.model, ModelType.llm) + + # Update params with the resolved model identifier + params.model = model_obj.identifier + provider = await self.routing_table.get_provider_impl(model_obj.identifier) - if stream: - return await provider.openai_completion(**params) + if params.stream: + return await provider.openai_completion(params) # TODO: Metrics do NOT work with openai_completion stream=True due to the fact # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently. 
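A minimal caller-side sketch of the refactored completion path shown in this hunk (the model id, the `inference_router` name, and the demo wrapper are hypothetical; field names come from the OpenAICompletionRequestParams model introduced above):

    from llama_stack.apis.inference import OpenAICompletionRequestParams

    async def demo_completion(inference_router) -> None:
        # Build one request object instead of passing ~20 keyword arguments.
        params = OpenAICompletionRequestParams(
            model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model alias
            prompt="Say hello",
            max_tokens=32,
            guided_choice=["yes", "no"],  # vLLM-specific field declared on the params model
        )
        # The router resolves params.model to the registered model's identifier
        # and forwards the object as-is to the selected provider.
        response = await inference_router.openai_completion(params)
        print(response.choices[0].text)
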
- # response_stream = await provider.openai_completion(**params) - response = await provider.openai_completion(**params) + response = await provider.openai_completion(params) if self.telemetry: metrics = self._construct_metrics( prompt_tokens=response.usage.prompt_tokens, @@ -254,93 +217,49 @@ class InferenceRouter(Inference): async def openai_chat_completion( self, - model: str, - messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, + params: OpenAIChatCompletionRequestParams, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: logger.debug( - f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}", + f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}", ) - model_obj = await self._get_model(model, ModelType.llm) + model_obj = await self._get_model(params.model, ModelType.llm) # Use the OpenAI client for a bit of extra input validation without # exposing the OpenAI client itself as part of our API surface - if tool_choice: - TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice) - if tools is None: + if params.tool_choice: + TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice) + if params.tools is None: raise ValueError("'tool_choice' is only allowed when 'tools' is also provided") - if tools: - for tool in tools: + if params.tools: + for tool in params.tools: TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool) # Some providers make tool calls even when tool_choice is "none" # so just clear them both out to avoid unexpected tool calls - if tool_choice == "none" and tools is not None: - tool_choice = None - tools = None + if params.tool_choice == "none" and params.tools is not None: + params.tool_choice = None + params.tools = None + + # Update params with the resolved model identifier + params.model = model_obj.identifier - params = dict( - model=model_obj.identifier, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) provider = await self.routing_table.get_provider_impl(model_obj.identifier) - if stream: - response_stream = await provider.openai_chat_completion(**params) + 
if params.stream: + response_stream = await provider.openai_chat_completion(params) # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk] # We need to add metrics to each chunk and store the final completion return self.stream_tokens_and_compute_metrics_openai_chat( response=response_stream, model=model_obj, - messages=messages, + messages=params.messages, ) response = await self._nonstream_openai_chat_completion(provider, params) # Store the response with the ID that will be returned to the client if self.store: - asyncio.create_task(self.store.store_chat_completion(response, messages)) + asyncio.create_task(self.store.store_chat_completion(response, params.messages)) if self.telemetry: metrics = self._construct_metrics( @@ -396,8 +315,10 @@ class InferenceRouter(Inference): return await self.store.get_chat_completion(completion_id) raise NotImplementedError("Get chat completion is not supported: inference store is not configured.") - async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion: - response = await provider.openai_chat_completion(**params) + async def _nonstream_openai_chat_completion( + self, provider: Inference, params: OpenAIChatCompletionRequestParams + ) -> OpenAIChatCompletion: + response = await provider.openai_chat_completion(params) for choice in response.choices: # some providers return an empty list for no tool calls in non-streaming responses # but the OpenAI API returns None. So, set tool_calls to None if it's empty diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index e19092816..4c53d916a 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -13,12 +13,13 @@ import logging # allow-direct-logging import os import sys import traceback +import types import warnings from collections.abc import Callable from contextlib import asynccontextmanager from importlib.metadata import version as parse_version from pathlib import Path -from typing import Annotated, Any, get_origin +from typing import Annotated, Any, Union, get_origin import httpx import rich.pretty @@ -177,7 +178,17 @@ async def lifespan(app: StackApp): def is_streaming_request(func_name: str, request: Request, **kwargs): # TODO: pass the api method and punt it to the Protocol definition directly - return kwargs.get("stream", False) + # Check for stream parameter at top level (old API style) + if "stream" in kwargs: + return kwargs["stream"] + + # Check for stream parameter inside Pydantic request params (new API style) + if "params" in kwargs: + params = kwargs["params"] + if hasattr(params, "stream"): + return params.stream + + return False async def maybe_await(value): @@ -282,21 +293,46 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: if method == "post": # Annotate parameters that are in the path with Path(...) and others with Body(...), # but preserve existing File() and Form() annotations for multipart form data - new_params = ( - [new_params[0]] - + [ - ( + def get_body_embed_value(param: inspect.Parameter) -> bool: + """Determine if Body should use embed=True or embed=False. + + For OpenAI-compatible endpoints (param name is 'params'), use embed=False + so the request body is parsed directly as the model (not nested). + This allows OpenAI clients to send standard OpenAI format. + For other endpoints, use embed=True for SDK compatibility. 
+ """ + # Get the actual type, stripping Optional/Union if present + param_type = param.annotation + origin = get_origin(param_type) + # Check for Union types (both typing.Union and types.UnionType for | syntax) + if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType): + # Handle Optional[T] / T | None + args = param_type.__args__ if hasattr(param_type, "__args__") else [] + param_type = next((arg for arg in args if arg is not type(None)), param_type) + + # Check if it's a Pydantic BaseModel and param name is 'params' (OpenAI-compatible) + try: + is_basemodel = isinstance(param_type, type) and issubclass(param_type, BaseModel) + if is_basemodel and param.name == "params": + return False # Use embed=False for OpenAI-compatible endpoints + return True # Use embed=True for everything else + except TypeError: + # Not a class, use default embed=True + return True + + original_params = new_params[1:] # Skip request parameter + new_params = [new_params[0]] # Keep request parameter + + for param in original_params: + if param.name in path_params: + new_params.append( param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)]) - if param.name in path_params - else ( - param # Keep original annotation if it's already an Annotated type - if get_origin(param.annotation) is Annotated - else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)]) - ) ) - for param in new_params[1:] - ] - ) + elif get_origin(param.annotation) is Annotated: + new_params.append(param) # Keep existing annotation + else: + embed = get_body_embed_value(param) + new_params.append(param.replace(annotation=Annotated[param.annotation, Body(..., embed=embed)])) route_handler.__signature__ = sig.replace(parameters=new_params) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index b17c720e9..483092eef 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -49,6 +49,7 @@ from llama_stack.apis.inference import ( Inference, Message, OpenAIAssistantMessageParam, + OpenAIChatCompletionRequestParams, OpenAIDeveloperMessageParam, OpenAIMessageParam, OpenAISystemMessageParam, @@ -582,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin): max_tokens = getattr(sampling_params, "max_tokens", None) # Use OpenAI chat completion - openai_stream = await self.inference_api.openai_chat_completion( + params = OpenAIChatCompletionRequestParams( model=self.agent_config.model, messages=openai_messages, tools=openai_tools if openai_tools else None, @@ -593,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin): max_tokens=max_tokens, stream=True, ) + openai_stream = await self.inference_api.openai_chat_completion(params) # Convert OpenAI stream back to Llama Stack format response_stream = convert_openai_chat_completion_stream( diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 9487edc61..daac995ca 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -41,6 +41,7 @@ from llama_stack.apis.inference import ( Inference, OpenAIAssistantMessageParam, OpenAIChatCompletion, + OpenAIChatCompletionRequestParams, OpenAIChatCompletionToolCall, OpenAIChoice, 
OpenAIMessageParam, @@ -130,7 +131,7 @@ class StreamingResponseOrchestrator: # (some providers don't support non-empty response_format when tools are present) response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}") - completion_result = await self.inference_api.openai_chat_completion( + params = OpenAIChatCompletionRequestParams( model=self.ctx.model, messages=messages, tools=self.ctx.chat_tools, @@ -138,6 +139,7 @@ class StreamingResponseOrchestrator: temperature=self.ctx.temperature, response_format=response_format, ) + completion_result = await self.inference_api.openai_chat_completion(params) # Process streaming chunks and build complete response completion_result_data = None diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index e049518a4..7d0b63312 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -22,6 +22,8 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose from llama_stack.apis.inference import ( Inference, OpenAIAssistantMessageParam, + OpenAIChatCompletionRequestParams, + OpenAICompletionRequestParams, OpenAIDeveloperMessageParam, OpenAIMessageParam, OpenAISystemMessageParam, @@ -601,7 +603,8 @@ class ReferenceBatchesImpl(Batches): # TODO(SECURITY): review body for security issues if request.url == "/v1/chat/completions": request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] - chat_response = await self.inference_api.openai_chat_completion(**request.body) + params = OpenAIChatCompletionRequestParams(**request.body) + chat_response = await self.inference_api.openai_chat_completion(params) # this is for mypy, we don't allow streaming so we'll get the right type assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" @@ -615,7 +618,8 @@ class ReferenceBatchesImpl(Batches): }, } else: # /v1/completions - completion_response = await self.inference_api.openai_completion(**request.body) + params = OpenAICompletionRequestParams(**request.body) + completion_response = await self.inference_api.openai_completion(params) # this is for mypy, we don't allow streaming so we'll get the right type assert hasattr(completion_response, "model_dump_json"), ( diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 0dfe23dca..97b79fe9f 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets -from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage +from llama_stack.apis.inference import ( + Inference, + OpenAIChatCompletionRequestParams, + OpenAICompletionRequestParams, + OpenAISystemMessageParam, + OpenAIUserMessageParam, + UserMessage, +) from llama_stack.apis.scoring import Scoring from llama_stack.providers.datatypes import BenchmarksProtocolPrivate from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( @@ -168,11 +175,12 @@ class MetaReferenceEvalImpl( 
sampling_params["stop"] = candidate.sampling_params.stop input_content = json.loads(x[ColumnName.completion_input.value]) - response = await self.inference_api.openai_completion( + params = OpenAICompletionRequestParams( model=candidate.model, prompt=input_content, **sampling_params, ) + response = await self.inference_api.openai_completion(params) generations.append({ColumnName.generated_answer.value: response.choices[0].text}) elif ColumnName.chat_completion_input.value in x: chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value]) @@ -187,11 +195,12 @@ class MetaReferenceEvalImpl( messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"] messages += input_messages - response = await self.inference_api.openai_chat_completion( + params = OpenAIChatCompletionRequestParams( model=candidate.model, messages=messages, **sampling_params, ) + response = await self.inference_api.openai_chat_completion(params) generations.append({ColumnName.generated_answer.value: response.choices[0].message.content}) else: raise ValueError("Invalid input row") diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index fd65fa10d..1748cbbde 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -6,16 +6,16 @@ import asyncio from collections.abc import AsyncIterator -from typing import Any from llama_stack.apis.inference import ( InferenceProvider, + OpenAIChatCompletionRequestParams, + OpenAICompletionRequestParams, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, - OpenAIMessageParam, - OpenAIResponseFormatParam, + OpenAICompletion, ) from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger @@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl( if self.config.create_distributed_process_group: self.generator.stop() - async def openai_completion(self, *args, **kwargs): + async def openai_completion( + self, + params: OpenAICompletionRequestParams, + ) -> OpenAICompletion: raise NotImplementedError("OpenAI completion not supported by meta reference provider") async def should_refresh_models(self) -> bool: @@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl( async def openai_chat_completion( self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, + params: OpenAIChatCompletionRequestParams, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: raise NotImplementedError("OpenAI chat completion not supported by meta-reference 
inference provider") diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index b984d97bf..34feb27d0 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -5,17 +5,16 @@ # the root directory of this source tree. from collections.abc import AsyncIterator -from typing import Any from llama_stack.apis.inference import ( InferenceProvider, + OpenAIChatCompletionRequestParams, + OpenAICompletionRequestParams, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, - OpenAIMessageParam, - OpenAIResponseFormatParam, ) from llama_stack.apis.models import ModelType from llama_stack.log import get_logger @@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl( async def openai_completion( self, - # Standard OpenAI completion parameters - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - # vLLM-specific parameters - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - # for fill-in-the-middle type completion - suffix: str | None = None, + params: OpenAICompletionRequestParams, ) -> OpenAICompletion: raise NotImplementedError("OpenAI completion not supported by sentence transformers provider") async def openai_chat_completion( self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, + params: OpenAIChatCompletionRequestParams, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider") diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 206182343..2adf01200 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -10,7 +10,13 @@ from string import Template from typing import Any from 
llama_stack.apis.common.content_types import ImageContentItem, TextContentItem -from llama_stack.apis.inference import Inference, Message, UserMessage +from llama_stack.apis.inference import ( + Inference, + Message, + OpenAIChatCompletionRequestParams, + OpenAIUserMessageParam, + UserMessage, +) from llama_stack.apis.safety import ( RunShieldResponse, Safety, @@ -290,20 +296,21 @@ class LlamaGuardShield: else: shield_input_message = self.build_text_shield_input(messages) - response = await self.inference_api.openai_chat_completion( + params = OpenAIChatCompletionRequestParams( model=self.model, messages=[shield_input_message], stream=False, temperature=0.0, # default is 1, which is too high for safety ) + response = await self.inference_api.openai_chat_completion(params) content = response.choices[0].message.content content = content.strip() return self.get_shield_response(content) - def build_text_shield_input(self, messages: list[Message]) -> UserMessage: - return UserMessage(content=self.build_prompt(messages)) + def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam: + return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages)) - def build_vision_shield_input(self, messages: list[Message]) -> UserMessage: + def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam: conversation = [] most_recent_img = None @@ -335,7 +342,7 @@ class LlamaGuardShield: prompt.append(most_recent_img) prompt.append(self.build_prompt(conversation[::-1])) - return UserMessage(content=prompt) + return OpenAIUserMessageParam(role="user", content=prompt) def build_prompt(self, messages: list[Message]) -> str: categories = self.get_safety_categories() @@ -377,11 +384,12 @@ class LlamaGuardShield: # TODO: Add Image based support for OpenAI Moderations shield_input_message = self.build_text_shield_input(messages) - response = await self.inference_api.openai_chat_completion( + params = OpenAIChatCompletionRequestParams( model=self.model, messages=[shield_input_message], stream=False, ) + response = await self.inference_api.openai_chat_completion(params) content = response.choices[0].message.content content = content.strip() return self.get_moderation_object(content) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index d60efe828..58cb42fa2 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -6,7 +6,7 @@ import re from typing import Any -from llama_stack.apis.inference import Inference +from llama_stack.apis.inference import Inference, OpenAIChatCompletionRequestParams from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn @@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn): generated_answer=generated_answer, ) - judge_response = await self.inference_api.openai_chat_completion( + params = OpenAIChatCompletionRequestParams( model=fn_def.params.judge_model, messages=[ { @@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn): } ], ) + judge_response = await self.inference_api.openai_chat_completion(params) content = judge_response.choices[0].message.content rating_regexes = 
fn_def.params.judge_score_regexes diff --git a/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py b/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py index 9bc22f979..1d52f7a22 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py @@ -8,7 +8,7 @@ from jinja2 import Template from llama_stack.apis.common.content_types import InterleavedContent -from llama_stack.apis.inference import OpenAIUserMessageParam +from llama_stack.apis.inference import OpenAIChatCompletionRequestParams, OpenAIUserMessageParam from llama_stack.apis.tools.rag_tool import ( DefaultRAGQueryGeneratorConfig, LLMRAGQueryGeneratorConfig, @@ -65,11 +65,12 @@ async def llm_rag_query_generator( model = config.model message = OpenAIUserMessageParam(content=rendered_content) - response = await inference_api.openai_chat_completion( + params = OpenAIChatCompletionRequestParams( model=model, messages=[message], stream=False, ) + response = await inference_api.openai_chat_completion(params) query = response.choices[0].message.content diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 9c8a74b47..4e4a40919 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -6,21 +6,20 @@ import json from collections.abc import AsyncIterator -from typing import Any from botocore.client import BaseClient from llama_stack.apis.inference import ( ChatCompletionRequest, Inference, + OpenAIChatCompletionRequestParams, + OpenAICompletionRequestParams, OpenAIEmbeddingsResponse, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, - OpenAIMessageParam, - OpenAIResponseFormatParam, ) from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.utils.bedrock.client import create_bedrock_client @@ -135,56 +134,12 @@ class BedrockInferenceAdapter( async def openai_completion( self, - # Standard OpenAI completion parameters - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - # vLLM-specific parameters - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - # for fill-in-the-middle type completion - suffix: str | None = None, + params: OpenAICompletionRequestParams, ) -> OpenAICompletion: raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") async def openai_chat_completion( self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool 
| None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, + params: OpenAIChatCompletionRequestParams, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider") diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 200b36171..3ae11f22b 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -5,11 +5,14 @@ # the root directory of this source tree. from collections.abc import Iterable -from typing import Any +from typing import TYPE_CHECKING from databricks.sdk import WorkspaceClient from llama_stack.apis.inference import OpenAICompletion + +if TYPE_CHECKING: + from llama_stack.apis.inference import OpenAICompletionRequestParams from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -43,25 +46,6 @@ class DatabricksInferenceAdapter(OpenAIMixin): async def openai_completion( self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, + params: "OpenAICompletionRequestParams", ) -> OpenAICompletion: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index 165992c16..c838ea0c5 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -3,9 +3,12 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from typing import Any +from typing import TYPE_CHECKING from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse + +if TYPE_CHECKING: + from llama_stack.apis.inference import OpenAICompletionRequestParams from llama_stack.log import get_logger from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -34,26 +37,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin): async def openai_completion( self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, + params: "OpenAICompletionRequestParams", ) -> OpenAICompletion: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 01078760a..5729e3900 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -13,15 +13,14 @@ from llama_stack.apis.inference import ( Inference, OpenAIChatCompletion, OpenAIChatCompletionChunk, + OpenAIChatCompletionRequestParams, OpenAICompletion, + OpenAICompletionRequestParams, OpenAIEmbeddingsResponse, - OpenAIMessageParam, - OpenAIResponseFormatParam, ) from llama_stack.apis.models import Model from llama_stack.core.library_client import convert_pydantic_to_json_value from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper -from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params from .config import PassthroughImplConfig @@ -80,110 +79,33 @@ class PassthroughInferenceAdapter(Inference): async def openai_completion( self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, + params: OpenAICompletionRequestParams, ) -> OpenAICompletion: client = self._get_client() - model_obj = await self.model_store.get_model(model) + model_obj = await self.model_store.get_model(params.model) - params = await prepare_openai_completion_params( - model=model_obj.provider_resource_id, - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, 
- presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - guided_choice=guided_choice, - prompt_logprobs=prompt_logprobs, - ) + # Copy params to avoid mutating the original + params = params.model_copy() + params.model = model_obj.provider_resource_id - return await client.inference.openai_completion(**params) + request_params = params.model_dump(exclude_none=True) + + return await client.inference.openai_completion(**request_params) async def openai_chat_completion( self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, + params: OpenAIChatCompletionRequestParams, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: client = self._get_client() - model_obj = await self.model_store.get_model(model) + model_obj = await self.model_store.get_model(params.model) - params = await prepare_openai_completion_params( - model=model_obj.provider_resource_id, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) + # Copy params to avoid mutating the original + params = params.model_copy() + params.model = model_obj.provider_resource_id - return await client.inference.openai_chat_completion(**params) + request_params = params.model_dump(exclude_none=True) + + return await client.inference.openai_chat_completion(**request_params) def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]: json_params = {} diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index f752740e5..7cb8cc71e 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -4,11 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
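The passthrough adapter's forwarding step reduces to "copy, override the model id, serialize"; a rough sketch under hypothetical identifiers, relying only on the standard Pydantic v2 methods used in the patch:

    from llama_stack.apis.inference import OpenAICompletionRequestParams

    params = OpenAICompletionRequestParams(model="my-alias", prompt="Hi", temperature=0.2)

    # Copy before mutating so the caller's object is untouched, then swap in the
    # provider's resource id and serialize, dropping fields that are still None.
    forwarded = params.model_copy()
    forwarded.model = "provider/resolved-model-id"  # hypothetical resolved id
    request_body = forwarded.model_dump(exclude_none=True)
    # request_body == {"model": "provider/resolved-model-id", "prompt": "Hi", "temperature": 0.2}
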
-from typing import Any +from collections.abc import AsyncIterator from llama_stack.apis.inference import ( - OpenAIMessageParam, - OpenAIResponseFormatParam, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAIChatCompletionRequestParams, ) from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -34,56 +35,13 @@ class RunpodInferenceAdapter(OpenAIMixin): async def openai_chat_completion( self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ): + params: OpenAIChatCompletionRequestParams, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: """Override to add RunPod-specific stream_options requirement.""" - if stream and not stream_options: - stream_options = {"include_usage": True} + # Copy params to avoid mutating the original + params = params.model_copy() - return await super().openai_chat_completion( - model=model, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) + if params.stream and not params.stream_options: + params.stream_options = {"include_usage": True} + + return await super().openai_chat_completion(params) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 310eaf7b6..eb3681e02 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
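The RunPod override above boils down to a small copy-and-default step before delegating to the mixin; a sketch of that behavior (the model id is hypothetical, and OpenAIUserMessageParam is constructed the same way as elsewhere in this patch):

    from llama_stack.apis.inference import (
        OpenAIChatCompletionRequestParams,
        OpenAIUserMessageParam,
    )

    params = OpenAIChatCompletionRequestParams(
        model="runpod/some-endpoint",  # hypothetical model id
        messages=[OpenAIUserMessageParam(role="user", content="Hello")],
        stream=True,
    )

    # What the adapter does before calling super().openai_chat_completion(params):
    params = params.model_copy()
    if params.stream and not params.stream_options:
        params.stream_options = {"include_usage": True}  # ask the server to emit usage chunks
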
from collections.abc import AsyncIterator -from typing import Any from urllib.parse import urljoin import httpx @@ -15,8 +14,7 @@ from pydantic import ConfigDict from llama_stack.apis.inference import ( OpenAIChatCompletion, - OpenAIMessageParam, - OpenAIResponseFormatParam, + OpenAIChatCompletionRequestParams, ToolChoice, ) from llama_stack.log import get_logger @@ -79,61 +77,20 @@ class VLLMInferenceAdapter(OpenAIMixin): async def openai_chat_completion( self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, + params: "OpenAIChatCompletionRequestParams", ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - max_tokens = max_tokens or self.config.max_tokens + # Copy params to avoid mutating the original + params = params.model_copy() + + # Apply vLLM-specific defaults + if params.max_tokens is None and self.config.max_tokens: + params.max_tokens = self.config.max_tokens # This is to be consistent with OpenAI API and support vLLM <= v0.6.3 # References: # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice # * https://github.com/vllm-project/vllm/pull/10000 - if not tools and tool_choice is not None: - tool_choice = ToolChoice.none.value + if not params.tools and params.tool_choice is not None: + params.tool_choice = ToolChoice.none.value - return await super().openai_chat_completion( - model=model, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) + return await super().openai_chat_completion(params) diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 6bef97dd5..b1d918585 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -7,7 +7,6 @@ import base64 import struct from collections.abc import AsyncIterator -from typing import Any import litellm @@ -17,12 +16,12 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, OpenAIChatCompletion, OpenAIChatCompletionChunk, + OpenAIChatCompletionRequestParams, OpenAICompletion, + OpenAICompletionRequestParams, OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, - 
OpenAIMessageParam, - OpenAIResponseFormatParam, ToolChoice, ) from llama_stack.core.request_headers import NeedsRequestProviderData @@ -227,116 +226,88 @@ class LiteLLMOpenAIMixin( async def openai_completion( self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, + params: OpenAICompletionRequestParams, ) -> OpenAICompletion: - model_obj = await self.model_store.get_model(model) - params = await prepare_openai_completion_params( + model_obj = await self.model_store.get_model(params.model) + + # Extract extra fields + extra_body = dict(params.__pydantic_extra__ or {}) + + request_params = await prepare_openai_completion_params( model=self.get_litellm_model_name(model_obj.provider_resource_id), - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, - presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - guided_choice=guided_choice, - prompt_logprobs=prompt_logprobs, + prompt=params.prompt, + best_of=params.best_of, + echo=params.echo, + frequency_penalty=params.frequency_penalty, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_tokens=params.max_tokens, + n=params.n, + presence_penalty=params.presence_penalty, + seed=params.seed, + stop=params.stop, + stream=params.stream, + stream_options=params.stream_options, + temperature=params.temperature, + top_p=params.top_p, + user=params.user, + guided_choice=params.guided_choice, + prompt_logprobs=params.prompt_logprobs, + suffix=params.suffix, api_key=self.get_api_key(), api_base=self.api_base, + **extra_body, ) - return await litellm.atext_completion(**params) + return await litellm.atext_completion(**request_params) async def openai_chat_completion( self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, + params: OpenAIChatCompletionRequestParams, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: # Add usage tracking for streaming 
when telemetry is active from llama_stack.providers.utils.telemetry.tracing import get_current_span - if stream and get_current_span() is not None: + stream_options = params.stream_options + if params.stream and get_current_span() is not None: if stream_options is None: stream_options = {"include_usage": True} elif "include_usage" not in stream_options: stream_options = {**stream_options, "include_usage": True} - model_obj = await self.model_store.get_model(model) - params = await prepare_openai_completion_params( + + model_obj = await self.model_store.get_model(params.model) + + # Extract extra fields + extra_body = dict(params.__pydantic_extra__ or {}) + + request_params = await prepare_openai_completion_params( model=self.get_litellm_model_name(model_obj.provider_resource_id), - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, + messages=params.messages, + frequency_penalty=params.frequency_penalty, + function_call=params.function_call, + functions=params.functions, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_completion_tokens=params.max_completion_tokens, + max_tokens=params.max_tokens, + n=params.n, + parallel_tool_calls=params.parallel_tool_calls, + presence_penalty=params.presence_penalty, + response_format=params.response_format, + seed=params.seed, + stop=params.stop, + stream=params.stream, stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, + temperature=params.temperature, + tool_choice=params.tool_choice, + tools=params.tools, + top_logprobs=params.top_logprobs, + top_p=params.top_p, + user=params.user, api_key=self.get_api_key(), api_base=self.api_base, + **extra_body, ) - return await litellm.acompletion(**params) + return await litellm.acompletion(**request_params) async def check_model_availability(self, model: str) -> bool: """ diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index cba7508a2..9044dd576 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -8,7 +8,7 @@ import base64 import uuid from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterable -from typing import Any +from typing import TYPE_CHECKING, Any from openai import NOT_GIVEN, AsyncOpenAI from pydantic import BaseModel, ConfigDict @@ -22,8 +22,13 @@ from llama_stack.apis.inference import ( OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, OpenAIMessageParam, - OpenAIResponseFormatParam, ) + +if TYPE_CHECKING: + from llama_stack.apis.inference import ( + OpenAIChatCompletionRequestParams, + OpenAICompletionRequestParams, + ) from llama_stack.apis.models import ModelType from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger @@ -227,96 +232,55 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): async def openai_completion( self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - 
logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, + params: "OpenAICompletionRequestParams", ) -> OpenAICompletion: """ Direct OpenAI completion API call. """ - # Handle parameters that are not supported by OpenAI API, but may be by the provider - # prompt_logprobs is supported by vLLM - # guided_choice is supported by vLLM - # TODO: test coverage - extra_body: dict[str, Any] = {} - if prompt_logprobs is not None and prompt_logprobs >= 0: - extra_body["prompt_logprobs"] = prompt_logprobs - if guided_choice: - extra_body["guided_choice"] = guided_choice + # Extract extra fields using Pydantic's built-in __pydantic_extra__ + extra_body = dict(params.__pydantic_extra__ or {}) + + # Add vLLM-specific parameters to extra_body if they are set + # (these are explicitly defined in the model but still go to extra_body) + if params.prompt_logprobs is not None and params.prompt_logprobs >= 0: + extra_body["prompt_logprobs"] = params.prompt_logprobs + if params.guided_choice: + extra_body["guided_choice"] = params.guided_choice # TODO: fix openai_completion to return type compatible with OpenAI's API response - resp = await self.client.completions.create( - **await prepare_openai_completion_params( - model=await self._get_provider_model_id(model), - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, - presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - suffix=suffix, - ), - extra_body=extra_body, + completion_kwargs = await prepare_openai_completion_params( + model=await self._get_provider_model_id(params.model), + prompt=params.prompt, + best_of=params.best_of, + echo=params.echo, + frequency_penalty=params.frequency_penalty, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_tokens=params.max_tokens, + n=params.n, + presence_penalty=params.presence_penalty, + seed=params.seed, + stop=params.stop, + stream=params.stream, + stream_options=params.stream_options, + temperature=params.temperature, + top_p=params.top_p, + user=params.user, + suffix=params.suffix, ) + resp = await self.client.completions.create(**completion_kwargs, extra_body=extra_body) - return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return] + return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return] async def openai_chat_completion( self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - 
stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, + params: "OpenAIChatCompletionRequestParams", ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: """ Direct OpenAI chat completion API call. """ + messages = params.messages + if self.download_images: async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam: @@ -335,35 +299,38 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): messages = [await _localize_image_url(m) for m in messages] - params = await prepare_openai_completion_params( - model=await self._get_provider_model_id(model), + request_params = await prepare_openai_completion_params( + model=await self._get_provider_model_id(params.model), messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, + frequency_penalty=params.frequency_penalty, + function_call=params.function_call, + functions=params.functions, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_completion_tokens=params.max_completion_tokens, + max_tokens=params.max_tokens, + n=params.n, + parallel_tool_calls=params.parallel_tool_calls, + presence_penalty=params.presence_penalty, + response_format=params.response_format, + seed=params.seed, + stop=params.stop, + stream=params.stream, + stream_options=params.stream_options, + temperature=params.temperature, + tool_choice=params.tool_choice, + tools=params.tools, + top_logprobs=params.top_logprobs, + top_p=params.top_p, + user=params.user, ) - resp = await self.client.chat.completions.create(**params) + # Extract any additional provider-specific parameters using Pydantic's __pydantic_extra__ + if extra_body := dict(params.__pydantic_extra__ or {}): + request_params["extra_body"] = extra_body + resp = await self.client.chat.completions.create(**request_params) - return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return] + return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return] async def openai_embeddings( self, diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py index 6d6bb20d5..8620aadbc 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -13,6 +13,7 @@ import pytest from llama_stack.apis.inference import ( OpenAIAssistantMessageParam, OpenAIChatCompletion, + OpenAIChatCompletionRequestParams, OpenAIChoice, ToolChoice, ) @@ -56,13 +57,14 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter): mock_client_property.return_value = mock_client # No tools but auto tool choice - await vllm_inference_adapter.openai_chat_completion( - "mock-model", - [], + params = OpenAIChatCompletionRequestParams( + model="mock-model", + messages=[{"role": 
"user", "content": "test"}], stream=False, tools=None, tool_choice=ToolChoice.auto.value, ) + await vllm_inference_adapter.openai_chat_completion(params) mock_client.chat.completions.create.assert_called() call_args = mock_client.chat.completions.create.call_args # Ensure tool_choice gets converted to none for older vLLM versions @@ -171,9 +173,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter): ) async def do_inference(): - await vllm_inference_adapter.openai_chat_completion( - "mock-model", messages=["one fish", "two fish"], stream=False + params = OpenAIChatCompletionRequestParams( + model="mock-model", + messages=[{"role": "user", "content": "one fish two fish"}], + stream=False, ) + await vllm_inference_adapter.openai_chat_completion(params) with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client: mock_client = MagicMock() @@ -186,3 +191,48 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter): assert mock_create_client.call_count == 4 # no cheating assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max" + + +async def test_extra_body_forwarding(vllm_inference_adapter): + """ + Test that extra_body parameters (e.g., chat_template_kwargs) are correctly + forwarded to the underlying OpenAI client. + """ + mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference") + vllm_inference_adapter.model_store.get_model.return_value = mock_model + + with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property: + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock( + return_value=OpenAIChatCompletion( + id="chatcmpl-abc123", + created=1, + model="mock-model", + choices=[ + OpenAIChoice( + message=OpenAIAssistantMessageParam( + content="test response", + ), + finish_reason="stop", + index=0, + ) + ], + ) + ) + mock_client_property.return_value = mock_client + + # Test with chat_template_kwargs for Granite thinking mode + params = OpenAIChatCompletionRequestParams( + model="mock-model", + messages=[{"role": "user", "content": "test"}], + stream=False, + chat_template_kwargs={"thinking": True}, + ) + await vllm_inference_adapter.openai_chat_completion(params) + + # Verify that the client was called with extra_body containing chat_template_kwargs + mock_client.chat.completions.create.assert_called_once() + call_kwargs = mock_client.chat.completions.create.call_args.kwargs + assert "extra_body" in call_kwargs + assert "chat_template_kwargs" in call_kwargs["extra_body"] + assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True} diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py index ad9406951..2df882598 100644 --- a/tests/unit/providers/utils/inference/test_openai_mixin.py +++ b/tests/unit/providers/utils/inference/test_openai_mixin.py @@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch import pytest from pydantic import BaseModel, Field -from llama_stack.apis.inference import Model, OpenAIUserMessageParam +from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestParams, OpenAIUserMessageParam from llama_stack.apis.models import ModelType from llama_stack.core.request_headers import request_provider_data_context from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig @@ 
-271,7 +271,8 @@ class TestOpenAIMixinImagePreprocessing: with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: mock_localize.return_value = (b"fake_image_data", "jpeg") - await mixin.openai_chat_completion(model="test-model", messages=[message]) + params = OpenAIChatCompletionRequestParams(model="test-model", messages=[message]) + await mixin.openai_chat_completion(params) mock_localize.assert_called_once_with("http://example.com/image.jpg") @@ -303,7 +304,8 @@ class TestOpenAIMixinImagePreprocessing: with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client): with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: - await mixin.openai_chat_completion(model="test-model", messages=[message]) + params = OpenAIChatCompletionRequestParams(model="test-model", messages=[message]) + await mixin.openai_chat_completion(params) mock_localize.assert_not_called()
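Reviewer note: the extra_body plumbing exercised by `test_extra_body_forwarding` relies on the request models being declared with `extra="allow"`, so unknown fields land in `__pydantic_extra__` and are forwarded verbatim. A condensed sketch of that flow; `ChatParams` and the literals are illustrative, not the real schema.

```python
# Sketch of how undeclared fields become extra_body, mirroring the OpenAIMixin change above.
from pydantic import BaseModel, ConfigDict


class ChatParams(BaseModel):
    model_config = ConfigDict(extra="allow")

    model: str
    temperature: float | None = None


params = ChatParams(
    model="mock-model",
    chat_template_kwargs={"thinking": True},  # undeclared field -> lands in __pydantic_extra__
)

# Declared fields go through the normal OpenAI request parameters.
request_kwargs = {"model": params.model}

# Anything the schema does not know about is forwarded as extra_body, which the
# OpenAI client serializes into the JSON request body.
if extra_body := dict(params.__pydantic_extra__ or {}):
    request_kwargs["extra_body"] = extra_body

assert request_kwargs["extra_body"] == {"chat_template_kwargs": {"thinking": True}}
```

This is what lets provider-specific knobs such as vLLM's `chat_template_kwargs` pass through the OpenAI-compatible surface without widening the declared API.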