llama-stack-mirror/docs/openapi_generator/pyopenapi/operations.py
Charlie Doern c88c4ff2c6
feat: introduce API leveling, post_training, eval to v1alpha (#3449)
# What does this PR do?

Rather than have a single `LLAMA_STACK_VERSION`, we need to have a
`_V1`, `_V1ALPHA`, and `_V1BETA` constant.

This also necessitated addition of `level` to the `WebMethod` so that
routing can be handeled properly.


For backwards compat, the `v1` routes are being kept around and marked
as `deprecated`. When used, the server will log a deprecation warning.

Deprecation log:

<img width="1224" height="134" alt="Screenshot 2025-09-25 at 2 43 36 PM"
src="https://github.com/user-attachments/assets/0cc7c245-dafc-48f0-be99-269fb9a686f9"
/>

move:
1. post_training to `v1alpha` as it is under heavy development and not
near its final state
2. eval: job scheduling is not implemented. Relies heavily on the
datasetio API which is under development missing implementations of
specific routes indicating the structure of those routes might change.
Additionally eval depends on the `inference` API which is going to be
deprecated, eval will likely need a major API surface change to conform
to using completions properly

implements leveling in #3317 

note: integration tests will fail until the SDK is regenerated with
v1alpha/inference as opposed to v1/inference

## Test Plan

existing tests should pass with newly generated schema. Conformance will
also pass as these routes are not the ones we currently test for
stability

Signed-off-by: Charlie Doern <cdoern@redhat.com>
2025-09-26 16:18:07 +02:00

431 lines
16 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import collections.abc
import enum
import inspect
import typing
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA, LLAMA_STACK_API_V1ALPHA
from termcolor import colored
from llama_stack.strong_typing.inspection import get_signature
from typing import get_origin, get_args
from fastapi import UploadFile
from fastapi.params import File, Form
from typing import Annotated
def split_prefix(
s: str, sep: str, prefix: Union[str, Iterable[str]]
) -> Tuple[Optional[str], str]:
"""
Recognizes a prefix at the beginning of a string.
:param s: The string to check.
:param sep: A separator between (one of) the prefix(es) and the rest of the string.
:param prefix: A string or a set of strings to identify as a prefix.
:return: A tuple of the recognized prefix (if any) and the rest of the string excluding the separator (or the entire string).
"""
if isinstance(prefix, str):
if s.startswith(prefix + sep):
return prefix, s[len(prefix) + len(sep) :]
else:
return None, s
for p in prefix:
if s.startswith(p + sep):
return p, s[len(p) + len(sep) :]
return None, s
def _get_annotation_type(annotation: Union[type, str], callable: Callable) -> type:
"Maps a stringized reference to a type, as if using `from __future__ import annotations`."
if isinstance(annotation, str):
return eval(annotation, callable.__globals__)
else:
return annotation
class HTTPMethod(enum.Enum):
"HTTP method used to invoke an endpoint operation."
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
OperationParameter = Tuple[str, type]
class ValidationError(TypeError):
pass
@dataclass
class EndpointOperation:
"""
Type information and metadata associated with an endpoint operation.
"param defining_class: The most specific class that defines the endpoint operation.
:param name: The short name of the endpoint operation.
:param func_name: The name of the function to invoke when the operation is triggered.
:param func_ref: The callable to invoke when the operation is triggered.
:param route: A custom route string assigned to the operation.
:param path_params: Parameters of the operation signature that are passed in the path component of the URL string.
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
:param request_params: The parameter that corresponds to the data transmitted in the request body.
:param multipart_params: Parameters that indicate multipart/form-data request body.
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
:param response_type: The Python type of the data that is transmitted in the response body.
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
:param public: True if the operation can be invoked without prior authentication.
:param request_examples: Sample requests that the operation might take.
:param response_examples: Sample responses that the operation might produce.
"""
defining_class: type
name: str
func_name: str
func_ref: Callable[..., Any]
route: Optional[str]
path_params: List[OperationParameter]
query_params: List[OperationParameter]
request_params: Optional[OperationParameter]
multipart_params: List[OperationParameter]
event_type: Optional[type]
response_type: type
http_method: HTTPMethod
public: bool
request_examples: Optional[List[Any]] = None
response_examples: Optional[List[Any]] = None
def get_route(self, webmethod) -> str:
api_level = webmethod.level
if self.route is not None:
return "/".join(["", api_level, self.route.lstrip("/")])
route_parts = ["", api_level, self.name]
for param_name, _ in self.path_params:
route_parts.append("{" + param_name + "}")
return "/".join(route_parts)
class _FormatParameterExtractor:
"A visitor to exract parameters in a format string."
keys: List[str]
def __init__(self) -> None:
self.keys = []
def __getitem__(self, key: str) -> None:
self.keys.append(key)
return None
def _get_route_parameters(route: str) -> List[str]:
extractor = _FormatParameterExtractor()
# Replace all occurrences of ":path" with empty string
route = route.replace(":path", "")
route.format_map(extractor)
return extractor.keys
def _get_endpoint_functions(
endpoint: type, prefixes: List[str]
) -> Iterator[Tuple[str, str, str, Callable]]:
if not inspect.isclass(endpoint):
raise ValueError(f"object is not a class type: {endpoint}")
functions = inspect.getmembers(endpoint, inspect.isfunction)
for func_name, func_ref in functions:
webmethods = []
# Check for multiple webmethods (stacked decorators)
if hasattr(func_ref, "__webmethods__"):
webmethods = func_ref.__webmethods__
if not webmethods:
continue
for webmethod in webmethods:
print(f"Processing {colored(func_name, 'white')}...")
operation_name = func_name
if webmethod.method == "GET":
prefix = "get"
elif webmethod.method == "DELETE":
prefix = "delete"
elif webmethod.method == "POST":
prefix = "post"
elif operation_name.startswith("get_") or operation_name.endswith("/get"):
prefix = "get"
elif (
operation_name.startswith("delete_")
or operation_name.startswith("remove_")
or operation_name.endswith("/delete")
or operation_name.endswith("/remove")
):
prefix = "delete"
else:
# by default everything else is a POST
prefix = "post"
yield prefix, operation_name, func_name, func_ref
def _get_defining_class(member_fn: str, derived_cls: type) -> type:
"Find the class in which a member function is first defined in a class inheritance hierarchy."
# This import must be dynamic here
from llama_stack.apis.tools import RAGToolRuntime, ToolRuntime
# iterate in reverse member resolution order to find most specific class first
for cls in reversed(inspect.getmro(derived_cls)):
for name, _ in inspect.getmembers(cls, inspect.isfunction):
if name == member_fn:
# HACK ALERT
if cls == RAGToolRuntime:
return ToolRuntime
return cls
raise ValidationError(
f"cannot find defining class for {member_fn} in {derived_cls}"
)
def get_endpoint_operations(
endpoint: type, use_examples: bool = True
) -> List[EndpointOperation]:
"""
Extracts a list of member functions in a class eligible for HTTP interface binding.
These member functions are expected to have a signature like
```
async def get_object(self, uuid: str, version: int) -> Object:
...
```
where the prefix `get_` translates to an HTTP GET, `object` corresponds to the name of the endpoint operation,
`uuid` and `version` are mapped to route path elements in "/object/{uuid}/{version}", and `Object` becomes
the response payload type, transmitted as an object serialized to JSON.
If the member function has a composite class type in the argument list, it becomes the request payload type,
and the caller is expected to provide the data as serialized JSON in an HTTP POST request.
:param endpoint: A class with member functions that can be mapped to an HTTP endpoint.
:param use_examples: Whether to return examples associated with member functions.
"""
result = []
for prefix, operation_name, func_name, func_ref in _get_endpoint_functions(
endpoint,
[
"create",
"delete",
"do",
"get",
"post",
"put",
"remove",
"set",
"update",
],
):
# Get all webmethods for this function
webmethods = getattr(func_ref, "__webmethods__", [])
# Create one EndpointOperation for each webmethod
for webmethod in webmethods:
route = webmethod.route
route_params = _get_route_parameters(route) if route is not None else None
public = webmethod.public
request_examples = webmethod.request_examples
response_examples = webmethod.response_examples
# inspect function signature for path and query parameters, and request/response payload type
signature = get_signature(func_ref)
path_params = []
query_params = []
request_params = []
multipart_params = []
for param_name, parameter in signature.parameters.items():
param_type = _get_annotation_type(parameter.annotation, func_ref)
# omit "self" for instance methods
if param_name == "self" and param_type is inspect.Parameter.empty:
continue
# check if all parameters have explicit type
if parameter.annotation is inspect.Parameter.empty:
raise ValidationError(
f"parameter '{param_name}' in function '{func_name}' has no type annotation"
)
is_multipart = _is_multipart_param(param_type)
if prefix in ["get", "delete"]:
if route_params is not None and param_name in route_params:
path_params.append((param_name, param_type))
else:
query_params.append((param_name, param_type))
else:
if route_params is not None and param_name in route_params:
path_params.append((param_name, param_type))
elif is_multipart:
multipart_params.append((param_name, param_type))
else:
request_params.append((param_name, param_type))
# check if function has explicit return type
if signature.return_annotation is inspect.Signature.empty:
raise ValidationError(
f"function '{func_name}' has no return type annotation"
)
return_type = _get_annotation_type(signature.return_annotation, func_ref)
# operations that produce events are labeled as Generator[YieldType, SendType, ReturnType]
# where YieldType is the event type, SendType is None, and ReturnType is the immediate response type to the request
if typing.get_origin(return_type) is collections.abc.Generator:
event_type, send_type, response_type = typing.get_args(return_type)
if send_type is not type(None):
raise ValidationError(
f"function '{func_name}' has a return type Generator[Y,S,R] and therefore looks like an event but has an explicit send type"
)
else:
event_type = None
def process_type(t):
if typing.get_origin(t) is collections.abc.AsyncIterator:
# NOTE(ashwin): this is SSE and there is no way to represent it. either we make it a List
# or the item type. I am choosing it to be the latter
args = typing.get_args(t)
return args[0]
elif typing.get_origin(t) is typing.Union:
types = [process_type(a) for a in typing.get_args(t)]
return typing._UnionGenericAlias(typing.Union, tuple(types))
else:
return t
response_type = process_type(return_type)
if prefix in ["delete", "remove"]:
http_method = HTTPMethod.DELETE
elif prefix == "post":
http_method = HTTPMethod.POST
elif prefix == "get":
http_method = HTTPMethod.GET
elif prefix == "set":
http_method = HTTPMethod.PUT
elif prefix == "update":
http_method = HTTPMethod.PATCH
else:
raise ValidationError(f"unknown prefix {prefix}")
# Create an EndpointOperation for this specific webmethod
operation = EndpointOperation(
defining_class=_get_defining_class(func_name, endpoint),
name=operation_name,
func_name=func_name,
func_ref=func_ref,
route=route,
path_params=path_params,
query_params=query_params,
request_params=request_params,
multipart_params=multipart_params,
event_type=event_type,
response_type=response_type,
http_method=http_method,
public=public,
request_examples=request_examples if use_examples else None,
response_examples=response_examples if use_examples else None,
)
# Store the specific webmethod with this operation
operation.webmethod = webmethod
result.append(operation)
if not result:
raise ValidationError(f"no eligible endpoint operations in type {endpoint}")
return result
def get_endpoint_events(endpoint: type) -> Dict[str, type]:
results = {}
for decl in typing.get_type_hints(endpoint).values():
# check if signature is Callable[...]
origin = typing.get_origin(decl)
if origin is None or not issubclass(origin, Callable): # type: ignore
continue
# check if signature is Callable[[...], Any]
args = typing.get_args(decl)
if len(args) != 2:
continue
params_type, return_type = args
if not isinstance(params_type, list):
continue
# check if signature is Callable[[...], None]
if not issubclass(return_type, type(None)):
continue
# check if signature is Callable[[EventType], None]
if len(params_type) != 1:
continue
param_type = params_type[0]
results[param_type.__name__] = param_type
return results
def _is_multipart_param(param_type: type) -> bool:
"""
Check if a parameter type indicates multipart form data.
Returns True if the type is:
- UploadFile
- Annotated[UploadFile, File()]
- Annotated[str, Form()]
- Annotated[Any, File()]
- Annotated[Any, Form()]
"""
if param_type is UploadFile:
return True
# Check for Annotated types
origin = get_origin(param_type)
if origin is None:
return False
if origin is Annotated:
args = get_args(param_type)
if len(args) < 2:
return False
# Check the annotations for File() or Form()
for annotation in args[1:]:
if isinstance(annotation, (File, Form)):
return True
return False