mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Kill "remote" providers and fix testing with a remote stack properly (#435)
# What does this PR do? This PR kills the notion of "pure passthrough" remote providers. You cannot specify a single provider you must specify a whole distribution (stack) as remote. This PR also significantly fixes / upgrades testing infrastructure so you can now test against a remotely hosted stack server by just doing ```bash pytest -s -v -m remote test_agents.py \ --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \ --env REMOTE_STACK_URL=http://localhost:5001 ``` Also fixed `test_agents_persistence.py` (which was broken) and killed some deprecated testing functions. ## Test Plan All the tests.
This commit is contained in:
parent
59a65e34d3
commit
12947ac19e
28 changed files with 406 additions and 519 deletions
|
@ -33,28 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
|
||||
api = get_impl_api(p)
|
||||
|
||||
if obj.provider_id == "remote":
|
||||
# TODO: this is broken right now because we use the generic
|
||||
# { identifier, provider_id, provider_resource_id } tuple here
|
||||
# but the APIs expect things like ModelInput, ShieldInput, etc.
|
||||
|
||||
# if this is just a passthrough, we want to let the remote
|
||||
# end actually do the registration with the correct provider
|
||||
obj = obj.model_copy(deep=True)
|
||||
obj.provider_id = ""
|
||||
assert obj.provider_id != "remote", "Remote provider should not be registered"
|
||||
|
||||
if api == Api.inference:
|
||||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
await p.register_shield(obj)
|
||||
return await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
await p.register_memory_bank(obj)
|
||||
return await p.register_memory_bank(obj)
|
||||
elif api == Api.datasetio:
|
||||
await p.register_dataset(obj)
|
||||
return await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
await p.register_scoring_function(obj)
|
||||
return await p.register_scoring_function(obj)
|
||||
elif api == Api.eval:
|
||||
await p.register_eval_task(obj)
|
||||
return await p.register_eval_task(obj)
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
@ -82,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if cls is None:
|
||||
obj.provider_id = provider_id
|
||||
else:
|
||||
if provider_id == "remote":
|
||||
# if this is just a passthrough, we got the *WithProvider object
|
||||
# so we should just override the provider in-place
|
||||
obj.provider_id = provider_id
|
||||
else:
|
||||
# Create a copy of the model data and explicitly set provider_id
|
||||
model_data = obj.model_dump()
|
||||
model_data["provider_id"] = provider_id
|
||||
obj = cls(**model_data)
|
||||
# Create a copy of the model data and explicitly set provider_id
|
||||
model_data = obj.model_dump()
|
||||
model_data["provider_id"] = provider_id
|
||||
obj = cls(**model_data)
|
||||
await self.dist_registry.register(obj)
|
||||
|
||||
# Register all objects from providers
|
||||
|
@ -100,18 +87,14 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
p.model_store = self
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
|
||||
elif api == Api.datasetio:
|
||||
p.dataset_store = self
|
||||
|
||||
elif api == Api.scoring:
|
||||
p.scoring_function_store = self
|
||||
scoring_functions = await p.list_scoring_functions()
|
||||
await add_objects(scoring_functions, pid, ScoringFn)
|
||||
|
||||
elif api == Api.eval:
|
||||
p.eval_task_store = self
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue