# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
Completes the refactoring started in previous commit by:

1. **Fix library client** (critical): Add logic to detect Pydantic model parameters
   and construct them properly from request bodies. The key fix is to NOT exclude
   any params when converting the body for Pydantic models - we need all fields
   to pass to the Pydantic constructor.

   Before: _convert_body excluded all params, leaving body empty for Pydantic construction
   After: Check for Pydantic params first, skip exclusion, construct model with full body

2. **Update remaining providers** to use new Pydantic-based signatures:
   - litellm_openai_mixin: Extract extra fields via __pydantic_extra__
   - databricks: Use TYPE_CHECKING import for params type
   - llama_openai_compat: Use TYPE_CHECKING import for params type
   - sentence_transformers: Update method signatures to use params

3. **Update unit tests** to use new Pydantic signature:
   - test_openai_mixin.py: Use OpenAIChatCompletionRequestParams

This fixes test failures where the library client was trying to construct
Pydantic models with empty dictionaries.
The previous fix had a bug: it called _convert_body() which only keeps fields
that match function parameter names. For Pydantic methods with signature:
  openai_chat_completion(params: OpenAIChatCompletionRequestParams)

The signature only has 'params', but the body has 'model', 'messages', etc.
So _convert_body() returned an empty dict.

Fix: Skip _convert_body() entirely for Pydantic params. Use the raw body
directly to construct the Pydantic model (after stripping NOT_GIVENs).

This properly fixes the ValidationError where required fields were missing.
The streaming code path (_call_streaming) had the same issue as non-streaming:
it called _convert_body() which returned empty dict for Pydantic params.

Applied the same fix as commit 7476c0ae:
- Detect Pydantic model parameters before body conversion
- Skip _convert_body() for Pydantic params
- Construct Pydantic model directly from raw body (after stripping NOT_GIVENs)

This fixes streaming endpoints like openai_chat_completion with stream=True.
The streaming code path (_call_streaming) had the same issue as non-streaming:
it called _convert_body() which returned empty dict for Pydantic params.

Applied the same fix as commit 7476c0ae:
- Detect Pydantic model parameters before body conversion
- Skip _convert_body() for Pydantic params
- Construct Pydantic model directly from raw body (after stripping NOT_GIVENs)

This fixes streaming endpoints like openai_chat_completion with stream=True.
This commit is contained in:
Eric Huang 2025-10-09 13:53:31 -07:00
parent 26fd5dbd34
commit a93130e323
295 changed files with 51966 additions and 3051 deletions

View file

@ -14,7 +14,7 @@ from typing import (
runtime_checkable,
)
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
@ -995,6 +995,81 @@ class ListOpenAIChatCompletionResponse(BaseModel):
object: Literal["list"] = "list"
@json_schema_type
class OpenAICompletionRequestParams(BaseModel):
"""Request parameters for OpenAI-compatible completion endpoint.
This model uses extra="allow" to capture provider-specific parameters
(like vLLM's guided_choice) which are passed through as extra_body.
"""
model_config = ConfigDict(extra="allow")
# Required parameters
model: str
prompt: str | list[str] | list[int] | list[list[int]]
# Standard OpenAI completion parameters
best_of: int | None = None
echo: bool | None = None
frequency_penalty: float | None = None
logit_bias: dict[str, float] | None = None
logprobs: bool | None = None
max_tokens: int | None = None
n: int | None = None
presence_penalty: float | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
top_p: float | None = None
user: str | None = None
suffix: str | None = None
# vLLM-specific parameters (documented here but also allowed via extra fields)
guided_choice: list[str] | None = None
prompt_logprobs: int | None = None
@json_schema_type
class OpenAIChatCompletionRequestParams(BaseModel):
"""Request parameters for OpenAI-compatible chat completion endpoint.
This model uses extra="allow" to capture provider-specific parameters
which are passed through as extra_body.
"""
model_config = ConfigDict(extra="allow")
# Required parameters
model: str
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)]
# Standard OpenAI chat completion parameters
frequency_penalty: float | None = None
function_call: str | dict[str, Any] | None = None
functions: list[dict[str, Any]] | None = None
logit_bias: dict[str, float] | None = None
logprobs: bool | None = None
max_completion_tokens: int | None = None
max_tokens: int | None = None
n: int | None = None
parallel_tool_calls: bool | None = None
presence_penalty: float | None = None
response_format: OpenAIResponseFormatParam | None = None
seed: int | None = None
stop: str | list[str] | None = None
stream: bool | None = None
stream_options: dict[str, Any] | None = None
temperature: float | None = None
tool_choice: str | dict[str, Any] | None = None
tools: list[dict[str, Any]] | None = None
top_logprobs: int | None = None
top_p: float | None = None
user: str | None = None
@runtime_checkable
@trace_protocol
class InferenceProvider(Protocol):
@ -1029,52 +1104,14 @@ class InferenceProvider(Protocol):
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
# Standard OpenAI completion parameters
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
# vLLM-specific parameters
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
# for fill-in-the-middle type completion
suffix: str | None = None,
params: OpenAICompletionRequestParams,
) -> OpenAICompletion:
"""Create completion.
Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:param suffix: (Optional) The suffix that should be appended to the completion.
:param params: Request parameters including model, prompt, and optional parameters.
Use params.get_extra_body() to extract provider-specific parameters.
:returns: An OpenAICompletion.
"""
...
@ -1083,58 +1120,15 @@ class InferenceProvider(Protocol):
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestParams,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""Create chat completions.
Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use.
:returns: An OpenAIChatCompletion.
:param params: Request parameters including model, messages, and optional parameters.
Use params.get_extra_body() to extract provider-specific parameters (e.g., chat_template_kwargs for vLLM).
:returns: An OpenAIChatCompletion or stream of OpenAIChatCompletionChunk.
"""
...