diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 0c124e64b..fdc68c0a4 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -9,6 +9,7 @@ import inspect import json import logging import os +import re from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path @@ -232,13 +233,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): endpoints = get_all_api_endpoints() endpoint_impls = {} + + def _convert_path_to_regex(path: str) -> str: + # Convert {param} to named capture groups + pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path) + return f"^{pattern}$" + for api, api_endpoints in endpoints.items(): if api not in self.impls: continue for endpoint in api_endpoints: impl = self.impls[api] func = getattr(impl, endpoint.name) - endpoint_impls[endpoint.route] = func + if endpoint.method not in endpoint_impls: + endpoint_impls[endpoint.method] = {} + endpoint_impls[endpoint.method][ + _convert_path_to_regex(endpoint.route) + ] = func self.endpoint_impls = endpoint_impls return True @@ -273,6 +284,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): await end_trace() return response + def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]: + """Find the matching endpoint implementation for a given method and path. + + Args: + method: HTTP method (GET, POST, etc.) + path: URL path to match against + + Returns: + A tuple of (endpoint_function, path_params) + + Raises: + ValueError: If no matching endpoint is found + """ + impls = self.endpoint_impls.get(method) + if not impls: + raise ValueError(f"No endpoint found for {path}") + + for regex, func in impls.items(): + match = re.match(regex, path) + if match: + # Extract named groups from the regex match + path_params = match.groupdict() + return func, path_params + + raise ValueError(f"No endpoint found for {path}") + async def _call_non_streaming( self, *, @@ -280,15 +317,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): options: Any, ): path = options.url - body = options.params or {} body |= options.json_data or {} - func = self.endpoint_impls.get(path) - if not func: - raise ValueError(f"No endpoint found for {path}") - body = self._convert_body(path, body) - result = await func(**body) + matched_func, path_params = self._find_matching_endpoint(options.method, path) + body |= path_params + body = self._convert_body(path, options.method, body) + result = await matched_func(**body) json_content = json.dumps(convert_pydantic_to_json_value(result)) mock_response = httpx.Response( @@ -325,11 +360,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): path = options.url body = options.params or {} body |= options.json_data or {} - func = self.endpoint_impls.get(path) - if not func: - raise ValueError(f"No endpoint found for {path}") + func, path_params = self._find_matching_endpoint(options.method, path) + body |= path_params - body = self._convert_body(path, body) + body = self._convert_body(path, options.method, body) async def gen(): async for chunk in await func(**body): @@ -367,11 +401,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return await response.parse() - def _convert_body(self, path: str, body: Optional[dict] = None) -> dict: + def _convert_body( + self, path: str, method: str, body: Optional[dict] = None + ) -> dict: if not body: return {} - func = self.endpoint_impls[path] + func, _ = self._find_matching_endpoint(method, path) sig = inspect.signature(func) # Strip NOT_GIVENs to use the defaults in signature