From 4f3b009980ee3a2d47ea05710405c5720ba94d03 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Nov 2024 19:47:46 -0800 Subject: [PATCH] undo some unintentional changes, clean up some stuff --- llama_stack/distribution/client.py | 36 ++++++++----------- llama_stack/distribution/resolver.py | 2 -- .../distribution/routers/routing_tables.py | 11 +++--- 3 files changed, 20 insertions(+), 29 deletions(-) diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index ce788a713..b36ef94e4 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig _CLIENT_CLASSES = {} -async def get_client_impl( - protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any -): - client_class = create_api_client_class(protocol, additional_protocol) +async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any): + client_class = create_api_client_class(protocol) impl = client_class(config.url) await impl.initialize() return impl -def create_api_client_class(protocol, additional_protocol) -> Type: +def create_api_client_class(protocol) -> Type: if protocol in _CLIENT_CLASSES: return _CLIENT_CLASSES[protocol] - protocols = [protocol, additional_protocol] if additional_protocol else [protocol] - class APIClient: def __init__(self, base_url: str): print(f"({protocol.__name__}) Connecting to {base_url}") @@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type: self.routes = {} # Store routes for this protocol - for p in protocols: - for name, method in inspect.getmembers(p): - if hasattr(method, "__webmethod__"): - sig = inspect.signature(method) - self.routes[name] = (method.__webmethod__, sig) + for name, method in inspect.getmembers(protocol): + if hasattr(method, "__webmethod__"): + sig = inspect.signature(method) + self.routes[name] = (method.__webmethod__, sig) async def initialize(self): pass @@ -160,17 +155,16 @@ def create_api_client_class(protocol, additional_protocol) -> Type: return ret # Add protocol methods to the wrapper - for p in protocols: - for name, method in inspect.getmembers(p): - if hasattr(method, "__webmethod__"): + for name, method in inspect.getmembers(protocol): + if hasattr(method, "__webmethod__"): - async def method_impl(self, *args, method_name=name, **kwargs): - return await self.__acall__(method_name, *args, **kwargs) + async def method_impl(self, *args, method_name=name, **kwargs): + return await self.__acall__(method_name, *args, **kwargs) - method_impl.__name__ = name - method_impl.__qualname__ = f"APIClient.{name}" - method_impl.__signature__ = inspect.signature(method) - setattr(APIClient, name, method_impl) + method_impl.__name__ = name + method_impl.__qualname__ = f"APIClient.{name}" + method_impl.__signature__ = inspect.signature(method) + setattr(APIClient, name, method_impl) # Name the class after the protocol APIClient.__name__ = f"{protocol.__name__}Client" diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index d00aedb5c..b95cc5418 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -369,7 +369,6 @@ async def resolve_remote_stack_impls( api = Api(api_str) impls[api] = await get_client_impl( protocols[api], - None, config, {}, ) @@ -377,7 +376,6 @@ async def resolve_remote_stack_impls( _, additional_protocol, additional_api = additional_protocols[api] impls[additional_api] = await get_client_impl( additional_protocol, - None, config, {}, ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 4bdeb608a..7b7433862 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -38,15 +38,15 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable if api == Api.inference: return await p.register_model(obj) elif api == Api.safety: - await p.register_shield(**obj.model_dump()) + await p.register_shield(obj) elif api == Api.memory: - await p.register_memory_bank(**obj.model_dump()) + await p.register_memory_bank(obj) elif api == Api.datasetio: - await p.register_dataset(**obj.model_dump()) + await p.register_dataset(obj) elif api == Api.scoring: - await p.register_scoring_function(**obj.model_dump()) + await p.register_scoring_function(obj) elif api == Api.eval: - await p.register_eval_task(**obj.model_dump()) + await p.register_eval_task(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -95,7 +95,6 @@ class CommonRoutingTableImpl(RoutingTable): p.scoring_function_store = self scoring_functions = await p.list_scoring_functions() await add_objects(scoring_functions, pid, ScoringFn) - elif api == Api.eval: p.eval_task_store = self