More work towards making remote stacks usable from tests

This commit is contained in:
Ashwin Bharambe 2024-11-12 17:09:31 -08:00
parent 8645f8bc9e
commit 8b7be87bec
7 changed files with 91 additions and 99 deletions

View file

@ -33,83 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
api = get_impl_api(p)
is_remote = obj.provider_id == "remote"
if is_remote:
# TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
# and (b) MemoryBankInput is missing BankParams
if isinstance(obj, Model):
obj = ModelInput(
model_id=obj.identifier,
metadata=obj.metadata,
provider_model_id=obj.provider_resource_id,
)
elif isinstance(obj, Shield):
obj = ShieldInput(
shield_id=obj.identifier,
params=obj.params,
provider_shield_id=obj.provider_resource_id,
)
elif isinstance(obj, MemoryBank):
# need to calculate params here
obj = MemoryBankInput(
memory_bank_id=obj.identifier,
provider_memory_bank_id=obj.provider_resource_id,
)
elif isinstance(obj, ScoringFn):
obj = ScoringFnInput(
scoring_fn_id=obj.identifier,
provider_scoring_fn_id=obj.provider_resource_id,
description=obj.description,
metadata=obj.metadata,
return_type=obj.return_type,
params=obj.params,
)
elif isinstance(obj, EvalTask):
obj = EvalTaskInput(
eval_task_id=obj.identifier,
provider_eval_task_id=obj.provider_resource_id,
dataset_id=obj.dataset_id,
scoring_function_id=obj.scoring_functions,
metadata=obj.metadata,
)
elif isinstance(obj, Dataset):
obj = DatasetInput(
dataset_id=obj.identifier,
provider_dataset_id=obj.provider_resource_id,
schema=obj.schema,
url=obj.url,
metadata=obj.metadata,
)
else:
raise ValueError(f"Unknown object type {type(obj)}")
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
if is_remote:
await p.register_shield(**obj.model_dump())
else:
await p.register_shield(obj)
await p.register_shield(**obj.model_dump())
elif api == Api.memory:
if is_remote:
await p.register_memory_bank(**obj.model_dump())
else:
await p.register_memory_bank(obj)
await p.register_memory_bank(**obj.model_dump())
elif api == Api.datasetio:
if is_remote:
await p.register_dataset(**obj.model_dump())
else:
await p.register_dataset(obj)
await p.register_dataset(**obj.model_dump())
elif api == Api.scoring:
if is_remote:
await p.register_scoring_function(**obj.model_dump())
else:
await p.register_scoring_function(obj)
await p.register_scoring_function(**obj.model_dump())
elif api == Api.eval:
if is_remote:
await p.register_eval_task(**obj.model_dump())
else:
await p.register_eval_task(obj)
await p.register_eval_task(**obj.model_dump())
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -137,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
if cls is None:
obj.provider_id = provider_id
else:
if provider_id == "remote":
# if this is just a passthrough, we got the *WithProvider object
# so we should just override the provider in-place
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers