mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Remove the "ShieldType" concept (#430)
# What does this PR do? This PR kills the notion of "ShieldType". The impetus for this is the realization: > Why is keyword llama-guard appearing so many times everywhere, sometimes with hyphens, sometimes with underscores? Now that we have a notion of "provider specific resource identifiers" and "user specific aliases" for those and the fact that this works with models ("Llama3.1-8B-Instruct" <> "fireworks/llama-3pv1-..."), we can follow the same rules for Shields. So each Safety provider can make up a notion of identifiers it has registered. This already happens with Bedrock correctly. We just generalize it for Llama Guard, Prompt Guard, etc. For Llama Guard, we further simplify by just adopting the underlying model name itself as the identifier! No confusion necessary. While doing this, I noticed a bug in our DistributionRegistry where we weren't scoping identifiers by type. Fixed. ## Feature/Issue validation/testing/test plan Ran (inference, safety, memory, agents) tests with ollama and fireworks providers.
This commit is contained in:
parent
09269e2a44
commit
983d6ce2df
26 changed files with 150 additions and 209 deletions
|
@ -136,17 +136,18 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
else:
|
||||
raise ValueError("Unknown routing table type")
|
||||
|
||||
apiname, objtype = apiname_object()
|
||||
|
||||
# Get objects from disk registry
|
||||
objects = self.dist_registry.get_cached(routing_key)
|
||||
objects = self.dist_registry.get_cached(objtype, routing_key)
|
||||
if not objects:
|
||||
apiname, objname = apiname_object()
|
||||
provider_ids = list(self.impls_by_provider_id.keys())
|
||||
if len(provider_ids) > 1:
|
||||
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
|
||||
else:
|
||||
provider_ids_str = f"provider: `{provider_ids[0]}`"
|
||||
raise ValueError(
|
||||
f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
|
||||
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
|
||||
)
|
||||
|
||||
for obj in objects:
|
||||
|
@ -156,19 +157,19 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||
|
||||
async def get_object_by_identifier(
|
||||
self, identifier: str
|
||||
self, type: str, identifier: str
|
||||
) -> Optional[RoutableObjectWithProvider]:
|
||||
# Get from disk registry
|
||||
objects = await self.dist_registry.get(identifier)
|
||||
objects = await self.dist_registry.get(type, identifier)
|
||||
if not objects:
|
||||
return None
|
||||
|
||||
# kind of ill-defined behavior here, but we'll just return the first one
|
||||
assert len(objects) == 1
|
||||
return objects[0]
|
||||
|
||||
async def register_object(self, obj: RoutableObjectWithProvider):
|
||||
# Get existing objects from registry
|
||||
existing_objects = await self.dist_registry.get(obj.identifier)
|
||||
existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
|
||||
|
||||
# Check for existing registration
|
||||
for existing_obj in existing_objects:
|
||||
|
@ -200,7 +201,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return await self.get_all_with_type("model")
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[Model]:
|
||||
return await self.get_object_by_identifier(identifier)
|
||||
return await self.get_object_by_identifier("model", identifier)
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
@ -236,12 +237,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
return await self.get_all_with_type(ResourceType.shield.value)
|
||||
|
||||
async def get_shield(self, identifier: str) -> Optional[Shield]:
|
||||
return await self.get_object_by_identifier(identifier)
|
||||
return await self.get_object_by_identifier("shield", identifier)
|
||||
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
shield_type: ShieldType,
|
||||
provider_shield_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
|
@ -260,7 +260,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
params = {}
|
||||
shield = Shield(
|
||||
identifier=shield_id,
|
||||
shield_type=shield_type,
|
||||
provider_resource_id=provider_shield_id,
|
||||
provider_id=provider_id,
|
||||
params=params,
|
||||
|
@ -274,7 +273,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
return await self.get_all_with_type(ResourceType.memory_bank.value)
|
||||
|
||||
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
|
||||
return await self.get_object_by_identifier(memory_bank_id)
|
||||
return await self.get_object_by_identifier("memory_bank", memory_bank_id)
|
||||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
|
@ -312,7 +311,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
return await self.get_all_with_type("dataset")
|
||||
|
||||
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
|
||||
return await self.get_object_by_identifier(dataset_id)
|
||||
return await self.get_object_by_identifier("dataset", dataset_id)
|
||||
|
||||
async def register_dataset(
|
||||
self,
|
||||
|
@ -348,10 +347,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
return await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||
return await self.get_all_with_type("scoring_function")
|
||||
|
||||
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
|
||||
return await self.get_object_by_identifier(scoring_fn_id)
|
||||
return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
|
@ -389,7 +388,7 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
|
|||
return await self.get_all_with_type("eval_task")
|
||||
|
||||
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
|
||||
return await self.get_object_by_identifier(name)
|
||||
return await self.get_object_by_identifier("eval_task", name)
|
||||
|
||||
async def register_eval_task(
|
||||
self,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue