mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
fix routing in library client
This commit is contained in:
parent
27e07b44b5
commit
85a3fcee8e
1 changed files with 49 additions and 13 deletions
|
@ -9,6 +9,7 @@ import inspect
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
@ -232,13 +233,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
|
||||
def _convert_path_to_regex(path: str) -> str:
|
||||
# Convert {param} to named capture groups
|
||||
pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
|
||||
return f"^{pattern}$"
|
||||
|
||||
for api, api_endpoints in endpoints.items():
|
||||
if api not in self.impls:
|
||||
continue
|
||||
for endpoint in api_endpoints:
|
||||
impl = self.impls[api]
|
||||
func = getattr(impl, endpoint.name)
|
||||
endpoint_impls[endpoint.route] = func
|
||||
if endpoint.method not in endpoint_impls:
|
||||
endpoint_impls[endpoint.method] = {}
|
||||
endpoint_impls[endpoint.method][
|
||||
_convert_path_to_regex(endpoint.route)
|
||||
] = func
|
||||
|
||||
self.endpoint_impls = endpoint_impls
|
||||
return True
|
||||
|
@ -273,6 +284,32 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
await end_trace()
|
||||
return response
|
||||
|
||||
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
|
||||
"""Find the matching endpoint implementation for a given method and path.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: URL path to match against
|
||||
|
||||
Returns:
|
||||
A tuple of (endpoint_function, path_params)
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching endpoint is found
|
||||
"""
|
||||
impls = self.endpoint_impls.get(method)
|
||||
if not impls:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
for regex, func in impls.items():
|
||||
match = re.match(regex, path)
|
||||
if match:
|
||||
# Extract named groups from the regex match
|
||||
path_params = match.groupdict()
|
||||
return func, path_params
|
||||
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
async def _call_non_streaming(
|
||||
self,
|
||||
*,
|
||||
|
@ -280,15 +317,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
options: Any,
|
||||
):
|
||||
path = options.url
|
||||
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
result = await func(**body)
|
||||
matched_func, path_params = self._find_matching_endpoint(options.method, path)
|
||||
body |= path_params
|
||||
body = self._convert_body(path, options.method, body)
|
||||
result = await matched_func(**body)
|
||||
|
||||
json_content = json.dumps(convert_pydantic_to_json_value(result))
|
||||
mock_response = httpx.Response(
|
||||
|
@ -325,11 +360,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
func, path_params = self._find_matching_endpoint(options.method, path)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
body = self._convert_body(path, options.method, body)
|
||||
|
||||
async def gen():
|
||||
async for chunk in await func(**body):
|
||||
|
@ -367,11 +401,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return await response.parse()
|
||||
|
||||
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
||||
def _convert_body(
|
||||
self, path: str, method: str, body: Optional[dict] = None
|
||||
) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
func = self.endpoint_impls[path]
|
||||
func, _ = self._find_matching_endpoint(method, path)
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue