# What does this PR do?

Refactors the OpenAI-compatible inference APIs so that `openai_completion` and `openai_chat_completion` each accept a single Pydantic request model (`OpenAICompletionRequest` and `OpenaiChatCompletionRequest`) instead of a long list of individual keyword parameters, and updates the generated OpenAPI specs, the inference router, the server request handling, and all in-tree callers accordingly.
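As a rough illustration of the call-site change (a minimal sketch; the model identifier and prompt are placeholders, and `inference_api` stands in for any object implementing the inference protocol):

```python
from llama_stack.apis.inference import OpenAICompletionRequest

# Before this PR: each OpenAI parameter was a separate keyword argument
# response = await inference_api.openai_completion(model="my-model", prompt="Hello", max_tokens=32)

# After this PR: a single Pydantic request object carries all parameters
params = OpenAICompletionRequest(model="my-model", prompt="Hello", max_tokens=32)
response = await inference_api.openai_completion(params)
```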


## Test Plan
Commit 4a3d1e33f8 (parent f50ce11a3b) by Eric Huang, 2025-10-09 20:53:19 -07:00
31 changed files with 727 additions and 892 deletions


@@ -11,6 +11,7 @@ from pathlib import Path
 from typing import TextIO
 from typing import Any, List, Optional, Union, get_type_hints, get_origin, get_args
+from pydantic import BaseModel
 from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
 from llama_stack.core.resolver import api_protocol_map
@@ -205,6 +206,14 @@ def _validate_has_return_in_docstring(method) -> str | None:
 def _validate_has_params_in_docstring(method) -> str | None:
     source = inspect.getsource(method)
     sig = inspect.signature(method)
+    params_list = [p for p in sig.parameters.values() if p.name != "self"]
+    if len(params_list) == 1:
+        param = params_list[0]
+        param_type = param.annotation
+        if issubclass(param_type, BaseModel):
+            return
     # Only check if the method has more than one parameter
     if len(sig.parameters) > 1 and ":param" not in source:
         return "does not have a ':param' in its docstring"
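The effect of this validator change, sketched with a hypothetical API class that is not part of the PR:

```python
from pydantic import BaseModel

class ExampleParams(BaseModel):  # hypothetical request model
    model: str

class ExampleAPI:
    async def do_work(self, params: ExampleParams) -> None:
        """A docstring without ':param' entries is now acceptable here,
        because the only non-self parameter is a single BaseModel whose
        fields are documented on the model itself."""
        ...
```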


@@ -7716,7 +7716,8 @@
 "model",
 "messages"
 ],
-"title": "OpenaiChatCompletionRequest"
+"title": "OpenaiChatCompletionRequest",
+"description": "Request parameters for OpenAI-compatible chat completion endpoint."
 },
 "OpenAIChatCompletion": {
 "type": "object",
@@ -7900,7 +7901,7 @@
 ],
 "title": "OpenAICompletionWithInputMessages"
 },
-"OpenaiCompletionRequest": {
+"OpenAICompletionRequest": {
 "type": "object",
 "properties": {
 "model": {
@@ -8031,18 +8032,20 @@
 "type": "string",
 "description": "(Optional) The user to use."
 },
+"suffix": {
+"type": "string",
+"description": "(Optional) The suffix that should be appended to the completion."
+},
 "guided_choice": {
 "type": "array",
 "items": {
 "type": "string"
-}
+},
+"description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
 },
 "prompt_logprobs": {
-"type": "integer"
-},
-"suffix": {
-"type": "string",
-"description": "(Optional) The suffix that should be appended to the completion."
+"type": "integer",
+"description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
 }
 },
 "additionalProperties": false,
@@ -8050,6 +8053,20 @@
 "model",
 "prompt"
 ],
+"title": "OpenAICompletionRequest",
+"description": "Request parameters for OpenAI-compatible completion endpoint."
+},
+"OpenaiCompletionRequest": {
+"type": "object",
+"properties": {
+"params": {
+"$ref": "#/components/schemas/OpenAICompletionRequest"
+}
+},
+"additionalProperties": false,
+"required": [
+"params"
+],
 "title": "OpenaiCompletionRequest"
 },
 "OpenAICompletion": {


@@ -5671,6 +5671,8 @@ components:
 - model
 - messages
 title: OpenaiChatCompletionRequest
+description: >-
+  Request parameters for OpenAI-compatible chat completion endpoint.
 OpenAIChatCompletion:
 type: object
 properties:
@@ -5824,7 +5826,7 @@ components:
 - model
 - input_messages
 title: OpenAICompletionWithInputMessages
-OpenaiCompletionRequest:
+OpenAICompletionRequest:
 type: object
 properties:
 model:
@@ -5912,20 +5914,37 @@ components:
 user:
 type: string
 description: (Optional) The user to use.
-guided_choice:
-type: array
-items:
-type: string
-prompt_logprobs:
-type: integer
 suffix:
 type: string
 description: >-
   (Optional) The suffix that should be appended to the completion.
+guided_choice:
+type: array
+items:
+type: string
+description: >-
+  (Optional) vLLM-specific parameter for guided generation with a list of
+  choices.
+prompt_logprobs:
+type: integer
+description: >-
+  (Optional) vLLM-specific parameter for number of log probabilities to
+  return for prompt tokens.
 additionalProperties: false
 required:
 - model
 - prompt
+title: OpenAICompletionRequest
+description: >-
+  Request parameters for OpenAI-compatible completion endpoint.
+OpenaiCompletionRequest:
+type: object
+properties:
+params:
+$ref: '#/components/schemas/OpenAICompletionRequest'
+additionalProperties: false
+required:
+- params
 title: OpenaiCompletionRequest
 OpenAICompletion:
 type: object


@@ -5212,7 +5212,8 @@
 "model",
 "messages"
 ],
-"title": "OpenaiChatCompletionRequest"
+"title": "OpenaiChatCompletionRequest",
+"description": "Request parameters for OpenAI-compatible chat completion endpoint."
 },
 "OpenAIChatCompletion": {
 "type": "object",
@@ -5396,7 +5397,7 @@
 ],
 "title": "OpenAICompletionWithInputMessages"
 },
-"OpenaiCompletionRequest": {
+"OpenAICompletionRequest": {
 "type": "object",
 "properties": {
 "model": {
@@ -5527,18 +5528,20 @@
 "type": "string",
 "description": "(Optional) The user to use."
 },
+"suffix": {
+"type": "string",
+"description": "(Optional) The suffix that should be appended to the completion."
+},
 "guided_choice": {
 "type": "array",
 "items": {
 "type": "string"
-}
+},
+"description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
 },
 "prompt_logprobs": {
-"type": "integer"
-},
-"suffix": {
-"type": "string",
-"description": "(Optional) The suffix that should be appended to the completion."
+"type": "integer",
+"description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
 }
 },
 "additionalProperties": false,
@@ -5546,6 +5549,20 @@
 "model",
 "prompt"
 ],
+"title": "OpenAICompletionRequest",
+"description": "Request parameters for OpenAI-compatible completion endpoint."
+},
+"OpenaiCompletionRequest": {
+"type": "object",
+"properties": {
+"params": {
+"$ref": "#/components/schemas/OpenAICompletionRequest"
+}
+},
+"additionalProperties": false,
+"required": [
+"params"
+],
 "title": "OpenaiCompletionRequest"
 },
 "OpenAICompletion": {


@@ -3920,6 +3920,8 @@ components:
 - model
 - messages
 title: OpenaiChatCompletionRequest
+description: >-
+  Request parameters for OpenAI-compatible chat completion endpoint.
 OpenAIChatCompletion:
 type: object
 properties:
@@ -4073,7 +4075,7 @@ components:
 - model
 - input_messages
 title: OpenAICompletionWithInputMessages
-OpenaiCompletionRequest:
+OpenAICompletionRequest:
 type: object
 properties:
 model:
@@ -4161,20 +4163,37 @@ components:
 user:
 type: string
 description: (Optional) The user to use.
-guided_choice:
-type: array
-items:
-type: string
-prompt_logprobs:
-type: integer
 suffix:
 type: string
 description: >-
   (Optional) The suffix that should be appended to the completion.
+guided_choice:
+type: array
+items:
+type: string
+description: >-
+  (Optional) vLLM-specific parameter for guided generation with a list of
+  choices.
+prompt_logprobs:
+type: integer
+description: >-
+  (Optional) vLLM-specific parameter for number of log probabilities to
+  return for prompt tokens.
 additionalProperties: false
 required:
 - model
 - prompt
+title: OpenAICompletionRequest
+description: >-
+  Request parameters for OpenAI-compatible completion endpoint.
+OpenaiCompletionRequest:
+type: object
+properties:
+params:
+$ref: '#/components/schemas/OpenAICompletionRequest'
+additionalProperties: false
+required:
+- params
 title: OpenaiCompletionRequest
 OpenAICompletion:
 type: object


@@ -7221,7 +7221,8 @@
 "model",
 "messages"
 ],
-"title": "OpenaiChatCompletionRequest"
+"title": "OpenaiChatCompletionRequest",
+"description": "Request parameters for OpenAI-compatible chat completion endpoint."
 },
 "OpenAIChatCompletion": {
 "type": "object",
@@ -7405,7 +7406,7 @@
 ],
 "title": "OpenAICompletionWithInputMessages"
 },
-"OpenaiCompletionRequest": {
+"OpenAICompletionRequest": {
 "type": "object",
 "properties": {
 "model": {
@@ -7536,18 +7537,20 @@
 "type": "string",
 "description": "(Optional) The user to use."
 },
+"suffix": {
+"type": "string",
+"description": "(Optional) The suffix that should be appended to the completion."
+},
 "guided_choice": {
 "type": "array",
 "items": {
 "type": "string"
-}
+},
+"description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
 },
 "prompt_logprobs": {
-"type": "integer"
-},
-"suffix": {
-"type": "string",
-"description": "(Optional) The suffix that should be appended to the completion."
+"type": "integer",
+"description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
 }
 },
 "additionalProperties": false,
@@ -7555,6 +7558,20 @@
 "model",
 "prompt"
 ],
+"title": "OpenAICompletionRequest",
+"description": "Request parameters for OpenAI-compatible completion endpoint."
+},
+"OpenaiCompletionRequest": {
+"type": "object",
+"properties": {
+"params": {
+"$ref": "#/components/schemas/OpenAICompletionRequest"
+}
+},
+"additionalProperties": false,
+"required": [
+"params"
+],
 "title": "OpenaiCompletionRequest"
 },
 "OpenAICompletion": {


@@ -5365,6 +5365,8 @@ components:
 - model
 - messages
 title: OpenaiChatCompletionRequest
+description: >-
+  Request parameters for OpenAI-compatible chat completion endpoint.
 OpenAIChatCompletion:
 type: object
 properties:
@@ -5518,7 +5520,7 @@ components:
 - model
 - input_messages
 title: OpenAICompletionWithInputMessages
-OpenaiCompletionRequest:
+OpenAICompletionRequest:
 type: object
 properties:
 model:
@@ -5606,20 +5608,37 @@ components:
 user:
 type: string
 description: (Optional) The user to use.
-guided_choice:
-type: array
-items:
-type: string
-prompt_logprobs:
-type: integer
 suffix:
 type: string
 description: >-
   (Optional) The suffix that should be appended to the completion.
+guided_choice:
+type: array
+items:
+type: string
+description: >-
+  (Optional) vLLM-specific parameter for guided generation with a list of
+  choices.
+prompt_logprobs:
+type: integer
+description: >-
+  (Optional) vLLM-specific parameter for number of log probabilities to
+  return for prompt tokens.
 additionalProperties: false
 required:
 - model
 - prompt
+title: OpenAICompletionRequest
+description: >-
+  Request parameters for OpenAI-compatible completion endpoint.
+OpenaiCompletionRequest:
+type: object
+properties:
+params:
+$ref: '#/components/schemas/OpenAICompletionRequest'
+additionalProperties: false
+required:
+- params
 title: OpenaiCompletionRequest
 OpenAICompletion:
 type: object


@@ -14,7 +14,7 @@ from typing import (
     runtime_checkable,
 )
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from typing_extensions import TypedDict
 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
@@ -995,6 +995,120 @@ class ListOpenAIChatCompletionResponse(BaseModel):
     object: Literal["list"] = "list"
+@json_schema_type
+class OpenAICompletionRequest(BaseModel):
+    """Request parameters for OpenAI-compatible completion endpoint.
+    :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+    :param prompt: The prompt to generate a completion for.
+    :param best_of: (Optional) The number of completions to generate.
+    :param echo: (Optional) Whether to echo the prompt.
+    :param frequency_penalty: (Optional) The penalty for repeated tokens.
+    :param logit_bias: (Optional) The logit bias to use.
+    :param logprobs: (Optional) The log probabilities to use.
+    :param max_tokens: (Optional) The maximum number of tokens to generate.
+    :param n: (Optional) The number of completions to generate.
+    :param presence_penalty: (Optional) The penalty for repeated tokens.
+    :param seed: (Optional) The seed to use.
+    :param stop: (Optional) The stop tokens to use.
+    :param stream: (Optional) Whether to stream the response.
+    :param stream_options: (Optional) The stream options to use.
+    :param temperature: (Optional) The temperature to use.
+    :param top_p: (Optional) The top p to use.
+    :param user: (Optional) The user to use.
+    :param suffix: (Optional) The suffix that should be appended to the completion.
+    :param guided_choice: (Optional) vLLM-specific parameter for guided generation with a list of choices.
+    :param prompt_logprobs: (Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens.
+    """
+    model_config = ConfigDict(extra="allow")
+    # Required parameters
+    model: str
+    prompt: str | list[str] | list[int] | list[list[int]]
+    # Standard OpenAI completion parameters
+    best_of: int | None = None
+    echo: bool | None = None
+    frequency_penalty: float | None = None
+    logit_bias: dict[str, float] | None = None
+    logprobs: bool | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    presence_penalty: float | None = None
+    seed: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool | None = None
+    stream_options: dict[str, Any] | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+    user: str | None = None
+    suffix: str | None = None
+    # vLLM-specific parameters (documented here but also allowed via extra fields)
+    guided_choice: list[str] | None = None
+    prompt_logprobs: int | None = None
+@json_schema_type
+class OpenaiChatCompletionRequest(BaseModel):
+    """Request parameters for OpenAI-compatible chat completion endpoint.
+    :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+    :param messages: List of messages in the conversation.
+    :param frequency_penalty: (Optional) The penalty for repeated tokens.
+    :param function_call: (Optional) The function call to use.
+    :param functions: (Optional) List of functions to use.
+    :param logit_bias: (Optional) The logit bias to use.
+    :param logprobs: (Optional) The log probabilities to use.
+    :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
+    :param max_tokens: (Optional) The maximum number of tokens to generate.
+    :param n: (Optional) The number of completions to generate.
+    :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
+    :param presence_penalty: (Optional) The penalty for repeated tokens.
+    :param response_format: (Optional) The response format to use.
+    :param seed: (Optional) The seed to use.
+    :param stop: (Optional) The stop tokens to use.
+    :param stream: (Optional) Whether to stream the response.
+    :param stream_options: (Optional) The stream options to use.
+    :param temperature: (Optional) The temperature to use.
+    :param tool_choice: (Optional) The tool choice to use.
+    :param tools: (Optional) The tools to use.
+    :param top_logprobs: (Optional) The top log probabilities to use.
+    :param top_p: (Optional) The top p to use.
+    :param user: (Optional) The user to use.
+    """
+    model_config = ConfigDict(extra="allow")
+    # Required parameters
+    model: str
+    messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
+    # Standard OpenAI chat completion parameters
+    frequency_penalty: float | None = None
+    function_call: str | dict[str, Any] | None = None
+    functions: list[dict[str, Any]] | None = None
+    logit_bias: dict[str, float] | None = None
+    logprobs: bool | None = None
+    max_completion_tokens: int | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    parallel_tool_calls: bool | None = None
+    presence_penalty: float | None = None
+    response_format: OpenAIResponseFormatParam | None = None
+    seed: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool | None = None
+    stream_options: dict[str, Any] | None = None
+    temperature: float | None = None
+    tool_choice: str | dict[str, Any] | None = None
+    tools: list[dict[str, Any]] | None = None
+    top_logprobs: int | None = None
+    top_p: float | None = None
+    user: str | None = None
 @runtime_checkable
 @trace_protocol
 class InferenceProvider(Protocol):
@@ -1029,52 +1143,11 @@ class InferenceProvider(Protocol):
     @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_completion(
         self,
-        # Standard OpenAI completion parameters
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        # vLLM-specific parameters
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        # for fill-in-the-middle type completion
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
         """Create completion.
         Generate an OpenAI-compatible completion for the given prompt using the specified model.
-        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param prompt: The prompt to generate a completion for.
-        :param best_of: (Optional) The number of completions to generate.
-        :param echo: (Optional) Whether to echo the prompt.
-        :param frequency_penalty: (Optional) The penalty for repeated tokens.
-        :param logit_bias: (Optional) The logit bias to use.
-        :param logprobs: (Optional) The log probabilities to use.
-        :param max_tokens: (Optional) The maximum number of tokens to generate.
-        :param n: (Optional) The number of completions to generate.
-        :param presence_penalty: (Optional) The penalty for repeated tokens.
-        :param seed: (Optional) The seed to use.
-        :param stop: (Optional) The stop tokens to use.
-        :param stream: (Optional) Whether to stream the response.
-        :param stream_options: (Optional) The stream options to use.
-        :param temperature: (Optional) The temperature to use.
-        :param top_p: (Optional) The top p to use.
-        :param user: (Optional) The user to use.
-        :param suffix: (Optional) The suffix that should be appended to the completion.
        :returns: An OpenAICompletion.
         """
         ...
@@ -1083,57 +1156,11 @@ class InferenceProvider(Protocol):
     @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """Create chat completions.
         Generate an OpenAI-compatible chat completion for the given messages using the specified model.
-        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation.
-        :param frequency_penalty: (Optional) The penalty for repeated tokens.
-        :param function_call: (Optional) The function call to use.
-        :param functions: (Optional) List of functions to use.
-        :param logit_bias: (Optional) The logit bias to use.
-        :param logprobs: (Optional) The log probabilities to use.
-        :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
-        :param max_tokens: (Optional) The maximum number of tokens to generate.
-        :param n: (Optional) The number of completions to generate.
-        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
-        :param presence_penalty: (Optional) The penalty for repeated tokens.
-        :param response_format: (Optional) The response format to use.
-        :param seed: (Optional) The seed to use.
-        :param stop: (Optional) The stop tokens to use.
-        :param stream: (Optional) Whether to stream the response.
-        :param stream_options: (Optional) The stream options to use.
-        :param temperature: (Optional) The temperature to use.
-        :param tool_choice: (Optional) The tool choice to use.
-        :param tools: (Optional) The tools to use.
-        :param top_logprobs: (Optional) The top log probabilities to use.
-        :param top_p: (Optional) The top p to use.
-        :param user: (Optional) The user to use.
        :returns: An OpenAIChatCompletion.
         """
         ...
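For reference, a hedged sketch of constructing the new chat request model (the model name and message content are placeholders):

```python
from llama_stack.apis.inference import OpenaiChatCompletionRequest, OpenAIUserMessageParam

params = OpenaiChatCompletionRequest(
    model="my-model",
    messages=[OpenAIUserMessageParam(role="user", content="Hello")],
    stream=False,
)
# ConfigDict(extra="allow") on the model means provider-specific fields that are
# not declared on the class can still be passed through without a validation error.
```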


@@ -383,7 +383,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         body, field_names = self._handle_file_uploads(options, body)
-        body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
+        body = self._convert_body(matched_func, body, exclude_params=set(field_names))
         trace_path = webmethod.descriptive_name or route_path
         await start_trace(trace_path, {"__location__": "library_client"})
@@ -446,7 +446,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
         body |= path_params
-        body = self._convert_body(path, options.method, body)
+        # Prepare body for the function call (handles both Pydantic and traditional params)
+        body = self._convert_body(func, body)
         trace_path = webmethod.descriptive_name or route_path
         await start_trace(trace_path, {"__location__": "library_client"})
@@ -493,17 +494,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         )
         return await response.parse()
-    def _convert_body(
-        self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
-    ) -> dict:
+    def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
         if not body:
             return {}
-        assert self.route_impls is not None  # Should be guaranteed by request() method, assertion for mypy
         exclude_params = exclude_params or set()
-        func, _, _, _ = find_matching_route(method, path, self.route_impls)
         sig = inspect.signature(func)
+        params_list = [p for p in sig.parameters.values() if p.name != "self"]
+        # Check if the method expects a single Pydantic model parameter
+        if len(params_list) == 1:
+            param = params_list[0]
+            param_type = param.annotation
+            if issubclass(param_type, BaseModel):
+                # Strip NOT_GIVENs before passing to Pydantic
+                clean_body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
+                # If the body has a single key matching the parameter name, unwrap it
+                if len(clean_body) == 1 and param.name in clean_body:
+                    clean_body = clean_body[param.name]
+                return {param.name: param_type(**clean_body)}
         # Strip NOT_GIVENs to use the defaults in signature
         body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
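A small self-contained sketch of the single-Pydantic-parameter branch added to `_convert_body` (the request model and endpoint below are stand-ins for the real ones):

```python
import inspect
from pydantic import BaseModel

class ExampleRequest(BaseModel):  # stand-in for OpenAICompletionRequest
    model: str
    prompt: str

async def openai_completion(params: ExampleRequest): ...

# The client receives a flat body and wraps it into the single model parameter:
body = {"model": "my-model", "prompt": "Hello"}
sig = inspect.signature(openai_completion)
param = next(iter(sig.parameters.values()))
converted = {param.name: param.annotation(**body)}
print(converted)  # {'params': ExampleRequest(model='my-model', prompt='Hello')}
```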


@@ -8,11 +8,11 @@ import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from datetime import UTC, datetime
-from typing import Annotated, Any
+from typing import Any
 from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
 from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
-from pydantic import Field, TypeAdapter
+from pydantic import TypeAdapter
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -31,15 +31,16 @@ from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
+    OpenaiChatCompletionRequest,
     OpenAIChatCompletionToolCall,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
     OpenAIChoiceLogprobs,
     OpenAICompletion,
+    OpenAICompletionRequest,
     OpenAICompletionWithInputMessages,
     OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     Order,
     StopReason,
     ToolPromptFormat,
@@ -181,61 +182,23 @@ class InferenceRouter(Inference):
     async def openai_completion(
         self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
         logger.debug(
-            f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
+            f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
-        model_obj = await self._get_model(model, ModelType.llm)
-        params = dict(
-            model=model_obj.identifier,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            guided_choice=guided_choice,
-            prompt_logprobs=prompt_logprobs,
-            suffix=suffix,
-        )
+        model_obj = await self._get_model(params.model, ModelType.llm)
+        # Update params with the resolved model identifier
+        params.model = model_obj.identifier
         provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        if stream:
-            return await provider.openai_completion(**params)
+        if params.stream:
+            return await provider.openai_completion(params)
         # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
         # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
-        # response_stream = await provider.openai_completion(**params)
-        response = await provider.openai_completion(**params)
+        response = await provider.openai_completion(params)
         if self.telemetry:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
@@ -254,93 +217,49 @@ class InferenceRouter(Inference):
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         logger.debug(
-            f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
+            f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
         )
-        model_obj = await self._get_model(model, ModelType.llm)
+        model_obj = await self._get_model(params.model, ModelType.llm)
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
-        if tool_choice:
-            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
-            if tools is None:
+        if params.tool_choice:
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
+            if params.tools is None:
                 raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
-        if tools:
-            for tool in tools:
+        if params.tools:
+            for tool in params.tools:
                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
         # Some providers make tool calls even when tool_choice is "none"
         # so just clear them both out to avoid unexpected tool calls
-        if tool_choice == "none" and tools is not None:
-            tool_choice = None
-            tools = None
+        if params.tool_choice == "none" and params.tools is not None:
+            params.tool_choice = None
+            params.tools = None
+        # Update params with the resolved model identifier
+        params.model = model_obj.identifier
-        params = dict(
-            model=model_obj.identifier,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
         provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        if stream:
-            response_stream = await provider.openai_chat_completion(**params)
+        if params.stream:
+            response_stream = await provider.openai_chat_completion(params)
             # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
             # We need to add metrics to each chunk and store the final completion
             return self.stream_tokens_and_compute_metrics_openai_chat(
                 response=response_stream,
                 model=model_obj,
-                messages=messages,
+                messages=params.messages,
             )
         response = await self._nonstream_openai_chat_completion(provider, params)
         # Store the response with the ID that will be returned to the client
         if self.store:
-            asyncio.create_task(self.store.store_chat_completion(response, messages))
+            asyncio.create_task(self.store.store_chat_completion(response, params.messages))

         if self.telemetry:
             metrics = self._construct_metrics(
@@ -396,8 +315,10 @@ class InferenceRouter(Inference):
             return await self.store.get_chat_completion(completion_id)
         raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
-    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
-        response = await provider.openai_chat_completion(**params)
+    async def _nonstream_openai_chat_completion(
+        self, provider: Inference, params: OpenaiChatCompletionRequest
+    ) -> OpenAIChatCompletion:
+        response = await provider.openai_chat_completion(params)
         for choice in response.choices:
             # some providers return an empty list for no tool calls in non-streaming responses
             # but the OpenAI API returns None. So, set tool_calls to None if it's empty


@@ -13,12 +13,13 @@ import logging  # allow-direct-logging
 import os
 import sys
 import traceback
+import types
 import warnings
 from collections.abc import Callable
 from contextlib import asynccontextmanager
 from importlib.metadata import version as parse_version
 from pathlib import Path
-from typing import Annotated, Any, get_origin
+from typing import Annotated, Any, Union, get_origin
 import httpx
 import rich.pretty
@@ -177,7 +178,17 @@ async def lifespan(app: StackApp):
 def is_streaming_request(func_name: str, request: Request, **kwargs):
     # TODO: pass the api method and punt it to the Protocol definition directly
-    return kwargs.get("stream", False)
+    # Check for stream parameter at top level (old API style)
+    if "stream" in kwargs:
+        return kwargs["stream"]
+    # Check for stream parameter inside Pydantic request params (new API style)
+    if "params" in kwargs:
+        params = kwargs["params"]
+        if hasattr(params, "stream"):
+            return params.stream
+    return False
 async def maybe_await(value):
@@ -282,21 +293,42 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
     if method == "post":
         # Annotate parameters that are in the path with Path(...) and others with Body(...),
         # but preserve existing File() and Form() annotations for multipart form data
-        new_params = (
-            [new_params[0]]
-            + [
-                (
-                    param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
-                    if param.name in path_params
-                    else (
-                        param  # Keep original annotation if it's already an Annotated type
-                        if get_origin(param.annotation) is Annotated
-                        else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
-                    )
-                )
-                for param in new_params[1:]
-            ]
-        )
+        def get_body_embed_value(param: inspect.Parameter) -> bool:
+            """Determine if Body should use embed=True or embed=False.
+            For OpenAI-compatible endpoints (param name is 'params'), use embed=False
+            so the request body is parsed directly as the model (not nested).
+            This allows OpenAI clients to send standard OpenAI format.
+            For other endpoints, use embed=True for SDK compatibility.
+            """
+            # Get the actual type, stripping Optional/Union if present
+            param_type = param.annotation
+            origin = get_origin(param_type)
+            # Check for Union types (both typing.Union and types.UnionType for | syntax)
+            if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
+                # Handle Optional[T] / T | None
+                args = param_type.__args__ if hasattr(param_type, "__args__") else []
+                param_type = next((arg for arg in args if arg is not type(None)), param_type)
+            # Check if it's a Pydantic BaseModel and param name is 'params' (OpenAI-compatible)
+            is_basemodel = issubclass(param_type, BaseModel)
+            if is_basemodel and param.name == "params":
+                return False  # Use embed=False for OpenAI-compatible endpoints
+            return True  # Use embed=True for everything else
+        original_params = new_params[1:]  # Skip request parameter
+        new_params = [new_params[0]]  # Keep request parameter
+        for param in original_params:
+            if param.name in path_params:
+                new_params.append(
+                    param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
+                )
+            elif get_origin(param.annotation) is Annotated:
+                new_params.append(param)  # Keep existing annotation
+            else:
+                embed = get_body_embed_value(param)
+                new_params.append(param.replace(annotation=Annotated[param.annotation, Body(..., embed=embed)]))
     route_handler.__signature__ = sig.replace(parameters=new_params)
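A short FastAPI sketch of the two body styles chosen above (the endpoint paths and request model are hypothetical; only the `embed` flag mirrors the PR):

```python
from typing import Annotated
from fastapi import Body, FastAPI
from pydantic import BaseModel

class ExampleParams(BaseModel):  # hypothetical request model
    model: str

app = FastAPI()

@app.post("/embedded")
async def embedded(params: Annotated[ExampleParams, Body(embed=True)]):
    # expects a nested body: {"params": {"model": "..."}}
    return params

@app.post("/flat")
async def flat(params: Annotated[ExampleParams, Body(embed=False)]):
    # expects the flat, OpenAI-style body: {"model": "..."}
    return params
```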


@@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
     Inference,
     Message,
     OpenAIAssistantMessageParam,
+    OpenaiChatCompletionRequest,
     OpenAIDeveloperMessageParam,
     OpenAIMessageParam,
     OpenAISystemMessageParam,
@@ -582,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
         max_tokens = getattr(sampling_params, "max_tokens", None)
         # Use OpenAI chat completion
-        openai_stream = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
             model=self.agent_config.model,
             messages=openai_messages,
             tools=openai_tools if openai_tools else None,
@@ -593,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
             max_tokens=max_tokens,
             stream=True,
         )
+        openai_stream = await self.inference_api.openai_chat_completion(params)
         # Convert OpenAI stream back to Llama Stack format
         response_stream = convert_openai_chat_completion_stream(


@@ -41,6 +41,7 @@ from llama_stack.apis.inference import (
     Inference,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
+    OpenaiChatCompletionRequest,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
     OpenAIMessageParam,
@@ -130,7 +131,7 @@ class StreamingResponseOrchestrator:
         # (some providers don't support non-empty response_format when tools are present)
         response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
         logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
-        completion_result = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
             model=self.ctx.model,
             messages=messages,
             tools=self.ctx.chat_tools,
@@ -138,6 +139,7 @@
             temperature=self.ctx.temperature,
             response_format=response_format,
         )
+        completion_result = await self.inference_api.openai_chat_completion(params)
         # Process streaming chunks and build complete response
         completion_result_data = None


@@ -22,6 +22,8 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
 from llama_stack.apis.inference import (
     Inference,
     OpenAIAssistantMessageParam,
+    OpenaiChatCompletionRequest,
+    OpenAICompletionRequest,
     OpenAIDeveloperMessageParam,
     OpenAIMessageParam,
     OpenAISystemMessageParam,
@@ -601,7 +603,8 @@
         # TODO(SECURITY): review body for security issues
         if request.url == "/v1/chat/completions":
             request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
-            chat_response = await self.inference_api.openai_chat_completion(**request.body)
+            params = OpenaiChatCompletionRequest(**request.body)
+            chat_response = await self.inference_api.openai_chat_completion(params)
             # this is for mypy, we don't allow streaming so we'll get the right type
             assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
@@ -615,7 +618,8 @@
                 },
             }
         else:  # /v1/completions
-            completion_response = await self.inference_api.openai_completion(**request.body)
+            params = OpenAICompletionRequest(**request.body)
+            completion_response = await self.inference_api.openai_completion(params)
             # this is for mypy, we don't allow streaming so we'll get the right type
             assert hasattr(completion_response, "model_dump_json"), (


@@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
 from llama_stack.apis.benchmarks import Benchmark
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
+from llama_stack.apis.inference import (
+    Inference,
+    OpenaiChatCompletionRequest,
+    OpenAICompletionRequest,
+    OpenAISystemMessageParam,
+    OpenAIUserMessageParam,
+    UserMessage,
+)
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@@ -168,11 +175,12 @@
                 sampling_params["stop"] = candidate.sampling_params.stop
             input_content = json.loads(x[ColumnName.completion_input.value])
-            response = await self.inference_api.openai_completion(
+            params = OpenAICompletionRequest(
                 model=candidate.model,
                 prompt=input_content,
                 **sampling_params,
             )
+            response = await self.inference_api.openai_completion(params)
             generations.append({ColumnName.generated_answer.value: response.choices[0].text})
         elif ColumnName.chat_completion_input.value in x:
             chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
@@ -187,11 +195,12 @@
             messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
             messages += input_messages
-            response = await self.inference_api.openai_chat_completion(
+            params = OpenaiChatCompletionRequest(
                 model=candidate.model,
                 messages=messages,
                 **sampling_params,
             )
+            response = await self.inference_api.openai_chat_completion(params)
             generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
         else:
             raise ValueError("Invalid input row")


@@ -6,16 +6,16 @@
 import asyncio
 from collections.abc import AsyncGenerator, AsyncIterator
-from typing import Any
 from llama_stack.apis.inference import (
     InferenceProvider,
+    OpenaiChatCompletionRequest,
+    OpenAICompletionRequest,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
+    OpenAICompletion,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
@@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()
-    async def openai_completion(self, *args, **kwargs):
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequest,
+    ) -> OpenAICompletion:
         raise NotImplementedError("OpenAI completion not supported by meta reference provider")
     async def should_refresh_models(self) -> bool:
@@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl(
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")


@@ -5,17 +5,16 @@
 # the root directory of this source tree.
 from collections.abc import AsyncIterator
-from typing import Any
 from llama_stack.apis.inference import (
     InferenceProvider,
+    OpenaiChatCompletionRequest,
+    OpenAICompletionRequest,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
 )
 from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
@@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl(
     async def openai_completion(
         self,
-        # Standard OpenAI completion parameters
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        # vLLM-specific parameters
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        # for fill-in-the-middle type completion
-        suffix: str | None = None,
+        params: OpenAICompletionRequest,
     ) -> OpenAICompletion:
         raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenaiChatCompletionRequest,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")


@@ -10,7 +10,13 @@ from string import Template
 from typing import Any
 from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
-from llama_stack.apis.inference import Inference, Message, UserMessage
+from llama_stack.apis.inference import (
+    Inference,
+    Message,
+    OpenaiChatCompletionRequest,
+    OpenAIUserMessageParam,
+    UserMessage,
+)
 from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
@@ -290,20 +296,21 @@ class LlamaGuardShield:
         else:
             shield_input_message = self.build_text_shield_input(messages)
-        response = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
             model=self.model,
             messages=[shield_input_message],
             stream=False,
             temperature=0.0,  # default is 1, which is too high for safety
         )
+        response = await self.inference_api.openai_chat_completion(params)
         content = response.choices[0].message.content
         content = content.strip()
         return self.get_shield_response(content)
-    def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
-        return UserMessage(content=self.build_prompt(messages))
+    def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
+        return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages))
-    def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
+    def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
         conversation = []
         most_recent_img = None
@@ -335,7 +342,7 @@ class LlamaGuardShield:
         prompt.append(most_recent_img)
         prompt.append(self.build_prompt(conversation[::-1]))
-        return UserMessage(content=prompt)
+        return OpenAIUserMessageParam(role="user", content=prompt)
     def build_prompt(self, messages: list[Message]) -> str:
         categories = self.get_safety_categories()
@@ -377,11 +384,12 @@ class LlamaGuardShield:
         # TODO: Add Image based support for OpenAI Moderations
         shield_input_message = self.build_text_shield_input(messages)
-        response = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
            model=self.model,
            messages=[shield_input_message],
            stream=False,
         )
+        response = await self.inference_api.openai_chat_completion(params)
         content = response.choices[0].message.content
         content = content.strip()
         return self.get_moderation_object(content)

View file

@@ -6,7 +6,7 @@
 import re
 from typing import Any

-from llama_stack.apis.inference import Inference
+from llama_stack.apis.inference import Inference, OpenaiChatCompletionRequest
 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
             generated_answer=generated_answer,
         )

-        judge_response = await self.inference_api.openai_chat_completion(
+        params = OpenaiChatCompletionRequest(
             model=fn_def.params.judge_model,
             messages=[
                 {
@@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
                 }
             ],
         )
+        judge_response = await self.inference_api.openai_chat_completion(params)
         content = judge_response.choices[0].message.content

         rating_regexes = fn_def.params.judge_score_regexes

View file

@@ -8,7 +8,7 @@
 from jinja2 import Template

 from llama_stack.apis.common.content_types import InterleavedContent
-from llama_stack.apis.inference import OpenAIUserMessageParam
+from llama_stack.apis.inference import OpenaiChatCompletionRequest, OpenAIUserMessageParam
 from llama_stack.apis.tools.rag_tool import (
     DefaultRAGQueryGeneratorConfig,
     LLMRAGQueryGeneratorConfig,
@@ -65,11 +65,12 @@ async def llm_rag_query_generator(
     model = config.model
     message = OpenAIUserMessageParam(content=rendered_content)

-    response = await inference_api.openai_chat_completion(
+    params = OpenaiChatCompletionRequest(
         model=model,
         messages=[message],
         stream=False,
     )
+    response = await inference_api.openai_chat_completion(params)

     query = response.choices[0].message.content

View file

@ -6,21 +6,20 @@
import json import json
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
from botocore.client import BaseClient from botocore.client import BaseClient
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
Inference, Inference,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
) )
from llama_stack.apis.inference.inference import ( from llama_stack.apis.inference.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAICompletion, OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
) )
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client from llama_stack.providers.utils.bedrock.client import create_bedrock_client
@ -135,56 +134,12 @@ class BedrockInferenceAdapter(
async def openai_completion( async def openai_completion(
self, self,
# Standard OpenAI completion parameters params: OpenAICompletionRequest,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: OpenaiChatCompletionRequest,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider") raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

View file

@@ -5,11 +5,14 @@
 # the root directory of this source tree.

 from collections.abc import Iterable
-from typing import Any
+from typing import TYPE_CHECKING

 from databricks.sdk import WorkspaceClient

 from llama_stack.apis.inference import OpenAICompletion
+
+if TYPE_CHECKING:
+    from llama_stack.apis.inference import OpenAICompletionRequest
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -43,25 +46,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
     async def openai_completion(
         self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
+        params: "OpenAICompletionRequest",
     ) -> OpenAICompletion:
         raise NotImplementedError()
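The adapter above only needs `OpenAICompletionRequest` for a type annotation, so the import is guarded by `TYPE_CHECKING` and the annotation is quoted. A small generic sketch of that pattern follows; the `ExampleAdapter` class is illustrative, not part of the codebase.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only; skipped at runtime, so it
    # cannot introduce an import cycle or a hard runtime dependency.
    from llama_stack.apis.inference import OpenAICompletionRequest


class ExampleAdapter:  # hypothetical, for illustration only
    async def openai_completion(self, params: "OpenAICompletionRequest"):
        # The quoted (forward-reference) annotation is only resolved when a
        # type checker or typing.get_type_hints() asks for it.
        raise NotImplementedError()
```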

View file

@ -3,9 +3,12 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any from typing import TYPE_CHECKING
from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse
if TYPE_CHECKING:
from llama_stack.apis.inference import OpenAICompletionRequest
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -34,26 +37,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
async def openai_completion( async def openai_completion(
self, self,
model: str, params: "OpenAICompletionRequest",
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
raise NotImplementedError() raise NotImplementedError()

View file

@ -13,15 +13,14 @@ from llama_stack.apis.inference import (
Inference, Inference,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenaiChatCompletionRequest,
OpenAICompletion, OpenAICompletion,
OpenAICompletionRequest,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.core.library_client import convert_pydantic_to_json_value from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import PassthroughImplConfig from .config import PassthroughImplConfig
@ -80,110 +79,33 @@ class PassthroughInferenceAdapter(Inference):
async def openai_completion( async def openai_completion(
self, self,
model: str, params: OpenAICompletionRequest,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
client = self._get_client() client = self._get_client()
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params( # Copy params to avoid mutating the original
model=model_obj.provider_resource_id, params = params.model_copy()
prompt=prompt, params.model = model_obj.provider_resource_id
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
return await client.inference.openai_completion(**params) request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_completion(**request_params)
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: OpenaiChatCompletionRequest,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
client = self._get_client() client = self._get_client()
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params( # Copy params to avoid mutating the original
model=model_obj.provider_resource_id, params = params.model_copy()
messages=messages, params.model = model_obj.provider_resource_id
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await client.inference.openai_chat_completion(**params) request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_chat_completion(**request_params)
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]: def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
json_params = {} json_params = {}
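The passthrough adapter above rewrites `params.model` to the provider's resource id and then forwards the remaining fields as keyword arguments. Below is a self-contained sketch of that copy-then-dump step, using a stand-in Pydantic model rather than the real `OpenAICompletionRequest`.

```python
from pydantic import BaseModel


class StubCompletionRequest(BaseModel):
    """Stand-in for OpenAICompletionRequest, for illustration only."""

    model: str
    prompt: str
    temperature: float | None = None
    max_tokens: int | None = None


params = StubCompletionRequest(model="alias-model", prompt="hi", temperature=0.0)

# Copy before mutating so the caller's request object is left untouched.
forwarded = params.model_copy()
forwarded.model = "provider/resolved-model-id"

# exclude_none keeps the forwarded kwargs minimal, so unset optional fields
# are simply not sent to the downstream client.
request_kwargs = forwarded.model_dump(exclude_none=True)
assert request_kwargs == {
    "model": "provider/resolved-model-id",
    "prompt": "hi",
    "temperature": 0.0,
}
# await client.inference.openai_completion(**request_kwargs)
```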

View file

@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any from collections.abc import AsyncIterator
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIMessageParam, OpenAIChatCompletion,
OpenAIResponseFormatParam, OpenAIChatCompletionChunk,
OpenaiChatCompletionRequest,
) )
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -34,56 +35,13 @@ class RunpodInferenceAdapter(OpenAIMixin):
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: OpenaiChatCompletionRequest,
messages: list[OpenAIMessageParam], ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
):
"""Override to add RunPod-specific stream_options requirement.""" """Override to add RunPod-specific stream_options requirement."""
if stream and not stream_options: # Copy params to avoid mutating the original
stream_options = {"include_usage": True} params = params.model_copy()
return await super().openai_chat_completion( if params.stream and not params.stream_options:
model=model, params.stream_options = {"include_usage": True}
messages=messages,
frequency_penalty=frequency_penalty, return await super().openai_chat_completion(params)
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
from urllib.parse import urljoin from urllib.parse import urljoin
import httpx import httpx
@ -15,8 +14,7 @@ from pydantic import ConfigDict
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIMessageParam, OpenaiChatCompletionRequest,
OpenAIResponseFormatParam,
ToolChoice, ToolChoice,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -79,61 +77,20 @@ class VLLMInferenceAdapter(OpenAIMixin):
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: "OpenaiChatCompletionRequest",
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
max_tokens = max_tokens or self.config.max_tokens # Copy params to avoid mutating the original
params = params.model_copy()
# Apply vLLM-specific defaults
if params.max_tokens is None and self.config.max_tokens:
params.max_tokens = self.config.max_tokens
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3 # This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References: # References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice # * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000 # * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_choice is not None: if not params.tools and params.tool_choice is not None:
tool_choice = ToolChoice.none.value params.tool_choice = ToolChoice.none.value
return await super().openai_chat_completion( return await super().openai_chat_completion(params)
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
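Here is a compact sketch of the two vLLM-specific adjustments above: backfilling `max_tokens` from the adapter config and forcing `tool_choice` to "none" when no tools are supplied, both applied to a copy of the request. The `StubChatRequest` model is a stand-in for `OpenaiChatCompletionRequest`.

```python
from pydantic import BaseModel


class StubChatRequest(BaseModel):
    """Stand-in for OpenaiChatCompletionRequest, for illustration only."""

    model: str
    max_tokens: int | None = None
    tools: list[dict] | None = None
    tool_choice: str | None = None


def apply_vllm_adjustments(params: StubChatRequest, config_max_tokens: int | None) -> StubChatRequest:
    # Work on a copy so the caller's request object is not mutated.
    params = params.model_copy()
    # Fall back to the adapter-level default when the caller did not set one.
    if params.max_tokens is None and config_max_tokens:
        params.max_tokens = config_max_tokens
    # Older vLLM versions reject tool_choice="auto" when no tools are given,
    # so downgrade it to "none" (mirrors ToolChoice.none.value).
    if not params.tools and params.tool_choice is not None:
        params.tool_choice = "none"
    return params


adjusted = apply_vllm_adjustments(StubChatRequest(model="m", tool_choice="auto"), 4096)
assert adjusted.max_tokens == 4096 and adjusted.tool_choice == "none"
```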

View file

@ -7,7 +7,6 @@
import base64 import base64
import struct import struct
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
import litellm import litellm
@ -17,12 +16,12 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenaiChatCompletionRequest,
OpenAICompletion, OpenAICompletion,
OpenAICompletionRequest,
OpenAIEmbeddingData, OpenAIEmbeddingData,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage, OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolChoice, ToolChoice,
) )
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
@ -227,116 +226,88 @@ class LiteLLMOpenAIMixin(
async def openai_completion( async def openai_completion(
self, self,
model: str, params: OpenAICompletionRequest,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
# Extract extra fields
extra_body = dict(params.__pydantic_extra__ or {})
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id), model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=prompt, prompt=params.prompt,
best_of=best_of, best_of=params.best_of,
echo=echo, echo=params.echo,
frequency_penalty=frequency_penalty, frequency_penalty=params.frequency_penalty,
logit_bias=logit_bias, logit_bias=params.logit_bias,
logprobs=logprobs, logprobs=params.logprobs,
max_tokens=max_tokens, max_tokens=params.max_tokens,
n=n, n=params.n,
presence_penalty=presence_penalty, presence_penalty=params.presence_penalty,
seed=seed, seed=params.seed,
stop=stop, stop=params.stop,
stream=stream, stream=params.stream,
stream_options=stream_options, stream_options=params.stream_options,
temperature=temperature, temperature=params.temperature,
top_p=top_p, top_p=params.top_p,
user=user, user=params.user,
guided_choice=guided_choice, guided_choice=params.guided_choice,
prompt_logprobs=prompt_logprobs, prompt_logprobs=params.prompt_logprobs,
suffix=params.suffix,
api_key=self.get_api_key(), api_key=self.get_api_key(),
api_base=self.api_base, api_base=self.api_base,
**extra_body,
) )
return await litellm.atext_completion(**params) return await litellm.atext_completion(**request_params)
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: OpenaiChatCompletionRequest,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
# Add usage tracking for streaming when telemetry is active # Add usage tracking for streaming when telemetry is active
from llama_stack.providers.utils.telemetry.tracing import get_current_span from llama_stack.providers.utils.telemetry.tracing import get_current_span
if stream and get_current_span() is not None: stream_options = params.stream_options
if params.stream and get_current_span() is not None:
if stream_options is None: if stream_options is None:
stream_options = {"include_usage": True} stream_options = {"include_usage": True}
elif "include_usage" not in stream_options: elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True} stream_options = {**stream_options, "include_usage": True}
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params( model_obj = await self.model_store.get_model(params.model)
# Extract extra fields
extra_body = dict(params.__pydantic_extra__ or {})
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id), model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=messages, messages=params.messages,
frequency_penalty=frequency_penalty, frequency_penalty=params.frequency_penalty,
function_call=function_call, function_call=params.function_call,
functions=functions, functions=params.functions,
logit_bias=logit_bias, logit_bias=params.logit_bias,
logprobs=logprobs, logprobs=params.logprobs,
max_completion_tokens=max_completion_tokens, max_completion_tokens=params.max_completion_tokens,
max_tokens=max_tokens, max_tokens=params.max_tokens,
n=n, n=params.n,
parallel_tool_calls=parallel_tool_calls, parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=presence_penalty, presence_penalty=params.presence_penalty,
response_format=response_format, response_format=params.response_format,
seed=seed, seed=params.seed,
stop=stop, stop=params.stop,
stream=stream, stream=params.stream,
stream_options=stream_options, stream_options=stream_options,
temperature=temperature, temperature=params.temperature,
tool_choice=tool_choice, tool_choice=params.tool_choice,
tools=tools, tools=params.tools,
top_logprobs=top_logprobs, top_logprobs=params.top_logprobs,
top_p=top_p, top_p=params.top_p,
user=user, user=params.user,
api_key=self.get_api_key(), api_key=self.get_api_key(),
api_base=self.api_base, api_base=self.api_base,
**extra_body,
) )
return await litellm.acompletion(**params) return await litellm.acompletion(**request_params)
async def check_model_availability(self, model: str) -> bool: async def check_model_availability(self, model: str) -> bool:
""" """

View file

@ -8,7 +8,7 @@ import base64
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterable from collections.abc import AsyncIterator, Iterable
from typing import Any from typing import TYPE_CHECKING, Any
from openai import NOT_GIVEN, AsyncOpenAI from openai import NOT_GIVEN, AsyncOpenAI
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -22,8 +22,13 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage, OpenAIEmbeddingUsage,
OpenAIMessageParam, OpenAIMessageParam,
OpenAIResponseFormatParam,
) )
if TYPE_CHECKING:
from llama_stack.apis.inference import (
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
)
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -227,96 +232,55 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
async def openai_completion( async def openai_completion(
self, self,
model: str, params: "OpenAICompletionRequest",
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
""" """
Direct OpenAI completion API call. Direct OpenAI completion API call.
""" """
# Handle parameters that are not supported by OpenAI API, but may be by the provider # Extract extra fields using Pydantic's built-in __pydantic_extra__
# prompt_logprobs is supported by vLLM extra_body = dict(params.__pydantic_extra__ or {})
# guided_choice is supported by vLLM
# TODO: test coverage # Add vLLM-specific parameters to extra_body if they are set
extra_body: dict[str, Any] = {} # (these are explicitly defined in the model but still go to extra_body)
if prompt_logprobs is not None and prompt_logprobs >= 0: if params.prompt_logprobs is not None and params.prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs extra_body["prompt_logprobs"] = params.prompt_logprobs
if guided_choice: if params.guided_choice:
extra_body["guided_choice"] = guided_choice extra_body["guided_choice"] = params.guided_choice
# TODO: fix openai_completion to return type compatible with OpenAI's API response # TODO: fix openai_completion to return type compatible with OpenAI's API response
resp = await self.client.completions.create( completion_kwargs = await prepare_openai_completion_params(
**await prepare_openai_completion_params( model=await self._get_provider_model_id(params.model),
model=await self._get_provider_model_id(model), prompt=params.prompt,
prompt=prompt, best_of=params.best_of,
best_of=best_of, echo=params.echo,
echo=echo, frequency_penalty=params.frequency_penalty,
frequency_penalty=frequency_penalty, logit_bias=params.logit_bias,
logit_bias=logit_bias, logprobs=params.logprobs,
logprobs=logprobs, max_tokens=params.max_tokens,
max_tokens=max_tokens, n=params.n,
n=n, presence_penalty=params.presence_penalty,
presence_penalty=presence_penalty, seed=params.seed,
seed=seed, stop=params.stop,
stop=stop, stream=params.stream,
stream=stream, stream_options=params.stream_options,
stream_options=stream_options, temperature=params.temperature,
temperature=temperature, top_p=params.top_p,
top_p=top_p, user=params.user,
user=user, suffix=params.suffix,
suffix=suffix,
),
extra_body=extra_body,
) )
resp = await self.client.completions.create(**completion_kwargs, extra_body=extra_body)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return] return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: "OpenaiChatCompletionRequest",
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
""" """
Direct OpenAI chat completion API call. Direct OpenAI chat completion API call.
""" """
messages = params.messages
if self.download_images: if self.download_images:
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam: async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
@ -335,35 +299,38 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
messages = [await _localize_image_url(m) for m in messages] messages = [await _localize_image_url(m) for m in messages]
params = await prepare_openai_completion_params( request_params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(model), model=await self._get_provider_model_id(params.model),
messages=messages, messages=messages,
frequency_penalty=frequency_penalty, frequency_penalty=params.frequency_penalty,
function_call=function_call, function_call=params.function_call,
functions=functions, functions=params.functions,
logit_bias=logit_bias, logit_bias=params.logit_bias,
logprobs=logprobs, logprobs=params.logprobs,
max_completion_tokens=max_completion_tokens, max_completion_tokens=params.max_completion_tokens,
max_tokens=max_tokens, max_tokens=params.max_tokens,
n=n, n=params.n,
parallel_tool_calls=parallel_tool_calls, parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=presence_penalty, presence_penalty=params.presence_penalty,
response_format=response_format, response_format=params.response_format,
seed=seed, seed=params.seed,
stop=stop, stop=params.stop,
stream=stream, stream=params.stream,
stream_options=stream_options, stream_options=params.stream_options,
temperature=temperature, temperature=params.temperature,
tool_choice=tool_choice, tool_choice=params.tool_choice,
tools=tools, tools=params.tools,
top_logprobs=top_logprobs, top_logprobs=params.top_logprobs,
top_p=top_p, top_p=params.top_p,
user=user, user=params.user,
) )
resp = await self.client.chat.completions.create(**params) # Extract any additional provider-specific parameters using Pydantic's __pydantic_extra__
if extra_body := dict(params.__pydantic_extra__ or {}):
request_params["extra_body"] = extra_body
resp = await self.client.chat.completions.create(**request_params)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return] return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_embeddings( async def openai_embeddings(
self, self,
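For orientation, the end of the forwarding chain above: whatever survives in `__pydantic_extra__` is attached as `extra_body`, which the AsyncOpenAI client merges into the request payload instead of treating it as a named argument. A minimal sketch of that final step (the dict values are placeholders; the real code builds `request_params` via `prepare_openai_completion_params`):

```python
# Sketch of the final forwarding step in OpenAIMixin.openai_chat_completion.
request_params = {
    "model": "provider/model-id",  # placeholder
    "messages": [{"role": "user", "content": "test"}],
    "stream": False,
}

# Provider-specific extras (e.g. chat_template_kwargs, exercised by the vLLM
# test further below) ride along in extra_body rather than as top-level kwargs.
extra_body = {"chat_template_kwargs": {"thinking": True}}
if extra_body:
    request_params["extra_body"] = extra_body

# resp = await client.chat.completions.create(**request_params)
```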

View file

@ -146,14 +146,17 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
# For streaming response, collect all chunks # For streaming response, collect all chunks
chunks = [chunk async for chunk in result] chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with( # Verify the inference API was called with the correct params
model=model, call_args = mock_inference_api.openai_chat_completion.call_args
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)], params = call_args.args[0] # params is passed as first positional arg
response_format=None, assert params.model == model
tools=None, assert params.messages == [
stream=True, OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)
temperature=0.1, ]
) assert params.response_format is None
assert params.tools is None
assert params.stream is True
assert params.temperature == 0.1
# Should have content part events for text streaming # Should have content part events for text streaming
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed # Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
@ -228,13 +231,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
# Verify # Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?" first_params = first_call.args[0]
assert first_call.kwargs["tools"] is not None assert first_params.messages[0].content == "What is the capital of Ireland?"
assert first_call.kwargs["temperature"] == 0.1 assert first_params.tools is not None
assert first_params.temperature == 0.1
second_call = mock_inference_api.openai_chat_completion.call_args_list[1] second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
assert second_call.kwargs["messages"][-1].content == "Dublin" second_params = second_call.args[0]
assert second_call.kwargs["temperature"] == 0.1 assert second_params.messages[-1].content == "Dublin"
assert second_params.temperature == 0.1
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search") openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with( openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
@ -309,9 +314,10 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Verify inference API was called correctly (after iterating over result) # Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text first_params = first_call.args[0]
assert first_call.kwargs["tools"] is not None assert first_params.messages[0].content == input_text
assert first_call.kwargs["temperature"] == 0.1 assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output) # Check response.created event (should have empty output)
assert chunks[0].type == "response.created" assert chunks[0].type == "response.created"
@ -386,9 +392,10 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
# Verify inference API was called correctly (after iterating over result) # Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text first_params = first_call.args[0]
assert first_call.kwargs["tools"] is not None assert first_params.messages[0].content == input_text
assert first_call.kwargs["temperature"] == 0.1 assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output) # Check response.created event (should have empty output)
assert chunks[0].type == "response.created" assert chunks[0].type == "response.created"
@ -435,9 +442,10 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
# Verify inference API was called correctly (after iterating over result) # Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text first_params = first_call.args[0]
assert first_call.kwargs["tools"] is not None assert first_params.messages[0].content == input_text
assert first_call.kwargs["temperature"] == 0.1 assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output) # Check response.created event (should have empty output)
assert chunks[0].type == "response.created" assert chunks[0].type == "response.created"
@ -485,7 +493,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
# Verify the the correct messages were sent to the inference API i.e. # Verify the the correct messages were sent to the inference API i.e.
# All of the responses message were convered to the chat completion message objects # All of the responses message were convered to the chat completion message objects
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"] call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
params = call_args.args[0]
inference_messages = params.messages
for i, m in enumerate(input_messages): for i, m in enumerate(input_messages):
if isinstance(m.content, str): if isinstance(m.content, str):
assert inference_messages[i].content == m.content assert inference_messages[i].content == m.content
@ -653,7 +663,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
# Verify # Verify
mock_inference_api.openai_chat_completion.assert_called_once() mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"] params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message # Check that instructions were prepended as a system message
assert len(sent_messages) == 2 assert len(sent_messages) == 2
@ -691,7 +702,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
# Verify # Verify
mock_inference_api.openai_chat_completion.assert_called_once() mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"] params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message # Check that instructions were prepended as a system message
assert len(sent_messages) == 4 # 1 system + 3 input messages assert len(sent_messages) == 4 # 1 system + 3 input messages
@ -751,7 +763,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
# Verify # Verify
mock_inference_api.openai_chat_completion.assert_called_once() mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"] params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message # Check that instructions were prepended as a system message
assert len(sent_messages) == 4, sent_messages assert len(sent_messages) == 4, sent_messages
@ -987,8 +1000,9 @@ async def test_create_openai_response_with_text_format(
# Verify # Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text first_params = first_call.args[0]
assert first_call.kwargs["response_format"] == response_format assert first_params.messages[0].content == input_text
assert first_params.response_format == response_format
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api): async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):

View file

@ -13,6 +13,7 @@ import pytest
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIAssistantMessageParam, OpenAIAssistantMessageParam,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenaiChatCompletionRequest,
OpenAIChoice, OpenAIChoice,
ToolChoice, ToolChoice,
) )
@ -56,13 +57,14 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_client_property.return_value = mock_client mock_client_property.return_value = mock_client
# No tools but auto tool choice # No tools but auto tool choice
await vllm_inference_adapter.openai_chat_completion( params = OpenaiChatCompletionRequest(
"mock-model", model="mock-model",
[], messages=[{"role": "user", "content": "test"}],
stream=False, stream=False,
tools=None, tools=None,
tool_choice=ToolChoice.auto.value, tool_choice=ToolChoice.auto.value,
) )
await vllm_inference_adapter.openai_chat_completion(params)
mock_client.chat.completions.create.assert_called() mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args call_args = mock_client.chat.completions.create.call_args
# Ensure tool_choice gets converted to none for older vLLM versions # Ensure tool_choice gets converted to none for older vLLM versions
@ -171,9 +173,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
) )
async def do_inference(): async def do_inference():
await vllm_inference_adapter.openai_chat_completion( params = OpenaiChatCompletionRequest(
"mock-model", messages=["one fish", "two fish"], stream=False model="mock-model",
messages=[{"role": "user", "content": "one fish two fish"}],
stream=False,
) )
await vllm_inference_adapter.openai_chat_completion(params)
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client: with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock() mock_client = MagicMock()
@ -186,3 +191,48 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
assert mock_create_client.call_count == 4 # no cheating assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max" assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
async def test_extra_body_forwarding(vllm_inference_adapter):
"""
Test that extra_body parameters (e.g., chat_template_kwargs) are correctly
forwarded to the underlying OpenAI client.
"""
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
return_value=OpenAIChatCompletion(
id="chatcmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAIChoice(
message=OpenAIAssistantMessageParam(
content="test response",
),
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Test with chat_template_kwargs for Granite thinking mode
params = OpenaiChatCompletionRequest(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
chat_template_kwargs={"thinking": True},
)
await vllm_inference_adapter.openai_chat_completion(params)
# Verify that the client was called with extra_body containing chat_template_kwargs
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "chat_template_kwargs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}

View file

@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam from llama_stack.apis.inference import Model, OpenaiChatCompletionRequest, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
@ -271,7 +271,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg") mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message]) params = OpenaiChatCompletionRequest(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_called_once_with("http://example.com/image.jpg") mock_localize.assert_called_once_with("http://example.com/image.jpg")
@ -303,7 +304,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client): with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message]) params = OpenaiChatCompletionRequest(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_not_called() mock_localize.assert_not_called()