Remove the "ShieldType" concept (#430)

# What does this PR do?

This PR kills the notion of "ShieldType". The impetus for this is the
realization:

> Why is keyword llama-guard appearing so many times everywhere,
sometimes with hyphens, sometimes with underscores?

Now that we have a notion of "provider specific resource identifiers"
and "user specific aliases" for those and the fact that this works with
models ("Llama3.1-8B-Instruct" <> "fireworks/llama-3pv1-..."), we can
follow the same rules for Shields.

So each Safety provider can make up a notion of identifiers it has
registered. This already happens with Bedrock correctly. We just
generalize it for Llama Guard, Prompt Guard, etc.

For Llama Guard, we further simplify by just adopting the underlying
model name itself as the identifier! No confusion necessary.

While doing this, I noticed a bug in our DistributionRegistry where we
weren't scoping identifiers by type. Fixed.

## Feature/Issue validation/testing/test plan

Ran (inference, safety, memory, agents) tests with ollama and fireworks
providers.
This commit is contained in:
Ashwin Bharambe 2024-11-12 12:37:24 -08:00 committed by GitHub
parent 09269e2a44
commit 983d6ce2df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 150 additions and 209 deletions

View file

@ -172,13 +172,12 @@ class SafetyRouter(Safety):
async def register_shield(
self,
shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
return await self.routing_table.register_shield(
shield_id, shield_type, provider_shield_id, provider_id, params
shield_id, provider_shield_id, provider_id, params
)
async def run_shield(

View file

@ -136,17 +136,18 @@ class CommonRoutingTableImpl(RoutingTable):
else:
raise ValueError("Unknown routing table type")
apiname, objtype = apiname_object()
# Get objects from disk registry
objects = self.dist_registry.get_cached(routing_key)
objects = self.dist_registry.get_cached(objtype, routing_key)
if not objects:
apiname, objname = apiname_object()
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
)
for obj in objects:
@ -156,19 +157,19 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(
self, identifier: str
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry
objects = await self.dist_registry.get(identifier)
objects = await self.dist_registry.get(type, identifier)
if not objects:
return None
# kind of ill-defined behavior here, but we'll just return the first one
assert len(objects) == 1
return objects[0]
async def register_object(self, obj: RoutableObjectWithProvider):
# Get existing objects from registry
existing_objects = await self.dist_registry.get(obj.identifier)
existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
# Check for existing registration
for existing_obj in existing_objects:
@ -200,7 +201,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return await self.get_all_with_type("model")
async def get_model(self, identifier: str) -> Optional[Model]:
return await self.get_object_by_identifier(identifier)
return await self.get_object_by_identifier("model", identifier)
async def register_model(
self,
@ -236,12 +237,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return await self.get_all_with_type(ResourceType.shield.value)
async def get_shield(self, identifier: str) -> Optional[Shield]:
return await self.get_object_by_identifier(identifier)
return await self.get_object_by_identifier("shield", identifier)
async def register_shield(
self,
shield_id: str,
shield_type: ShieldType,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
@ -260,7 +260,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
params = {}
shield = Shield(
identifier=shield_id,
shield_type=shield_type,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
params=params,
@ -274,7 +273,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return await self.get_all_with_type(ResourceType.memory_bank.value)
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
return await self.get_object_by_identifier(memory_bank_id)
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
async def register_memory_bank(
self,
@ -312,7 +311,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
return await self.get_all_with_type("dataset")
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
return await self.get_object_by_identifier(dataset_id)
return await self.get_object_by_identifier("dataset", dataset_id)
async def register_dataset(
self,
@ -348,10 +347,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
return await self.get_all_with_type(ResourceType.scoring_function.value)
return await self.get_all_with_type("scoring_function")
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
return await self.get_object_by_identifier(scoring_fn_id)
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
async def register_scoring_function(
self,
@ -389,7 +388,7 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
return await self.get_all_with_type("eval_task")
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier(name)
return await self.get_object_by_identifier("eval_task", name)
async def register_eval_task(
self,

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
from typing import Any, Dict
from termcolor import colored
from termcolor import colored

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
from typing import Dict, List, Optional, Protocol
from typing import Dict, List, Optional, Protocol, Tuple
import pydantic
@ -35,7 +35,8 @@ class DistributionRegistry(Protocol):
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
KEY_FORMAT = "distributions:registry:v1::{}"
KEY_VERSION = "v1"
KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}"
class DiskDistributionRegistry(DistributionRegistry):
@ -45,18 +46,24 @@ class DiskDistributionRegistry(DistributionRegistry):
async def initialize(self) -> None:
pass
def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]:
def get_cached(
self, type: str, identifier: str
) -> List[RoutableObjectWithProvider]:
# Disk registry does not have a cache
return []
async def get_all(self) -> List[RoutableObjectWithProvider]:
start_key = KEY_FORMAT.format("")
end_key = KEY_FORMAT.format("\xff")
start_key = KEY_FORMAT.format(type="", identifier="")
end_key = KEY_FORMAT.format(type="", identifier="\xff")
keys = await self.kvstore.range(start_key, end_key)
return [await self.get(key.split(":")[-1]) for key in keys]
async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(KEY_FORMAT.format(identifier))
tuples = [(key.split(":")[-2], key.split(":")[-1]) for key in keys]
return [await self.get(type, identifier) for type, identifier in tuples]
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
json_str = await self.kvstore.get(
KEY_FORMAT.format(type=type, identifier=identifier)
)
if not json_str:
return []
@ -70,7 +77,7 @@ class DiskDistributionRegistry(DistributionRegistry):
]
async def register(self, obj: RoutableObjectWithProvider) -> bool:
existing_objects = await self.get(obj.identifier)
existing_objects = await self.get(obj.type, obj.identifier)
# dont register if the object's providerid already exists
for eobj in existing_objects:
if eobj.provider_id == obj.provider_id:
@ -82,7 +89,8 @@ class DiskDistributionRegistry(DistributionRegistry):
obj.model_dump_json() for obj in existing_objects
] # Fixed variable name
await self.kvstore.set(
KEY_FORMAT.format(obj.identifier), json.dumps(objects_json)
KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
json.dumps(objects_json),
)
return True
@ -90,33 +98,36 @@ class DiskDistributionRegistry(DistributionRegistry):
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
super().__init__(kvstore)
self.cache: Dict[str, List[RoutableObjectWithProvider]] = {}
self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}
async def initialize(self) -> None:
start_key = KEY_FORMAT.format("")
end_key = KEY_FORMAT.format("\xff")
start_key = KEY_FORMAT.format(type="", identifier="")
end_key = KEY_FORMAT.format(type="", identifier="\xff")
keys = await self.kvstore.range(start_key, end_key)
for key in keys:
identifier = key.split(":")[-1]
objects = await super().get(identifier)
type, identifier = key.split(":")[-2:]
objects = await super().get(type, identifier)
if objects:
self.cache[identifier] = objects
self.cache[type, identifier] = objects
def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]:
return self.cache.get(identifier, [])
def get_cached(
self, type: str, identifier: str
) -> List[RoutableObjectWithProvider]:
return self.cache.get((type, identifier), [])
async def get_all(self) -> List[RoutableObjectWithProvider]:
return [item for sublist in self.cache.values() for item in sublist]
async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
if identifier in self.cache:
return self.cache[identifier]
async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
cachekey = (type, identifier)
if cachekey in self.cache:
return self.cache[cachekey]
objects = await super().get(identifier)
objects = await super().get(type, identifier)
if objects:
self.cache[identifier] = objects
self.cache[cachekey] = objects
return objects
@ -126,16 +137,17 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
if success:
# Then update cache
if obj.identifier not in self.cache:
self.cache[obj.identifier] = []
cachekey = (obj.type, obj.identifier)
if cachekey not in self.cache:
self.cache[cachekey] = []
# Check if provider already exists in cache
for cached_obj in self.cache[obj.identifier]:
for cached_obj in self.cache[cachekey]:
if cached_obj.provider_id == obj.provider_id:
return success
# If not, update cache
self.cache[obj.identifier].append(obj)
self.cache[cachekey].append(obj)
return success