mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 16:57:21 +00:00
## Summary Introduce `ExtraBodyField` annotation to enable parameters that arrive via extra_body in client SDKs but are accessible server-side with full typing. These parameters are documented in OpenAPI specs under **`x-llama-stack-extra-body-params`** but excluded from generated SDK signatures. Add `shields` parameter to `create_openai_response` as the first implementation using this pattern. ## Test Plan - added an integration test which checks that shields parameter passed via extra_body reaches server implementation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com>
463 lines
17 KiB
Python
463 lines
17 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import collections.abc
|
|
import enum
|
|
import inspect
|
|
import typing
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
|
|
|
|
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA, LLAMA_STACK_API_V1ALPHA
|
|
|
|
from termcolor import colored
|
|
|
|
from llama_stack.strong_typing.inspection import get_signature
|
|
|
|
from typing import get_origin, get_args
|
|
|
|
from fastapi import UploadFile
|
|
from fastapi.params import File, Form
|
|
from typing import Annotated
|
|
|
|
from llama_stack.schema_utils import ExtraBodyField
|
|
|
|
|
|
def split_prefix(
    s: str, sep: str, prefix: Union[str, Iterable[str]]
) -> Tuple[Optional[str], str]:
    """
    Splits a recognized prefix off the beginning of a string.

    :param s: The string to check.
    :param sep: The separator expected between (one of) the prefix(es) and the rest of the string.
    :param prefix: A single prefix, or an iterable of candidate prefixes that are tried in iteration order.
    :return: A tuple of the first matching prefix and the remainder of the string after the separator,
        or `(None, s)` if no candidate followed by the separator starts the string.
    """

    # a lone string is handled as a one-element list of candidates
    candidates = (prefix,) if isinstance(prefix, str) else prefix

    for candidate in candidates:
        head = candidate + sep
        if s.startswith(head):
            return candidate, s[len(head) :]

    return None, s
|
|
|
|
|
|
def _get_annotation_type(annotation: Union[type, str], callable: Callable) -> type:
    "Maps a stringized reference to a type, as if using `from __future__ import annotations`."

    # annotations that are already objects are passed through untouched
    if not isinstance(annotation, str):
        return annotation

    # NOTE: `eval` executes the annotation text as code, with the global namespace of the
    # function that declares it; only use this on annotations from trusted source files.
    return eval(annotation, callable.__globals__)
|
|
|
|
|
|
class HTTPMethod(enum.Enum):
    """
    HTTP method used to invoke an endpoint operation.

    The value of each member is the upper-case method name.
    """

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
|
|
|
|
|
|
# A `(name, type)` pair describing a single parameter of an endpoint operation.
OperationParameter = Tuple[str, type]
|
|
|
|
|
|
class ValidationError(TypeError):
    "Signals that an endpoint class or one of its member functions cannot be mapped to an operation."
|
|
|
|
|
|
@dataclass
class EndpointOperation:
    """
    Type information and metadata associated with an endpoint operation.

    :param defining_class: The most specific class that defines the endpoint operation.
    :param name: The short name of the endpoint operation.
    :param func_name: The name of the function to invoke when the operation is triggered.
    :param func_ref: The callable to invoke when the operation is triggered.
    :param route: A custom route string assigned to the operation.
    :param path_params: Parameters of the operation signature that are passed in the path component of the URL string.
    :param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
    :param request_params: Parameters of the operation signature that are transmitted in the request body.
    :param multipart_params: Parameters that indicate multipart/form-data request body.
    :param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK,
        as `(name, type, description)` triples.
    :param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
    :param response_type: The Python type of the data that is transmitted in the response body.
    :param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
    :param public: True if the operation can be invoked without prior authentication.
    :param request_examples: Sample requests that the operation might take.
    :param response_examples: Sample responses that the operation might produce.
    """

    defining_class: type
    name: str
    func_name: str
    func_ref: Callable[..., Any]
    route: Optional[str]
    path_params: List[OperationParameter]
    query_params: List[OperationParameter]
    # `get_endpoint_operations` fills this field with a list of `(name, type)` pairs,
    # one per body parameter, so it is annotated as a list rather than a single pair
    request_params: Optional[List[OperationParameter]]
    multipart_params: List[OperationParameter]
    extra_body_params: List[tuple[str, type, str | None]]
    event_type: Optional[type]
    response_type: type
    http_method: HTTPMethod
    public: bool
    request_examples: Optional[List[Any]] = None
    response_examples: Optional[List[Any]] = None

    def get_route(self, webmethod) -> str:
        """
        Builds the URL path of the operation.

        :param webmethod: An object whose `level` attribute becomes the first path segment.
        :return: `/<level>/<route>` if a custom route is set (leading slashes of the route are dropped),
            otherwise `/<level>/<name>` followed by one `{param}` placeholder per path parameter.
        """

        api_level = webmethod.level

        if self.route is not None:
            return "/".join(["", api_level, self.route.lstrip("/")])

        # no custom route: derive one from the operation name and its path parameters
        route_parts = ["", api_level, self.name]
        route_parts.extend("{" + param_name + "}" for param_name, _ in self.path_params)
        return "/".join(route_parts)
|
|
|
|
|
|
class _FormatParameterExtractor:
    "A visitor to extract parameters in a format string."

    # names of the replacement fields, in the order in which they were looked up
    keys: List[str]

    def __init__(self) -> None:
        self.keys = list()

    def __getitem__(self, key: str) -> None:
        # each lookup made on this object (e.g. by `str.format_map`) is recorded;
        # the looked-up value itself is always `None`
        self.keys += [key]
|
|
|
|
|
|
def _get_route_parameters(route: str) -> List[str]:
    "Lists the names of the `{placeholder}` fields in a route string, in order of appearance."

    collector = _FormatParameterExtractor()
    # Drop every ":path" marker first, so that "{name:path}" is looked up as plain "{name}";
    # formatting the route then makes the collector see each field name
    route.replace(":path", "").format_map(collector)
    return collector.keys
|
|
|
|
|
|
def _get_endpoint_functions(
    endpoint: type, prefixes: List[str]
) -> Iterator[Tuple[str, str, str, Callable]]:
    """
    Enumerates the member functions of an endpoint class that carry web method metadata.

    Functions without a non-empty `__webmethods__` attribute are skipped. For every other function,
    one tuple is produced per entry in `__webmethods__`.

    :param endpoint: The class whose member functions are inspected.
    :param prefixes: Not consulted by this function.
    :return: Tuples of `(prefix, operation_name, func_name, func_ref)`, where `prefix` is one of
        `"get"`, `"delete"` or `"post"` and `operation_name` equals `func_name`.
    :raises ValueError: If `endpoint` is not a class (raised when iteration starts).
    """

    if not inspect.isclass(endpoint):
        raise ValueError(f"object is not a class type: {endpoint}")

    for func_name, func_ref in inspect.getmembers(endpoint, inspect.isfunction):
        # Check for multiple webmethods (stacked decorators)
        webmethods = getattr(func_ref, "__webmethods__", [])
        if not webmethods:
            continue

        for webmethod in webmethods:
            print(f"Processing {colored(func_name, 'white')}...")
            operation_name = func_name

            # an explicit GET, DELETE or POST on the webmethod decides the prefix ...
            for verb in ("GET", "DELETE", "POST"):
                if webmethod.method == verb:
                    prefix = verb.lower()
                    break
            else:
                # ... otherwise the prefix is guessed from the function name
                if operation_name.startswith("get_") or operation_name.endswith("/get"):
                    prefix = "get"
                elif operation_name.startswith(
                    ("delete_", "remove_")
                ) or operation_name.endswith(("/delete", "/remove")):
                    prefix = "delete"
                else:
                    # by default everything else is a POST
                    prefix = "post"

            yield prefix, operation_name, func_name, func_ref
|
|
|
|
|
|
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
    "Find the class in which a member function is first defined in a class inheritance hierarchy."

    # This import must be dynamic here
    from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime

    # walk the method resolution order backwards, i.e. from the root of the hierarchy
    # towards `derived_cls`, and stop at the first class that exposes the function
    for cls in reversed(inspect.getmro(derived_cls)):
        function_names = (name for name, _ in inspect.getmembers(cls, inspect.isfunction))
        if member_fn not in function_names:
            continue

        # HACK ALERT
        return ToolRuntime if cls == RAGToolRuntime else cls

    raise ValidationError(
        f"cannot find defining class for {member_fn} in {derived_cls}"
    )
|
|
|
|
|
|
def _unwrap_streaming_type(t: Any) -> Any:
    "Replaces `AsyncIterator[T]` by `T`, also inside the members of a `Union`; other types are returned as-is."

    origin = typing.get_origin(t)
    if origin is collections.abc.AsyncIterator:
        # NOTE(ashwin): this is SSE and there is no way to represent it. either we make it a List
        # or the item type. I am choosing it to be the latter
        return typing.get_args(t)[0]
    if origin is typing.Union:
        members = tuple(_unwrap_streaming_type(a) for a in typing.get_args(t))
        # subscripting `Union` is the public way to build a union; the private
        # `typing._UnionGenericAlias` constructor is not a stable API
        return typing.Union[members]
    return t


def _http_method_for_prefix(prefix: str) -> HTTPMethod:
    "Maps an operation prefix to its HTTP method, raising `ValidationError` for an unknown prefix."

    if prefix in ["delete", "remove"]:
        return HTTPMethod.DELETE
    if prefix == "post":
        return HTTPMethod.POST
    if prefix == "get":
        return HTTPMethod.GET
    if prefix == "set":
        return HTTPMethod.PUT
    if prefix == "update":
        return HTTPMethod.PATCH
    raise ValidationError(f"unknown prefix {prefix}")


def get_endpoint_operations(
    endpoint: type, use_examples: bool = True
) -> List[EndpointOperation]:
    """
    Extracts a list of member functions in a class eligible for HTTP interface binding.

    These member functions are expected to have a signature like
    ```
    async def get_object(self, uuid: str, version: int) -> Object:
        ...
    ```
    and to carry one or more web methods in their `__webmethods__` attribute. One operation is produced
    per function and web method. Every parameter and the return value must have a type annotation.

    Parameters are classified as follows:
    * parameters marked with `ExtraBodyField` are recorded as extra body parameters only;
    * parameters whose name appears as a placeholder in the route of the web method are path parameters;
    * for operations with a `get` or `delete` prefix, all remaining parameters are query parameters;
    * for all other operations, remaining parameters are multipart parameters if they are recognized as
      multipart form data, and request body parameters otherwise.

    A return type `Generator[YieldType, None, ReturnType]` denotes an operation that produces events:
    `YieldType` is the event type and `ReturnType` is the response type. In response types,
    `AsyncIterator[T]` is represented by its item type `T`.

    :param endpoint: A class with member functions that can be mapped to an HTTP endpoint.
    :param use_examples: Whether to return examples associated with member functions.
    :return: The operations found in the class.
    :raises ValidationError: If a parameter or return value lacks a type annotation, an event-producing
        function declares a send type, a prefix is unknown, or the class has no eligible operations.
    """

    result = []

    # `_get_endpoint_functions` yields a function once per web method, and every yield is expanded
    # to all web methods of the function below; this set prevents emitting the same
    # (function, prefix, web method) combination more than once
    emitted = set()

    for prefix, operation_name, func_name, func_ref in _get_endpoint_functions(
        endpoint,
        [
            "create",
            "delete",
            "do",
            "get",
            "post",
            "put",
            "remove",
            "set",
            "update",
        ],
    ):
        # Get all webmethods for this function
        webmethods = getattr(func_ref, "__webmethods__", [])

        # Create one EndpointOperation for each webmethod
        for index, webmethod in enumerate(webmethods):
            key = (func_name, prefix, index)
            if key in emitted:
                continue
            emitted.add(key)

            route = webmethod.route
            route_params = _get_route_parameters(route) if route is not None else None
            public = webmethod.public
            request_examples = webmethod.request_examples
            response_examples = webmethod.response_examples

            # inspect function signature for path and query parameters, and request/response payload type
            signature = get_signature(func_ref)

            path_params = []
            query_params = []
            request_params = []
            multipart_params = []
            extra_body_params = []

            for param_name, parameter in signature.parameters.items():
                param_type = _get_annotation_type(parameter.annotation, func_ref)

                # omit "self" for instance methods
                if param_name == "self" and param_type is inspect.Parameter.empty:
                    continue

                # check if all parameters have explicit type
                if parameter.annotation is inspect.Parameter.empty:
                    raise ValidationError(
                        f"parameter '{param_name}' in function '{func_name}' has no type annotation"
                    )

                # Check if this is an extra_body parameter
                is_extra_body, extra_body_desc = _is_extra_body_param(param_type)
                if is_extra_body:
                    # Store in a separate list for documentation
                    extra_body_params.append((param_name, param_type, extra_body_desc))
                    continue  # Skip adding to request_params

                is_multipart = _is_multipart_param(param_type)

                if route_params is not None and param_name in route_params:
                    path_params.append((param_name, param_type))
                elif prefix in ["get", "delete"]:
                    query_params.append((param_name, param_type))
                elif is_multipart:
                    multipart_params.append((param_name, param_type))
                else:
                    request_params.append((param_name, param_type))

            # check if function has explicit return type
            if signature.return_annotation is inspect.Signature.empty:
                raise ValidationError(
                    f"function '{func_name}' has no return type annotation"
                )

            return_type = _get_annotation_type(signature.return_annotation, func_ref)

            # operations that produce events are labeled as Generator[YieldType, SendType, ReturnType]
            # where YieldType is the event type, SendType is None, and ReturnType is the immediate response type to the request
            if typing.get_origin(return_type) is collections.abc.Generator:
                event_type, send_type, immediate_type = typing.get_args(return_type)
                if send_type is not type(None):
                    raise ValidationError(
                        f"function '{func_name}' has a return type Generator[Y,S,R] and therefore looks like an event but has an explicit send type"
                    )
                # the response is the generator's ReturnType, not the Generator annotation itself
                response_type = _unwrap_streaming_type(immediate_type)
            else:
                event_type = None
                response_type = _unwrap_streaming_type(return_type)

            http_method = _http_method_for_prefix(prefix)

            # Create an EndpointOperation for this specific webmethod
            operation = EndpointOperation(
                defining_class=_get_defining_class(func_name, endpoint),
                name=operation_name,
                func_name=func_name,
                func_ref=func_ref,
                route=route,
                path_params=path_params,
                query_params=query_params,
                request_params=request_params,
                multipart_params=multipart_params,
                extra_body_params=extra_body_params,
                event_type=event_type,
                response_type=response_type,
                http_method=http_method,
                public=public,
                request_examples=request_examples if use_examples else None,
                response_examples=response_examples if use_examples else None,
            )

            # Store the specific webmethod with this operation
            operation.webmethod = webmethod
            result.append(operation)

    if not result:
        raise ValidationError(f"no eligible endpoint operations in type {endpoint}")

    return result
|
|
|
|
|
|
def get_endpoint_events(endpoint: type) -> Dict[str, type]:
    """
    Collects the event types declared by the type hints of an endpoint class.

    An event is recognized by a type hint of the form `Callable[[EventType], None]`, i.e. a callable
    that takes exactly one positional argument and returns `None`. All other type hints are ignored.

    :param endpoint: The class whose type hints are inspected.
    :return: A dictionary that maps `EventType.__name__` to `EventType`.
    """

    results = {}

    for decl in typing.get_type_hints(endpoint).values():
        # check if signature is Callable[...]
        # (the origin of hints such as `Optional[X]` is not a class, and `issubclass` would
        # raise `TypeError` for it, so such hints are skipped up front)
        origin = typing.get_origin(decl)
        if not (isinstance(origin, type) and issubclass(origin, collections.abc.Callable)):
            continue

        # check if signature is Callable[[...], Any]
        args = typing.get_args(decl)
        if len(args) != 2:
            continue
        params_type, return_type = args
        if not isinstance(params_type, list):
            continue

        # check if signature is Callable[[...], None]
        # (an identity test, as the return type may be a generic alias that `issubclass` rejects)
        if return_type is not type(None):
            continue

        # check if signature is Callable[[EventType], None]
        if len(params_type) != 1:
            continue

        param_type = params_type[0]
        results[param_type.__name__] = param_type

    return results
|
|
|
|
|
|
def _is_multipart_param(param_type: type) -> bool:
    """
    Tell whether a parameter type indicates multipart form data.

    Returns True if the type is:
    - UploadFile
    - Annotated[UploadFile, File()]
    - Annotated[str, Form()]
    - Annotated[Any, File()]
    - Annotated[Any, Form()]

    i.e. `UploadFile` itself, or any `Annotated` type with a `File` or `Form` instance in its metadata.
    """

    if param_type is UploadFile:
        return True

    # anything other than an Annotated type carries no File()/Form() marker
    if get_origin(param_type) is not Annotated:
        return False

    # the first argument is the annotated type; the metadata follows it
    metadata = get_args(param_type)[1:]
    return any(isinstance(marker, (File, Form)) for marker in metadata)
|
|
|
|
|
|
def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]:
|
|
"""
|
|
Check if parameter is marked as coming from extra_body.
|
|
|
|
Returns:
|
|
(is_extra_body, description): Tuple of boolean and optional description
|
|
"""
|
|
origin = get_origin(param_type)
|
|
if origin is Annotated:
|
|
args = get_args(param_type)
|
|
for annotation in args[1:]:
|
|
if isinstance(annotation, ExtraBodyField):
|
|
return True, annotation.description
|
|
# Also check by type name for cases where import matters
|
|
if type(annotation).__name__ == 'ExtraBodyField':
|
|
return True, getattr(annotation, 'description', None)
|
|
return False, None
|