diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 79701d926..fb02dd136 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -4,7 +4,8 @@ In short, provide a summary of what this PR does and why. Usually, the relevant
- [ ] Addresses issue (#issue)
-## Feature/Issue validation/testing/test plan
+
+## Test Plan
Please describe:
- tests you ran to verify your changes with result summaries.
diff --git a/docs/_deprecating_soon.ipynb b/docs/_deprecating_soon.ipynb
index 343005962..7fa4034ce 100644
--- a/docs/_deprecating_soon.ipynb
+++ b/docs/_deprecating_soon.ipynb
@@ -180,8 +180,8 @@
" tools=tools,\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
- " input_shields=[\"llama_guard\"],\n",
- " output_shields=[\"llama_guard\"],\n",
+ " input_shields=[\"Llama-Guard-3-1B\"],\n",
+ " output_shields=[\"Llama-Guard-3-1B\"],\n",
" enable_session_persistence=True,\n",
" )\n",
"\n",
diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 231633464..7ef4ece21 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
"info": {
"title": "[DRAFT] Llama Stack Specification",
"version": "0.0.1",
- "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
+ "description": "This is the specification of the llama stack that provides\n a set of endpoints and their corresponding interfaces that are tailored to\n best leverage Llama Models. The specification is still in draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
},
"servers": [
{
@@ -5743,9 +5743,6 @@
"const": "shield",
"default": "shield"
},
- "shield_type": {
- "$ref": "#/components/schemas/ShieldType"
- },
"params": {
"type": "object",
"additionalProperties": {
@@ -5777,20 +5774,10 @@
"identifier",
"provider_resource_id",
"provider_id",
- "type",
- "shield_type"
+ "type"
],
"title": "A safety shield resource that can be used to check content"
},
- "ShieldType": {
- "type": "string",
- "enum": [
- "generic_content_shield",
- "llama_guard",
- "code_scanner",
- "prompt_guard"
- ]
- },
"Trace": {
"type": "object",
"properties": {
@@ -7262,9 +7249,6 @@
"shield_id": {
"type": "string"
},
- "shield_type": {
- "$ref": "#/components/schemas/ShieldType"
- },
"provider_shield_id": {
"type": "string"
},
@@ -7299,8 +7283,7 @@
},
"additionalProperties": false,
"required": [
- "shield_id",
- "shield_type"
+ "shield_id"
]
},
"RunEvalRequest": {
@@ -7854,13 +7837,19 @@
],
"tags": [
{
- "name": "Inference"
+ "name": "MemoryBanks"
+ },
+ {
+ "name": "BatchInference"
},
{
"name": "Agents"
},
{
- "name": "Telemetry"
+ "name": "Inference"
+ },
+ {
+ "name": "DatasetIO"
},
{
"name": "Eval"
@@ -7869,43 +7858,37 @@
"name": "Models"
},
{
- "name": "Inspect"
- },
- {
- "name": "EvalTasks"
+ "name": "PostTraining"
},
{
"name": "ScoringFunctions"
},
{
- "name": "Memory"
- },
- {
- "name": "Safety"
- },
- {
- "name": "DatasetIO"
- },
- {
- "name": "MemoryBanks"
+ "name": "Datasets"
},
{
"name": "Shields"
},
{
- "name": "PostTraining"
+ "name": "Telemetry"
},
{
- "name": "Datasets"
+ "name": "Inspect"
},
{
- "name": "Scoring"
+ "name": "Safety"
},
{
"name": "SyntheticDataGeneration"
},
{
- "name": "BatchInference"
+ "name": "Memory"
+ },
+ {
+ "name": "Scoring"
+ },
+ {
+ "name": "EvalTasks"
},
{
"name": "BuiltinTool",
@@ -8255,10 +8238,6 @@
"name": "Shield",
"description": "A safety shield resource that can be used to check content\n\n"
},
- {
- "name": "ShieldType",
- "description": ""
- },
{
"name": "Trace",
"description": ""
@@ -8614,7 +8593,6 @@
"Session",
"Shield",
"ShieldCallStep",
- "ShieldType",
"SpanEndPayload",
"SpanStartPayload",
"SpanStatus",
diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 4e02e8075..b86c0df61 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -2227,11 +2227,8 @@ components:
type: string
shield_id:
type: string
- shield_type:
- $ref: '#/components/schemas/ShieldType'
required:
- shield_id
- - shield_type
type: object
RestAPIExecutionConfig:
additionalProperties: false
@@ -2698,8 +2695,6 @@ components:
type: string
provider_resource_id:
type: string
- shield_type:
- $ref: '#/components/schemas/ShieldType'
type:
const: shield
default: shield
@@ -2709,7 +2704,6 @@ components:
- provider_resource_id
- provider_id
- type
- - shield_type
title: A safety shield resource that can be used to check content
type: object
ShieldCallStep:
@@ -2736,13 +2730,6 @@ components:
- step_id
- step_type
type: object
- ShieldType:
- enum:
- - generic_content_shield
- - llama_guard
- - code_scanner
- - prompt_guard
- type: string
SpanEndPayload:
additionalProperties: false
properties:
@@ -3397,7 +3384,7 @@ info:
description: "This is the specification of the llama stack that provides\n \
\ a set of endpoints and their corresponding interfaces that are tailored\
\ to\n best leverage Llama Models. The specification is still in\
- \ draft and subject to change.\n Generated at 2024-11-12 11:16:58.657871"
+ \ draft and subject to change.\n Generated at 2024-11-12 11:39:48.665782"
title: '[DRAFT] Llama Stack Specification'
version: 0.0.1
jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -4761,24 +4748,24 @@ security:
servers:
- url: http://any-hosted-llama-stack.com
tags:
-- name: Inference
+- name: MemoryBanks
+- name: BatchInference
- name: Agents
-- name: Telemetry
+- name: Inference
+- name: DatasetIO
- name: Eval
- name: Models
-- name: Inspect
-- name: EvalTasks
-- name: ScoringFunctions
-- name: Memory
-- name: Safety
-- name: DatasetIO
-- name: MemoryBanks
-- name: Shields
- name: PostTraining
+- name: ScoringFunctions
- name: Datasets
-- name: Scoring
+- name: Shields
+- name: Telemetry
+- name: Inspect
+- name: Safety
- name: SyntheticDataGeneration
-- name: BatchInference
+- name: Memory
+- name: Scoring
+- name: EvalTasks
- description:
name: BuiltinTool
- description: '
name: Shield
-- description:
- name: ShieldType
- description:
name: Trace
- description: 'Checkpoint created during training runs
@@ -5343,7 +5328,6 @@ x-tagGroups:
- Session
- Shield
- ShieldCallStep
- - ShieldType
- SpanEndPayload
- SpanStartPayload
- SpanStatus
diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb
index e1e9301d3..f5352627e 100644
--- a/docs/zero_to_hero_guide/06_Safety101.ipynb
+++ b/docs/zero_to_hero_guide/06_Safety101.ipynb
@@ -182,13 +182,13 @@
" pass\n",
"\n",
" async def run_shield(\n",
- " self, shield_type: str, messages: List[dict]\n",
+ " self, shield_id: str, messages: List[dict]\n",
" ) -> RunShieldResponse:\n",
" async with httpx.AsyncClient() as client:\n",
" response = await client.post(\n",
" f\"{self.base_url}/safety/run_shield\",\n",
" json=dict(\n",
- " shield_type=shield_type,\n",
+ " shield_id=shield_id,\n",
" messages=[encodable_dict(m) for m in messages],\n",
" ),\n",
" headers={\n",
@@ -216,7 +216,7 @@
" ]:\n",
" cprint(f\"User>{message['content']}\", \"green\")\n",
" response = await client.run_shield(\n",
- " shield_type=\"llama_guard\",\n",
+ " shield_id=\"Llama-Guard-3-1B\",\n",
" messages=[message],\n",
" )\n",
" print(response)\n",
diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
index f0f02b3c5..2dc74e6ec 100644
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ -38,7 +38,6 @@ class Dataset(CommonDatasetFields, Resource):
return self.provider_resource_id
-@json_schema_type
class DatasetInput(CommonDatasetFields, BaseModel):
dataset_id: str
provider_id: Optional[str] = None
diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py
index 10c35c3ee..940dafc06 100644
--- a/llama_stack/apis/eval_tasks/eval_tasks.py
+++ b/llama_stack/apis/eval_tasks/eval_tasks.py
@@ -34,7 +34,6 @@ class EvalTask(CommonEvalTaskFields, Resource):
return self.provider_resource_id
-@json_schema_type
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
eval_task_id: str
provider_id: Optional[str] = None
diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index 83b292612..c0a0c643a 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -122,7 +122,6 @@ MemoryBank = Annotated[
]
-@json_schema_type
class MemoryBankInput(BaseModel):
memory_bank_id: str
params: BankParams
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index a5d226886..2cd12b4bc 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -32,7 +32,6 @@ class Model(CommonModelFields, Resource):
return self.provider_resource_id
-@json_schema_type
class ModelInput(CommonModelFields):
model_id: str
provider_id: Optional[str] = None
diff --git a/llama_stack/apis/safety/client.py b/llama_stack/apis/safety/client.py
index 96168fedd..d7d4bc981 100644
--- a/llama_stack/apis/safety/client.py
+++ b/llama_stack/apis/safety/client.py
@@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
def encodable_dict(d: BaseModel):
- return json.loads(d.json())
+ return json.loads(d.model_dump_json())
class SafetyClient(Safety):
@@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None):
)
cprint(f"User>{message.content}", "green")
response = await client.run_shield(
- shield_id="llama_guard",
+ shield_id="Llama-Guard-3-1B",
messages=[message],
)
print(response)
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index 7a2a83c72..251a683c1 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -96,7 +96,6 @@ class ScoringFn(CommonScoringFnFields, Resource):
return self.provider_resource_id
-@json_schema_type
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: Optional[str] = None
diff --git a/llama_stack/apis/shields/client.py b/llama_stack/apis/shields/client.py
index 2f6b5e649..7556d2d12 100644
--- a/llama_stack/apis/shields/client.py
+++ b/llama_stack/apis/shields/client.py
@@ -37,7 +37,6 @@ class ShieldsClient(Shields):
async def register_shield(
self,
shield_id: str,
- shield_type: ShieldType,
provider_shield_id: Optional[str],
provider_id: Optional[str],
params: Optional[Dict[str, Any]],
@@ -47,7 +46,6 @@ class ShieldsClient(Shields):
f"{self.base_url}/shields/register",
json={
"shield_id": shield_id,
- "shield_type": shield_type,
"provider_shield_id": provider_shield_id,
"provider_id": provider_id,
"params": params,
@@ -56,12 +54,12 @@ class ShieldsClient(Shields):
)
response.raise_for_status()
- async def get_shield(self, shield_type: str) -> Optional[Shield]:
+ async def get_shield(self, shield_id: str) -> Optional[Shield]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.base_url}/shields/get",
params={
- "shield_type": shield_type,
+ "shield_id": shield_id,
},
headers={"Content-Type": "application/json"},
)
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
index 1dcfd4f4c..5ee444f68 100644
--- a/llama_stack/apis/shields/shields.py
+++ b/llama_stack/apis/shields/shields.py
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
@@ -13,16 +12,7 @@ from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
-@json_schema_type
-class ShieldType(Enum):
- generic_content_shield = "generic_content_shield"
- llama_guard = "llama_guard"
- code_scanner = "code_scanner"
- prompt_guard = "prompt_guard"
-
-
class CommonShieldFields(BaseModel):
- shield_type: ShieldType
params: Optional[Dict[str, Any]] = None
@@ -59,7 +49,6 @@ class Shields(Protocol):
async def register_shield(
self,
shield_id: str,
- shield_type: ShieldType,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 5f6395e0d..220dfdb56 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -172,13 +172,12 @@ class SafetyRouter(Safety):
async def register_shield(
self,
shield_id: str,
- shield_type: ShieldType,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
return await self.routing_table.register_shield(
- shield_id, shield_type, provider_shield_id, provider_id, params
+ shield_id, provider_shield_id, provider_id, params
)
async def run_shield(
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 7b369df2c..d6fb5d662 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -136,17 +136,18 @@ class CommonRoutingTableImpl(RoutingTable):
else:
raise ValueError("Unknown routing table type")
+ apiname, objtype = apiname_object()
+
# Get objects from disk registry
- objects = self.dist_registry.get_cached(routing_key)
+ objects = self.dist_registry.get_cached(objtype, routing_key)
if not objects:
- apiname, objname = apiname_object()
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
- f"{objname.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objname}."
+ f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
)
for obj in objects:
@@ -156,19 +157,19 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(
- self, identifier: str
+ self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry
- objects = await self.dist_registry.get(identifier)
+ objects = await self.dist_registry.get(type, identifier)
if not objects:
return None
- # kind of ill-defined behavior here, but we'll just return the first one
+ assert len(objects) == 1
return objects[0]
async def register_object(self, obj: RoutableObjectWithProvider):
# Get existing objects from registry
- existing_objects = await self.dist_registry.get(obj.identifier)
+ existing_objects = await self.dist_registry.get(obj.type, obj.identifier)
# Check for existing registration
for existing_obj in existing_objects:
@@ -200,7 +201,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return await self.get_all_with_type("model")
async def get_model(self, identifier: str) -> Optional[Model]:
- return await self.get_object_by_identifier(identifier)
+ return await self.get_object_by_identifier("model", identifier)
async def register_model(
self,
@@ -236,12 +237,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return await self.get_all_with_type(ResourceType.shield.value)
async def get_shield(self, identifier: str) -> Optional[Shield]:
- return await self.get_object_by_identifier(identifier)
+ return await self.get_object_by_identifier("shield", identifier)
async def register_shield(
self,
shield_id: str,
- shield_type: ShieldType,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
@@ -260,7 +260,6 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
params = {}
shield = Shield(
identifier=shield_id,
- shield_type=shield_type,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
params=params,
@@ -274,7 +273,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return await self.get_all_with_type(ResourceType.memory_bank.value)
async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
- return await self.get_object_by_identifier(memory_bank_id)
+ return await self.get_object_by_identifier("memory_bank", memory_bank_id)
async def register_memory_bank(
self,
@@ -312,7 +311,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
return await self.get_all_with_type("dataset")
async def get_dataset(self, dataset_id: str) -> Optional[Dataset]:
- return await self.get_object_by_identifier(dataset_id)
+ return await self.get_object_by_identifier("dataset", dataset_id)
async def register_dataset(
self,
@@ -348,10 +347,10 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFn]:
- return await self.get_all_with_type(ResourceType.scoring_function.value)
+ return await self.get_all_with_type("scoring_function")
async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]:
- return await self.get_object_by_identifier(scoring_fn_id)
+ return await self.get_object_by_identifier("scoring_function", scoring_fn_id)
async def register_scoring_function(
self,
@@ -389,7 +388,7 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
return await self.get_all_with_type("eval_task")
async def get_eval_task(self, name: str) -> Optional[EvalTask]:
- return await self.get_object_by_identifier(name)
+ return await self.get_object_by_identifier("eval_task", name)
async def register_eval_task(
self,
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index 3afd51304..1c7325eee 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -5,7 +5,6 @@
# the root directory of this source tree.
from typing import Any, Dict
-from termcolor import colored
from termcolor import colored
diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py
index 6115ea1b3..d837c4375 100644
--- a/llama_stack/distribution/store/registry.py
+++ b/llama_stack/distribution/store/registry.py
@@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
-from typing import Dict, List, Optional, Protocol
+from typing import Dict, List, Optional, Protocol, Tuple
import pydantic
@@ -35,7 +35,8 @@ class DistributionRegistry(Protocol):
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
-KEY_FORMAT = "distributions:registry:v1::{}"
+KEY_VERSION = "v1"
+KEY_FORMAT = f"distributions:registry:{KEY_VERSION}::" + "{type}:{identifier}"
class DiskDistributionRegistry(DistributionRegistry):
@@ -45,18 +46,24 @@ class DiskDistributionRegistry(DistributionRegistry):
async def initialize(self) -> None:
pass
- def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]:
+ def get_cached(
+ self, type: str, identifier: str
+ ) -> List[RoutableObjectWithProvider]:
# Disk registry does not have a cache
return []
async def get_all(self) -> List[RoutableObjectWithProvider]:
- start_key = KEY_FORMAT.format("")
- end_key = KEY_FORMAT.format("\xff")
+ start_key = KEY_FORMAT.partition("{type}")[0]  # literal prefix shared by every registry key (everything before "{type}")
+ end_key = start_key + "\xff"  # bound the scan to keys under that prefix; formatting empty type/identifier left a stray ":" that excluded all real keys
keys = await self.kvstore.range(start_key, end_key)
- return [await self.get(key.split(":")[-1]) for key in keys]
- async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
- json_str = await self.kvstore.get(KEY_FORMAT.format(identifier))
+ tuples = [tuple(key[len(start_key):].split(":", 1)) for key in keys]  # split once: the identifier itself may contain ":"
+ return [obj for type, identifier in tuples for obj in await self.get(type, identifier)]  # flatten to match the declared return type
+
+ async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
+ json_str = await self.kvstore.get(
+ KEY_FORMAT.format(type=type, identifier=identifier)
+ )
if not json_str:
return []
@@ -70,7 +77,7 @@ class DiskDistributionRegistry(DistributionRegistry):
]
async def register(self, obj: RoutableObjectWithProvider) -> bool:
- existing_objects = await self.get(obj.identifier)
+ existing_objects = await self.get(obj.type, obj.identifier)
# dont register if the object's providerid already exists
for eobj in existing_objects:
if eobj.provider_id == obj.provider_id:
@@ -82,7 +89,8 @@ class DiskDistributionRegistry(DistributionRegistry):
obj.model_dump_json() for obj in existing_objects
] # Fixed variable name
await self.kvstore.set(
- KEY_FORMAT.format(obj.identifier), json.dumps(objects_json)
+ KEY_FORMAT.format(type=obj.type, identifier=obj.identifier),
+ json.dumps(objects_json),
)
return True
@@ -90,33 +98,36 @@ class DiskDistributionRegistry(DistributionRegistry):
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
def __init__(self, kvstore: KVStore):
super().__init__(kvstore)
- self.cache: Dict[str, List[RoutableObjectWithProvider]] = {}
+ self.cache: Dict[Tuple[str, str], List[RoutableObjectWithProvider]] = {}
async def initialize(self) -> None:
- start_key = KEY_FORMAT.format("")
- end_key = KEY_FORMAT.format("\xff")
+ start_key = KEY_FORMAT.partition("{type}")[0]  # literal prefix shared by every registry key (everything before "{type}")
+ end_key = start_key + "\xff"  # bound the scan to keys under that prefix; formatting empty type/identifier left a stray ":" that excluded all real keys
keys = await self.kvstore.range(start_key, end_key)
for key in keys:
- identifier = key.split(":")[-1]
- objects = await super().get(identifier)
+ type, identifier = key[len(start_key):].split(":", 1)  # split once: the identifier itself may contain ":"
+ objects = await super().get(type, identifier)
if objects:
- self.cache[identifier] = objects
+ self.cache[type, identifier] = objects
- def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]:
- return self.cache.get(identifier, [])
+ def get_cached(
+ self, type: str, identifier: str
+ ) -> List[RoutableObjectWithProvider]:
+ return self.cache.get((type, identifier), [])
async def get_all(self) -> List[RoutableObjectWithProvider]:
return [item for sublist in self.cache.values() for item in sublist]
- async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
- if identifier in self.cache:
- return self.cache[identifier]
+ async def get(self, type: str, identifier: str) -> List[RoutableObjectWithProvider]:
+ cachekey = (type, identifier)
+ if cachekey in self.cache:
+ return self.cache[cachekey]
- objects = await super().get(identifier)
+ objects = await super().get(type, identifier)
if objects:
- self.cache[identifier] = objects
+ self.cache[cachekey] = objects
return objects
@@ -126,16 +137,17 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
if success:
# Then update cache
- if obj.identifier not in self.cache:
- self.cache[obj.identifier] = []
+ cachekey = (obj.type, obj.identifier)
+ if cachekey not in self.cache:
+ self.cache[cachekey] = []
# Check if provider already exists in cache
- for cached_obj in self.cache[obj.identifier]:
+ for cached_obj in self.cache[cachekey]:
if cached_obj.provider_id == obj.provider_id:
return success
# If not, update cache
- self.cache[obj.identifier].append(obj)
+ self.cache[cachekey].append(obj)
return success
diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
index 1ca65c9bb..c477c685c 100644
--- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
+++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py
@@ -14,6 +14,12 @@ from .config import CodeScannerConfig
from llama_stack.apis.safety import * # noqa: F403
+ALLOWED_CODE_SCANNER_MODEL_IDS = [
+ "CodeScanner",
+ "CodeShield",
+]
+
+
class MetaReferenceCodeScannerSafetyImpl(Safety):
def __init__(self, config: CodeScannerConfig, deps) -> None:
self.config = config
@@ -25,8 +31,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
pass
async def register_shield(self, shield: Shield) -> None:
- if shield.shield_type != ShieldType.code_scanner:
- raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")
+ if shield.provider_resource_id not in ALLOWED_CODE_SCANNER_MODEL_IDS:
+ raise ValueError(
+ f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
+ )
async def run_shield(
self,
diff --git a/llama_stack/providers/inline/safety/llama_guard/config.py b/llama_stack/providers/inline/safety/llama_guard/config.py
index aec856bce..72036fd1c 100644
--- a/llama_stack/providers/inline/safety/llama_guard/config.py
+++ b/llama_stack/providers/inline/safety/llama_guard/config.py
@@ -6,32 +6,8 @@
from typing import List
-from llama_models.sku_list import CoreModelId, safety_models
-
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel
class LlamaGuardConfig(BaseModel):
- model: str = "Llama-Guard-3-1B"
excluded_categories: List[str] = []
-
- @field_validator("model")
- @classmethod
- def validate_model(cls, model: str) -> str:
- permitted_models = [
- m.descriptor()
- for m in safety_models()
- if (
- m.core_model_id
- in {
- CoreModelId.llama_guard_3_8b,
- CoreModelId.llama_guard_3_1b,
- CoreModelId.llama_guard_3_11b_vision,
- }
- )
- ]
- if model not in permitted_models:
- raise ValueError(
- f"Invalid model: {model}. Must be one of {permitted_models}"
- )
- return model
diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
index 12d012b16..494c1b43e 100644
--- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
+++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py
@@ -73,6 +73,11 @@ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_ELECTIONS,
]
+LLAMA_GUARD_MODEL_IDS = [
+ CoreModelId.llama_guard_3_8b.value,
+ CoreModelId.llama_guard_3_1b.value,
+ CoreModelId.llama_guard_3_11b_vision.value,
+]
MODEL_TO_SAFETY_CATEGORIES_MAP = {
CoreModelId.llama_guard_3_8b.value: (
@@ -118,18 +123,16 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
self.inference_api = deps[Api.inference]
async def initialize(self) -> None:
- self.shield = LlamaGuardShield(
- model=self.config.model,
- inference_api=self.inference_api,
- excluded_categories=self.config.excluded_categories,
- )
+ pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
- if shield.shield_type != ShieldType.llama_guard:
- raise ValueError(f"Unsupported shield type: {shield.shield_type}")
+ if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
+ raise ValueError(
+ f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
+ )
async def run_shield(
self,
@@ -147,7 +150,13 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
- return await self.shield.run(messages)
+ impl = LlamaGuardShield(
+ model=shield.provider_resource_id,
+ inference_api=self.inference_api,
+ excluded_categories=self.config.excluded_categories,
+ )
+
+ return await impl.run(messages)
class LlamaGuardShield:
diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
index 20bfdd241..9f3d78374 100644
--- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
+++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
@@ -36,8 +36,10 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
- if shield.shield_type != ShieldType.prompt_guard:
- raise ValueError(f"Unsupported shield type: {shield.shield_type}")
+ if shield.provider_resource_id != PROMPT_GUARD_MODEL:
+ raise ValueError(
+ f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard, got: {shield.provider_resource_id}"
+ )
async def run_shield(
self,
diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py
index d49035321..78e8105e0 100644
--- a/llama_stack/providers/remote/safety/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py
@@ -20,11 +20,6 @@ from .config import BedrockSafetyConfig
logger = logging.getLogger(__name__)
-BEDROCK_SUPPORTED_SHIELDS = [
- ShieldType.generic_content_shield,
-]
-
-
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: BedrockSafetyConfig) -> None:
self.config = config
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
index 64f493b88..db157174f 100644
--- a/llama_stack/providers/tests/agents/fixtures.py
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -44,7 +44,7 @@ def agents_meta_reference() -> ProviderFixture:
providers=[
Provider(
provider_id="meta-reference",
- provider_type="meta-reference",
+ provider_type="inline::meta-reference",
config=MetaReferenceAgentsImplConfig(
# TODO: make this an in-memory store
persistence_store=SqliteKVStoreConfig(
diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py
index b3f3dc31c..47e5a751f 100644
--- a/llama_stack/providers/tests/agents/test_agents.py
+++ b/llama_stack/providers/tests/agents/test_agents.py
@@ -81,15 +81,17 @@ async def create_agent_session(agents_impl, agent_config):
class TestAgents:
@pytest.mark.asyncio
- async def test_agent_turns_with_safety(self, agents_stack, common_params):
+ async def test_agent_turns_with_safety(
+ self, safety_model, agents_stack, common_params
+ ):
agents_impl, _ = agents_stack
agent_id, session_id = await create_agent_session(
agents_impl,
AgentConfig(
**{
**common_params,
- "input_shields": ["llama_guard"],
- "output_shields": ["llama_guard"],
+ "input_shields": [safety_model],
+ "output_shields": [safety_model],
}
),
)
diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py
index 66576e9d7..b73c2d798 100644
--- a/llama_stack/providers/tests/safety/fixtures.py
+++ b/llama_stack/providers/tests/safety/fixtures.py
@@ -9,7 +9,7 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput
-from llama_stack.apis.shields import ShieldInput, ShieldType
+from llama_stack.apis.shields import ShieldInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@@ -41,7 +41,7 @@ def safety_llama_guard(safety_model) -> ProviderFixture:
Provider(
provider_id="inline::llama-guard",
provider_type="inline::llama-guard",
- config=LlamaGuardConfig(model=safety_model).model_dump(),
+ config=LlamaGuardConfig().model_dump(),
)
],
)
@@ -101,6 +101,8 @@ async def safety_stack(inference_model, safety_model, request):
shield_provider_type = safety_fixture.providers[0].provider_type
shield_input = get_shield_to_register(shield_provider_type, safety_model)
+ print(f"inference_model: {inference_model}")
+ print(f"shield_input = {shield_input}")
impls = await resolve_impls_for_test_v2(
[Api.safety, Api.shields, Api.inference],
providers,
@@ -114,20 +116,14 @@ async def safety_stack(inference_model, safety_model, request):
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
- shield_config = {}
- shield_type = ShieldType.llama_guard
- identifier = "llama_guard"
- if provider_type == "meta-reference":
- shield_config["model"] = safety_model
- elif provider_type == "remote::together":
- shield_config["model"] = safety_model
- elif provider_type == "remote::bedrock":
+ if provider_type == "remote::bedrock":
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
- shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
- shield_type = ShieldType.generic_content_shield
+ params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
+ else:
+ params = {}
+ identifier = safety_model
return ShieldInput(
shield_id=identifier,
- shield_type=shield_type,
- params=shield_config,
+ params=params,
)
diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py
index 48fab9741..9daa7bf40 100644
--- a/llama_stack/providers/tests/safety/test_safety.py
+++ b/llama_stack/providers/tests/safety/test_safety.py
@@ -34,7 +34,6 @@ class TestSafety:
for shield in response:
assert isinstance(shield, Shield)
- assert shield.shield_type in [v for v in ShieldType]
@pytest.mark.asyncio
async def test_run_shield(self, safety_stack):