# What does this PR do?


## Test Plan
This commit is contained in:
Eric Huang 2025-10-09 20:53:19 -07:00
parent f50ce11a3b
commit 4a3d1e33f8
31 changed files with 727 additions and 892 deletions

View file

@ -11,6 +11,7 @@ from pathlib import Path
from typing import TextIO
from typing import Any, List, Optional, Union, get_type_hints, get_origin, get_args
from pydantic import BaseModel
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
from llama_stack.core.resolver import api_protocol_map
@ -205,6 +206,14 @@ def _validate_has_return_in_docstring(method) -> str | None:
def _validate_has_params_in_docstring(method) -> str | None:
source = inspect.getsource(method)
sig = inspect.signature(method)
params_list = [p for p in sig.parameters.values() if p.name != "self"]
if len(params_list) == 1:
param = params_list[0]
param_type = param.annotation
if issubclass(param_type, BaseModel):
return
# Only check if the method has more than one parameter
if len(sig.parameters) > 1 and ":param" not in source:
return "does not have a ':param' in its docstring"

View file

@ -7716,7 +7716,8 @@
"model",
"messages"
],
"title": "OpenaiChatCompletionRequest"
"title": "OpenaiChatCompletionRequest",
"description": "Request parameters for OpenAI-compatible chat completion endpoint."
},
"OpenAIChatCompletion": {
"type": "object",
@ -7900,7 +7901,7 @@
],
"title": "OpenAICompletionWithInputMessages"
},
"OpenaiCompletionRequest": {
"OpenAICompletionRequest": {
"type": "object",
"properties": {
"model": {
@ -8031,18 +8032,20 @@
"type": "string",
"description": "(Optional) The user to use."
},
"suffix": {
"type": "string",
"description": "(Optional) The suffix that should be appended to the completion."
},
"guided_choice": {
"type": "array",
"items": {
"type": "string"
}
},
"description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
},
"prompt_logprobs": {
"type": "integer"
},
"suffix": {
"type": "string",
"description": "(Optional) The suffix that should be appended to the completion."
"type": "integer",
"description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
}
},
"additionalProperties": false,
@ -8050,6 +8053,20 @@
"model",
"prompt"
],
"title": "OpenAICompletionRequest",
"description": "Request parameters for OpenAI-compatible completion endpoint."
},
"OpenaiCompletionRequest": {
"type": "object",
"properties": {
"params": {
"$ref": "#/components/schemas/OpenAICompletionRequest"
}
},
"additionalProperties": false,
"required": [
"params"
],
"title": "OpenaiCompletionRequest"
},
"OpenAICompletion": {

View file

@ -5671,6 +5671,8 @@ components:
- model
- messages
title: OpenaiChatCompletionRequest
description: >-
Request parameters for OpenAI-compatible chat completion endpoint.
OpenAIChatCompletion:
type: object
properties:
@ -5824,7 +5826,7 @@ components:
- model
- input_messages
title: OpenAICompletionWithInputMessages
OpenaiCompletionRequest:
OpenAICompletionRequest:
type: object
properties:
model:
@ -5912,20 +5914,37 @@ components:
user:
type: string
description: (Optional) The user to use.
guided_choice:
type: array
items:
type: string
prompt_logprobs:
type: integer
suffix:
type: string
description: >-
(Optional) The suffix that should be appended to the completion.
guided_choice:
type: array
items:
type: string
description: >-
(Optional) vLLM-specific parameter for guided generation with a list of
choices.
prompt_logprobs:
type: integer
description: >-
(Optional) vLLM-specific parameter for number of log probabilities to
return for prompt tokens.
additionalProperties: false
required:
- model
- prompt
title: OpenAICompletionRequest
description: >-
Request parameters for OpenAI-compatible completion endpoint.
OpenaiCompletionRequest:
type: object
properties:
params:
$ref: '#/components/schemas/OpenAICompletionRequest'
additionalProperties: false
required:
- params
title: OpenaiCompletionRequest
OpenAICompletion:
type: object

View file

@ -5212,7 +5212,8 @@
"model",
"messages"
],
"title": "OpenaiChatCompletionRequest"
"title": "OpenaiChatCompletionRequest",
"description": "Request parameters for OpenAI-compatible chat completion endpoint."
},
"OpenAIChatCompletion": {
"type": "object",
@ -5396,7 +5397,7 @@
],
"title": "OpenAICompletionWithInputMessages"
},
"OpenaiCompletionRequest": {
"OpenAICompletionRequest": {
"type": "object",
"properties": {
"model": {
@ -5527,18 +5528,20 @@
"type": "string",
"description": "(Optional) The user to use."
},
"suffix": {
"type": "string",
"description": "(Optional) The suffix that should be appended to the completion."
},
"guided_choice": {
"type": "array",
"items": {
"type": "string"
}
},
"description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
},
"prompt_logprobs": {
"type": "integer"
},
"suffix": {
"type": "string",
"description": "(Optional) The suffix that should be appended to the completion."
"type": "integer",
"description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
}
},
"additionalProperties": false,
@ -5546,6 +5549,20 @@
"model",
"prompt"
],
"title": "OpenAICompletionRequest",
"description": "Request parameters for OpenAI-compatible completion endpoint."
},
"OpenaiCompletionRequest": {
"type": "object",
"properties": {
"params": {
"$ref": "#/components/schemas/OpenAICompletionRequest"
}
},
"additionalProperties": false,
"required": [
"params"
],
"title": "OpenaiCompletionRequest"
},
"OpenAICompletion": {

View file

@ -3920,6 +3920,8 @@ components:
- model
- messages
title: OpenaiChatCompletionRequest
description: >-
Request parameters for OpenAI-compatible chat completion endpoint.
OpenAIChatCompletion:
type: object
properties:
@ -4073,7 +4075,7 @@ components:
- model
- input_messages
title: OpenAICompletionWithInputMessages
OpenaiCompletionRequest:
OpenAICompletionRequest:
type: object
properties:
model:
@ -4161,20 +4163,37 @@ components:
user:
type: string
description: (Optional) The user to use.
guided_choice:
type: array
items:
type: string
prompt_logprobs:
type: integer
suffix:
type: string
description: >-
(Optional) The suffix that should be appended to the completion.
guided_choice:
type: array
items:
type: string
description: >-
(Optional) vLLM-specific parameter for guided generation with a list of
choices.
prompt_logprobs:
type: integer
description: >-
(Optional) vLLM-specific parameter for number of log probabilities to
return for prompt tokens.
additionalProperties: false
required:
- model
- prompt
title: OpenAICompletionRequest
description: >-
Request parameters for OpenAI-compatible completion endpoint.
OpenaiCompletionRequest:
type: object
properties:
params:
$ref: '#/components/schemas/OpenAICompletionRequest'
additionalProperties: false
required:
- params
title: OpenaiCompletionRequest
OpenAICompletion:
type: object

View file

@ -7221,7 +7221,8 @@
"model",
"messages"
],
"title": "OpenaiChatCompletionRequest"
"title": "OpenaiChatCompletionRequest",
"description": "Request parameters for OpenAI-compatible chat completion endpoint."
},
"OpenAIChatCompletion": {
"type": "object",
@ -7405,7 +7406,7 @@
],
"title": "OpenAICompletionWithInputMessages"
},
"OpenaiCompletionRequest": {
"OpenAICompletionRequest": {
"type": "object",
"properties": {
"model": {
@ -7536,18 +7537,20 @@
"type": "string",
"description": "(Optional) The user to use."
},
"suffix": {
"type": "string",
"description": "(Optional) The suffix that should be appended to the completion."
},
"guided_choice": {
"type": "array",
"items": {
"type": "string"
}
},
"description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
},
"prompt_logprobs": {
"type": "integer"
},
"suffix": {
"type": "string",
"description": "(Optional) The suffix that should be appended to the completion."
"type": "integer",
"description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
}
},
"additionalProperties": false,
@ -7555,6 +7558,20 @@
"model",
"prompt"
],
"title": "OpenAICompletionRequest",
"description": "Request parameters for OpenAI-compatible completion endpoint."
},
"OpenaiCompletionRequest": {
"type": "object",
"properties": {
"params": {
"$ref": "#/components/schemas/OpenAICompletionRequest"
}
},
"additionalProperties": false,
"required": [
"params"
],
"title": "OpenaiCompletionRequest"
},
"OpenAICompletion": {

View file

@ -5365,6 +5365,8 @@ components:
- model
- messages
title: OpenaiChatCompletionRequest
description: >-
Request parameters for OpenAI-compatible chat completion endpoint.
OpenAIChatCompletion:
type: object
properties:
@ -5518,7 +5520,7 @@ components:
- model
- input_messages
title: OpenAICompletionWithInputMessages
OpenaiCompletionRequest:
OpenAICompletionRequest:
type: object
properties:
model:
@ -5606,20 +5608,37 @@ components:
user:
type: string
description: (Optional) The user to use.
guided_choice:
type: array
items:
type: string
prompt_logprobs:
type: integer
suffix:
type: string
description: >-
(Optional) The suffix that should be appended to the completion.
guided_choice:
type: array
items:
type: string
description: >-
(Optional) vLLM-specific parameter for guided generation with a list of
choices.
prompt_logprobs:
type: integer
description: >-
(Optional) vLLM-specific parameter for number of log probabilities to
return for prompt tokens.
additionalProperties: false
required:
- model
- prompt
title: OpenAICompletionRequest
description: >-
Request parameters for OpenAI-compatible completion endpoint.
OpenaiCompletionRequest:
type: object
properties:
params:
$ref: '#/components/schemas/OpenAICompletionRequest'
additionalProperties: false
required:
- params
title: OpenaiCompletionRequest
OpenAICompletion:
type: object

View file

@ -14,7 +14,7 @@ from typing import (
runtime_checkable,
)
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
@ -995,6 +995,120 @@ class ListOpenAIChatCompletionResponse(BaseModel):
object: Literal["list"] = "list"
@json_schema_type
class OpenAICompletionRequest(BaseModel):
"""Request parameters for OpenAI-compatible completion endpoint.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:param suffix: (Optional) The suffix that should be appended to the completion.
:param guided_choice: (Optional) vLLM-specific parameter for guided generation with a list of choices.
:param prompt_logprobs: (Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens.
"""
model_config = ConfigDict(extra="allow")
# Required parameters
model: str
prompt: str | list[str] | list[int] | list[list[int]]
# Standard OpenAI completion parameters
best_of: int | None = None
echo: bool | None = None
frequency_penalty: float | None = None
logit_bias: dict[str, float] | None = None
logprobs: bool | None = None
max_tokens: int | None = None
n: int | None = None
presence_penalty: float | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
top_p: float | None = None
user: str | None = None
suffix: str | None = None
# vLLM-specific parameters (documented here but also allowed via extra fields)
guided_choice: list[str] | None = None
prompt_logprobs: int | None = None
@json_schema_type
class OpenaiChatCompletionRequest(BaseModel):
"""Request parameters for OpenAI-compatible chat completion endpoint.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
"""
model_config = ConfigDict(extra="allow")
# Required parameters
model: str
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
# Standard OpenAI chat completion parameters
frequency_penalty: float | None = None
function_call: str | dict[str, Any] | None = None
functions: list[dict[str, Any]] | None = None
logit_bias: dict[str, float] | None = None
logprobs: bool | None = None
max_completion_tokens: int | None = None
max_tokens: int | None = None
n: int | None = None
parallel_tool_calls: bool | None = None
presence_penalty: float | None = None
response_format: OpenAIResponseFormatParam | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
tool_choice: str | dict[str, Any] | None = None
tools: list[dict[str, Any]] | None = None
top_logprobs: int | None = None
top_p: float | None = None
user: str | None = None
@runtime_checkable
@trace_protocol
class InferenceProvider(Protocol):
@ -1029,52 +1143,11 @@ class InferenceProvider(Protocol):
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequest,
) -> OpenAICompletion:
"""Create completion.
Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:param suffix: (Optional) The suffix that should be appended to the completion.
:returns: An OpenAICompletion.
"""
...
@ -1083,57 +1156,11 @@ class InferenceProvider(Protocol):
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Create chat completions.
Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:returns: An OpenAIChatCompletion.
"""
...

View file

@ -383,7 +383,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
@ -446,7 +446,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
# Prepare body for the function call (handles both Pydantic and traditional params)
body = self._convert_body(func, body)
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
@ -493,17 +494,27 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _convert_body(
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
) -> dict:
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
if not body:
return {}
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
exclude_params = exclude_params or set()
func, _, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func)
params_list = [p for p in sig.parameters.values() if p.name != "self"]
# Check if the method expects a single Pydantic model parameter
if len(params_list) == 1:
param = params_list[0]
param_type = param.annotation
if issubclass(param_type, BaseModel):
# Strip NOT_GIVENs before passing to Pydantic
clean_body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
# If the body has a single key matching the parameter name, unwrap it
if len(clean_body) == 1 and param.name in clean_body:
clean_body = clean_body[param.name]
return {param.name: param_type(**clean_body)}
# Strip NOT_GIVENs to use the defaults in signature
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}

View file

@ -8,11 +8,11 @@ import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from datetime import UTC, datetime
from typing import Annotated, Any
from typing import Any
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -31,15 +31,16 @@ from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenaiChatCompletionRequest,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAICompletion,
OpenAICompletionRequest,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
StopReason,
ToolPromptFormat,
@ -181,61 +182,23 @@ class InferenceRouter(Inference):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequest,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self._get_model(model, ModelType.llm)
params = dict(
model=model_obj.identifier,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
suffix=suffix,
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
)
model_obj = await self._get_model(params.model, ModelType.llm)
# Update params with the resolved model identifier
params.model = model_obj.identifier
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
return await provider.openai_completion(**params)
if params.stream:
return await provider.openai_completion(params)
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
# response_stream = await provider.openai_completion(**params)
response = await provider.openai_completion(**params)
response = await provider.openai_completion(params)
if self.telemetry:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
@ -254,93 +217,49 @@ class InferenceRouter(Inference):
async def openai_chat_completion(
self,
model: str,
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
)
model_obj = await self._get_model(model, ModelType.llm)
model_obj = await self._get_model(params.model, ModelType.llm)
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
if tools is None:
if params.tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
if params.tools is None:
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
if tools:
for tool in tools:
if params.tools:
for tool in params.tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
# Some providers make tool calls even when tool_choice is "none"
# so just clear them both out to avoid unexpected tool calls
if tool_choice == "none" and tools is not None:
tool_choice = None
tools = None
if params.tool_choice == "none" and params.tools is not None:
params.tool_choice = None
params.tools = None
# Update params with the resolved model identifier
params.model = model_obj.identifier
params = dict(
model=model_obj.identifier,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
response_stream = await provider.openai_chat_completion(**params)
if params.stream:
response_stream = await provider.openai_chat_completion(params)
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
# We need to add metrics to each chunk and store the final completion
return self.stream_tokens_and_compute_metrics_openai_chat(
response=response_stream,
model=model_obj,
messages=messages,
messages=params.messages,
)
response = await self._nonstream_openai_chat_completion(provider, params)
# Store the response with the ID that will be returned to the client
if self.store:
asyncio.create_task(self.store.store_chat_completion(response, messages))
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
if self.telemetry:
metrics = self._construct_metrics(
@ -396,8 +315,10 @@ class InferenceRouter(Inference):
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
async def _nonstream_openai_chat_completion(
self, provider: Inference, params: OpenaiChatCompletionRequest
) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty

View file

@ -13,12 +13,13 @@ import logging # allow-direct-logging
import os
import sys
import traceback
import types
import warnings
from collections.abc import Callable
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Annotated, Any, get_origin
from typing import Annotated, Any, Union, get_origin
import httpx
import rich.pretty
@ -177,7 +178,17 @@ async def lifespan(app: StackApp):
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
# Check for stream parameter at top level (old API style)
if "stream" in kwargs:
return kwargs["stream"]
# Check for stream parameter inside Pydantic request params (new API style)
if "params" in kwargs:
params = kwargs["params"]
if hasattr(params, "stream"):
return params.stream
return False
async def maybe_await(value):
@ -282,21 +293,42 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
if method == "post":
# Annotate parameters that are in the path with Path(...) and others with Body(...),
# but preserve existing File() and Form() annotations for multipart form data
new_params = (
[new_params[0]]
+ [
(
def get_body_embed_value(param: inspect.Parameter) -> bool:
"""Determine if Body should use embed=True or embed=False.
For OpenAI-compatible endpoints (param name is 'params'), use embed=False
so the request body is parsed directly as the model (not nested).
This allows OpenAI clients to send standard OpenAI format.
For other endpoints, use embed=True for SDK compatibility.
"""
# Get the actual type, stripping Optional/Union if present
param_type = param.annotation
origin = get_origin(param_type)
# Check for Union types (both typing.Union and types.UnionType for | syntax)
if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
# Handle Optional[T] / T | None
args = param_type.__args__ if hasattr(param_type, "__args__") else []
param_type = next((arg for arg in args if arg is not type(None)), param_type)
# Check if it's a Pydantic BaseModel and param name is 'params' (OpenAI-compatible)
is_basemodel = issubclass(param_type, BaseModel)
if is_basemodel and param.name == "params":
return False # Use embed=False for OpenAI-compatible endpoints
return True # Use embed=True for everything else
original_params = new_params[1:] # Skip request parameter
new_params = [new_params[0]] # Keep request parameter
for param in original_params:
if param.name in path_params:
new_params.append(
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
if param.name in path_params
else (
param # Keep original annotation if it's already an Annotated type
if get_origin(param.annotation) is Annotated
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
)
)
for param in new_params[1:]
]
)
elif get_origin(param.annotation) is Annotated:
new_params.append(param) # Keep existing annotation
else:
embed = get_body_embed_value(param)
new_params.append(param.replace(annotation=Annotated[param.annotation, Body(..., embed=embed)]))
route_handler.__signature__ = sig.replace(parameters=new_params)

View file

@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
Inference,
Message,
OpenAIAssistantMessageParam,
OpenaiChatCompletionRequest,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
@ -582,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens = getattr(sampling_params, "max_tokens", None)
# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=self.agent_config.model,
messages=openai_messages,
tools=openai_tools if openai_tools else None,
@ -593,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
max_tokens=max_tokens,
stream=True,
)
openai_stream = await self.inference_api.openai_chat_completion(params)
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(

View file

@ -41,6 +41,7 @@ from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenaiChatCompletionRequest,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
@ -130,7 +131,7 @@ class StreamingResponseOrchestrator:
# (some providers don't support non-empty response_format when tools are present)
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
completion_result = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
@ -138,6 +139,7 @@ class StreamingResponseOrchestrator:
temperature=self.ctx.temperature,
response_format=response_format,
)
completion_result = await self.inference_api.openai_chat_completion(params)
# Process streaming chunks and build complete response
completion_result_data = None

View file

@ -22,6 +22,8 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
@ -601,7 +603,8 @@ class ReferenceBatchesImpl(Batches):
# TODO(SECURITY): review body for security issues
if request.url == "/v1/chat/completions":
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
params = OpenaiChatCompletionRequest(**request.body)
chat_response = await self.inference_api.openai_chat_completion(params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
@ -615,7 +618,8 @@ class ReferenceBatchesImpl(Batches):
},
}
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
params = OpenAICompletionRequest(**request.body)
completion_response = await self.inference_api.openai_completion(params)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (

View file

@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
from llama_stack.apis.inference import (
Inference,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
OpenAISystemMessageParam,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@ -168,11 +175,12 @@ class MetaReferenceEvalImpl(
sampling_params["stop"] = candidate.sampling_params.stop
input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.openai_completion(
params = OpenAICompletionRequest(
model=candidate.model,
prompt=input_content,
**sampling_params,
)
response = await self.inference_api.openai_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
@ -187,11 +195,12 @@ class MetaReferenceEvalImpl(
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
response = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=candidate.model,
messages=messages,
**sampling_params,
)
response = await self.inference_api.openai_chat_completion(params)
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
else:
raise ValueError("Invalid input row")

View file

@ -6,16 +6,16 @@
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAICompletion,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
self.generator.stop()
async def openai_completion(self, *args, **kwargs):
async def openai_completion(
self,
params: OpenAICompletionRequest,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by meta reference provider")
async def should_refresh_models(self) -> bool:
@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl(
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")

View file

@ -5,17 +5,16 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import (
InferenceProvider,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl(
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequest,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -10,7 +10,13 @@ from string import Template
from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import Inference, Message, UserMessage
from llama_stack.apis.inference import (
Inference,
Message,
OpenaiChatCompletionRequest,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
@ -290,20 +296,21 @@ class LlamaGuardShield:
else:
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
return UserMessage(content=self.build_prompt(messages))
def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
conversation = []
most_recent_img = None
@ -335,7 +342,7 @@ class LlamaGuardShield:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=prompt)
return OpenAIUserMessageParam(role="user", content=prompt)
def build_prompt(self, messages: list[Message]) -> str:
categories = self.get_safety_categories()
@ -377,11 +384,12 @@ class LlamaGuardShield:
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=self.model,
messages=[shield_input_message],
stream=False,
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_moderation_object(content)

View file

@ -6,7 +6,7 @@
import re
from typing import Any
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import Inference, OpenaiChatCompletionRequest
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
generated_answer=generated_answer,
)
judge_response = await self.inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=fn_def.params.judge_model,
messages=[
{
@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
}
],
)
judge_response = await self.inference_api.openai_chat_completion(params)
content = judge_response.choices[0].message.content
rating_regexes = fn_def.params.judge_score_regexes

View file

@ -8,7 +8,7 @@
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.inference import OpenaiChatCompletionRequest, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
@ -65,11 +65,12 @@ async def llm_rag_query_generator(
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
params = OpenaiChatCompletionRequest(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content

View file

@ -6,21 +6,20 @@
import json
from collections.abc import AsyncIterator
from typing import Any
from botocore.client import BaseClient
from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
@ -135,56 +134,12 @@ class BedrockInferenceAdapter(
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequest,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

View file

@ -5,11 +5,14 @@
# the root directory of this source tree.
from collections.abc import Iterable
from typing import Any
from typing import TYPE_CHECKING
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import OpenAICompletion
if TYPE_CHECKING:
from llama_stack.apis.inference import OpenAICompletionRequest
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -43,25 +46,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: "OpenAICompletionRequest",
) -> OpenAICompletion:
raise NotImplementedError()

View file

@ -3,9 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from typing import TYPE_CHECKING
from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse
if TYPE_CHECKING:
from llama_stack.apis.inference import OpenAICompletionRequest
from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -34,26 +37,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: "OpenAICompletionRequest",
) -> OpenAICompletion:
raise NotImplementedError()

View file

@ -13,15 +13,14 @@ from llama_stack.apis.inference import (
Inference,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenaiChatCompletionRequest,
OpenAICompletion,
OpenAICompletionRequest,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model
from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import PassthroughImplConfig
@ -80,110 +79,33 @@ class PassthroughInferenceAdapter(Inference):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequest,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
# Copy params to avoid mutating the original
params = params.model_copy()
params.model = model_obj.provider_resource_id
return await client.inference.openai_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_completion(**request_params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
# Copy params to avoid mutating the original
params = params.model_copy()
params.model = model_obj.provider_resource_id
return await client.inference.openai_chat_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_chat_completion(**request_params)
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
json_params = {}

View file

@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from collections.abc import AsyncIterator
from llama_stack.apis.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenaiChatCompletionRequest,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -34,56 +35,13 @@ class RunpodInferenceAdapter(OpenAIMixin):
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
):
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Override to add RunPod-specific stream_options requirement."""
if stream and not stream_options:
stream_options = {"include_usage": True}
# Copy params to avoid mutating the original
params = params.model_copy()
return await super().openai_chat_completion(
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.stream and not params.stream_options:
params.stream_options = {"include_usage": True}
return await super().openai_chat_completion(params)

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncIterator
from typing import Any
from urllib.parse import urljoin
import httpx
@ -15,8 +14,7 @@ from pydantic import ConfigDict
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenaiChatCompletionRequest,
ToolChoice,
)
from llama_stack.log import get_logger
@ -79,61 +77,20 @@ class VLLMInferenceAdapter(OpenAIMixin):
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: "OpenaiChatCompletionRequest",
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
max_tokens = max_tokens or self.config.max_tokens
# Copy params to avoid mutating the original
params = params.model_copy()
# Apply vLLM-specific defaults
if params.max_tokens is None and self.config.max_tokens:
params.max_tokens = self.config.max_tokens
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_choice is not None:
tool_choice = ToolChoice.none.value
if not params.tools and params.tool_choice is not None:
params.tool_choice = ToolChoice.none.value
return await super().openai_chat_completion(
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await super().openai_chat_completion(params)

View file

@ -7,7 +7,6 @@
import base64
import struct
from collections.abc import AsyncIterator
from typing import Any
import litellm
@ -17,12 +16,12 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenaiChatCompletionRequest,
OpenAICompletion,
OpenAICompletionRequest,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolChoice,
)
from llama_stack.core.request_headers import NeedsRequestProviderData
@ -227,116 +226,88 @@ class LiteLLMOpenAIMixin(
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequest,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model_obj = await self.model_store.get_model(params.model)
# Extract extra fields
extra_body = dict(params.__pydantic_extra__ or {})
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
frequency_penalty=params.frequency_penalty,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_tokens=params.max_tokens,
n=params.n,
presence_penalty=params.presence_penalty,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
guided_choice=params.guided_choice,
prompt_logprobs=params.prompt_logprobs,
suffix=params.suffix,
api_key=self.get_api_key(),
api_base=self.api_base,
**extra_body,
)
return await litellm.atext_completion(**params)
return await litellm.atext_completion(**request_params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenaiChatCompletionRequest,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
# Add usage tracking for streaming when telemetry is active
from llama_stack.providers.utils.telemetry.tracing import get_current_span
if stream and get_current_span() is not None:
stream_options = params.stream_options
if params.stream and get_current_span() is not None:
if stream_options is None:
stream_options = {"include_usage": True}
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model_obj = await self.model_store.get_model(params.model)
# Extract extra fields
extra_body = dict(params.__pydantic_extra__ or {})
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
messages=params.messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
functions=params.functions,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_completion_tokens=params.max_completion_tokens,
max_tokens=params.max_tokens,
n=params.n,
parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=params.presence_penalty,
response_format=params.response_format,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
top_logprobs=params.top_logprobs,
top_p=params.top_p,
user=params.user,
api_key=self.get_api_key(),
api_base=self.api_base,
**extra_body,
)
return await litellm.acompletion(**params)
return await litellm.acompletion(**request_params)
async def check_model_availability(self, model: str) -> bool:
"""

View file

@ -8,7 +8,7 @@ import base64
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterable
from typing import Any
from typing import TYPE_CHECKING, Any
from openai import NOT_GIVEN, AsyncOpenAI
from pydantic import BaseModel, ConfigDict
@ -22,8 +22,13 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
if TYPE_CHECKING:
from llama_stack.apis.inference import (
OpenaiChatCompletionRequest,
OpenAICompletionRequest,
)
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
@ -227,96 +232,55 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: "OpenAICompletionRequest",
) -> OpenAICompletion:
"""
Direct OpenAI completion API call.
"""
# Handle parameters that are not supported by OpenAI API, but may be by the provider
# prompt_logprobs is supported by vLLM
# guided_choice is supported by vLLM
# TODO: test coverage
extra_body: dict[str, Any] = {}
if prompt_logprobs is not None and prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs
if guided_choice:
extra_body["guided_choice"] = guided_choice
# Extract extra fields using Pydantic's built-in __pydantic_extra__
extra_body = dict(params.__pydantic_extra__ or {})
# Add vLLM-specific parameters to extra_body if they are set
# (these are explicitly defined in the model but still go to extra_body)
if params.prompt_logprobs is not None and params.prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = params.prompt_logprobs
if params.guided_choice:
extra_body["guided_choice"] = params.guided_choice
# TODO: fix openai_completion to return type compatible with OpenAI's API response
resp = await self.client.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
),
extra_body=extra_body,
completion_kwargs = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
frequency_penalty=params.frequency_penalty,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_tokens=params.max_tokens,
n=params.n,
presence_penalty=params.presence_penalty,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
top_p=params.top_p,
user=params.user,
suffix=params.suffix,
)
resp = await self.client.completions.create(**completion_kwargs, extra_body=extra_body)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: "OpenaiChatCompletionRequest",
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Direct OpenAI chat completion API call.
"""
messages = params.messages
if self.download_images:
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
@ -335,35 +299,38 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
messages = [await _localize_image_url(m) for m in messages]
params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
request_params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(params.model),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
functions=params.functions,
logit_bias=params.logit_bias,
logprobs=params.logprobs,
max_completion_tokens=params.max_completion_tokens,
max_tokens=params.max_tokens,
n=params.n,
parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=params.presence_penalty,
response_format=params.response_format,
seed=params.seed,
stop=params.stop,
stream=params.stream,
stream_options=params.stream_options,
temperature=params.temperature,
tool_choice=params.tool_choice,
tools=params.tools,
top_logprobs=params.top_logprobs,
top_p=params.top_p,
user=params.user,
)
resp = await self.client.chat.completions.create(**params)
# Extract any additional provider-specific parameters using Pydantic's __pydantic_extra__
if extra_body := dict(params.__pydantic_extra__ or {}):
request_params["extra_body"] = extra_body
resp = await self.client.chat.completions.create(**request_params)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_embeddings(
self,

View file

@ -146,14 +146,17 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
# For streaming response, collect all chunks
chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=None,
tools=None,
stream=True,
temperature=0.1,
)
# Verify the inference API was called with the correct params
call_args = mock_inference_api.openai_chat_completion.call_args
params = call_args.args[0] # params is passed as first positional arg
assert params.model == model
assert params.messages == [
OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)
]
assert params.response_format is None
assert params.tools is None
assert params.stream is True
assert params.temperature == 0.1
# Should have content part events for text streaming
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
@ -228,13 +231,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == "What is the capital of Ireland?"
assert first_params.tools is not None
assert first_params.temperature == 0.1
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
assert second_call.kwargs["messages"][-1].content == "Dublin"
assert second_call.kwargs["temperature"] == 0.1
second_params = second_call.args[0]
assert second_params.messages[-1].content == "Dublin"
assert second_params.temperature == 0.1
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
@ -309,9 +314,10 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
@ -386,9 +392,10 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
@ -435,9 +442,10 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
@ -485,7 +493,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
# Verify the the correct messages were sent to the inference API i.e.
# All of the responses message were convered to the chat completion message objects
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
params = call_args.args[0]
inference_messages = params.messages
for i, m in enumerate(input_messages):
if isinstance(m.content, str):
assert inference_messages[i].content == m.content
@ -653,7 +663,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 2
@ -691,7 +702,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 4 # 1 system + 3 input messages
@ -751,7 +763,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 4, sent_messages
@ -987,8 +1000,9 @@ async def test_create_openai_response_with_text_format(
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["response_format"] == response_format
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.response_format == response_format
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):

View file

@ -13,6 +13,7 @@ import pytest
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenaiChatCompletionRequest,
OpenAIChoice,
ToolChoice,
)
@ -56,13 +57,14 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_client_property.return_value = mock_client
# No tools but auto tool choice
await vllm_inference_adapter.openai_chat_completion(
"mock-model",
[],
params = OpenaiChatCompletionRequest(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
tools=None,
tool_choice=ToolChoice.auto.value,
)
await vllm_inference_adapter.openai_chat_completion(params)
mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args
# Ensure tool_choice gets converted to none for older vLLM versions
@ -171,9 +173,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
)
async def do_inference():
await vllm_inference_adapter.openai_chat_completion(
"mock-model", messages=["one fish", "two fish"], stream=False
params = OpenaiChatCompletionRequest(
model="mock-model",
messages=[{"role": "user", "content": "one fish two fish"}],
stream=False,
)
await vllm_inference_adapter.openai_chat_completion(params)
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
@ -186,3 +191,48 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
async def test_extra_body_forwarding(vllm_inference_adapter):
"""
Test that extra_body parameters (e.g., chat_template_kwargs) are correctly
forwarded to the underlying OpenAI client.
"""
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm-inference")
vllm_inference_adapter.model_store.get_model.return_value = mock_model
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
return_value=OpenAIChatCompletion(
id="chatcmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAIChoice(
message=OpenAIAssistantMessageParam(
content="test response",
),
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Test with chat_template_kwargs for Granite thinking mode
params = OpenaiChatCompletionRequest(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
chat_template_kwargs={"thinking": True},
)
await vllm_inference_adapter.openai_chat_completion(params)
# Verify that the client was called with extra_body containing chat_template_kwargs
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "chat_template_kwargs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}

View file

@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.inference import Model, OpenaiChatCompletionRequest, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
@ -271,7 +271,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
params = OpenaiChatCompletionRequest(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_called_once_with("http://example.com/image.jpg")
@ -303,7 +304,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
params = OpenaiChatCompletionRequest(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_not_called()