mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-11 21:48:36 +00:00
chore: refactor (chat)completions endpoints to use shared params struct (#3761)
# What does this PR do? Converts openai(_chat)_completions params to pydantic BaseModel to reduce code duplication across all providers. ## Test Plan CI --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/llamastack/llama-stack/pull/3761). * #3777 * __->__ #3761
This commit is contained in:
parent
6954fe2274
commit
80d58ab519
33 changed files with 599 additions and 890 deletions
|
@ -54,6 +54,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
@ -383,7 +384,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
|
||||
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
@ -446,7 +447,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
# Prepare body for the function call (handles both Pydantic and traditional params)
|
||||
body = self._convert_body(func, body)
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
@ -493,17 +495,20 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return await response.parse()
|
||||
|
||||
def _convert_body(
|
||||
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
|
||||
) -> dict:
|
||||
def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy
|
||||
exclude_params = exclude_params or set()
|
||||
|
||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
sig = inspect.signature(func)
|
||||
params_list = [p for p in sig.parameters.values() if p.name != "self"]
|
||||
# Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
|
||||
if len(params_list) == 1:
|
||||
param = params_list[0]
|
||||
param_type = param.annotation
|
||||
if is_unwrapped_body_param(param_type):
|
||||
base_type = get_args(param_type)[0]
|
||||
return {param.name: base_type(**body)}
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue