mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix routing in library client (#776)
# What does this PR do? Library client needs to match the impl based on both the path and method. Since path is no longer static, this PR uses the inefficient way of using regexes computed based on the annotated route path to match against the incoming request path. The variables now also can come to the impl from both path or the body, which is also handled cleanly by finding all the regex matches. ## Test Plan LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py
This commit is contained in:
parent
3e518c049a
commit
8fd9bcb8cd
1 changed files with 49 additions and 13 deletions
|
@ -9,6 +9,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -232,13 +233,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
|
|
||||||
endpoints = get_all_api_endpoints()
|
endpoints = get_all_api_endpoints()
|
||||||
endpoint_impls = {}
|
endpoint_impls = {}
|
||||||
|
|
||||||
|
def _convert_path_to_regex(path: str) -> str:
|
||||||
|
# Convert {param} to named capture groups
|
||||||
|
pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
|
||||||
|
return f"^{pattern}$"
|
||||||
|
|
||||||
for api, api_endpoints in endpoints.items():
|
for api, api_endpoints in endpoints.items():
|
||||||
if api not in self.impls:
|
if api not in self.impls:
|
||||||
continue
|
continue
|
||||||
for endpoint in api_endpoints:
|
for endpoint in api_endpoints:
|
||||||
impl = self.impls[api]
|
impl = self.impls[api]
|
||||||
func = getattr(impl, endpoint.name)
|
func = getattr(impl, endpoint.name)
|
||||||
endpoint_impls[endpoint.route] = func
|
if endpoint.method not in endpoint_impls:
|
||||||
|
endpoint_impls[endpoint.method] = {}
|
||||||
|
endpoint_impls[endpoint.method][
|
||||||
|
_convert_path_to_regex(endpoint.route)
|
||||||
|
] = func
|
||||||
|
|
||||||
self.endpoint_impls = endpoint_impls
|
self.endpoint_impls = endpoint_impls
|
||||||
return True
|
return True
|
||||||
|
@ -273,6 +284,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
await end_trace()
|
await end_trace()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
|
||||||
|
"""Find the matching endpoint implementation for a given method and path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method (GET, POST, etc.)
|
||||||
|
path: URL path to match against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (endpoint_function, path_params)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no matching endpoint is found
|
||||||
|
"""
|
||||||
|
impls = self.endpoint_impls.get(method)
|
||||||
|
if not impls:
|
||||||
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
||||||
|
for regex, func in impls.items():
|
||||||
|
match = re.match(regex, path)
|
||||||
|
if match:
|
||||||
|
# Extract named groups from the regex match
|
||||||
|
path_params = match.groupdict()
|
||||||
|
return func, path_params
|
||||||
|
|
||||||
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
||||||
async def _call_non_streaming(
|
async def _call_non_streaming(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -280,15 +317,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
options: Any,
|
options: Any,
|
||||||
):
|
):
|
||||||
path = options.url
|
path = options.url
|
||||||
|
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
func = self.endpoint_impls.get(path)
|
|
||||||
if not func:
|
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
|
||||||
|
|
||||||
body = self._convert_body(path, body)
|
matched_func, path_params = self._find_matching_endpoint(options.method, path)
|
||||||
result = await func(**body)
|
body |= path_params
|
||||||
|
body = self._convert_body(path, options.method, body)
|
||||||
|
result = await matched_func(**body)
|
||||||
|
|
||||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||||
mock_response = httpx.Response(
|
mock_response = httpx.Response(
|
||||||
|
@ -325,11 +360,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
path = options.url
|
path = options.url
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
func = self.endpoint_impls.get(path)
|
func, path_params = self._find_matching_endpoint(options.method, path)
|
||||||
if not func:
|
body |= path_params
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
|
||||||
|
|
||||||
body = self._convert_body(path, body)
|
body = self._convert_body(path, options.method, body)
|
||||||
|
|
||||||
async def gen():
|
async def gen():
|
||||||
async for chunk in await func(**body):
|
async for chunk in await func(**body):
|
||||||
|
@ -367,11 +401,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
)
|
)
|
||||||
return await response.parse()
|
return await response.parse()
|
||||||
|
|
||||||
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
def _convert_body(
|
||||||
|
self, path: str, method: str, body: Optional[dict] = None
|
||||||
|
) -> dict:
|
||||||
if not body:
|
if not body:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
func = self.endpoint_impls[path]
|
func, _ = self._find_matching_endpoint(method, path)
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
# Strip NOT_GIVENs to use the defaults in signature
|
# Strip NOT_GIVENs to use the defaults in signature
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue