This commit is contained in:
Ashwin Bharambe 2024-11-12 11:11:01 -08:00
parent 593ba80162
commit d0ad198be9
2 changed files with 4 additions and 1 deletions

View file

@ -32,6 +32,10 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if obj.provider_id == "remote":
# TODO: this is broken right now because we use the generic
# { identifier, provider_id, provider_resource_id } tuple here
# but the APIs expect things like ModelInput, ShieldInput, etc.
# if this is just a passthrough, we want to let the remote
# end actually do the registration with the correct provider
obj = obj.model_copy(deep=True)

View file

@ -128,7 +128,6 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
print(f"Registering shield {shield}")
if shield.shield_type != ShieldType.llama_guard:
raise ValueError(f"Unsupported shield type: {shield.shield_type}")