# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
# What does this PR do?


## Test Plan
Completes the refactoring started in previous commit by:

1. **Fix library client** (critical): Add logic to detect Pydantic model parameters
   and construct them properly from request bodies. The key fix is to NOT exclude
   any params when converting the body for Pydantic models - we need all fields
   to pass to the Pydantic constructor.

   Before: _convert_body excluded all params, leaving body empty for Pydantic construction
   After: Check for Pydantic params first, skip exclusion, construct model with full body

2. **Update remaining providers** to use new Pydantic-based signatures:
   - litellm_openai_mixin: Extract extra fields via __pydantic_extra__
   - databricks: Use TYPE_CHECKING import for params type
   - llama_openai_compat: Use TYPE_CHECKING import for params type
   - sentence_transformers: Update method signatures to use params

3. **Update unit tests** to use new Pydantic signature:
   - test_openai_mixin.py: Use OpenAIChatCompletionRequestParams

This fixes test failures where the library client was trying to construct
Pydantic models with empty dictionaries.
The previous fix had a bug: it called _convert_body() which only keeps fields
that match function parameter names. For Pydantic methods with signature:
  openai_chat_completion(params: OpenAIChatCompletionRequestParams)

The signature only has 'params', but the body has 'model', 'messages', etc.
So _convert_body() returned an empty dict.

Fix: Skip _convert_body() entirely for Pydantic params. Use the raw body
directly to construct the Pydantic model (after stripping NOT_GIVENs).

This properly fixes the ValidationError where required fields were missing.
The streaming code path (_call_streaming) had the same issue as non-streaming:
it called _convert_body() which returned empty dict for Pydantic params.

Applied the same fix as commit 7476c0ae:
- Detect Pydantic model parameters before body conversion
- Skip _convert_body() for Pydantic params
- Construct Pydantic model directly from raw body (after stripping NOT_GIVENs)

This fixes streaming endpoints like openai_chat_completion with stream=True.
The streaming code path (_call_streaming) had the same issue as non-streaming:
it called _convert_body() which returned empty dict for Pydantic params.

Applied the same fix as commit 7476c0ae:
- Detect Pydantic model parameters before body conversion
- Skip _convert_body() for Pydantic params
- Construct Pydantic model directly from raw body (after stripping NOT_GIVENs)

This fixes streaming endpoints like openai_chat_completion with stream=True.
This commit is contained in:
Eric Huang 2025-10-09 13:53:17 -07:00
parent 9e9a827fcd
commit a70fc60485
295 changed files with 51966 additions and 3051 deletions

View file

@ -363,6 +363,56 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
return body, field_names
def _prepare_request_body(
self, func: Any, body: dict, path: str, method: str, exclude_params: set[str] | None = None
) -> dict:
"""Prepare request body by converting to Pydantic models or traditional parameters.
For endpoints with a single Pydantic parameter, constructs the model from the body.
For traditional endpoints, converts body to match function parameters.
Args:
func: The function to call
body: The request body
path: The request path
method: The HTTP method
exclude_params: Parameters to exclude from conversion
Returns:
The prepared body dict ready to pass to the function
"""
sig = inspect.signature(func)
params_list = [p for p in sig.parameters.values() if p.name != "self"]
# Check if the method expects a single Pydantic model parameter
is_pydantic_param = False
if len(params_list) == 1:
param = params_list[0]
param_type = param.annotation
try:
if isinstance(param_type, type) and issubclass(param_type, BaseModel):
is_pydantic_param = True
except (TypeError, AttributeError):
pass
# For Pydantic models, use the raw body directly to construct the model
# For traditional methods, convert body to match function parameters
if is_pydantic_param:
param = params_list[0]
param_type = param.annotation
# Strip NOT_GIVENs before passing to Pydantic
clean_body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
# If the body has a single key matching the parameter name, unwrap it
# This handles cases where the client passes agent_config={...} and we need
# to construct AgentConfig from the inner dict, not {"agent_config": {...}}
if len(clean_body) == 1 and param.name in clean_body:
clean_body = clean_body[param.name]
return {param.name: param_type(**clean_body)}
else:
return self._convert_body(path, method, body, exclude_params=exclude_params)
async def _call_non_streaming(
self,
*,
@ -383,7 +433,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
# Prepare body for the function call (handles both Pydantic and traditional params)
body = self._prepare_request_body(matched_func, body, path, options.method, exclude_params=set(field_names))
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
@ -446,7 +497,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
# Prepare body for the function call (handles both Pydantic and traditional params)
body = self._prepare_request_body(func, body, path, options.method)
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})

View file

@ -31,15 +31,16 @@ from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestParams,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIChoiceLogprobs,
OpenAICompletion,
OpenAICompletionRequestParams,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
StopReason,
ToolPromptFormat,
@ -181,61 +182,23 @@ class InferenceRouter(Inference):
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
params: OpenAICompletionRequestParams,
) -> OpenAICompletion:
logger.debug(
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
)
model_obj = await self._get_model(model, ModelType.llm)
params = dict(
model=model_obj.identifier,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
suffix=suffix,
f"InferenceRouter.openai_completion: model={params.model}, stream={params.stream}, prompt={params.prompt}",
)
model_obj = await self._get_model(params.model, ModelType.llm)
# Update params with the resolved model identifier
params.model = model_obj.identifier
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
return await provider.openai_completion(**params)
if params.stream:
return await provider.openai_completion(params)
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
# response_stream = await provider.openai_completion(**params)
response = await provider.openai_completion(**params)
response = await provider.openai_completion(params)
if self.telemetry:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
@ -254,93 +217,49 @@ class InferenceRouter(Inference):
async def openai_chat_completion(
self,
model: str,
messages: Annotated[list[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
params: OpenAIChatCompletionRequestParams,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
logger.debug(
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
f"InferenceRouter.openai_chat_completion: model={params.model}, stream={params.stream}, messages={params.messages}",
)
model_obj = await self._get_model(model, ModelType.llm)
model_obj = await self._get_model(params.model, ModelType.llm)
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
if tools is None:
if params.tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(params.tool_choice)
if params.tools is None:
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
if tools:
for tool in tools:
if params.tools:
for tool in params.tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
# Some providers make tool calls even when tool_choice is "none"
# so just clear them both out to avoid unexpected tool calls
if tool_choice == "none" and tools is not None:
tool_choice = None
tools = None
if params.tool_choice == "none" and params.tools is not None:
params.tool_choice = None
params.tools = None
# Update params with the resolved model identifier
params.model = model_obj.identifier
params = dict(
model=model_obj.identifier,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
provider = await self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
response_stream = await provider.openai_chat_completion(**params)
if params.stream:
response_stream = await provider.openai_chat_completion(params)
# For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk]
# We need to add metrics to each chunk and store the final completion
return self.stream_tokens_and_compute_metrics_openai_chat(
response=response_stream,
model=model_obj,
messages=messages,
messages=params.messages,
)
response = await self._nonstream_openai_chat_completion(provider, params)
# Store the response with the ID that will be returned to the client
if self.store:
asyncio.create_task(self.store.store_chat_completion(response, messages))
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
if self.telemetry:
metrics = self._construct_metrics(
@ -396,8 +315,10 @@ class InferenceRouter(Inference):
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
async def _nonstream_openai_chat_completion(
self, provider: Inference, params: OpenAIChatCompletionRequestParams
) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty

View file

@ -268,21 +268,42 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
if method == "post":
# Annotate parameters that are in the path with Path(...) and others with Body(...),
# but preserve existing File() and Form() annotations for multipart form data
new_params = (
[new_params[0]]
+ [
(
def get_body_embed_value(param: inspect.Parameter) -> bool:
"""Determine if Body should use embed=True or embed=False.
For Pydantic BaseModel subclasses, use embed=False so the request body
is parsed directly as the model (not nested under param name).
For other types, use embed=True.
"""
# Get the actual type, stripping Optional/Union if present
param_type = param.annotation
if get_origin(param_type) in (type(None) | type, type | type(None)):
# Handle Optional[T] / T | None
args = param_type.__args__ if hasattr(param_type, '__args__') else []
param_type = next((arg for arg in args if arg is not type(None)), param_type)
# Check if it's a Pydantic BaseModel
try:
return not (isinstance(param_type, type) and issubclass(param_type, BaseModel))
except TypeError:
# Not a class, use default embed=True
return True
original_params = new_params[1:] # Skip request parameter
new_params = [new_params[0]] # Keep request parameter
for param in original_params:
if param.name in path_params:
new_params.append(
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
if param.name in path_params
else (
param # Keep original annotation if it's already an Annotated type
if get_origin(param.annotation) is Annotated
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
)
)
for param in new_params[1:]
]
)
elif get_origin(param.annotation) is Annotated:
new_params.append(param) # Keep existing annotation
else:
embed = get_body_embed_value(param)
new_params.append(
param.replace(annotation=Annotated[param.annotation, Body(..., embed=embed)])
)
route_handler.__signature__ = sig.replace(parameters=new_params)