mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
chore: return webmethod from find_matching_route (#2883)
This will be used to support API access control, i.e. Webmethod would have a `required_scope` attribute, and we need access to that in the middleware.
This commit is contained in:
parent
1463b79218
commit
cbe89d2bdd
4 changed files with 29 additions and 21 deletions
|
@ -42,8 +42,8 @@ class DistributionInspectImpl(Inspect):
|
||||||
run_config: StackRunConfig = self.config.run_config
|
run_config: StackRunConfig = self.config.run_config
|
||||||
|
|
||||||
ret = []
|
ret = []
|
||||||
all_endpoints = get_all_api_routes()
|
api_to_routes = get_all_api_routes()
|
||||||
for api, endpoints in all_endpoints.items():
|
for api, endpoints in api_to_routes.items():
|
||||||
# Always include provider and inspect APIs, filter others based on run config
|
# Always include provider and inspect APIs, filter others based on run config
|
||||||
if api.value in ["providers", "inspect"]:
|
if api.value in ["providers", "inspect"]:
|
||||||
ret.extend(
|
ret.extend(
|
||||||
|
@ -53,7 +53,7 @@ class DistributionInspectImpl(Inspect):
|
||||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||||
)
|
)
|
||||||
for e in endpoints
|
for e, _ in endpoints
|
||||||
if e.methods is not None
|
if e.methods is not None
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -67,7 +67,7 @@ class DistributionInspectImpl(Inspect):
|
||||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||||
provider_types=[p.provider_type for p in providers],
|
provider_types=[p.provider_type for p in providers],
|
||||||
)
|
)
|
||||||
for e in endpoints
|
for e, _ in endpoints
|
||||||
if e.methods is not None
|
if e.methods is not None
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -359,13 +359,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
|
|
||||||
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||||
body |= path_params
|
body |= path_params
|
||||||
|
|
||||||
body, field_names = self._handle_file_uploads(options, body)
|
body, field_names = self._handle_file_uploads(options, body)
|
||||||
|
|
||||||
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
|
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
|
||||||
await start_trace(route, {"__location__": "library_client"})
|
|
||||||
|
trace_path = webmethod.descriptive_name or route_path
|
||||||
|
await start_trace(trace_path, {"__location__": "library_client"})
|
||||||
try:
|
try:
|
||||||
result = await matched_func(**body)
|
result = await matched_func(**body)
|
||||||
finally:
|
finally:
|
||||||
|
@ -415,12 +417,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
path = options.url
|
path = options.url
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||||
body |= path_params
|
body |= path_params
|
||||||
|
|
||||||
body = self._convert_body(path, options.method, body)
|
body = self._convert_body(path, options.method, body)
|
||||||
|
|
||||||
await start_trace(route, {"__location__": "library_client"})
|
trace_path = webmethod.descriptive_name or route_path
|
||||||
|
await start_trace(trace_path, {"__location__": "library_client"})
|
||||||
|
|
||||||
async def gen():
|
async def gen():
|
||||||
try:
|
try:
|
||||||
|
@ -475,7 +478,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
|
|
||||||
exclude_params = exclude_params or set()
|
exclude_params = exclude_params or set()
|
||||||
|
|
||||||
func, _, _ = find_matching_route(method, path, self.route_impls)
|
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
# Strip NOT_GIVENs to use the defaults in signature
|
# Strip NOT_GIVENs to use the defaults in signature
|
||||||
|
|
|
@ -16,13 +16,14 @@ from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||||
from llama_stack.distribution.resolver import api_protocol_map
|
from llama_stack.distribution.resolver import api_protocol_map
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
from llama_stack.schema_utils import WebMethod
|
||||||
|
|
||||||
EndpointFunc = Callable[..., Any]
|
EndpointFunc = Callable[..., Any]
|
||||||
PathParams = dict[str, str]
|
PathParams = dict[str, str]
|
||||||
RouteInfo = tuple[EndpointFunc, str]
|
RouteInfo = tuple[EndpointFunc, str, WebMethod]
|
||||||
PathImpl = dict[str, RouteInfo]
|
PathImpl = dict[str, RouteInfo]
|
||||||
RouteImpls = dict[str, PathImpl]
|
RouteImpls = dict[str, PathImpl]
|
||||||
RouteMatch = tuple[EndpointFunc, PathParams, str]
|
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
|
||||||
|
|
||||||
|
|
||||||
def toolgroup_protocol_map():
|
def toolgroup_protocol_map():
|
||||||
|
@ -31,7 +32,7 @@ def toolgroup_protocol_map():
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_all_api_routes() -> dict[Api, list[Route]]:
|
def get_all_api_routes() -> dict[Api, list[tuple[Route, WebMethod]]]:
|
||||||
apis = {}
|
apis = {}
|
||||||
|
|
||||||
protocols = api_protocol_map()
|
protocols = api_protocol_map()
|
||||||
|
@ -65,7 +66,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
||||||
else:
|
else:
|
||||||
http_method = hdrs.METH_POST
|
http_method = hdrs.METH_POST
|
||||||
routes.append(
|
routes.append(
|
||||||
Route(path=path, methods=[http_method], name=name, endpoint=None)
|
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
|
||||||
) # setting endpoint to None since don't use a Router object
|
) # setting endpoint to None since don't use a Router object
|
||||||
|
|
||||||
apis[api] = routes
|
apis[api] = routes
|
||||||
|
@ -74,7 +75,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
||||||
|
|
||||||
|
|
||||||
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
||||||
routes = get_all_api_routes()
|
api_to_routes = get_all_api_routes()
|
||||||
route_impls: RouteImpls = {}
|
route_impls: RouteImpls = {}
|
||||||
|
|
||||||
def _convert_path_to_regex(path: str) -> str:
|
def _convert_path_to_regex(path: str) -> str:
|
||||||
|
@ -88,10 +89,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
||||||
|
|
||||||
return f"^{pattern}$"
|
return f"^{pattern}$"
|
||||||
|
|
||||||
for api, api_routes in routes.items():
|
for api, api_routes in api_to_routes.items():
|
||||||
if api not in impls:
|
if api not in impls:
|
||||||
continue
|
continue
|
||||||
for route in api_routes:
|
for route, webmethod in api_routes:
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
func = getattr(impl, route.name)
|
func = getattr(impl, route.name)
|
||||||
# Get the first (and typically only) method from the set, filtering out HEAD
|
# Get the first (and typically only) method from the set, filtering out HEAD
|
||||||
|
@ -104,6 +105,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
||||||
route_impls[method][_convert_path_to_regex(route.path)] = (
|
route_impls[method][_convert_path_to_regex(route.path)] = (
|
||||||
func,
|
func,
|
||||||
route.path,
|
route.path,
|
||||||
|
webmethod,
|
||||||
)
|
)
|
||||||
|
|
||||||
return route_impls
|
return route_impls
|
||||||
|
@ -118,7 +120,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
|
||||||
route_impls: A dictionary of endpoint implementations
|
route_impls: A dictionary of endpoint implementations
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (endpoint_function, path_params, descriptive_name)
|
A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If no matching endpoint is found
|
ValueError: If no matching endpoint is found
|
||||||
|
@ -127,11 +129,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
|
||||||
if not impls:
|
if not impls:
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
||||||
for regex, (func, descriptive_name) in impls.items():
|
for regex, (func, route_path, webmethod) in impls.items():
|
||||||
match = re.match(regex, path)
|
match = re.match(regex, path)
|
||||||
if match:
|
if match:
|
||||||
# Extract named groups from the regex match
|
# Extract named groups from the regex match
|
||||||
path_params = match.groupdict()
|
path_params = match.groupdict()
|
||||||
return func, path_params, descriptive_name
|
return func, path_params, route_path, webmethod
|
||||||
|
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
|
@ -304,7 +304,9 @@ class TracingMiddleware:
|
||||||
self.route_impls = initialize_route_impls(self.impls)
|
self.route_impls = initialize_route_impls(self.impls)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
|
_, _, route_path, webmethod = find_matching_route(
|
||||||
|
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# If no matching endpoint is found, pass through to FastAPI
|
# If no matching endpoint is found, pass through to FastAPI
|
||||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||||
|
@ -321,6 +323,7 @@ class TracingMiddleware:
|
||||||
if tracestate:
|
if tracestate:
|
||||||
trace_attributes["tracestate"] = tracestate
|
trace_attributes["tracestate"] = tracestate
|
||||||
|
|
||||||
|
trace_path = webmethod.descriptive_name or route_path
|
||||||
trace_context = await start_trace(trace_path, trace_attributes)
|
trace_context = await start_trace(trace_path, trace_attributes)
|
||||||
|
|
||||||
async def send_with_trace_id(message):
|
async def send_with_trace_id(message):
|
||||||
|
@ -504,7 +507,7 @@ def main(args: argparse.Namespace | None = None):
|
||||||
routes = all_routes[api]
|
routes = all_routes[api]
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
|
|
||||||
for route in routes:
|
for route, _ in routes:
|
||||||
if not hasattr(impl, route.name):
|
if not hasattr(impl, route.name):
|
||||||
# ideally this should be a typing violation already
|
# ideally this should be a typing violation already
|
||||||
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue