fix routing in library client (#776)

# What does this PR do?

Library client needs to match the impl based on both the path and
method. Since path is no longer static, this PR uses the inefficient way
of using regexes computed based on the annotated route path to match
against the incoming request path. The variables now also can come to
the impl from both path or the body, which is also handled cleanly by
finding all the regex matches.



## Test Plan


LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml"
pytest -v tests/client-sdk/agents/test_agents.py
This commit is contained in:
Dinesh Yeduguru 2025-01-15 15:59:45 -08:00 committed by GitHub
parent 3e518c049a
commit 8fd9bcb8cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -9,6 +9,7 @@ import inspect
import json import json
import logging import logging
import os import os
import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@ -232,13 +233,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
endpoints = get_all_api_endpoints() endpoints = get_all_api_endpoints()
endpoint_impls = {} endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items(): for api, api_endpoints in endpoints.items():
if api not in self.impls: if api not in self.impls:
continue continue
for endpoint in api_endpoints: for endpoint in api_endpoints:
impl = self.impls[api] impl = self.impls[api]
func = getattr(impl, endpoint.name) func = getattr(impl, endpoint.name)
endpoint_impls[endpoint.route] = func if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][
_convert_path_to_regex(endpoint.route)
] = func
self.endpoint_impls = endpoint_impls self.endpoint_impls = endpoint_impls
return True return True
@ -273,6 +284,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
await end_trace() await end_trace()
return response return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
Returns:
A tuple of (endpoint_function, path_params)
Raises:
ValueError: If no matching endpoint is found
"""
impls = self.endpoint_impls.get(method)
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, func in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params
raise ValueError(f"No endpoint found for {path}")
async def _call_non_streaming( async def _call_non_streaming(
self, self,
*, *,
@ -280,15 +317,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
options: Any, options: Any,
): ):
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func = self.endpoint_impls.get(path)
if not func:
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body) matched_func, path_params = self._find_matching_endpoint(options.method, path)
result = await func(**body) body |= path_params
body = self._convert_body(path, options.method, body)
result = await matched_func(**body)
json_content = json.dumps(convert_pydantic_to_json_value(result)) json_content = json.dumps(convert_pydantic_to_json_value(result))
mock_response = httpx.Response( mock_response = httpx.Response(
@ -325,11 +360,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func = self.endpoint_impls.get(path) func, path_params = self._find_matching_endpoint(options.method, path)
if not func: body |= path_params
raise ValueError(f"No endpoint found for {path}")
body = self._convert_body(path, body) body = self._convert_body(path, options.method, body)
async def gen(): async def gen():
async for chunk in await func(**body): async for chunk in await func(**body):
@ -367,11 +401,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return await response.parse() return await response.parse()
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict: def _convert_body(
self, path: str, method: str, body: Optional[dict] = None
) -> dict:
if not body: if not body:
return {} return {}
func = self.endpoint_impls[path] func, _ = self._find_matching_endpoint(method, path)
sig = inspect.signature(func) sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature # Strip NOT_GIVENs to use the defaults in signature