undo some unintentional changes, clean up some stuff

This commit is contained in:
Ashwin Bharambe 2024-11-12 19:47:46 -08:00
parent 0121114a5d
commit 4f3b009980
3 changed files with 20 additions and 29 deletions

View file

@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
async def get_client_impl(
protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
):
client_class = create_api_client_class(protocol, additional_protocol)
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
client_class = create_api_client_class(protocol)
impl = client_class(config.url)
await impl.initialize()
return impl
def create_api_client_class(protocol, additional_protocol) -> Type:
def create_api_client_class(protocol) -> Type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]
protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
class APIClient:
def __init__(self, base_url: str):
print(f"({protocol.__name__}) Connecting to {base_url}")
@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
self.routes = {}
# Store routes for this protocol
for p in protocols:
for name, method in inspect.getmembers(p):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
async def initialize(self):
pass
@ -160,17 +155,16 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
return ret
# Add protocol methods to the wrapper
for p in protocols:
for name, method in inspect.getmembers(p):
if hasattr(method, "__webmethod__"):
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
# Name the class after the protocol
APIClient.__name__ = f"{protocol.__name__}Client"

View file

@ -369,7 +369,6 @@ async def resolve_remote_stack_impls(
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
None,
config,
{},
)
@ -377,7 +376,6 @@ async def resolve_remote_stack_impls(
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
None,
config,
{},
)

View file

@ -38,15 +38,15 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(**obj.model_dump())
await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(**obj.model_dump())
await p.register_memory_bank(obj)
elif api == Api.datasetio:
await p.register_dataset(**obj.model_dump())
await p.register_dataset(obj)
elif api == Api.scoring:
await p.register_scoring_function(**obj.model_dump())
await p.register_scoring_function(obj)
elif api == Api.eval:
await p.register_eval_task(**obj.model_dump())
await p.register_eval_task(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -95,7 +95,6 @@ class CommonRoutingTableImpl(RoutingTable):
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self