mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 13:44:38 +00:00
test
# What does this PR do? ## Test Plan
This commit is contained in:
parent
f50ce11a3b
commit
4a3d1e33f8
31 changed files with 727 additions and 892 deletions
|
@ -13,12 +13,13 @@ import logging # allow-direct-logging
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import types
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, get_origin
|
||||
from typing import Annotated, Any, Union, get_origin
|
||||
|
||||
import httpx
|
||||
import rich.pretty
|
||||
|
@ -177,7 +178,17 @@ async def lifespan(app: StackApp):
|
|||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
# Check for stream parameter at top level (old API style)
|
||||
if "stream" in kwargs:
|
||||
return kwargs["stream"]
|
||||
|
||||
# Check for stream parameter inside Pydantic request params (new API style)
|
||||
if "params" in kwargs:
|
||||
params = kwargs["params"]
|
||||
if hasattr(params, "stream"):
|
||||
return params.stream
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def maybe_await(value):
|
||||
|
@ -282,21 +293,42 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
if method == "post":
|
||||
# Annotate parameters that are in the path with Path(...) and others with Body(...),
|
||||
# but preserve existing File() and Form() annotations for multipart form data
|
||||
new_params = (
|
||||
[new_params[0]]
|
||||
+ [
|
||||
(
|
||||
def get_body_embed_value(param: inspect.Parameter) -> bool:
|
||||
"""Determine if Body should use embed=True or embed=False.
|
||||
|
||||
For OpenAI-compatible endpoints (param name is 'params'), use embed=False
|
||||
so the request body is parsed directly as the model (not nested).
|
||||
This allows OpenAI clients to send standard OpenAI format.
|
||||
For other endpoints, use embed=True for SDK compatibility.
|
||||
"""
|
||||
# Get the actual type, stripping Optional/Union if present
|
||||
param_type = param.annotation
|
||||
origin = get_origin(param_type)
|
||||
# Check for Union types (both typing.Union and types.UnionType for | syntax)
|
||||
if origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
|
||||
# Handle Optional[T] / T | None
|
||||
args = param_type.__args__ if hasattr(param_type, "__args__") else []
|
||||
param_type = next((arg for arg in args if arg is not type(None)), param_type)
|
||||
|
||||
# Check if it's a Pydantic BaseModel and param name is 'params' (OpenAI-compatible)
|
||||
is_basemodel = issubclass(param_type, BaseModel)
|
||||
if is_basemodel and param.name == "params":
|
||||
return False # Use embed=False for OpenAI-compatible endpoints
|
||||
return True # Use embed=True for everything else
|
||||
|
||||
original_params = new_params[1:] # Skip request parameter
|
||||
new_params = [new_params[0]] # Keep request parameter
|
||||
|
||||
for param in original_params:
|
||||
if param.name in path_params:
|
||||
new_params.append(
|
||||
param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
|
||||
if param.name in path_params
|
||||
else (
|
||||
param # Keep original annotation if it's already an Annotated type
|
||||
if get_origin(param.annotation) is Annotated
|
||||
else param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
|
||||
)
|
||||
)
|
||||
for param in new_params[1:]
|
||||
]
|
||||
)
|
||||
elif get_origin(param.annotation) is Annotated:
|
||||
new_params.append(param) # Keep existing annotation
|
||||
else:
|
||||
embed = get_body_embed_value(param)
|
||||
new_params.append(param.replace(annotation=Annotated[param.annotation, Body(..., embed=embed)]))
|
||||
|
||||
route_handler.__signature__ = sig.replace(parameters=new_params)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue