fix routing in library client

This commit is contained in:
Dinesh Yeduguru 2025-01-15 15:52:21 -08:00
parent 27e07b44b5
commit 85a3fcee8e

View file

@ -9,6 +9,7 @@ import inspect
import json import json
import logging import logging
import os import os
import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@ -232,13 +233,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
endpoints = get_all_api_endpoints() endpoints = get_all_api_endpoints()
endpoint_impls = {} endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items(): for api, api_endpoints in endpoints.items():
if api not in self.impls: if api not in self.impls:
continue continue
for endpoint in api_endpoints: for endpoint in api_endpoints:
impl = self.impls[api] impl = self.impls[api]
func = getattr(impl, endpoint.name) func = getattr(impl, endpoint.name)
endpoint_impls[endpoint.route] = func if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][
_convert_path_to_regex(endpoint.route)
] = func
self.endpoint_impls = endpoint_impls self.endpoint_impls = endpoint_impls
return True return True
@ -273,6 +284,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
await end_trace() await end_trace()
return response return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
Returns:
A tuple of (endpoint_function, path_params)
Raises:
ValueError: If no matching endpoint is found
"""
impls = self.endpoint_impls.get(method)
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, func in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params
raise ValueError(f"No endpoint found for {path}")
async def _call_non_streaming( async def _call_non_streaming(
self, self,
*, *,
@ -280,15 +317,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
options: Any, options: Any,
): ):
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body) matched_func, path_params = self._find_matching_endpoint(options.method, path)
result = await func(**body) body |= path_params
body = self._convert_body(path, options.method, body)
result = await matched_func(**body)
json_content = json.dumps(convert_pydantic_to_json_value(result)) json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response( mock_response = httpx.Response(
@ -325,11 +360,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func = self.endpoint_impls.get(path) func, path_params = self._find_matching_endpoint(options.method, path)
if not func: body |= path_params
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body) body = self._convert_body(path, options.method, body)
async def gen(): async def gen():
async for chunk in await func(**body): async for chunk in await func(**body):
@ -367,11 +401,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return await response.parse() return await response.parse()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict: def _convert_body(
self, path: str, method: str, body: Optional[dict] = None
) -> dict:
if not body: if not body:
return {} return {}
func = self.endpoint_impls[path] func, _ = self._find_matching_endpoint(method, path)
sig = inspect.signature(func) sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature # Strip NOT_GIVENs to use the defaults in signature