From 386372dd246fc7c104fc017e4124f1aeb4bce57d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 31 Oct 2024 10:21:36 -0700 Subject: [PATCH] inference + memory + agents tests now pass with "remote" providers --- llama_stack/distribution/client.py | 124 +++++++++++------- llama_stack/distribution/resolver.py | 35 +++-- .../distribution/routers/routing_tables.py | 50 ++++--- llama_stack/providers/datatypes.py | 4 +- .../providers/tests/agents/test_agents.py | 3 +- .../providers/tests/memory/test_memory.py | 2 - 6 files changed, 127 insertions(+), 91 deletions(-) diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index cd8db9e8f..acc871f01 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -8,51 +8,51 @@ import inspect import json from collections.abc import AsyncIterator +from enum import Enum from typing import Any, get_args, get_origin, Type, Union import httpx - -from llama_models.schema_utils import WebMethod from pydantic import BaseModel, parse_obj_as from termcolor import cprint - -def extract_non_async_iterator_type(type_hint): - if get_origin(type_hint) is Union: - args = get_args(type_hint) - for arg in args: - if not issubclass(get_origin(arg) or arg, AsyncIterator): - return arg - return None - - -def extract_async_iterator_type(type_hint): - if get_origin(type_hint) is Union: - args = get_args(type_hint) - for arg in args: - if issubclass(get_origin(arg) or arg, AsyncIterator): - inner_args = get_args(arg) - return inner_args[0] - return None - +from llama_stack.providers.datatypes import RemoteProviderConfig _CLIENT_CLASSES = {} -def create_api_client_class(protocol) -> Type: +async def get_client_impl( + protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any +): + client_class = create_api_client_class(protocol, additional_protocol) + impl = client_class(config.url) + await impl.initialize() + return impl + + +def create_api_client_class(protocol, additional_protocol) -> Type: if protocol in _CLIENT_CLASSES: return _CLIENT_CLASSES[protocol] + protocols = [protocol, additional_protocol] if additional_protocol else [protocol] + class APIClient: def __init__(self, base_url: str): + print(f"({protocol.__name__}) Connecting to {base_url}") self.base_url = base_url.rstrip("/") self.routes = {} # Store routes for this protocol - for name, method in inspect.getmembers(protocol): - if hasattr(method, "__webmethod__"): - sig = inspect.signature(method) - self.routes[name] = (method.__webmethod__, sig) + for p in protocols: + for name, method in inspect.getmembers(p): + if hasattr(method, "__webmethod__"): + sig = inspect.signature(method) + self.routes[name] = (method.__webmethod__, sig) + + async def initialize(self): + pass + + async def shutdown(self): + pass async def __acall__(self, method_name: str, *args, **kwargs) -> Any: assert method_name in self.routes, f"Unknown endpoint: {method_name}" @@ -65,21 +65,23 @@ def create_api_client_class(protocol) -> Type: return await self._call_non_streaming(method_name, *args, **kwargs) async def _call_non_streaming(self, method_name: str, *args, **kwargs) -> Any: - webmethod, sig = self.routes[method_name] + _, sig = self.routes[method_name] - return_type = extract_non_async_iterator_type(sig.return_annotation) - assert ( - return_type - ), f"Could not extract return type for {sig.return_annotation}" - cprint(f"{return_type=}", "yellow") + if sig.return_annotation is None: + return_type = None + else: + return_type = extract_non_async_iterator_type(sig.return_annotation) + assert ( + return_type + ), f"Could not extract return type for {sig.return_annotation}" async with httpx.AsyncClient() as client: - params = self.httpx_request_params(webmethod, **kwargs) + params = self.httpx_request_params(method_name, *args, **kwargs) response = await client.request(**params) response.raise_for_status() j = response.json() - if not j: + if j is None: return None return parse_obj_as(return_type, j) @@ -90,10 +92,9 @@ def create_api_client_class(protocol) -> Type: assert ( return_type ), f"Could not extract return type for {sig.return_annotation}" - cprint(f"{return_type=}", "yellow") async with httpx.AsyncClient() as client: - params = self.httpx_request_params(webmethod, **kwargs) + params = self.httpx_request_params(method_name, *args, **kwargs) async with client.stream(**params) as response: response.raise_for_status() @@ -110,7 +111,15 @@ def create_api_client_class(protocol) -> Type: print(data) print(f"Error with parsing or validation: {e}") - def httpx_request_params(self, webmethod: WebMethod, **kwargs) -> dict: + def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict: + webmethod, sig = self.routes[method_name] + + parameters = list(sig.parameters.values())[1:] # skip `self` + for i, param in enumerate(parameters): + if i >= len(args): + break + kwargs[param.name] = args[i] + url = f"{self.base_url}{webmethod.route}" def convert(value): @@ -119,7 +128,9 @@ def create_api_client_class(protocol) -> Type: elif isinstance(value, dict): return {k: convert(v) for k, v in value.items()} elif isinstance(value, BaseModel): - return json.loads(value.json()) + return json.loads(value.model_dump_json()) + elif isinstance(value, Enum): + return value.value else: return value @@ -140,16 +151,17 @@ def create_api_client_class(protocol) -> Type: ) # Add protocol methods to the wrapper - for name, method in inspect.getmembers(protocol): - if hasattr(method, "__webmethod__"): + for p in protocols: + for name, method in inspect.getmembers(p): + if hasattr(method, "__webmethod__"): - async def method_impl(self, *args, method_name=name, **kwargs): - return await self.__acall__(method_name, *args, **kwargs) + async def method_impl(self, *args, method_name=name, **kwargs): + return await self.__acall__(method_name, *args, **kwargs) - method_impl.__name__ = name - method_impl.__qualname__ = f"APIClient.{name}" - method_impl.__signature__ = inspect.signature(method) - setattr(APIClient, name, method_impl) + method_impl.__name__ = name + method_impl.__qualname__ = f"APIClient.{name}" + method_impl.__signature__ = inspect.signature(method) + setattr(APIClient, name, method_impl) # Name the class after the protocol APIClient.__name__ = f"{protocol.__name__}Client" @@ -157,6 +169,26 @@ def create_api_client_class(protocol) -> Type: return APIClient +# not quite general these methods are +def extract_non_async_iterator_type(type_hint): + if get_origin(type_hint) is Union: + args = get_args(type_hint) + for arg in args: + if not issubclass(get_origin(arg) or arg, AsyncIterator): + return arg + return type_hint + + +def extract_async_iterator_type(type_hint): + if get_origin(type_hint) is Union: + args = get_args(type_hint) + for arg in args: + if issubclass(get_origin(arg) or arg, AsyncIterator): + inner_args = get_args(arg) + return inner_args[0] + return None + + async def example(model: str = None): from llama_stack.apis.inference import Inference, UserMessage # noqa: F403 from llama_stack.apis.inference.event_logger import EventLogger diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index bab807da9..a93cc1183 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -40,19 +40,21 @@ def api_protocol_map() -> Dict[Api, Any]: Api.safety: Safety, Api.shields: Shields, Api.telemetry: Telemetry, - Api.datasets: Datasets, Api.datasetio: DatasetIO, - Api.scoring_functions: ScoringFunctions, + Api.datasets: Datasets, Api.scoring: Scoring, + Api.scoring_functions: ScoringFunctions, Api.eval: Eval, } def additional_protocols_map() -> Dict[Api, Any]: return { - Api.inference: ModelsProtocolPrivate, - Api.memory: MemoryBanksProtocolPrivate, - Api.safety: ShieldsProtocolPrivate, + Api.inference: (ModelsProtocolPrivate, Models), + Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks), + Api.safety: (ShieldsProtocolPrivate, Shields), + Api.datasetio: (DatasetsProtocolPrivate, Datasets), + Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions), } @@ -112,8 +114,6 @@ async def resolve_impls( if info.router_api.value not in apis_to_serve: continue - available_providers = providers_with_specs[f"inner-{info.router_api.value}"] - providers_with_specs[info.routing_table_api.value] = { "__builtin__": ProviderWithSpec( provider_id="__routing_table__", @@ -246,14 +246,21 @@ async def instantiate_provider( args = [] if isinstance(provider_spec, RemoteProviderSpec): - if provider_spec.adapter: - method = "get_adapter_impl" - else: - method = "get_client_impl" - config_type = instantiate_class_type(provider_spec.config_class) config = config_type(**provider.config) - args = [config, deps] + + if provider_spec.adapter: + method = "get_adapter_impl" + args = [config, deps] + else: + method = "get_client_impl" + protocol = protocols[provider_spec.api] + if provider_spec.api in additional_protocols: + _, additional_protocol = additional_protocols[provider_spec.api] + else: + additional_protocol = None + args = [protocol, additional_protocol, config, deps] + elif isinstance(provider_spec, AutoRoutedProviderSpec): method = "get_auto_router_impl" @@ -282,7 +289,7 @@ async def instantiate_provider( not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols ): - additional_api = additional_protocols[provider_spec.api] + additional_api, _ = additional_protocols[provider_spec.api] check_protocol_compliance(impl, additional_api) return impl diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 3e07b9162..4e462c54b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -22,6 +22,13 @@ def get_impl_api(p: Any) -> Api: async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: api = get_impl_api(p) + + if obj.provider_id == "remote": + # if this is just a passthrough, we want to let the remote + # end actually do the registration with the correct provider + obj = obj.model_copy(deep=True) + obj.provider_id = "" + if api == Api.inference: await p.register_model(obj) elif api == Api.safety: @@ -51,11 +58,22 @@ class CommonRoutingTableImpl(RoutingTable): async def initialize(self) -> None: self.registry: Registry = {} - def add_objects(objs: List[RoutableObjectWithProvider]) -> None: + def add_objects( + objs: List[RoutableObjectWithProvider], provider_id: str, cls + ) -> None: for obj in objs: if obj.identifier not in self.registry: self.registry[obj.identifier] = [] + if cls is None: + obj.provider_id = provider_id + else: + if provider_id == "remote": + # if this is just a passthrough, we got the *WithProvider object + # so we should just override the provider in-place + obj.provider_id = provider_id + else: + obj = cls(**obj.model_dump(), provider_id=provider_id) self.registry[obj.identifier].append(obj) for pid, p in self.impls_by_provider_id.items(): @@ -63,47 +81,27 @@ class CommonRoutingTableImpl(RoutingTable): if api == Api.inference: p.model_store = self models = await p.list_models() - add_objects( - [ModelDefWithProvider(**m.dict(), provider_id=pid) for m in models] - ) + add_objects(models, pid, ModelDefWithProvider) elif api == Api.safety: p.shield_store = self shields = await p.list_shields() - add_objects( - [ - ShieldDefWithProvider(**s.dict(), provider_id=pid) - for s in shields - ] - ) + add_objects(shields, pid, ShieldDefWithProvider) elif api == Api.memory: p.memory_bank_store = self memory_banks = await p.list_memory_banks() - - # do in-memory updates due to pesky Annotated unions - for m in memory_banks: - m.provider_id = pid - - add_objects(memory_banks) + add_objects(memory_banks, pid, None) elif api == Api.datasetio: p.dataset_store = self datasets = await p.list_datasets() - - # do in-memory updates due to pesky Annotated unions - for d in datasets: - d.provider_id = pid + add_objects(datasets, pid, DatasetDefWithProvider) elif api == Api.scoring: p.scoring_function_store = self scoring_functions = await p.list_scoring_functions() - add_objects( - [ - ScoringFnDefWithProvider(**s.dict(), provider_id=pid) - for s in scoring_functions - ] - ) + add_objects(scoring_functions, pid, ScoringFnDefWithProvider) async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index eace0ea1a..9a37a28a9 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -60,7 +60,7 @@ class MemoryBanksProtocolPrivate(Protocol): class DatasetsProtocolPrivate(Protocol): async def list_datasets(self) -> List[DatasetDef]: ... - async def register_datasets(self, dataset_def: DatasetDef) -> None: ... + async def register_dataset(self, dataset_def: DatasetDef) -> None: ... class ScoringFunctionsProtocolPrivate(Protocol): @@ -171,7 +171,7 @@ as being "Llama Stack compatible" def module(self) -> str: if self.adapter: return self.adapter.module - return f"llama_stack.apis.{self.api.value}.client" + return "llama_stack.distribution.client" @property def pip_packages(self) -> List[str]: diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 9c34c3a28..c09db3d20 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -26,6 +26,7 @@ from dotenv import load_dotenv # # ```bash # PROVIDER_ID= \ +# MODEL_ID= \ # PROVIDER_CONFIG=provider_config.yaml \ # pytest -s llama_stack/providers/tests/agents/test_agents.py \ # --tb=short --disable-warnings @@ -44,7 +45,7 @@ async def agents_settings(): "impl": impls[Api.agents], "memory_impl": impls[Api.memory], "common_params": { - "model": "Llama3.1-8B-Instruct", + "model": os.environ["MODEL_ID"] or "Llama3.1-8B-Instruct", "instructions": "You are a helpful assistant.", }, } diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index b26bf75a7..d83601de1 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os import pytest import pytest_asyncio @@ -73,7 +72,6 @@ async def register_memory_bank(banks_impl: MemoryBanks): embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, overlap_size_in_tokens=64, - provider_id=os.environ["PROVIDER_ID"], ) await banks_impl.register_memory_bank(bank)