mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
fix routing in library client
This commit is contained in:
parent
27e07b44b5
commit
85a3fcee8e
1 changed files with 49 additions and 13 deletions
|
@ -9,6 +9,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -232,13 +233,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
|
|
||||||
endpoints = get_all_api_endpoints()
|
endpoints = get_all_api_endpoints()
|
||||||
endpoint_impls = {}
|
endpoint_impls = {}
|
||||||
|
|
||||||
|
def _convert_path_to_regex(path: str) -> str:
|
||||||
|
# Convert {param} to named capture groups
|
||||||
|
pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
|
||||||
|
return f"^{pattern}$"
|
||||||
|
|
||||||
for api, api_endpoints in endpoints.items():
|
for api, api_endpoints in endpoints.items():
|
||||||
if api not in self.impls:
|
if api not in self.impls:
|
||||||
continue
|
continue
|
||||||
for endpoint in api_endpoints:
|
for endpoint in api_endpoints:
|
||||||
impl = self.impls[api]
|
impl = self.impls[api]
|
||||||
func = getattr(impl, endpoint.name)
|
func = getattr(impl, endpoint.name)
|
||||||
endpoint_impls[endpoint.route] = func
|
if endpoint.method not in endpoint_impls:
|
||||||
|
endpoint_impls[endpoint.method] = {}
|
||||||
|
endpoint_impls[endpoint.method][
|
||||||
|
_convert_path_to_regex(endpoint.route)
|
||||||
|
] = func
|
||||||
|
|
||||||
self.endpoint_impls = endpoint_impls
|
self.endpoint_impls = endpoint_impls
|
||||||
return True
|
return True
|
||||||
|
@ -273,6 +284,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
await end_trace()
|
await end_trace()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
|
||||||
|
"""Find the matching endpoint implementation for a given method and path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method (GET, POST, etc.)
|
||||||
|
path: URL path to match against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (endpoint_function, path_params)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no matching endpoint is found
|
||||||
|
"""
|
||||||
|
impls = self.endpoint_impls.get(method)
|
||||||
|
if not impls:
|
||||||
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
||||||
|
for regex, func in impls.items():
|
||||||
|
match = re.match(regex, path)
|
||||||
|
if match:
|
||||||
|
# Extract named groups from the regex match
|
||||||
|
path_params = match.groupdict()
|
||||||
|
return func, path_params
|
||||||
|
|
||||||
|
raise ValueError(f"No endpoint found for {path}")
|
||||||
|
|
||||||
async def _call_non_streaming(
|
async def _call_non_streaming(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -280,15 +317,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
options: Any,
|
options: Any,
|
||||||
):
|
):
|
||||||
path = options.url
|
path = options.url
|
||||||
|
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
func = self.endpoint_impls.get(path)
|
|
||||||
if not func:
|
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
|
||||||
|
|
||||||
body = self._convert_body(path, body)
|
matched_func, path_params = self._find_matching_endpoint(options.method, path)
|
||||||
result = await func(**body)
|
body |= path_params
|
||||||
|
body = self._convert_body(path, options.method, body)
|
||||||
|
result = await matched_func(**body)
|
||||||
|
|
||||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||||
mock_response = httpx.Response(
|
mock_response = httpx.Response(
|
||||||
|
@ -325,11 +360,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
path = options.url
|
path = options.url
|
||||||
body = options.params or {}
|
body = options.params or {}
|
||||||
body |= options.json_data or {}
|
body |= options.json_data or {}
|
||||||
func = self.endpoint_impls.get(path)
|
func, path_params = self._find_matching_endpoint(options.method, path)
|
||||||
if not func:
|
body |= path_params
|
||||||
raise ValueError(f"No endpoint found for {path}")
|
|
||||||
|
|
||||||
body = self._convert_body(path, body)
|
body = self._convert_body(path, options.method, body)
|
||||||
|
|
||||||
async def gen():
|
async def gen():
|
||||||
async for chunk in await func(**body):
|
async for chunk in await func(**body):
|
||||||
|
@ -367,11 +401,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
)
|
)
|
||||||
return await response.parse()
|
return await response.parse()
|
||||||
|
|
||||||
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
def _convert_body(
|
||||||
|
self, path: str, method: str, body: Optional[dict] = None
|
||||||
|
) -> dict:
|
||||||
if not body:
|
if not body:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
func = self.endpoint_impls[path]
|
func, _ = self._find_matching_endpoint(method, path)
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
|
|
||||||
# Strip NOT_GIVENs to use the defaults in signature
|
# Strip NOT_GIVENs to use the defaults in signature
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue