forked from phoenix-oss/llama-stack-mirror
Fix shield_type and routing table breakage
This commit is contained in:
parent
657de08f04
commit
fb2678b134
6 changed files with 30 additions and 35 deletions
|
@ -13,14 +13,22 @@ apis:
|
||||||
- safety
|
- safety
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
- provider_id: meta0
|
- provider_id: meta-reference-inference
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
||||||
config:
|
config:
|
||||||
model: Llama3.1-8B-Instruct
|
model: Llama3.2-3B-Instruct
|
||||||
quantization: null
|
quantization: null
|
||||||
torch_seed: null
|
torch_seed: null
|
||||||
max_seq_len: 4096
|
max_seq_len: 4096
|
||||||
max_batch_size: 1
|
max_batch_size: 1
|
||||||
|
- provider_id: meta-reference-safety
|
||||||
|
provider_type: meta-reference
|
||||||
|
config:
|
||||||
|
model: Llama-Guard-3-1B
|
||||||
|
quantization: null
|
||||||
|
torch_seed: null
|
||||||
|
max_seq_len: 2048
|
||||||
|
max_batch_size: 1
|
||||||
safety:
|
safety:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
||||||
|
@ -28,10 +36,9 @@ providers:
|
||||||
llama_guard_shield:
|
llama_guard_shield:
|
||||||
model: Llama-Guard-3-1B
|
model: Llama-Guard-3-1B
|
||||||
excluded_categories: []
|
excluded_categories: []
|
||||||
disable_input_check: false
|
# Uncomment to use prompt guard
|
||||||
disable_output_check: false
|
# prompt_guard_shield:
|
||||||
prompt_guard_shield:
|
# model: Prompt-Guard-86M
|
||||||
model: Prompt-Guard-86M
|
|
||||||
memory:
|
memory:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
||||||
|
@ -52,7 +59,7 @@ providers:
|
||||||
persistence_store:
|
persistence_store:
|
||||||
namespace: null
|
namespace: null
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ~/.llama/runtime/kvstore.db
|
db_path: ~/.llama/runtime/agents_store.db
|
||||||
telemetry:
|
telemetry:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
||||||
|
|
|
@ -23,7 +23,7 @@ class ShieldDef(BaseModel):
|
||||||
identifier: str = Field(
|
identifier: str = Field(
|
||||||
description="A unique identifier for the shield type",
|
description="A unique identifier for the shield type",
|
||||||
)
|
)
|
||||||
type: str = Field(
|
shield_type: str = Field(
|
||||||
description="The type of shield this is; the value is one of the ShieldType enum"
|
description="The type of shield this is; the value is one of the ShieldType enum"
|
||||||
)
|
)
|
||||||
params: Dict[str, Any] = Field(
|
params: Dict[str, Any] = Field(
|
||||||
|
|
|
@ -178,13 +178,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
await register_object_with_provider(obj, p)
|
await register_object_with_provider(obj, p)
|
||||||
await self.dist_registry.register(obj)
|
await self.dist_registry.register(obj)
|
||||||
|
|
||||||
|
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||||
|
return await self.dist_registry.get_all()
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
async def list_models(self) -> List[ModelDefWithProvider]:
|
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||||
objects = []
|
return await self.get_all()
|
||||||
for objs in self.registry.values():
|
|
||||||
objects.extend(objs)
|
|
||||||
return objects
|
|
||||||
|
|
||||||
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
|
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
|
||||||
return self.get_object_by_identifier(identifier)
|
return self.get_object_by_identifier(identifier)
|
||||||
|
@ -195,10 +195,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
||||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
async def list_shields(self) -> List[ShieldDef]:
|
async def list_shields(self) -> List[ShieldDef]:
|
||||||
objects = []
|
return await self.get_all()
|
||||||
for objs in self.registry.values():
|
|
||||||
objects.extend(objs)
|
|
||||||
return objects
|
|
||||||
|
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||||
return self.get_object_by_identifier(shield_type)
|
return self.get_object_by_identifier(shield_type)
|
||||||
|
@ -209,10 +206,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
|
|
||||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||||
objects = []
|
return await self.get_all()
|
||||||
for objs in self.registry.values():
|
|
||||||
objects.extend(objs)
|
|
||||||
return objects
|
|
||||||
|
|
||||||
async def get_memory_bank(
|
async def get_memory_bank(
|
||||||
self, identifier: str
|
self, identifier: str
|
||||||
|
@ -227,10 +221,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
|
|
||||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
async def list_datasets(self) -> List[DatasetDefWithProvider]:
|
async def list_datasets(self) -> List[DatasetDefWithProvider]:
|
||||||
objects = []
|
return await self.get_all()
|
||||||
for objs in self.registry.values():
|
|
||||||
objects.extend(objs)
|
|
||||||
return objects
|
|
||||||
|
|
||||||
async def get_dataset(
|
async def get_dataset(
|
||||||
self, dataset_identifier: str
|
self, dataset_identifier: str
|
||||||
|
@ -243,10 +234,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
|
||||||
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
|
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
|
||||||
objects = []
|
return await self.get_all()
|
||||||
for objs in self.registry.values():
|
|
||||||
objects.extend(objs)
|
|
||||||
return objects
|
|
||||||
|
|
||||||
async def get_scoring_function(
|
async def get_scoring_function(
|
||||||
self, name: str
|
self, name: str
|
||||||
|
|
|
@ -37,7 +37,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat
|
||||||
return [
|
return [
|
||||||
ShieldDef(
|
ShieldDef(
|
||||||
identifier=ShieldType.llama_guard.value,
|
identifier=ShieldType.llama_guard.value,
|
||||||
type=ShieldType.llama_guard.value,
|
shield_type=ShieldType.llama_guard.value,
|
||||||
params={},
|
params={},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
|
@ -25,8 +25,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(self, shield: ShieldDef) -> None:
|
async def register_shield(self, shield: ShieldDef) -> None:
|
||||||
if shield.type != ShieldType.code_scanner.value:
|
if shield.shield_type != ShieldType.code_scanner.value:
|
||||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -49,7 +49,7 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
return [
|
return [
|
||||||
ShieldDef(
|
ShieldDef(
|
||||||
identifier=shield_type,
|
identifier=shield_type,
|
||||||
type=shield_type,
|
shield_type=shield_type,
|
||||||
params={},
|
params={},
|
||||||
)
|
)
|
||||||
for shield_type in self.available_shields
|
for shield_type in self.available_shields
|
||||||
|
@ -92,14 +92,14 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
return RunShieldResponse(violation=violation)
|
return RunShieldResponse(violation=violation)
|
||||||
|
|
||||||
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
|
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
|
||||||
if shield.type == ShieldType.llama_guard.value:
|
if shield.shield_type == ShieldType.llama_guard.value:
|
||||||
cfg = self.config.llama_guard_shield
|
cfg = self.config.llama_guard_shield
|
||||||
return LlamaGuardShield(
|
return LlamaGuardShield(
|
||||||
model=cfg.model,
|
model=cfg.model,
|
||||||
inference_api=self.inference_api,
|
inference_api=self.inference_api,
|
||||||
excluded_categories=cfg.excluded_categories,
|
excluded_categories=cfg.excluded_categories,
|
||||||
)
|
)
|
||||||
elif shield.type == ShieldType.prompt_guard.value:
|
elif shield.shield_type == ShieldType.prompt_guard.value:
|
||||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||||
subtype = shield.params.get("prompt_guard_type", "injection")
|
subtype = shield.params.get("prompt_guard_type", "injection")
|
||||||
if subtype == "injection":
|
if subtype == "injection":
|
||||||
|
@ -109,4 +109,4 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown shield type: {shield.type}")
|
raise ValueError(f"Unknown shield type: {shield.shield_type}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue