chore: refactor (chat)completions endpoints to use shared params struct (#3761)

# What does this PR do?

Converts the `openai_completion` and `openai_chat_completion` parameters into shared Pydantic `BaseModel` request structs (`OpenAICompletionRequest` and `OpenAIChatCompletionRequest`) to reduce code duplication across all providers.
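
For illustration, a minimal sketch of the new calling convention (the `inference_api` handle and the model id are placeholders; the request classes and the `openai_chat_completion(params)` call follow the diff below):

```python
from llama_stack.apis.inference import (
    OpenAIChatCompletionRequest,
    OpenAIUserMessageParam,
)


async def ask(inference_api) -> str:
    # Callers now build one shared Pydantic params struct instead of passing
    # ~20 keyword arguments that every provider had to re-declare.
    params = OpenAIChatCompletionRequest(
        model="llama3.2:3b",  # placeholder model id
        messages=[OpenAIUserMessageParam(role="user", content="Hello!")],
        stream=False,
    )
    response = await inference_api.openai_chat_completion(params)
    return response.choices[0].message.content
```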

## Test Plan
CI

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/llamastack/llama-stack/pull/3761).
* #3777
* __->__ #3761

---
ehhuang authored on 2025-10-10 15:46:34 -07:00, committed by GitHub
commit 80d58ab519 (parent 6954fe2274)
33 changed files with 599 additions and 890 deletions

@@ -23,6 +23,7 @@ from llama_stack.strong_typing.inspection import (
     is_generic_list,
     is_type_optional,
     is_type_union,
+    is_unwrapped_body_param,
     unwrap_generic_list,
     unwrap_optional_type,
     unwrap_union_types,
@@ -769,24 +770,30 @@ class Generator:
         first = next(iter(op.request_params))
         request_name, request_type = first
-        op_name = "".join(word.capitalize() for word in op.name.split("_"))
-        request_name = f"{op_name}Request"
-        fields = [
-            (
-                name,
-                type_,
-            )
-            for name, type_ in op.request_params
-        ]
-        request_type = make_dataclass(
-            request_name,
-            fields,
-            namespace={
-                "__doc__": create_docstring_for_request(
-                    request_name, fields, doc_params
-                )
-            },
-        )
+        # Special case: if there's a single parameter with Body(embed=False) that's a BaseModel,
+        # unwrap it to show the flat structure in the OpenAPI spec
+        # Example: openai_chat_completion()
+        if (len(op.request_params) == 1 and is_unwrapped_body_param(request_type)):
+            pass
+        else:
+            op_name = "".join(word.capitalize() for word in op.name.split("_"))
+            request_name = f"{op_name}Request"
+            fields = [
+                (
+                    name,
+                    type_,
+                )
+                for name, type_ in op.request_params
+            ]
+            request_type = make_dataclass(
+                request_name,
+                fields,
+                namespace={
+                    "__doc__": create_docstring_for_request(
+                        request_name, fields, doc_params
+                    )
+                },
+            )

         requestBody = RequestBody(
             content={

@@ -8,10 +8,11 @@ import json
 import typing
 import inspect
 from pathlib import Path
-from typing import TextIO
-from typing import Any, List, Optional, Union, get_type_hints, get_origin, get_args
+from typing import Any, List, Optional, TextIO, Union, get_type_hints, get_origin, get_args

+from pydantic import BaseModel
 from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
+from llama_stack.strong_typing.inspection import is_unwrapped_body_param
 from llama_stack.core.resolver import api_protocol_map

 from .generator import Generator
@@ -205,6 +206,14 @@ def _validate_has_return_in_docstring(method) -> str | None:
 def _validate_has_params_in_docstring(method) -> str | None:
     source = inspect.getsource(method)
     sig = inspect.signature(method)
+
+    params_list = [p for p in sig.parameters.values() if p.name != "self"]
+    if len(params_list) == 1:
+        param = params_list[0]
+        param_type = param.annotation
+        if is_unwrapped_body_param(param_type):
+            return
+
     # Only check if the method has more than one parameter
     if len(sig.parameters) > 1 and ":param" not in source:
         return "does not have a ':param' in its docstring"

@@ -1527,7 +1527,7 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/OpenaiChatCompletionRequest"
+                "$ref": "#/components/schemas/OpenAIChatCompletionRequest"
               }
             }
           },
@@ -1617,7 +1617,7 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/OpenaiCompletionRequest"
+                "$ref": "#/components/schemas/OpenAICompletionRequest"
              }
            }
          },
@@ -7522,7 +7522,7 @@
         "title": "OpenAIResponseFormatText",
         "description": "Text response format for OpenAI-compatible chat completion requests."
       },
-      "OpenaiChatCompletionRequest": {
+      "OpenAIChatCompletionRequest": {
         "type": "object",
         "properties": {
           "model": {
@@ -7769,7 +7769,8 @@
           "model",
           "messages"
         ],
-        "title": "OpenaiChatCompletionRequest"
+        "title": "OpenAIChatCompletionRequest",
+        "description": "Request parameters for OpenAI-compatible chat completion endpoint."
       },
       "OpenAIChatCompletion": {
         "type": "object",
@@ -7965,7 +7966,7 @@
         ],
         "title": "OpenAICompletionWithInputMessages"
       },
-      "OpenaiCompletionRequest": {
+      "OpenAICompletionRequest": {
         "type": "object",
         "properties": {
           "model": {
@@ -8100,10 +8101,12 @@
           "type": "array",
           "items": {
             "type": "string"
-          }
+          },
+          "description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
         },
         "prompt_logprobs": {
-          "type": "integer"
+          "type": "integer",
+          "description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
         },
         "suffix": {
           "type": "string",
@@ -8115,7 +8118,8 @@
           "model",
           "prompt"
         ],
-        "title": "OpenaiCompletionRequest"
+        "title": "OpenAICompletionRequest",
+        "description": "Request parameters for OpenAI-compatible completion endpoint."
       },
       "OpenAICompletion": {
         "type": "object",

@@ -1098,7 +1098,7 @@ paths:
         content:
           application/json:
             schema:
-              $ref: '#/components/schemas/OpenaiChatCompletionRequest'
+              $ref: '#/components/schemas/OpenAIChatCompletionRequest'
         required: true
       deprecated: true
   /v1/openai/v1/chat/completions/{completion_id}:
@@ -1167,7 +1167,7 @@ paths:
         content:
           application/json:
             schema:
-              $ref: '#/components/schemas/OpenaiCompletionRequest'
+              $ref: '#/components/schemas/OpenAICompletionRequest'
         required: true
       deprecated: true
   /v1/openai/v1/embeddings:
@@ -5575,7 +5575,7 @@ components:
       title: OpenAIResponseFormatText
       description: >-
         Text response format for OpenAI-compatible chat completion requests.
-    OpenaiChatCompletionRequest:
+    OpenAIChatCompletionRequest:
       type: object
       properties:
        model:
@@ -5717,7 +5717,9 @@ components:
       required:
         - model
         - messages
-      title: OpenaiChatCompletionRequest
+      title: OpenAIChatCompletionRequest
+      description: >-
+        Request parameters for OpenAI-compatible chat completion endpoint.
     OpenAIChatCompletion:
       type: object
       properties:
@@ -5883,7 +5885,7 @@ components:
         - model
         - input_messages
       title: OpenAICompletionWithInputMessages
-    OpenaiCompletionRequest:
+    OpenAICompletionRequest:
       type: object
       properties:
        model:
@@ -5975,8 +5977,14 @@ components:
           type: array
           items:
             type: string
+          description: >-
+            (Optional) vLLM-specific parameter for guided generation with a list of
+            choices.
         prompt_logprobs:
           type: integer
+          description: >-
+            (Optional) vLLM-specific parameter for number of log probabilities to
+            return for prompt tokens.
         suffix:
           type: string
           description: >-
@@ -5985,7 +5993,9 @@ components:
       required:
         - model
         - prompt
-      title: OpenaiCompletionRequest
+      title: OpenAICompletionRequest
+      description: >-
+        Request parameters for OpenAI-compatible completion endpoint.
     OpenAICompletion:
       type: object
      properties:

@@ -153,7 +153,7 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/OpenaiChatCompletionRequest"
+                "$ref": "#/components/schemas/OpenAIChatCompletionRequest"
               }
             }
           },
@@ -243,7 +243,7 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/OpenaiCompletionRequest"
+                "$ref": "#/components/schemas/OpenAICompletionRequest"
              }
            }
          },
@@ -5018,7 +5018,7 @@
         "title": "OpenAIResponseFormatText",
         "description": "Text response format for OpenAI-compatible chat completion requests."
       },
-      "OpenaiChatCompletionRequest": {
+      "OpenAIChatCompletionRequest": {
         "type": "object",
         "properties": {
           "model": {
@@ -5265,7 +5265,8 @@
           "model",
           "messages"
         ],
-        "title": "OpenaiChatCompletionRequest"
+        "title": "OpenAIChatCompletionRequest",
+        "description": "Request parameters for OpenAI-compatible chat completion endpoint."
       },
       "OpenAIChatCompletion": {
         "type": "object",
@@ -5461,7 +5462,7 @@
         ],
         "title": "OpenAICompletionWithInputMessages"
       },
-      "OpenaiCompletionRequest": {
+      "OpenAICompletionRequest": {
         "type": "object",
         "properties": {
           "model": {
@@ -5596,10 +5597,12 @@
           "type": "array",
           "items": {
             "type": "string"
-          }
+          },
+          "description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
         },
         "prompt_logprobs": {
-          "type": "integer"
+          "type": "integer",
+          "description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
         },
         "suffix": {
           "type": "string",
@@ -5611,7 +5614,8 @@
           "model",
           "prompt"
         ],
-        "title": "OpenaiCompletionRequest"
+        "title": "OpenAICompletionRequest",
+        "description": "Request parameters for OpenAI-compatible completion endpoint."
       },
       "OpenAICompletion": {
         "type": "object",

@@ -98,7 +98,7 @@ paths:
         content:
           application/json:
             schema:
-              $ref: '#/components/schemas/OpenaiChatCompletionRequest'
+              $ref: '#/components/schemas/OpenAIChatCompletionRequest'
         required: true
       deprecated: false
   /v1/chat/completions/{completion_id}:
@@ -167,7 +167,7 @@ paths:
         content:
           application/json:
             schema:
-              $ref: '#/components/schemas/OpenaiCompletionRequest'
+              $ref: '#/components/schemas/OpenAICompletionRequest'
         required: true
       deprecated: false
   /v1/conversations:
@@ -3824,7 +3824,7 @@ components:
       title: OpenAIResponseFormatText
       description: >-
         Text response format for OpenAI-compatible chat completion requests.
-    OpenaiChatCompletionRequest:
+    OpenAIChatCompletionRequest:
       type: object
      properties:
        model:
@@ -3966,7 +3966,9 @@ components:
       required:
         - model
         - messages
-      title: OpenaiChatCompletionRequest
+      title: OpenAIChatCompletionRequest
+      description: >-
+        Request parameters for OpenAI-compatible chat completion endpoint.
     OpenAIChatCompletion:
       type: object
       properties:
@@ -4132,7 +4134,7 @@ components:
         - model
         - input_messages
       title: OpenAICompletionWithInputMessages
-    OpenaiCompletionRequest:
+    OpenAICompletionRequest:
       type: object
      properties:
        model:
@@ -4224,8 +4226,14 @@ components:
           type: array
           items:
             type: string
+          description: >-
+            (Optional) vLLM-specific parameter for guided generation with a list of
+            choices.
         prompt_logprobs:
           type: integer
+          description: >-
+            (Optional) vLLM-specific parameter for number of log probabilities to
+            return for prompt tokens.
         suffix:
           type: string
           description: >-
@@ -4234,7 +4242,9 @@ components:
       required:
         - model
         - prompt
-      title: OpenaiCompletionRequest
+      title: OpenAICompletionRequest
+      description: >-
+        Request parameters for OpenAI-compatible completion endpoint.
     OpenAICompletion:
       type: object
      properties:

@@ -153,7 +153,7 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/OpenaiChatCompletionRequest"
+                "$ref": "#/components/schemas/OpenAIChatCompletionRequest"
               }
             }
           },
@@ -243,7 +243,7 @@
           "content": {
             "application/json": {
               "schema": {
-                "$ref": "#/components/schemas/OpenaiCompletionRequest"
+                "$ref": "#/components/schemas/OpenAICompletionRequest"
              }
            }
          },
@@ -7027,7 +7027,7 @@
         "title": "OpenAIResponseFormatText",
         "description": "Text response format for OpenAI-compatible chat completion requests."
       },
-      "OpenaiChatCompletionRequest": {
+      "OpenAIChatCompletionRequest": {
         "type": "object",
         "properties": {
           "model": {
@@ -7274,7 +7274,8 @@
           "model",
           "messages"
         ],
-        "title": "OpenaiChatCompletionRequest"
+        "title": "OpenAIChatCompletionRequest",
+        "description": "Request parameters for OpenAI-compatible chat completion endpoint."
       },
       "OpenAIChatCompletion": {
         "type": "object",
@@ -7470,7 +7471,7 @@
         ],
         "title": "OpenAICompletionWithInputMessages"
       },
-      "OpenaiCompletionRequest": {
+      "OpenAICompletionRequest": {
         "type": "object",
         "properties": {
           "model": {
@@ -7605,10 +7606,12 @@
           "type": "array",
           "items": {
             "type": "string"
-          }
+          },
+          "description": "(Optional) vLLM-specific parameter for guided generation with a list of choices."
         },
         "prompt_logprobs": {
-          "type": "integer"
+          "type": "integer",
+          "description": "(Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens."
         },
         "suffix": {
           "type": "string",
@@ -7620,7 +7623,8 @@
           "model",
           "prompt"
         ],
-        "title": "OpenaiCompletionRequest"
+        "title": "OpenAICompletionRequest",
+        "description": "Request parameters for OpenAI-compatible completion endpoint."
       },
       "OpenAICompletion": {
         "type": "object",

@@ -101,7 +101,7 @@ paths:
         content:
           application/json:
             schema:
-              $ref: '#/components/schemas/OpenaiChatCompletionRequest'
+              $ref: '#/components/schemas/OpenAIChatCompletionRequest'
         required: true
       deprecated: false
   /v1/chat/completions/{completion_id}:
@@ -170,7 +170,7 @@ paths:
         content:
           application/json:
             schema:
-              $ref: '#/components/schemas/OpenaiCompletionRequest'
+              $ref: '#/components/schemas/OpenAICompletionRequest'
         required: true
       deprecated: false
   /v1/conversations:
@@ -5269,7 +5269,7 @@ components:
       title: OpenAIResponseFormatText
       description: >-
         Text response format for OpenAI-compatible chat completion requests.
-    OpenaiChatCompletionRequest:
+    OpenAIChatCompletionRequest:
      type: object
      properties:
        model:
@@ -5411,7 +5411,9 @@ components:
       required:
         - model
         - messages
-      title: OpenaiChatCompletionRequest
+      title: OpenAIChatCompletionRequest
+      description: >-
+        Request parameters for OpenAI-compatible chat completion endpoint.
     OpenAIChatCompletion:
       type: object
       properties:
@@ -5577,7 +5579,7 @@ components:
         - model
         - input_messages
       title: OpenAICompletionWithInputMessages
-    OpenaiCompletionRequest:
+    OpenAICompletionRequest:
      type: object
      properties:
        model:
@@ -5669,8 +5671,14 @@ components:
           type: array
           items:
             type: string
+          description: >-
+            (Optional) vLLM-specific parameter for guided generation with a list of
+            choices.
         prompt_logprobs:
           type: integer
+          description: >-
+            (Optional) vLLM-specific parameter for number of log probabilities to
+            return for prompt tokens.
         suffix:
           type: string
           description: >-
@@ -5679,7 +5687,9 @@ components:
       required:
         - model
         - prompt
-      title: OpenaiCompletionRequest
+      title: OpenAICompletionRequest
+      description: >-
+        Request parameters for OpenAI-compatible completion endpoint.
     OpenAICompletion:
       type: object
      properties:

@@ -14,7 +14,8 @@ from typing import (
     runtime_checkable,
 )

-from pydantic import BaseModel, Field, field_validator
+from fastapi import Body
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from typing_extensions import TypedDict

 from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
@@ -1035,6 +1036,118 @@ class ListOpenAIChatCompletionResponse(BaseModel):
     object: Literal["list"] = "list"


+@json_schema_type
+class OpenAICompletionRequest(BaseModel):
+    """Request parameters for OpenAI-compatible completion endpoint.
+
+    :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+    :param prompt: The prompt to generate a completion for.
+    :param best_of: (Optional) The number of completions to generate.
+    :param echo: (Optional) Whether to echo the prompt.
+    :param frequency_penalty: (Optional) The penalty for repeated tokens.
+    :param logit_bias: (Optional) The logit bias to use.
+    :param logprobs: (Optional) The log probabilities to use.
+    :param max_tokens: (Optional) The maximum number of tokens to generate.
+    :param n: (Optional) The number of completions to generate.
+    :param presence_penalty: (Optional) The penalty for repeated tokens.
+    :param seed: (Optional) The seed to use.
+    :param stop: (Optional) The stop tokens to use.
+    :param stream: (Optional) Whether to stream the response.
+    :param stream_options: (Optional) The stream options to use.
+    :param temperature: (Optional) The temperature to use.
+    :param top_p: (Optional) The top p to use.
+    :param user: (Optional) The user to use.
+    :param suffix: (Optional) The suffix that should be appended to the completion.
+    :param guided_choice: (Optional) vLLM-specific parameter for guided generation with a list of choices.
+    :param prompt_logprobs: (Optional) vLLM-specific parameter for number of log probabilities to return for prompt tokens.
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+    # Standard OpenAI completion parameters
+    model: str
+    prompt: str | list[str] | list[int] | list[list[int]]
+    best_of: int | None = None
+    echo: bool | None = None
+    frequency_penalty: float | None = None
+    logit_bias: dict[str, float] | None = None
+    logprobs: bool | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    presence_penalty: float | None = None
+    seed: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool | None = None
+    stream_options: dict[str, Any] | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+    user: str | None = None
+    # vLLM-specific parameters (documented here but also allowed via extra fields)
+    guided_choice: list[str] | None = None
+    prompt_logprobs: int | None = None
+    # for fill-in-the-middle type completion
+    suffix: str | None = None
+
+
+@json_schema_type
+class OpenAIChatCompletionRequest(BaseModel):
+    """Request parameters for OpenAI-compatible chat completion endpoint.
+
+    :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+    :param messages: List of messages in the conversation.
+    :param frequency_penalty: (Optional) The penalty for repeated tokens.
+    :param function_call: (Optional) The function call to use.
+    :param functions: (Optional) List of functions to use.
+    :param logit_bias: (Optional) The logit bias to use.
+    :param logprobs: (Optional) The log probabilities to use.
+    :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
+    :param max_tokens: (Optional) The maximum number of tokens to generate.
+    :param n: (Optional) The number of completions to generate.
+    :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
+    :param presence_penalty: (Optional) The penalty for repeated tokens.
+    :param response_format: (Optional) The response format to use.
+    :param seed: (Optional) The seed to use.
+    :param stop: (Optional) The stop tokens to use.
+    :param stream: (Optional) Whether to stream the response.
+    :param stream_options: (Optional) The stream options to use.
+    :param temperature: (Optional) The temperature to use.
+    :param tool_choice: (Optional) The tool choice to use.
+    :param tools: (Optional) The tools to use.
+    :param top_logprobs: (Optional) The top log probabilities to use.
+    :param top_p: (Optional) The top p to use.
+    :param user: (Optional) The user to use.
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+    # Standard OpenAI chat completion parameters
+    model: str
+    messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
+    frequency_penalty: float | None = None
+    function_call: str | dict[str, Any] | None = None
+    functions: list[dict[str, Any]] | None = None
+    logit_bias: dict[str, float] | None = None
+    logprobs: bool | None = None
+    max_completion_tokens: int | None = None
+    max_tokens: int | None = None
+    n: int | None = None
+    parallel_tool_calls: bool | None = None
+    presence_penalty: float | None = None
+    response_format: OpenAIResponseFormatParam | None = None
+    seed: int | None = None
+    stop: str | list[str] | None = None
+    stream: bool | None = None
+    stream_options: dict[str, Any] | None = None
+    temperature: float | None = None
+    tool_choice: str | dict[str, Any] | None = None
+    tools: list[dict[str, Any]] | None = None
+    top_logprobs: int | None = None
+    top_p: float | None = None
+    user: str | None = None
+
+
 @runtime_checkable
 @trace_protocol
 class InferenceProvider(Protocol):
@@ -1069,52 +1182,11 @@ class InferenceProvider(Protocol):
     @webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_completion(
         self,
-        # Standard OpenAI completion parameters
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        # vLLM-specific parameters
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        # for fill-in-the-middle type completion
-        suffix: str | None = None,
+        params: Annotated[OpenAICompletionRequest, Body(...)],
     ) -> OpenAICompletion:
         """Create completion.

         Generate an OpenAI-compatible completion for the given prompt using the specified model.

-        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param prompt: The prompt to generate a completion for.
-        :param best_of: (Optional) The number of completions to generate.
-        :param echo: (Optional) Whether to echo the prompt.
-        :param frequency_penalty: (Optional) The penalty for repeated tokens.
-        :param logit_bias: (Optional) The logit bias to use.
-        :param logprobs: (Optional) The log probabilities to use.
-        :param max_tokens: (Optional) The maximum number of tokens to generate.
-        :param n: (Optional) The number of completions to generate.
-        :param presence_penalty: (Optional) The penalty for repeated tokens.
-        :param seed: (Optional) The seed to use.
-        :param stop: (Optional) The stop tokens to use.
-        :param stream: (Optional) Whether to stream the response.
-        :param stream_options: (Optional) The stream options to use.
-        :param temperature: (Optional) The temperature to use.
-        :param top_p: (Optional) The top p to use.
-        :param user: (Optional) The user to use.
-        :param suffix: (Optional) The suffix that should be appended to the completion.
         :returns: An OpenAICompletion.
         """
         ...
@@ -1123,57 +1195,11 @@ class InferenceProvider(Protocol):
     @webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: Annotated[OpenAIChatCompletionRequest, Body(...)],
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         """Create chat completions.

         Generate an OpenAI-compatible chat completion for the given messages using the specified model.

-        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
-        :param messages: List of messages in the conversation.
-        :param frequency_penalty: (Optional) The penalty for repeated tokens.
-        :param function_call: (Optional) The function call to use.
-        :param functions: (Optional) List of functions to use.
-        :param logit_bias: (Optional) The logit bias to use.
-        :param logprobs: (Optional) The log probabilities to use.
-        :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
-        :param max_tokens: (Optional) The maximum number of tokens to generate.
-        :param n: (Optional) The number of completions to generate.
-        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
-        :param presence_penalty: (Optional) The penalty for repeated tokens.
-        :param response_format: (Optional) The response format to use.
-        :param seed: (Optional) The seed to use.
-        :param stop: (Optional) The stop tokens to use.
-        :param stream: (Optional) Whether to stream the response.
-        :param stream_options: (Optional) The stream options to use.
-        :param temperature: (Optional) The temperature to use.
-        :param tool_choice: (Optional) The tool choice to use.
-        :param tools: (Optional) The tools to use.
-        :param top_logprobs: (Optional) The top log probabilities to use.
-        :param top_p: (Optional) The top p to use.
-        :param user: (Optional) The user to use.
         :returns: An OpenAIChatCompletion.
         """
         ...

@@ -54,6 +54,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
     setup_logger,
     start_trace,
 )
+from llama_stack.strong_typing.inspection import is_unwrapped_body_param

 logger = get_logger(name=__name__, category="core")
@@ -383,7 +384,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             body, field_names = self._handle_file_uploads(options, body)

-            body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
+            body = self._convert_body(matched_func, body, exclude_params=set(field_names))

             trace_path = webmethod.descriptive_name or route_path
             await start_trace(trace_path, {"__location__": "library_client"})
@@ -446,7 +447,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
         body |= path_params

-        body = self._convert_body(path, options.method, body)
+        # Prepare body for the function call (handles both Pydantic and traditional params)
+        body = self._convert_body(func, body)

         trace_path = webmethod.descriptive_name or route_path
         await start_trace(trace_path, {"__location__": "library_client"})
@@ -493,17 +495,20 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             return await response.parse()

-    def _convert_body(
-        self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
-    ) -> dict:
+    def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
         if not body:
             return {}

-        assert self.route_impls is not None  # Should be guaranteed by request() method, assertion for mypy
         exclude_params = exclude_params or set()

-        func, _, _, _ = find_matching_route(method, path, self.route_impls)
         sig = inspect.signature(func)
+        params_list = [p for p in sig.parameters.values() if p.name != "self"]
+
+        # Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
+        if len(params_list) == 1:
+            param = params_list[0]
+            param_type = param.annotation
+            if is_unwrapped_body_param(param_type):
+                base_type = get_args(param_type)[0]
+                return {param.name: base_type(**body)}

         # Strip NOT_GIVENs to use the defaults in signature
         body = {k: v for k, v in body.items() if v is not NOT_GIVEN}

@@ -10,9 +10,10 @@ from collections.abc import AsyncGenerator, AsyncIterator
 from datetime import UTC, datetime
 from typing import Annotated, Any

+from fastapi import Body
 from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
 from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
-from pydantic import Field, TypeAdapter
+from pydantic import TypeAdapter

 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -31,15 +32,16 @@ from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequest,
     OpenAIChatCompletionToolCall,
     OpenAIChatCompletionToolCallFunction,
     OpenAIChoice,
     OpenAIChoiceLogprobs,
     OpenAICompletion,
+    OpenAICompletionRequest,
     OpenAICompletionWithInputMessages,
     OpenAIEmbeddingsResponse,
     OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     Order,
     StopReason,
     ToolPromptFormat,
@@ -181,61 +183,23 @@ class InferenceRouter(Inference):
     async def openai_completion(
         self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
+        params: Annotated[OpenAICompletionRequest, Body(...)],
     ) -> OpenAICompletion:
         logger.debug(
-            f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
-        )
-        model_obj = await self._get_model(model, ModelType.llm)
-        params = dict(
-            model=model_obj.identifier,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            guided_choice=guided_choice,
-            prompt_logprobs=prompt_logprobs,
-            suffix=suffix,
+            f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
         )
+        model_obj = await self._get_model(params.model, ModelType.llm)
+
+        # Update params with the resolved model identifier
+        params.model = model_obj.identifier
+
         provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        if stream:
-            return await provider.openai_completion(**params)
+        if params.stream:
+            return await provider.openai_completion(params)
         # TODO: Metrics do NOT work with openai_completion stream=True due to the fact
         # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
-        # response_stream = await provider.openai_completion(**params)
-        response = await provider.openai_completion(**params)
+        response = await provider.openai_completion(params)
         if self.telemetry:
             metrics = self._construct_metrics(
                 prompt_tokens=response.usage.prompt_tokens,
@@ -254,93 +218,49 @@ class InferenceRouter(Inference):
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: Annotated[OpenAIChatCompletionRequest, Body(...)],
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         logger.debug(
-            f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
+            f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
         )
-        model_obj = await self._get_model(model, ModelType.llm)
+        model_obj = await self._get_model(params.model, ModelType.llm)

         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
-        if tool_choice:
-            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
-            if tools is None:
+        if params.tool_choice:
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
+            if params.tools is None:
                 raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
-        if tools:
-            for tool in tools:
+        if params.tools:
+            for tool in params.tools:
                 TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)

         # Some providers make tool calls even when tool_choice is "none"
         # so just clear them both out to avoid unexpected tool calls
-        if tool_choice == "none" and tools is not None:
-            tool_choice = None
-            tools = None
+        if params.tool_choice == "none" and params.tools is not None:
+            params.tool_choice = None
+            params.tools = None
+
+        # Update params with the resolved model identifier
+        params.model = model_obj.identifier

-        params = dict(
-            model=model_obj.identifier,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
         provider = await self.routing_table.get_provider_impl(model_obj.identifier)
-        if stream:
-            response_stream = await provider.openai_chat_completion(**params)
+        if params.stream:
+            response_stream = await provider.openai_chat_completion(params)
             # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
             # We need to add metrics to each chunk and store the final completion
             return self.stream_tokens_and_compute_metrics_openai_chat(
                 response=response_stream,
                 model=model_obj,
-                messages=messages,
+                messages=params.messages,
             )

         response = await self._nonstream_openai_chat_completion(provider, params)

         # Store the response with the ID that will be returned to the client
         if self.store:
-            asyncio.create_task(self.store.store_chat_completion(response, messages))
+            asyncio.create_task(self.store.store_chat_completion(response, params.messages))

         if self.telemetry:
             metrics = self._construct_metrics(
@@ -396,8 +316,10 @@ class InferenceRouter(Inference):
             return await self.store.get_chat_completion(completion_id)
         raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")

-    async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
-        response = await provider.openai_chat_completion(**params)
+    async def _nonstream_openai_chat_completion(
+        self, provider: Inference, params: OpenAIChatCompletionRequest
+    ) -> OpenAIChatCompletion:
+        response = await provider.openai_chat_completion(params)
         for choice in response.choices:
             # some providers return an empty list for no tool calls in non-streaming responses
             # but the OpenAI API returns None. So, set tool_calls to None if it's empty

@@ -184,7 +184,17 @@ async def lifespan(app: StackApp):

 def is_streaming_request(func_name: str, request: Request, **kwargs):
     # TODO: pass the api method and punt it to the Protocol definition directly
-    return kwargs.get("stream", False)
+    # If there's a stream parameter at top level, use it
+    if "stream" in kwargs:
+        return kwargs["stream"]
+
+    # If there's a stream parameter inside a "params" parameter, e.g. openai_chat_completion() use it
+    if "params" in kwargs:
+        params = kwargs["params"]
+        if hasattr(params, "stream"):
+            return params.stream
+
+    return False


 async def maybe_await(value):

@@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
     Inference,
     Message,
     OpenAIAssistantMessageParam,
+    OpenAIChatCompletionRequest,
     OpenAIDeveloperMessageParam,
     OpenAIMessageParam,
     OpenAISystemMessageParam,
@@ -582,7 +583,7 @@ class ChatAgent(ShieldRunnerMixin):
         max_tokens = getattr(sampling_params, "max_tokens", None)

         # Use OpenAI chat completion
-        openai_stream = await self.inference_api.openai_chat_completion(
+        params = OpenAIChatCompletionRequest(
             model=self.agent_config.model,
             messages=openai_messages,
             tools=openai_tools if openai_tools else None,
@@ -593,6 +594,7 @@ class ChatAgent(ShieldRunnerMixin):
             max_tokens=max_tokens,
             stream=True,
         )
+        openai_stream = await self.inference_api.openai_chat_completion(params)

         # Convert OpenAI stream back to Llama Stack format
         response_stream = convert_openai_chat_completion_stream(

@@ -49,6 +49,7 @@ from llama_stack.apis.inference import (
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequest,
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
     OpenAIMessageParam,
@@ -168,7 +169,7 @@ class StreamingResponseOrchestrator:
         # (some providers don't support non-empty response_format when tools are present)
         response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
         logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
-        completion_result = await self.inference_api.openai_chat_completion(
+        params = OpenAIChatCompletionRequest(
             model=self.ctx.model,
             messages=messages,
             tools=self.ctx.chat_tools,
@@ -179,6 +180,7 @@ class StreamingResponseOrchestrator:
                 "include_usage": True,
             },
         )
+        completion_result = await self.inference_api.openai_chat_completion(params)

         # Process streaming chunks and build complete response
         completion_result_data = None

@@ -22,6 +22,8 @@ from llama_stack.apis.files import Files, OpenAIFilePurpose
 from llama_stack.apis.inference import (
     Inference,
     OpenAIAssistantMessageParam,
+    OpenAIChatCompletionRequest,
+    OpenAICompletionRequest,
     OpenAIDeveloperMessageParam,
     OpenAIMessageParam,
     OpenAISystemMessageParam,
@@ -606,7 +608,8 @@ class ReferenceBatchesImpl(Batches):
             # TODO(SECURITY): review body for security issues
             if request.url == "/v1/chat/completions":
                 request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
-                chat_response = await self.inference_api.openai_chat_completion(**request.body)
+                chat_params = OpenAIChatCompletionRequest(**request.body)
+                chat_response = await self.inference_api.openai_chat_completion(chat_params)

                 # this is for mypy, we don't allow streaming so we'll get the right type
                 assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
@@ -620,7 +623,8 @@ class ReferenceBatchesImpl(Batches):
                     },
                 }
             elif request.url == "/v1/completions":
-                completion_response = await self.inference_api.openai_completion(**request.body)
+                completion_params = OpenAICompletionRequest(**request.body)
+                completion_response = await self.inference_api.openai_completion(completion_params)

                 # this is for mypy, we don't allow streaming so we'll get the right type
                 assert hasattr(completion_response, "model_dump_json"), (

@@ -12,7 +12,14 @@ from llama_stack.apis.agents import Agents, StepType
 from llama_stack.apis.benchmarks import Benchmark
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
+from llama_stack.apis.inference import (
+    Inference,
+    OpenAIChatCompletionRequest,
+    OpenAICompletionRequest,
+    OpenAISystemMessageParam,
+    OpenAIUserMessageParam,
+    UserMessage,
+)
 from llama_stack.apis.scoring import Scoring
 from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@@ -168,11 +175,12 @@ class MetaReferenceEvalImpl(
                     sampling_params["stop"] = candidate.sampling_params.stop

                 input_content = json.loads(x[ColumnName.completion_input.value])
-                response = await self.inference_api.openai_completion(
+                params = OpenAICompletionRequest(
                     model=candidate.model,
                     prompt=input_content,
                     **sampling_params,
                 )
+                response = await self.inference_api.openai_completion(params)
                 generations.append({ColumnName.generated_answer.value: response.choices[0].text})
             elif ColumnName.chat_completion_input.value in x:
                 chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
@@ -187,11 +195,12 @@ class MetaReferenceEvalImpl(
                 messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
                 messages += input_messages
-                response = await self.inference_api.openai_chat_completion(
+                params = OpenAIChatCompletionRequest(
                     model=candidate.model,
                     messages=messages,
                     **sampling_params,
                 )
+                response = await self.inference_api.openai_chat_completion(params)
                 generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
             else:
                 raise ValueError("Invalid input row")

@@ -6,16 +6,16 @@
 import asyncio
 from collections.abc import AsyncIterator
-from typing import Any

 from llama_stack.apis.inference import (
     InferenceProvider,
+    OpenAIChatCompletionRequest,
+    OpenAICompletionRequest,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
+    OpenAICompletion,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
@@ -65,7 +65,10 @@ class MetaReferenceInferenceImpl(
         if self.config.create_distributed_process_group:
             self.generator.stop()

-    async def openai_completion(self, *args, **kwargs):
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequest,
+    ) -> OpenAICompletion:
         raise NotImplementedError("OpenAI completion not supported by meta reference provider")

     async def should_refresh_models(self) -> bool:
@@ -150,28 +153,6 @@ class MetaReferenceInferenceImpl(
     async def openai_chat_completion(
         self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
+        params: OpenAIChatCompletionRequest,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
         raise NotImplementedError("OpenAI chat completion not supported by meta-reference inference provider")

View file

@ -5,17 +5,16 @@
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
InferenceProvider, InferenceProvider,
OpenAIChatCompletionRequest,
OpenAICompletionRequest,
) )
from llama_stack.apis.inference.inference import ( from llama_stack.apis.inference.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAICompletion, OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
) )
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -73,56 +72,12 @@ class SentenceTransformersInferenceImpl(
async def openai_completion( async def openai_completion(
self, self,
# Standard OpenAI completion parameters params: OpenAICompletionRequest,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by sentence transformers provider") raise NotImplementedError("OpenAI completion not supported by sentence transformers provider")
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: OpenAIChatCompletionRequest,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by sentence transformers provider")

View file

@ -10,7 +10,13 @@ from string import Template
from typing import Any from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import Inference, Message, UserMessage from llama_stack.apis.inference import (
Inference,
Message,
OpenAIChatCompletionRequest,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.safety import ( from llama_stack.apis.safety import (
RunShieldResponse, RunShieldResponse,
Safety, Safety,
@ -290,20 +296,21 @@ class LlamaGuardShield:
else:
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequest(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)
def build_text_shield_input(self, messages: list[Message]) -> UserMessage:
def build_text_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
return UserMessage(content=self.build_prompt(messages))
return OpenAIUserMessageParam(role="user", content=self.build_prompt(messages))
def build_vision_shield_input(self, messages: list[Message]) -> UserMessage:
def build_vision_shield_input(self, messages: list[Message]) -> OpenAIUserMessageParam:
conversation = []
most_recent_img = None
@ -335,7 +342,7 @@ class LlamaGuardShield:
prompt.append(most_recent_img)
prompt.append(self.build_prompt(conversation[::-1]))
return UserMessage(content=prompt)
return OpenAIUserMessageParam(role="user", content=prompt)
def build_prompt(self, messages: list[Message]) -> str:
categories = self.get_safety_categories()
@ -377,11 +384,12 @@ class LlamaGuardShield:
# TODO: Add Image based support for OpenAI Moderations
shield_input_message = self.build_text_shield_input(messages)
response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequest(
model=self.model,
messages=[shield_input_message],
stream=False,
)
response = await self.inference_api.openai_chat_completion(params)
content = response.choices[0].message.content
content = content.strip()
return self.get_moderation_object(content)
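For reference, a minimal sketch of the calling convention this diff introduces: build the request model once, then pass it as a single positional argument. The `inference_api` handle and the shield model id below are illustrative, not taken from this diff.

from llama_stack.apis.inference import OpenAIChatCompletionRequest, OpenAIUserMessageParam

async def run_guard_check(inference_api, prompt: str) -> str:
    # Build the request object once, then pass it positionally.
    params = OpenAIChatCompletionRequest(
        model="llama-guard-model-id",  # hypothetical shield model id
        messages=[OpenAIUserMessageParam(role="user", content=prompt)],
        stream=False,
        temperature=0.0,  # deterministic output for safety classification
    )
    response = await inference_api.openai_chat_completion(params)
    return response.choices[0].message.content.strip()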

View file

@ -6,7 +6,7 @@
import re import re
from typing import Any from typing import Any
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference, OpenAIChatCompletionRequest
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
@ -55,7 +55,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
generated_answer=generated_answer,
)
judge_response = await self.inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequest(
model=fn_def.params.judge_model,
messages=[
{
@ -64,6 +64,7 @@ class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
}
],
)
judge_response = await self.inference_api.openai_chat_completion(params)
content = judge_response.choices[0].message.content
rating_regexes = fn_def.params.judge_score_regexes

View file

@ -8,7 +8,7 @@
from jinja2 import Template from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIUserMessageParam from llama_stack.apis.inference import OpenAIChatCompletionRequest, OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import ( from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig, DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig, LLMRAGQueryGeneratorConfig,
@ -65,11 +65,12 @@ async def llm_rag_query_generator(
model = config.model
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
params = OpenAIChatCompletionRequest(
model=model,
messages=[message],
stream=False,
)
response = await inference_api.openai_chat_completion(params)
query = response.choices[0].message.content

View file

@ -6,21 +6,20 @@
import json import json
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
from botocore.client import BaseClient from botocore.client import BaseClient
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionRequest, ChatCompletionRequest,
Inference, Inference,
OpenAIChatCompletionRequest,
OpenAICompletionRequest,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
) )
from llama_stack.apis.inference.inference import ( from llama_stack.apis.inference.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAICompletion, OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
) )
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client from llama_stack.providers.utils.bedrock.client import create_bedrock_client
@ -135,56 +134,12 @@ class BedrockInferenceAdapter(
async def openai_completion( async def openai_completion(
self, self,
# Standard OpenAI completion parameters params: OpenAICompletionRequest,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
async def openai_chat_completion(
self,
model: str,
params: OpenAIChatCompletionRequest,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")

View file

@ -5,11 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any
from databricks.sdk import WorkspaceClient from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import OpenAICompletion from llama_stack.apis.inference import OpenAICompletion, OpenAICompletionRequest
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -40,25 +39,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
async def openai_completion( async def openai_completion(
self, self,
model: str, params: OpenAICompletionRequest,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()

View file

@ -3,9 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any from llama_stack.apis.inference.inference import OpenAICompletion, OpenAICompletionRequest, OpenAIEmbeddingsResponse
from llama_stack.apis.inference.inference import OpenAICompletion, OpenAIEmbeddingsResponse
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -31,26 +29,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
async def openai_completion( async def openai_completion(
self, self,
model: str, params: OpenAICompletionRequest,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()

View file

@ -13,15 +13,14 @@ from llama_stack.apis.inference import (
Inference, Inference,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAIChatCompletionRequest,
OpenAICompletion, OpenAICompletion,
OpenAICompletionRequest,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
) )
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.core.library_client import convert_pydantic_to_json_value from llama_stack.core.library_client import convert_pydantic_to_json_value
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
from .config import PassthroughImplConfig from .config import PassthroughImplConfig
@ -80,110 +79,31 @@ class PassthroughInferenceAdapter(Inference):
async def openai_completion( async def openai_completion(
self, self,
model: str, params: OpenAICompletionRequest,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
params = params.model_copy()
model=model_obj.provider_resource_id,
params.model = model_obj.provider_resource_id
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
)
return await client.inference.openai_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_completion(**request_params)
async def openai_chat_completion(
self,
model: str,
params: OpenAIChatCompletionRequest,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
client = self._get_client()
model_obj = await self.model_store.get_model(model)
model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
params = params.model_copy()
model=model_obj.provider_resource_id,
params.model = model_obj.provider_resource_id
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await client.inference.openai_chat_completion(**params)
request_params = params.model_dump(exclude_none=True)
return await client.inference.openai_chat_completion(**request_params)
def cast_value_to_json_dict(self, request_params: dict[str, Any]) -> dict[str, Any]:
json_params = {}
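The passthrough change above reduces to one pattern: copy the request, swap in the provider-resolved model id, and forward every non-None field. A rough sketch of that pattern, assuming stand-in `client` and `model_store` objects:

from llama_stack.apis.inference import OpenAIChatCompletionRequest

async def forward_chat_completion(client, model_store, params: OpenAIChatCompletionRequest):
    model_obj = await model_store.get_model(params.model)
    # Copy before mutating so the caller's request object stays untouched.
    params = params.model_copy()
    params.model = model_obj.provider_resource_id
    # exclude_none keeps the downstream payload minimal.
    request_params = params.model_dump(exclude_none=True)
    return await client.inference.openai_chat_completion(**request_params)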

View file

@ -4,11 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any from collections.abc import AsyncIterator
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIMessageParam, OpenAIChatCompletion,
OpenAIResponseFormatParam, OpenAIChatCompletionChunk,
OpenAIChatCompletionRequest,
) )
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -30,56 +31,12 @@ class RunpodInferenceAdapter(OpenAIMixin):
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: OpenAIChatCompletionRequest,
messages: list[OpenAIMessageParam], ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
):
"""Override to add RunPod-specific stream_options requirement.""" """Override to add RunPod-specific stream_options requirement."""
if stream and not stream_options: params = params.model_copy()
stream_options = {"include_usage": True}
return await super().openai_chat_completion( if params.stream and not params.stream_options:
model=model, params.stream_options = {"include_usage": True}
messages=messages,
frequency_penalty=frequency_penalty, return await super().openai_chat_completion(params)
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
from urllib.parse import urljoin from urllib.parse import urljoin
import httpx import httpx
@ -15,8 +14,7 @@ from pydantic import ConfigDict
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIMessageParam, OpenAIChatCompletionRequest,
OpenAIResponseFormatParam,
ToolChoice, ToolChoice,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -95,61 +93,19 @@ class VLLMInferenceAdapter(OpenAIMixin):
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: OpenAIChatCompletionRequest,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
max_tokens = max_tokens or self.config.max_tokens
params = params.model_copy()
# Apply vLLM-specific defaults
if params.max_tokens is None and self.config.max_tokens:
params.max_tokens = self.config.max_tokens
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_choice is not None:
if not params.tools and params.tool_choice is not None:
tool_choice = ToolChoice.none.value
params.tool_choice = ToolChoice.none.value
return await super().openai_chat_completion(
return await super().openai_chat_completion(params)
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
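The vLLM adapter now only transforms the request before delegating to the mixin. A sketch of that transformation as a standalone helper; the helper name and the example default are illustrative, not from this diff:

from llama_stack.apis.inference import OpenAIChatCompletionRequest

def apply_vllm_defaults(params: OpenAIChatCompletionRequest, configured_max_tokens: int | None) -> OpenAIChatCompletionRequest:
    # Copy before mutating so the caller's request object is untouched.
    params = params.model_copy()
    if params.max_tokens is None and configured_max_tokens:
        params.max_tokens = configured_max_tokens  # e.g. 512, a made-up default
    # Older vLLM releases reject tool_choice when no tools are supplied.
    if not params.tools and params.tool_choice is not None:
        params.tool_choice = "none"  # mirrors ToolChoice.none.value
    return params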

View file

@ -7,7 +7,6 @@
import base64 import base64
import struct import struct
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
import litellm import litellm
@ -17,12 +16,12 @@ from llama_stack.apis.inference import (
JsonSchemaResponseFormat, JsonSchemaResponseFormat,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAIChatCompletionRequest,
OpenAICompletion, OpenAICompletion,
OpenAICompletionRequest,
OpenAIEmbeddingData, OpenAIEmbeddingData,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage, OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ToolChoice, ToolChoice,
) )
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
@ -227,116 +226,80 @@ class LiteLLMOpenAIMixin(
async def openai_completion( async def openai_completion(
self, self,
model: str, params: OpenAICompletionRequest,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(params.model)
params = await prepare_openai_completion_params(
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id), model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=prompt, prompt=params.prompt,
best_of=best_of, best_of=params.best_of,
echo=echo, echo=params.echo,
frequency_penalty=frequency_penalty, frequency_penalty=params.frequency_penalty,
logit_bias=logit_bias, logit_bias=params.logit_bias,
logprobs=logprobs, logprobs=params.logprobs,
max_tokens=max_tokens, max_tokens=params.max_tokens,
n=n, n=params.n,
presence_penalty=presence_penalty, presence_penalty=params.presence_penalty,
seed=seed, seed=params.seed,
stop=stop, stop=params.stop,
stream=stream, stream=params.stream,
stream_options=stream_options, stream_options=params.stream_options,
temperature=temperature, temperature=params.temperature,
top_p=top_p, top_p=params.top_p,
user=user, user=params.user,
guided_choice=guided_choice, guided_choice=params.guided_choice,
prompt_logprobs=prompt_logprobs, prompt_logprobs=params.prompt_logprobs,
suffix=params.suffix,
api_key=self.get_api_key(), api_key=self.get_api_key(),
api_base=self.api_base, api_base=self.api_base,
) )
return await litellm.atext_completion(**params) return await litellm.atext_completion(**request_params)
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: OpenAIChatCompletionRequest,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
# Add usage tracking for streaming when telemetry is active # Add usage tracking for streaming when telemetry is active
from llama_stack.providers.utils.telemetry.tracing import get_current_span from llama_stack.providers.utils.telemetry.tracing import get_current_span
if stream and get_current_span() is not None: stream_options = params.stream_options
if params.stream and get_current_span() is not None:
if stream_options is None: if stream_options is None:
stream_options = {"include_usage": True} stream_options = {"include_usage": True}
elif "include_usage" not in stream_options: elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True} stream_options = {**stream_options, "include_usage": True}
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params( model_obj = await self.model_store.get_model(params.model)
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id), model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=messages, messages=params.messages,
frequency_penalty=frequency_penalty, frequency_penalty=params.frequency_penalty,
function_call=function_call, function_call=params.function_call,
functions=functions, functions=params.functions,
logit_bias=logit_bias, logit_bias=params.logit_bias,
logprobs=logprobs, logprobs=params.logprobs,
max_completion_tokens=max_completion_tokens, max_completion_tokens=params.max_completion_tokens,
max_tokens=max_tokens, max_tokens=params.max_tokens,
n=n, n=params.n,
parallel_tool_calls=parallel_tool_calls, parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=presence_penalty, presence_penalty=params.presence_penalty,
response_format=response_format, response_format=params.response_format,
seed=seed, seed=params.seed,
stop=stop, stop=params.stop,
stream=stream, stream=params.stream,
stream_options=stream_options, stream_options=stream_options,
temperature=temperature, temperature=params.temperature,
tool_choice=tool_choice, tool_choice=params.tool_choice,
tools=tools, tools=params.tools,
top_logprobs=top_logprobs, top_logprobs=params.top_logprobs,
top_p=top_p, top_p=params.top_p,
user=user, user=params.user,
api_key=self.get_api_key(), api_key=self.get_api_key(),
api_base=self.api_base, api_base=self.api_base,
) )
return await litellm.acompletion(**params) return await litellm.acompletion(**request_params)
async def check_model_availability(self, model: str) -> bool:
"""

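The LiteLLM usage-tracking tweak above boils down to a small merge on stream_options. A sketch as a free function; the function name is made up and the real adapter inlines this logic:

from typing import Any

def ensure_include_usage(stream: bool | None, stream_options: dict[str, Any] | None, telemetry_active: bool) -> dict[str, Any] | None:
    # Only request usage chunks when streaming under an active telemetry span.
    if not (stream and telemetry_active):
        return stream_options
    if stream_options is None:
        return {"include_usage": True}
    if "include_usage" not in stream_options:
        return {**stream_options, "include_usage": True}
    return stream_options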
View file

@ -17,12 +17,13 @@ from llama_stack.apis.inference import (
Model, Model,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAIChatCompletionRequest,
OpenAICompletion, OpenAICompletion,
OpenAICompletionRequest,
OpenAIEmbeddingData, OpenAIEmbeddingData,
OpenAIEmbeddingsResponse, OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage, OpenAIEmbeddingUsage,
OpenAIMessageParam, OpenAIMessageParam,
OpenAIResponseFormatParam,
) )
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.core.request_headers import NeedsRequestProviderData
@ -222,26 +223,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
async def openai_completion( async def openai_completion(
self, self,
model: str, params: OpenAICompletionRequest,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
""" """
Direct OpenAI completion API call. Direct OpenAI completion API call.
@ -251,67 +233,45 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
# guided_choice is supported by vLLM # guided_choice is supported by vLLM
# TODO: test coverage # TODO: test coverage
extra_body: dict[str, Any] = {} extra_body: dict[str, Any] = {}
if prompt_logprobs is not None and prompt_logprobs >= 0: if params.prompt_logprobs is not None and params.prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs extra_body["prompt_logprobs"] = params.prompt_logprobs
if guided_choice: if params.guided_choice:
extra_body["guided_choice"] = guided_choice extra_body["guided_choice"] = params.guided_choice
# TODO: fix openai_completion to return type compatible with OpenAI's API response # TODO: fix openai_completion to return type compatible with OpenAI's API response
resp = await self.client.completions.create( completion_kwargs = await prepare_openai_completion_params(
**await prepare_openai_completion_params( model=await self._get_provider_model_id(params.model),
model=await self._get_provider_model_id(model), prompt=params.prompt,
prompt=prompt, best_of=params.best_of,
best_of=best_of, echo=params.echo,
echo=echo, frequency_penalty=params.frequency_penalty,
frequency_penalty=frequency_penalty, logit_bias=params.logit_bias,
logit_bias=logit_bias, logprobs=params.logprobs,
logprobs=logprobs, max_tokens=params.max_tokens,
max_tokens=max_tokens, n=params.n,
n=n, presence_penalty=params.presence_penalty,
presence_penalty=presence_penalty, seed=params.seed,
seed=seed, stop=params.stop,
stop=stop, stream=params.stream,
stream=stream, stream_options=params.stream_options,
stream_options=stream_options, temperature=params.temperature,
temperature=temperature, top_p=params.top_p,
top_p=top_p, user=params.user,
user=user, suffix=params.suffix,
suffix=suffix,
),
extra_body=extra_body,
) )
resp = await self.client.completions.create(**completion_kwargs, extra_body=extra_body)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return] return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, params: OpenAIChatCompletionRequest,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Direct OpenAI chat completion API call.
"""
messages = params.messages
if self.download_images:
async def _localize_image_url(m: OpenAIMessageParam) -> OpenAIMessageParam:
@ -330,35 +290,35 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
messages = [await _localize_image_url(m) for m in messages]
params = await prepare_openai_completion_params(
request_params = await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
model=await self._get_provider_model_id(params.model),
messages=messages,
frequency_penalty=frequency_penalty, frequency_penalty=params.frequency_penalty,
function_call=function_call, function_call=params.function_call,
functions=functions, functions=params.functions,
logit_bias=logit_bias, logit_bias=params.logit_bias,
logprobs=logprobs, logprobs=params.logprobs,
max_completion_tokens=max_completion_tokens, max_completion_tokens=params.max_completion_tokens,
max_tokens=max_tokens, max_tokens=params.max_tokens,
n=n, n=params.n,
parallel_tool_calls=parallel_tool_calls, parallel_tool_calls=params.parallel_tool_calls,
presence_penalty=presence_penalty, presence_penalty=params.presence_penalty,
response_format=response_format, response_format=params.response_format,
seed=seed, seed=params.seed,
stop=stop, stop=params.stop,
stream=stream, stream=params.stream,
stream_options=stream_options, stream_options=params.stream_options,
temperature=temperature, temperature=params.temperature,
tool_choice=tool_choice, tool_choice=params.tool_choice,
tools=tools, tools=params.tools,
top_logprobs=top_logprobs, top_logprobs=params.top_logprobs,
top_p=top_p, top_p=params.top_p,
user=user, user=params.user,
) )
resp = await self.client.chat.completions.create(**params)
resp = await self.client.chat.completions.create(**request_params)
return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]
return await self._maybe_overwrite_id(resp, params.stream) # type: ignore[no-any-return]
async def openai_embeddings( async def openai_embeddings(
self, self,

View file

@ -50,6 +50,10 @@ if sys.version_info >= (3, 10):
else: else:
from typing_extensions import TypeGuard from typing_extensions import TypeGuard
from pydantic import BaseModel
from pydantic.fields import FieldInfo
S = TypeVar("S") S = TypeVar("S")
T = TypeVar("T") T = TypeVar("T")
K = TypeVar("K") K = TypeVar("K")
@ -570,7 +574,8 @@ def get_class_properties(typ: type) -> Iterable[Tuple[str, type | str]]:
elif hasattr(typ, "model_fields"): elif hasattr(typ, "model_fields"):
# Pydantic BaseModel - use model_fields to exclude ClassVar and other non-field attributes # Pydantic BaseModel - use model_fields to exclude ClassVar and other non-field attributes
# Reconstruct Annotated type if discriminator exists to preserve metadata # Reconstruct Annotated type if discriminator exists to preserve metadata
from typing import Annotated, Any, cast
from typing import Annotated, Any
from pydantic.fields import FieldInfo
def get_field_type(name: str, field: Any) -> type | str: def get_field_type(name: str, field: Any) -> type | str:
@ -1049,3 +1054,32 @@ def check_recursive(
pred = lambda typ, obj: True # noqa: E731 pred = lambda typ, obj: True # noqa: E731
return RecursiveChecker(pred).check(type(obj), obj) return RecursiveChecker(pred).check(type(obj), obj)
def is_unwrapped_body_param(param_type: Any) -> bool:
"""
Check if a parameter type represents an unwrapped body parameter.
An unwrapped body parameter is an Annotated type with Body(embed=False)
This is used to determine whether request parameters should be flattened
in OpenAPI specs and client libraries (matching FastAPI's embed=False behavior).
Args:
param_type: The parameter type annotation to check
Returns:
True if the parameter should be treated as an unwrapped body parameter
"""
# Check if it's Annotated with Body(embed=False)
if typing.get_origin(param_type) is Annotated:
args = typing.get_args(param_type)
base_type = args[0]
metadata = args[1:]
# Look for Body annotation with embed=False
# Body() returns a FieldInfo object, so we check for that type and the embed attribute
for item in metadata:
if isinstance(item, FieldInfo) and hasattr(item, "embed") and not item.embed:
return inspect.isclass(base_type) and issubclass(base_type, BaseModel)
return False
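A short sketch of the annotation this helper is meant to detect, using the function defined above and assuming FastAPI's Body (whose parameter objects subclass pydantic's FieldInfo); the request model and endpoint names below are illustrative only:

from typing import Annotated, get_type_hints

from fastapi import Body
from pydantic import BaseModel

class ExampleRequest(BaseModel):  # stand-in for e.g. OpenAIChatCompletionRequest
    model: str
    stream: bool | None = None

async def example_endpoint(params: Annotated[ExampleRequest, Body(embed=False)]) -> None:
    ...

# The single Body(embed=False) BaseModel parameter is treated as an unwrapped
# (flat) request body by the spec generator.
hints = get_type_hints(example_endpoint, include_extras=True)
assert is_unwrapped_body_param(hints["params"])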

View file

@ -33,6 +33,7 @@ from llama_stack.apis.agents.openai_responses import (
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIAssistantMessageParam, OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartTextParam, OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequest,
OpenAIDeveloperMessageParam, OpenAIDeveloperMessageParam,
OpenAIJSONSchema, OpenAIJSONSchema,
OpenAIResponseFormatJSONObject, OpenAIResponseFormatJSONObject,
@ -161,15 +162,17 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
chunks = [chunk async for chunk in result] chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with( mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model, OpenAIChatCompletionRequest(
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)], model=model,
response_format=None, messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
tools=None, response_format=None,
stream=True, tools=None,
temperature=0.1, stream=True,
stream_options={ temperature=0.1,
"include_usage": True, stream_options={
}, "include_usage": True,
},
)
) )
# Should have content part events for text streaming # Should have content part events for text streaming
@ -256,13 +259,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
# Verify # Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?" first_params = first_call.args[0]
assert first_call.kwargs["tools"] is not None assert first_params.messages[0].content == "What is the capital of Ireland?"
assert first_call.kwargs["temperature"] == 0.1 assert first_params.tools is not None
assert first_params.temperature == 0.1
second_call = mock_inference_api.openai_chat_completion.call_args_list[1] second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
assert second_call.kwargs["messages"][-1].content == "Dublin" second_params = second_call.args[0]
assert second_call.kwargs["temperature"] == 0.1 assert second_params.messages[-1].content == "Dublin"
assert second_params.temperature == 0.1
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search") openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with( openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
@ -348,9 +353,10 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Verify inference API was called correctly (after iterating over result) # Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text first_params = first_call.args[0]
assert first_call.kwargs["tools"] is not None assert first_params.messages[0].content == input_text
assert first_call.kwargs["temperature"] == 0.1 assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output) # Check response.created event (should have empty output)
assert len(chunks[0].response.output) == 0 assert len(chunks[0].response.output) == 0
@ -394,9 +400,10 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
def assert_common_expectations(chunks) -> None: def assert_common_expectations(chunks) -> None:
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text first_params = first_call.args[0]
assert first_call.kwargs["tools"] is not None assert first_params.messages[0].content == input_text
assert first_call.kwargs["temperature"] == 0.1 assert first_params.tools is not None
assert first_params.temperature == 0.1
assert len(chunks[0].response.output) == 0 assert len(chunks[0].response.output) == 0
completed_chunk = chunks[-1] completed_chunk = chunks[-1]
assert completed_chunk.type == "response.completed" assert completed_chunk.type == "response.completed"
@ -512,7 +519,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
# Verify that the correct messages were sent to the inference API, i.e.
# all of the response messages were converted to the chat completion message objects
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
params = call_args.args[0]
inference_messages = params.messages
for i, m in enumerate(input_messages): for i, m in enumerate(input_messages):
if isinstance(m.content, str): if isinstance(m.content, str):
assert inference_messages[i].content == m.content assert inference_messages[i].content == m.content
@ -680,7 +689,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
# Verify # Verify
mock_inference_api.openai_chat_completion.assert_called_once() mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"] params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message # Check that instructions were prepended as a system message
assert len(sent_messages) == 2 assert len(sent_messages) == 2
@ -718,7 +728,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
# Verify # Verify
mock_inference_api.openai_chat_completion.assert_called_once() mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"] params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message # Check that instructions were prepended as a system message
assert len(sent_messages) == 4 # 1 system + 3 input messages assert len(sent_messages) == 4 # 1 system + 3 input messages
@ -778,7 +789,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
# Verify # Verify
mock_inference_api.openai_chat_completion.assert_called_once() mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"] params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message # Check that instructions were prepended as a system message
assert len(sent_messages) == 4, sent_messages assert len(sent_messages) == 4, sent_messages
@ -1018,7 +1030,8 @@ async def test_reuse_mcp_tool_list(
) )
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2 assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
second_call = mock_inference_api.openai_chat_completion.call_args_list[1] second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
tools_seen = second_call.kwargs["tools"] second_params = second_call.args[0]
tools_seen = second_params.tools
assert len(tools_seen) == 1 assert len(tools_seen) == 1
assert tools_seen[0]["function"]["name"] == "test_tool" assert tools_seen[0]["function"]["name"] == "test_tool"
assert tools_seen[0]["function"]["description"] == "a test tool" assert tools_seen[0]["function"]["description"] == "a test tool"
@ -1065,8 +1078,9 @@ async def test_create_openai_response_with_text_format(
# Verify # Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0] first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text first_params = first_call.args[0]
assert first_call.kwargs["response_format"] == response_format assert first_params.messages[0].content == input_text
assert first_params.response_format == response_format
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api): async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):

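Sketch of the assertion pattern the updated tests use, with a throwaway AsyncMock standing in for the inference API; the model id and message content are placeholders:

from unittest.mock import AsyncMock

from llama_stack.apis.inference import OpenAIChatCompletionRequest

async def example_assertion() -> None:
    mock_inference_api = AsyncMock()
    params = OpenAIChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
    )
    await mock_inference_api.openai_chat_completion(params)

    # The request object is now the single positional argument of the call.
    sent = mock_inference_api.openai_chat_completion.call_args.args[0]
    assert sent.model == "test-model"
    assert sent.stream is True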
View file

@ -13,6 +13,7 @@ import pytest
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIAssistantMessageParam, OpenAIAssistantMessageParam,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionRequest,
OpenAIChoice, OpenAIChoice,
ToolChoice, ToolChoice,
) )
@ -56,13 +57,14 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_client_property.return_value = mock_client mock_client_property.return_value = mock_client
# No tools but auto tool choice # No tools but auto tool choice
await vllm_inference_adapter.openai_chat_completion( params = OpenAIChatCompletionRequest(
"mock-model", model="mock-model",
[], messages=[{"role": "user", "content": "test"}],
stream=False, stream=False,
tools=None, tools=None,
tool_choice=ToolChoice.auto.value, tool_choice=ToolChoice.auto.value,
) )
await vllm_inference_adapter.openai_chat_completion(params)
mock_client.chat.completions.create.assert_called() mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args call_args = mock_client.chat.completions.create.call_args
# Ensure tool_choice gets converted to none for older vLLM versions # Ensure tool_choice gets converted to none for older vLLM versions
@ -171,9 +173,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
) )
async def do_inference(): async def do_inference():
await vllm_inference_adapter.openai_chat_completion( params = OpenAIChatCompletionRequest(
"mock-model", messages=["one fish", "two fish"], stream=False model="mock-model",
messages=[{"role": "user", "content": "one fish two fish"}],
stream=False,
) )
await vllm_inference_adapter.openai_chat_completion(params)
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client: with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock() mock_client = MagicMock()

View file

@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam from llama_stack.apis.inference import Model, OpenAIChatCompletionRequest, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
@ -271,7 +271,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg") mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message]) params = OpenAIChatCompletionRequest(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_called_once_with("http://example.com/image.jpg") mock_localize.assert_called_once_with("http://example.com/image.jpg")
@ -303,7 +304,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client): with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize: with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message]) params = OpenAIChatCompletionRequest(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_not_called() mock_localize.assert_not_called()