mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Kill "remote" providers and fix testing with a remote stack properly (#435)
# What does this PR do? This PR kills the notion of "pure passthrough" remote providers. You cannot specify a single provider you must specify a whole distribution (stack) as remote. This PR also significantly fixes / upgrades testing infrastructure so you can now test against a remotely hosted stack server by just doing ```bash pytest -s -v -m remote test_agents.py \ --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \ --env REMOTE_STACK_URL=http://localhost:5001 ``` Also fixed `test_agents_persistence.py` (which was broken) and killed some deprecated testing functions. ## Test Plan All the tests.
This commit is contained in:
parent
59a65e34d3
commit
12947ac19e
28 changed files with 406 additions and 519 deletions
|
@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
|
|||
_CLIENT_CLASSES = {}
|
||||
|
||||
|
||||
async def get_client_impl(
|
||||
protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
|
||||
):
|
||||
client_class = create_api_client_class(protocol, additional_protocol)
|
||||
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
|
||||
client_class = create_api_client_class(protocol)
|
||||
impl = client_class(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
def create_api_client_class(protocol, additional_protocol) -> Type:
|
||||
def create_api_client_class(protocol) -> Type:
|
||||
if protocol in _CLIENT_CLASSES:
|
||||
return _CLIENT_CLASSES[protocol]
|
||||
|
||||
protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
|
||||
|
||||
class APIClient:
|
||||
def __init__(self, base_url: str):
|
||||
print(f"({protocol.__name__}) Connecting to {base_url}")
|
||||
|
@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
|||
self.routes = {}
|
||||
|
||||
# Store routes for this protocol
|
||||
for p in protocols:
|
||||
for name, method in inspect.getmembers(p):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
sig = inspect.signature(method)
|
||||
self.routes[name] = (method.__webmethod__, sig)
|
||||
for name, method in inspect.getmembers(protocol):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
sig = inspect.signature(method)
|
||||
self.routes[name] = (method.__webmethod__, sig)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
@ -160,17 +155,16 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
|||
return ret
|
||||
|
||||
# Add protocol methods to the wrapper
|
||||
for p in protocols:
|
||||
for name, method in inspect.getmembers(p):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
for name, method in inspect.getmembers(protocol):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
|
||||
async def method_impl(self, *args, method_name=name, **kwargs):
|
||||
return await self.__acall__(method_name, *args, **kwargs)
|
||||
async def method_impl(self, *args, method_name=name, **kwargs):
|
||||
return await self.__acall__(method_name, *args, **kwargs)
|
||||
|
||||
method_impl.__name__ = name
|
||||
method_impl.__qualname__ = f"APIClient.{name}"
|
||||
method_impl.__signature__ = inspect.signature(method)
|
||||
setattr(APIClient, name, method_impl)
|
||||
method_impl.__name__ = name
|
||||
method_impl.__qualname__ = f"APIClient.{name}"
|
||||
method_impl.__signature__ = inspect.signature(method)
|
||||
setattr(APIClient, name, method_impl)
|
||||
|
||||
# Name the class after the protocol
|
||||
APIClient.__name__ = f"{protocol.__name__}Client"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue