inference + memory + agents tests now pass with "remote" providers

This commit is contained in:
Ashwin Bharambe 2024-10-31 10:21:36 -07:00
parent fc66131fea
commit 386372dd24
6 changed files with 127 additions and 91 deletions

View file

@ -8,51 +8,51 @@ import inspect
import json
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, get_args, get_origin, Type, Union
import httpx
from llama_models.schema_utils import WebMethod
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
def extract_non_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if not issubclass(get_origin(arg) or arg, AsyncIterator):
return arg
return None
def extract_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if issubclass(get_origin(arg) or arg, AsyncIterator):
inner_args = get_args(arg)
return inner_args[0]
return None
from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
def create_api_client_class(protocol) -> Type:
async def get_client_impl(
protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
):
client_class = create_api_client_class(protocol, additional_protocol)
impl = client_class(config.url)
await impl.initialize()
return impl
def create_api_client_class(protocol, additional_protocol) -> Type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]
protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
class APIClient:
def __init__(self, base_url: str):
print(f"({protocol.__name__}) Connecting to {base_url}")
self.base_url = base_url.rstrip("/")
self.routes = {}
# Store routes for this protocol
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
for p in protocols:
for name, method in inspect.getmembers(p):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
async def initialize(self):
pass
async def shutdown(self):
pass
async def __acall__(self, method_name: str, *args, **kwargs) -> Any:
assert method_name in self.routes, f"Unknown endpoint: {method_name}"
@ -65,21 +65,23 @@ def create_api_client_class(protocol) -> Type:
return await self._call_non_streaming(method_name, *args, **kwargs)
async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any:
webmethod, sig = self.routes[method_name]
_, sig = self.routes[method_name]
return_type = extract_non_async_iterator_type(sig.return_annotation)
assert (
return_type
), f"Could not extract return type for {sig.return_annotation}"
cprint(f"{return_type=}", "yellow")
if sig.return_annotation is None:
return_type = None
else:
return_type = extract_non_async_iterator_type(sig.return_annotation)
assert (
return_type
), f"Could not extract return type for {sig.return_annotation}"
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(webmethod, **kwargs)
params = self.httpx_request_params(method_name, *args, **kwargs)
response = await client.request(**params)
response.raise_for_status()
j = response.json()
if not j:
if j is None:
return None
return parse_obj_as(return_type, j)
@ -90,10 +92,9 @@ def create_api_client_class(protocol) -> Type:
assert (
return_type
), f"Could not extract return type for {sig.return_annotation}"
cprint(f"{return_type=}", "yellow")
async with httpx.AsyncClient() as client:
params = self.httpx_request_params(webmethod, **kwargs)
params = self.httpx_request_params(method_name, *args, **kwargs)
async with client.stream(**params) as response:
response.raise_for_status()
@ -110,7 +111,15 @@ def create_api_client_class(protocol) -> Type:
print(data)
print(f"Error with parsing or validation: {e}")
def httpx_request_params(self, webmethod: WebMethod, **kwargs) -> dict:
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
webmethod, sig = self.routes[method_name]
parameters = list(sig.parameters.values())[1:] # skip `self`
for i, param in enumerate(parameters):
if i >= len(args):
break
kwargs[param.name] = args[i]
url = f"{self.base_url}{webmethod.route}"
def convert(value):
@ -119,7 +128,9 @@ def create_api_client_class(protocol) -> Type:
elif isinstance(value, dict):
return {k: convert(v) for k, v in value.items()}
elif isinstance(value, BaseModel):
return json.loads(value.json())
return json.loads(value.model_dump_json())
elif isinstance(value, Enum):
return value.value
else:
return value
@ -140,16 +151,17 @@ def create_api_client_class(protocol) -> Type:
)
# Add protocol methods to the wrapper
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
for p in protocols:
for name, method in inspect.getmembers(p):
if hasattr(method, "__webmethod__"):
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
# Name the class after the protocol
APIClient.__name__ = f"{protocol.__name__}Client"
@ -157,6 +169,26 @@ def create_api_client_class(protocol) -> Type:
return APIClient
# not quite general these methods are
def extract_non_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if not issubclass(get_origin(arg) or arg, AsyncIterator):
return arg
return type_hint
def extract_async_iterator_type(type_hint):
if get_origin(type_hint) is Union:
args = get_args(type_hint)
for arg in args:
if issubclass(get_origin(arg) or arg, AsyncIterator):
inner_args = get_args(arg)
return inner_args[0]
return None
async def example(model: str = None):
from llama_stack.apis.inference import Inference, UserMessage # noqa: F403
from llama_stack.apis.inference.event_logger import EventLogger