undo some unintentional changes, clean up some stuff

This commit is contained in:
Ashwin Bharambe 2024-11-12 19:47:46 -08:00
parent 0121114a5d
commit 4f3b009980
3 changed files with 20 additions and 29 deletions

View file

@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
async def get_client_impl(
protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
):
client_class = create_api_client_class(protocol, additional_protocol)
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
client_class = create_api_client_class(protocol)
impl = client_class(config.url)
await impl.initialize()
return impl
def create_api_client_class(protocol, additional_protocol) -> Type:
def create_api_client_class(protocol) -> Type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]
protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
class APIClient:
def __init__(self, base_url: str):
print(f"({protocol.__name__}) Connecting to {base_url}")
@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
self.routes = {}
# Store routes for this protocol
for p in protocols:
for name, method in inspect.getmembers(p):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
async def initialize(self):
pass
@ -160,17 +155,16 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
return ret
# Add protocol methods to the wrapper
for p in protocols:
for name, method in inspect.getmembers(p):
if hasattr(method, "__webmethod__"):
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
# Name the class after the protocol
APIClient.__name__ = f"{protocol.__name__}Client"