fix routing in library client

This commit is contained in:
Dinesh Yeduguru 2025-01-15 15:52:21 -08:00
parent 27e07b44b5
commit 85a3fcee8e

View file

@ -9,6 +9,7 @@ import inspect
import json
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
@ -232,13 +233,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in self.impls:
continue
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
endpoint_impls[endpoint.route] = func
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][
_convert_path_to_regex(endpoint.route)
] = func
self.endpoint_impls = endpoint_impls
return True
@ -273,6 +284,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
await end_trace()
return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
Returns:
A tuple of (endpoint_function, path_params)
Raises:
ValueError: If no matching endpoint is found
"""
impls = self.endpoint_impls.get(method)
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, func in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params
raise ValueError(f"No endpoint found for {path}")
async def _call_non_streaming(
self,
*,
@ -280,15 +317,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
options: Any,
):
path = options.url
body = options.params or {}
body |= options.json_data or {}
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body)
result = await func(**body)
matched_func, path_params = self._find_matching_endpoint(options.method, path)
body |= path_params
body = self._convert_body(path, options.method, body)
result = await matched_func(**body)
json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response(
@ -325,11 +360,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url
body = options.params or {}
body |= options.json_data or {}
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
func, path_params = self._find_matching_endpoint(options.method, path)
body |= path_params
body = self._convert_body(path, body)
body = self._convert_body(path, options.method, body)
async def gen():
async for chunk in await func(**body):
@ -367,11 +401,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
def _convert_body(
self, path: str, method: str, body: Optional[dict] = None
) -> dict:
if not body:
return {}
func = self.endpoint_impls[path]
func, _ = self._find_matching_endpoint(method, path)
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature