Kill "remote" providers and fix testing with a remote stack properly (#435)

# What does this PR do?

This PR kills the notion of "pure passthrough" remote providers. You
cannot specify a single provider as remote; you must specify a whole
distribution (stack) as remote.

This PR also significantly fixes and upgrades the testing infrastructure,
so you can now test against a remotely hosted stack server by just
running:
```bash
pytest -s -v -m remote  test_agents.py \
  --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \
  --env REMOTE_STACK_URL=http://localhost:5001
```

Also fixed `test_agents_persistence.py` (which was broken) and killed
some deprecated testing functions.

## Test Plan

All the tests.
This commit is contained in:
Ashwin Bharambe 2024-11-12 21:51:29 -08:00 committed by GitHub
parent 59a65e34d3
commit 12947ac19e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 406 additions and 519 deletions

View file

@ -33,28 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
api = get_impl_api(p)
if obj.provider_id == "remote":
# TODO: this is broken right now because we use the generic
# { identifier, provider_id, provider_resource_id } tuple here
# but the APIs expect things like ModelInput, ShieldInput, etc.
# if this is just a passthrough, we want to let the remote
# end actually do the registration with the correct provider
obj = obj.model_copy(deep=True)
obj.provider_id = ""
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
return await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(obj)
return await p.register_memory_bank(obj)
elif api == Api.datasetio:
await p.register_dataset(obj)
return await p.register_dataset(obj)
elif api == Api.scoring:
await p.register_scoring_function(obj)
return await p.register_scoring_function(obj)
elif api == Api.eval:
await p.register_eval_task(obj)
return await p.register_eval_task(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -82,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
if cls is None:
obj.provider_id = provider_id
else:
if provider_id == "remote":
# if this is just a passthrough, we got the *WithProvider object
# so we should just override the provider in-place
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
@ -100,18 +87,14 @@ class CommonRoutingTableImpl(RoutingTable):
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.memory:
p.memory_bank_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self