mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 18:50:44 +00:00
Squashed commit of the following:
commit a95d2b15b83057e194cf69e57a03deeeeeadd7c2 Author: m-misiura <mmisiura@redhat.com> Date: Mon Mar 24 14:33:50 2025 +0000 🚧 working on the config file so that it is inheriting from pydantic base models commit 0546379f817e37bca030247b48c72ce84899a766 Author: m-misiura <mmisiura@redhat.com> Date: Mon Mar 24 09:14:31 2025 +0000 🚧 dealing with ruff checks commit 8abe39ee4cb4b8fb77c7252342c4809fa6ddc432 Author: m-misiura <mmisiura@redhat.com> Date: Mon Mar 24 09:03:18 2025 +0000 🚧 dealing with mypy errors in `base.py` commit 045f833e79c9a25af3d46af6c8896da91a0e6e62 Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 17:31:25 2025 +0000 🚧 fixing mypy errors in content.py commit a9c1ee4e92ad1b5db89039317555cd983edbde65 Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 17:09:02 2025 +0000 🚧 fixing mypy errors in chat.py commit 69e8ddc2f8a4e13cecbab30272fd7d685d7864ec Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 16:57:28 2025 +0000 🚧 fixing mypy errors commit 56739d69a145c55335ac2859ecbe5b43d556e3b1 Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 14:01:03 2025 +0000 🚧 fixing mypy errors in `__init__.py` commit 4d2e3b55c4102ed75d997c8189847bbc5524cb2c Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 12:58:06 2025 +0000 🚧 ensuring routing_tables.py do not fail the ci commit c0cc7b4b09ef50d5ec95fdb0a916c7ed228bf366 Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 12:09:24 2025 +0000 🐛 fixing linter problems commit 115a50211b604feb4106275204fe7f863da865f6 Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 11:47:04 2025 +0000 🐛 fixing ruff errors commit 29b5bfaabc77a35ea036b57f75fded711228dbbf Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 11:33:31 2025 +0000 🎨 automatic ruff fixes commit 7c5a334c7d4649c2fc297993f89791c1e5643e5b Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 11:15:02 2025 +0000 Squashed commit of the following: commit e671aae5bcd4ea57d601ee73c9e3adf5e223e830 Merge: b0dd9a4f9114bef4
Author: Mac Misiura <82826099+m-misiura@users.noreply.github.com> Date: Fri Mar 21 09:45:08 2025 +0000 Merge branch 'meta-llama:main' into feat_fms_remote_safety_provider commit b0dd9a4f746b0c8c54d1189d381a7ff8e51c812c Author: m-misiura <mmisiura@redhat.com> Date: Fri Mar 21 09:27:21 2025 +0000 📝 updated `provider_id` commit 4c8906c1a4e960968b93251d09d5e5735db15026 Author: m-misiura <mmisiura@redhat.com> Date: Thu Mar 20 16:54:46 2025 +0000 📝 renaming from `fms` to `trustyai_fms` commit 4c0b62abc51b02143b5c818f2d30e1a1fee9e4f3 Merge: bb842d6954035825
Author: m-misiura <mmisiura@redhat.com> Date: Thu Mar 20 16:35:52 2025 +0000 Merge branch 'main' into feat_fms_remote_safety_provider commit bb842d69548df256927465792e0cd107a267d2a0 Author: m-misiura <mmisiura@redhat.com> Date: Wed Mar 19 15:03:17 2025 +0000 ✨ added a better way of handling params from the configs commit 58b6beabf0994849ac50317ed00b748596e8961d Merge: a22cf36c7c044845
Author: m-misiura <mmisiura@redhat.com> Date: Wed Mar 19 09:19:57 2025 +0000 Merge main into feat_fms_remote_safety_provider, resolve conflicts by keeping main version commit a22cf36c8757f74ed656c1310a4be6b288bf923a Author: m-misiura <mmisiura@redhat.com> Date: Wed Mar 5 16:17:46 2025 +0000 🎉 added a new remote safety provider compatible with FMS Orchestrator API and Detectors API Signed-off-by: m-misiura <mmisiura@redhat.com>
This commit is contained in:
parent
9e1ddf2b53
commit
87d209d6ef
7 changed files with 2883 additions and 20 deletions
|
@ -118,7 +118,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
self.dist_registry = dist_registry
|
self.dist_registry = dist_registry
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
|
async def add_objects(
|
||||||
|
objs: List[RoutableObjectWithProvider], provider_id: str, cls
|
||||||
|
) -> None:
|
||||||
for obj in objs:
|
for obj in objs:
|
||||||
if cls is None:
|
if cls is None:
|
||||||
obj.provider_id = provider_id
|
obj.provider_id = provider_id
|
||||||
|
@ -153,7 +155,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
for p in self.impls_by_provider_id.values():
|
for p in self.impls_by_provider_id.values():
|
||||||
await p.shutdown()
|
await p.shutdown()
|
||||||
|
|
||||||
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
|
def get_provider_impl(
|
||||||
|
self, routing_key: str, provider_id: Optional[str] = None
|
||||||
|
) -> Any:
|
||||||
def apiname_object():
|
def apiname_object():
|
||||||
if isinstance(self, ModelsRoutingTable):
|
if isinstance(self, ModelsRoutingTable):
|
||||||
return ("Inference", "model")
|
return ("Inference", "model")
|
||||||
|
@ -191,24 +195,32 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||||
|
|
||||||
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
|
async def get_object_by_identifier(
|
||||||
|
self, type: str, identifier: str
|
||||||
|
) -> Optional[RoutableObjectWithProvider]:
|
||||||
# Get from disk registry
|
# Get from disk registry
|
||||||
obj = await self.dist_registry.get(type, identifier)
|
obj = await self.dist_registry.get(type, identifier)
|
||||||
if not obj:
|
if not obj:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if user has permission to access this object
|
# Check if user has permission to access this object
|
||||||
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
|
if not check_access(obj, get_auth_attributes()):
|
||||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
logger.debug(
|
||||||
|
f"Access denied to {type} '{identifier}' based on attribute mismatch"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||||
await self.dist_registry.delete(obj.type, obj.identifier)
|
await self.dist_registry.delete(obj.type, obj.identifier)
|
||||||
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
|
await unregister_object_from_provider(
|
||||||
|
obj, self.impls_by_provider_id[obj.provider_id]
|
||||||
|
)
|
||||||
|
|
||||||
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
|
async def register_object(
|
||||||
|
self, obj: RoutableObjectWithProvider
|
||||||
|
) -> RoutableObjectWithProvider:
|
||||||
# if provider_id is not specified, pick an arbitrary one from existing entries
|
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
@ -223,7 +235,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
creator_attributes = get_auth_attributes()
|
creator_attributes = get_auth_attributes()
|
||||||
if creator_attributes:
|
if creator_attributes:
|
||||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
logger.info(
|
||||||
|
f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity"
|
||||||
|
)
|
||||||
|
|
||||||
registered_obj = await register_object_with_provider(obj, p)
|
registered_obj = await register_object_with_provider(obj, p)
|
||||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||||
|
@ -242,9 +256,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
# Apply attribute-based access control filtering
|
# Apply attribute-based access control filtering
|
||||||
if filtered_objs:
|
if filtered_objs:
|
||||||
filtered_objs = [
|
filtered_objs = [
|
||||||
obj
|
obj for obj in filtered_objs if check_access(obj, get_auth_attributes())
|
||||||
for obj in filtered_objs
|
|
||||||
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return filtered_objs
|
return filtered_objs
|
||||||
|
@ -283,7 +295,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
if model_type is None:
|
if model_type is None:
|
||||||
model_type = ModelType.llm
|
model_type = ModelType.llm
|
||||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
raise ValueError(
|
||||||
|
"Embedding model must have an embedding dimension in its metadata"
|
||||||
|
)
|
||||||
model = ModelWithACL(
|
model = ModelWithACL(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
|
@ -302,8 +316,54 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
||||||
|
|
||||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Initialize routing table and providers"""
|
||||||
|
# First do the common initialization
|
||||||
|
await super().initialize()
|
||||||
|
|
||||||
|
# Then explicitly initialize each safety provider
|
||||||
|
for provider_id, provider in self.impls_by_provider_id.items():
|
||||||
|
api = get_impl_api(provider)
|
||||||
|
if api == Api.safety:
|
||||||
|
logger.info(f"Explicitly initializing safety provider: {provider_id}")
|
||||||
|
await provider.initialize()
|
||||||
|
|
||||||
|
# Fetch shields after initialization - with robust error handling
|
||||||
|
try:
|
||||||
|
# Check if the provider implements list_shields
|
||||||
|
if hasattr(provider, "list_shields") and callable(
|
||||||
|
getattr(provider, "list_shields")
|
||||||
|
):
|
||||||
|
shields_response = await provider.list_shields()
|
||||||
|
if (
|
||||||
|
shields_response
|
||||||
|
and hasattr(shields_response, "data")
|
||||||
|
and shields_response.data
|
||||||
|
):
|
||||||
|
for shield in shields_response.data:
|
||||||
|
# Ensure type is set
|
||||||
|
if not hasattr(shield, "type") or not shield.type:
|
||||||
|
shield.type = ResourceType.shield.value
|
||||||
|
await self.dist_registry.register(shield)
|
||||||
|
logger.info(
|
||||||
|
f"Registered {len(shields_response.data)} shields from provider {provider_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"No shields found for provider {provider_id}")
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Provider {provider_id} does not support listing shields"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Log the error but continue initialization
|
||||||
|
logger.warning(
|
||||||
|
f"Error listing shields from provider {provider_id}: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
async def list_shields(self) -> ListShieldsResponse:
|
async def list_shields(self) -> ListShieldsResponse:
|
||||||
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
|
return ListShieldsResponse(
|
||||||
|
data=await self.get_all_with_type(ResourceType.shield.value)
|
||||||
|
)
|
||||||
|
|
||||||
async def get_shield(self, identifier: str) -> Shield:
|
async def get_shield(self, identifier: str) -> Shield:
|
||||||
shield = await self.get_object_by_identifier("shield", identifier)
|
shield = await self.get_object_by_identifier("shield", identifier)
|
||||||
|
@ -368,14 +428,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
raise ValueError(
|
||||||
|
"No provider available. Please configure a vector_io provider."
|
||||||
|
)
|
||||||
model = await self.get_object_by_identifier("model", embedding_model)
|
model = await self.get_object_by_identifier("model", embedding_model)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model {embedding_model} not found")
|
raise ValueError(f"Model {embedding_model} not found")
|
||||||
if model.model_type != ModelType.embedding:
|
if model.model_type != ModelType.embedding:
|
||||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||||
if "embedding_dimension" not in model.metadata:
|
if "embedding_dimension" not in model.metadata:
|
||||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
raise ValueError(
|
||||||
|
f"Model {embedding_model} does not have an embedding dimension"
|
||||||
|
)
|
||||||
vector_db_data = {
|
vector_db_data = {
|
||||||
"identifier": vector_db_id,
|
"identifier": vector_db_id,
|
||||||
"type": ResourceType.vector_db.value,
|
"type": ResourceType.vector_db.value,
|
||||||
|
@ -397,7 +461,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
|
|
||||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
async def list_datasets(self) -> ListDatasetsResponse:
|
async def list_datasets(self) -> ListDatasetsResponse:
|
||||||
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
return ListDatasetsResponse(
|
||||||
|
data=await self.get_all_with_type(ResourceType.dataset.value)
|
||||||
|
)
|
||||||
|
|
||||||
async def get_dataset(self, dataset_id: str) -> Dataset:
|
async def get_dataset(self, dataset_id: str) -> Dataset:
|
||||||
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
||||||
|
@ -459,10 +525,14 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||||
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
|
return ListScoringFunctionsResponse(
|
||||||
|
data=await self.get_all_with_type(ResourceType.scoring_function.value)
|
||||||
|
)
|
||||||
|
|
||||||
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
|
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
|
||||||
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
scoring_fn = await self.get_object_by_identifier(
|
||||||
|
"scoring_function", scoring_fn_id
|
||||||
|
)
|
||||||
if scoring_fn is None:
|
if scoring_fn is None:
|
||||||
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
|
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
|
||||||
return scoring_fn
|
return scoring_fn
|
||||||
|
@ -565,8 +635,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
args: Optional[Dict[str, Any]] = None,
|
args: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
tools = []
|
tools = []
|
||||||
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
|
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
|
||||||
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
toolgroup_id, mcp_endpoint
|
||||||
|
)
|
||||||
|
tool_host = (
|
||||||
|
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
|
||||||
|
)
|
||||||
|
|
||||||
for tool_def in tool_defs:
|
for tool_def in tool_defs:
|
||||||
tools.append(
|
tools.append(
|
||||||
|
|
|
@ -64,4 +64,13 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.safety,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="trustyai_fms",
|
||||||
|
pip_packages=[],
|
||||||
|
module="llama_stack.providers.remote.safety.trustyai_fms",
|
||||||
|
config_class="llama_stack.providers.remote.safety.trustyai_fms.config.FMSSafetyProviderConfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
119
llama_stack/providers/remote/safety/trustyai_fms/__init__.py
Normal file
119
llama_stack/providers/remote/safety/trustyai_fms/__init__.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
# Remove register_provider import since registration is in registry/safety.py
|
||||||
|
from llama_stack.apis.safety import Safety
|
||||||
|
from llama_stack.providers.remote.safety.trustyai_fms.config import (
|
||||||
|
ChatDetectorConfig,
|
||||||
|
ContentDetectorConfig,
|
||||||
|
DetectorParams,
|
||||||
|
EndpointType,
|
||||||
|
FMSSafetyProviderConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import (
|
||||||
|
BaseDetector,
|
||||||
|
DetectorProvider,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.remote.safety.trustyai_fms.detectors.chat import ChatDetector
|
||||||
|
from llama_stack.providers.remote.safety.trustyai_fms.detectors.content import (
|
||||||
|
ContentDetector,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Type aliases for better readability
|
||||||
|
ConfigType = Union[ContentDetectorConfig, ChatDetectorConfig, FMSSafetyProviderConfig]
|
||||||
|
DetectorType = Union[BaseDetector, DetectorProvider]
|
||||||
|
|
||||||
|
|
||||||
|
class DetectorConfigError(ValueError):
|
||||||
|
"""Raised when detector configuration is invalid"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def create_fms_provider(config: Dict[str, Any]) -> Safety:
|
||||||
|
"""Create FMS safety provider instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Safety: Configured FMS safety provider
|
||||||
|
"""
|
||||||
|
logger.debug("Creating trustyai-fms provider")
|
||||||
|
return await get_adapter_impl(FMSSafetyProviderConfig(**config))
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(
|
||||||
|
config: Union[Dict[str, Any], FMSSafetyProviderConfig],
|
||||||
|
_deps: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> DetectorType:
|
||||||
|
"""Get appropriate detector implementation(s) based on config type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary or FMSSafetyProviderConfig instance
|
||||||
|
_deps: Optional dependencies for testing/injection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured detector implementation
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DetectorConfigError: If configuration is invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(config, FMSSafetyProviderConfig):
|
||||||
|
provider_config = config
|
||||||
|
else:
|
||||||
|
provider_config = FMSSafetyProviderConfig(**config)
|
||||||
|
|
||||||
|
detectors: Dict[str, DetectorType] = {}
|
||||||
|
|
||||||
|
# Changed from provider_config.detectors to provider_config.shields
|
||||||
|
for shield_id, shield_config in provider_config.shields.items():
|
||||||
|
impl: BaseDetector
|
||||||
|
if isinstance(shield_config, ChatDetectorConfig):
|
||||||
|
impl = ChatDetector(shield_config)
|
||||||
|
elif isinstance(shield_config, ContentDetectorConfig):
|
||||||
|
impl = ContentDetector(shield_config)
|
||||||
|
else:
|
||||||
|
raise DetectorConfigError(
|
||||||
|
f"Invalid shield config type for {shield_id}: {type(shield_config)}"
|
||||||
|
)
|
||||||
|
await impl.initialize()
|
||||||
|
detectors[shield_id] = impl
|
||||||
|
|
||||||
|
detectors_for_provider: Dict[str, BaseDetector] = {}
|
||||||
|
for shield_id, detector in detectors.items():
|
||||||
|
if isinstance(detector, BaseDetector):
|
||||||
|
detectors_for_provider[shield_id] = detector
|
||||||
|
|
||||||
|
return DetectorProvider(detectors_for_provider)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise DetectorConfigError(
|
||||||
|
f"Failed to create detector implementation: {str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Factory methods
|
||||||
|
"get_adapter_impl",
|
||||||
|
"create_fms_provider",
|
||||||
|
# Configurations
|
||||||
|
"ContentDetectorConfig",
|
||||||
|
"ChatDetectorConfig",
|
||||||
|
"FMSSafetyProviderConfig",
|
||||||
|
"EndpointType",
|
||||||
|
"DetectorParams",
|
||||||
|
# Implementations
|
||||||
|
"ChatDetector",
|
||||||
|
"ContentDetector",
|
||||||
|
"BaseDetector",
|
||||||
|
"DetectorProvider",
|
||||||
|
# Types
|
||||||
|
"ConfigType",
|
||||||
|
"DetectorType",
|
||||||
|
"DetectorConfigError",
|
||||||
|
]
|
562
llama_stack/providers/remote/safety/trustyai_fms/config.py
Normal file
562
llama_stack/providers/remote/safety/trustyai_fms/config.py
Normal file
|
@ -0,0 +1,562 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(Enum):
|
||||||
|
"""Valid message types for detectors"""
|
||||||
|
|
||||||
|
USER = "user"
|
||||||
|
SYSTEM = "system"
|
||||||
|
TOOL = "tool"
|
||||||
|
COMPLETION = "completion"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def as_set(cls) -> Set[str]:
|
||||||
|
"""Get all valid message types as a set"""
|
||||||
|
return {member.value for member in cls}
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointType(Enum):
|
||||||
|
"""API endpoint types and their paths"""
|
||||||
|
|
||||||
|
DIRECT_CONTENT = {
|
||||||
|
"path": "/api/v1/text/contents",
|
||||||
|
"version": "v1",
|
||||||
|
"type": "content",
|
||||||
|
}
|
||||||
|
DIRECT_CHAT = {"path": "/api/v1/text/chat", "version": "v1", "type": "chat"}
|
||||||
|
ORCHESTRATOR_CONTENT = {
|
||||||
|
"path": "/api/v2/text/detection/content",
|
||||||
|
"version": "v2",
|
||||||
|
"type": "content",
|
||||||
|
}
|
||||||
|
ORCHESTRATOR_CHAT = {
|
||||||
|
"path": "/api/v2/text/detection/chat",
|
||||||
|
"version": "v2",
|
||||||
|
"type": "chat",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_endpoint(cls, is_orchestrator: bool, is_chat: bool) -> EndpointType:
|
||||||
|
"""Get appropriate endpoint based on configuration"""
|
||||||
|
if is_orchestrator:
|
||||||
|
return cls.ORCHESTRATOR_CHAT if is_chat else cls.ORCHESTRATOR_CONTENT
|
||||||
|
return cls.DIRECT_CHAT if is_chat else cls.DIRECT_CONTENT
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
@dataclass
|
||||||
|
class DetectorParams:
|
||||||
|
"""Flexible parameter container supporting nested structure and arbitrary parameters"""
|
||||||
|
|
||||||
|
# Store all parameters in a single dictionary for maximum flexibility
|
||||||
|
params: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Parameter categories for organization
|
||||||
|
model_params: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Store detectors directly as an attribute (not in params) for orchestrator mode
|
||||||
|
_raw_detectors: Optional[Dict[str, Dict[str, Any]]] = None
|
||||||
|
|
||||||
|
# Standard parameters kept for backward compatibility
|
||||||
|
@property
|
||||||
|
def regex(self) -> Optional[List[str]]:
|
||||||
|
return self.params.get("regex")
|
||||||
|
|
||||||
|
@regex.setter
|
||||||
|
def regex(self, value: List[str]) -> None:
|
||||||
|
self.params["regex"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def temperature(self) -> Optional[float]:
|
||||||
|
return self.model_params.get("temperature") or self.params.get("temperature")
|
||||||
|
|
||||||
|
@temperature.setter
|
||||||
|
def temperature(self, value: float) -> None:
|
||||||
|
self.model_params["temperature"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def risk_name(self) -> Optional[str]:
|
||||||
|
return self.metadata.get("risk_name") or self.params.get("risk_name")
|
||||||
|
|
||||||
|
@risk_name.setter
|
||||||
|
def risk_name(self, value: str) -> None:
|
||||||
|
self.metadata["risk_name"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def risk_definition(self) -> Optional[str]:
|
||||||
|
return self.metadata.get("risk_definition") or self.params.get(
|
||||||
|
"risk_definition"
|
||||||
|
)
|
||||||
|
|
||||||
|
@risk_definition.setter
|
||||||
|
def risk_definition(self, value: str) -> None:
|
||||||
|
self.metadata["risk_definition"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def orchestrator_detectors(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""Return detectors in the format required by orchestrator API"""
|
||||||
|
if (
|
||||||
|
not hasattr(self, "_detectors") or not self._detectors
|
||||||
|
): # Direct attribute access
|
||||||
|
return {}
|
||||||
|
|
||||||
|
flattened = {}
|
||||||
|
for (
|
||||||
|
detector_id,
|
||||||
|
detector_config,
|
||||||
|
) in self._detectors.items(): # Direct attribute access
|
||||||
|
# Create a flattened version without extra nesting
|
||||||
|
flat_config = {}
|
||||||
|
|
||||||
|
# Extract detector_params if present and flatten them
|
||||||
|
params = detector_config.get("detector_params", {})
|
||||||
|
if isinstance(params, dict):
|
||||||
|
# Move params up to top level
|
||||||
|
for key, value in params.items():
|
||||||
|
flat_config[key] = value
|
||||||
|
|
||||||
|
flattened[detector_id] = flat_config
|
||||||
|
|
||||||
|
return flattened
|
||||||
|
|
||||||
|
@property
|
||||||
|
def formatted_detectors(self) -> Optional[Dict[str, Dict[str, Any]]]:
|
||||||
|
"""Return detectors properly formatted for orchestrator API"""
|
||||||
|
# Direct return for API usage - avoid calling other properties
|
||||||
|
if hasattr(self, "_detectors") and self._detectors:
|
||||||
|
return self.orchestrator_detectors
|
||||||
|
return None
|
||||||
|
|
||||||
|
@formatted_detectors.setter
|
||||||
|
def formatted_detectors(self, value: Dict[str, Dict[str, Any]]) -> None:
|
||||||
|
self._detectors = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def detectors(self) -> Optional[Dict[str, Dict[str, Any]]]:
|
||||||
|
"""COMPATIBILITY: Returns the same as formatted_detectors to maintain API compatibility"""
|
||||||
|
# Using a different implementation to avoid the redefinition error
|
||||||
|
# while maintaining the same functionality
|
||||||
|
if not hasattr(self, "_detectors") or not self._detectors:
|
||||||
|
return None
|
||||||
|
return self.orchestrator_detectors
|
||||||
|
|
||||||
|
@detectors.setter
|
||||||
|
def detectors(self, value: Dict[str, Dict[str, Any]]) -> None:
|
||||||
|
"""COMPATIBILITY: Set detectors while maintaining compatibility"""
|
||||||
|
self._detectors = value
|
||||||
|
|
||||||
|
# And fix the __setitem__ method:
|
||||||
|
def __setitem__(self, key: str, value: Any) -> None:
|
||||||
|
"""Allow dictionary-like assignment with smart categorization"""
|
||||||
|
# Special handling for known params
|
||||||
|
if key == "detectors":
|
||||||
|
self._detectors = value # Set underlying attribute directly
|
||||||
|
return
|
||||||
|
elif key == "regex":
|
||||||
|
self.params[key] = value
|
||||||
|
return
|
||||||
|
|
||||||
|
# Rest of the method remains unchanged
|
||||||
|
known_model_params = ["temperature", "top_p", "top_k", "max_tokens", "n"]
|
||||||
|
known_metadata = ["risk_name", "risk_definition", "category", "severity"]
|
||||||
|
|
||||||
|
if key in known_model_params:
|
||||||
|
self.model_params[key] = value
|
||||||
|
elif key in known_metadata:
|
||||||
|
self.metadata[key] = value
|
||||||
|
else:
|
||||||
|
self.kwargs[key] = value
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
"""Initialize from any keyword arguments with smart categorization"""
|
||||||
|
# Initialize containers
|
||||||
|
self.params = {}
|
||||||
|
self.model_params = {}
|
||||||
|
self.metadata = {}
|
||||||
|
self.kwargs = {}
|
||||||
|
self._raw_detectors = None
|
||||||
|
|
||||||
|
# Special handling for nested detectors structure
|
||||||
|
if "detectors" in kwargs:
|
||||||
|
self._raw_detectors = kwargs.pop("detectors")
|
||||||
|
|
||||||
|
# Special handling for regex
|
||||||
|
if "regex" in kwargs:
|
||||||
|
self.params["regex"] = kwargs.pop("regex")
|
||||||
|
|
||||||
|
# Categorize known parameters
|
||||||
|
known_model_params = ["temperature", "top_p", "top_k", "max_tokens", "n"]
|
||||||
|
known_metadata = ["risk_name", "risk_definition", "category", "severity"]
|
||||||
|
|
||||||
|
# Explicit categories if provided
|
||||||
|
if "model_params" in kwargs:
|
||||||
|
self.model_params.update(kwargs.pop("model_params"))
|
||||||
|
|
||||||
|
if "metadata" in kwargs:
|
||||||
|
self.metadata.update(kwargs.pop("metadata"))
|
||||||
|
|
||||||
|
if "kwargs" in kwargs:
|
||||||
|
self.kwargs.update(kwargs.pop("kwargs"))
|
||||||
|
|
||||||
|
# Categorize remaining parameters
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if key in known_model_params:
|
||||||
|
self.model_params[key] = value
|
||||||
|
elif key in known_metadata:
|
||||||
|
self.metadata[key] = value
|
||||||
|
else:
|
||||||
|
self.kwargs[key] = value
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> Any:
|
||||||
|
"""Allow dictionary-like access with category lookup"""
|
||||||
|
if key in self.params:
|
||||||
|
return self.params[key]
|
||||||
|
elif key in self.model_params:
|
||||||
|
return self.model_params[key]
|
||||||
|
elif key in self.metadata:
|
||||||
|
return self.metadata[key]
|
||||||
|
elif key in self.kwargs:
|
||||||
|
return self.kwargs[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""Dictionary-style get with category lookup"""
|
||||||
|
result = self.__getitem__(key)
|
||||||
|
return default if result is None else result
|
||||||
|
|
||||||
|
def set(self, key: str, value: Any) -> None:
|
||||||
|
"""Set a parameter value with smart categorization"""
|
||||||
|
self.__setitem__(key, value)
|
||||||
|
|
||||||
|
def update(self, params: Dict[str, Any]) -> None:
    """Merge every key/value pair from *params* via categorized assignment."""
    for name in params:
        self[name] = params[name]
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
    """Flatten every category into a single dict for API requests.

    Merge order is params -> model_params -> metadata -> kwargs, so on a
    key clash the later category wins.
    """
    return {
        **self.params,
        **self.model_params,
        **self.metadata,
        **self.kwargs,
    }
|
||||||
|
|
||||||
|
def to_categorized_dict(self) -> Dict[str, Any]:
    """Return core params at the top level, with each non-empty category
    nested under its own key ("model_params", "metadata", "kwargs")."""
    out: Dict[str, Any] = dict(self.params)
    for label, bucket in (
        ("model_params", self.model_params),
        ("metadata", self.metadata),
        ("kwargs", self.kwargs),
    ):
        if bucket:
            out[label] = dict(bucket)
    return out
|
||||||
|
|
||||||
|
def create_flattened_detector_configs(self) -> Dict[str, Dict[str, Any]]:
    """Build per-detector configurations for orchestrator mode.

    The orchestrator API rejects the extra "detector_params" nesting, so
    each detector's params are lifted to the top level of its config.
    A non-dict detector_params value yields an empty config for that
    detector; no nested detectors means an empty result.
    """
    if not self.detectors:
        return {}

    result: Dict[str, Dict[str, Any]] = {}
    for det_id, cfg in self.detectors.items():
        raw = cfg.get("detector_params", {})
        # Lift params to top level; drop anything that is not a dict.
        result[det_id] = dict(raw) if isinstance(raw, dict) else {}
    return result
|
||||||
|
|
||||||
|
def validate(self) -> None:
    """Check parameter values, raising ValueError when out of range.

    Only `temperature` is validated; None means "unset" and is accepted.
    """
    temp = self.temperature
    if temp is None:
        return
    if temp < 0 or temp > 1:
        raise ValueError("Temperature must be between 0 and 1")
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
@dataclass
class BaseDetectorConfig:
    """Base configuration for all detectors with flexible parameter handling.

    Subclasses (ContentDetectorConfig, ChatDetectorConfig) override `is_chat`
    to select the message format used when calling the detector service.
    """

    # Unique identifier for this detector instance
    detector_id: str
    # Score threshold for detections; assumed to be in [0, 1] — TODO confirm
    confidence_threshold: float = 0.5
    # Message roles this detector applies to; defaults to every known type
    message_types: Set[str] = field(default_factory=lambda: MessageType.as_set())
    # Optional bearer token used when calling the detector service
    auth_token: Optional[str] = None
    # Structured extra parameters forwarded to the detector API
    detector_params: Optional[DetectorParams] = None

    # URL fields directly on detector configs
    detector_url: Optional[str] = None
    orchestrator_url: Optional[str] = None

    # Flexible storage for any additional parameters
    _extra_params: Dict[str, Any] = field(default_factory=dict)

    # Runtime execution parameters
    max_concurrency: int = 10  # Maximum concurrent API requests
    request_timeout: float = 30.0  # HTTP request timeout in seconds
    max_retries: int = 3  # Maximum number of retry attempts
    backoff_factor: float = 1.5  # Exponential backoff multiplier
    max_keepalive_connections: int = 5  # Max number of keepalive connections
    max_connections: int = 10  # Max number of connections in the pool

    @property
    def use_orchestrator_api(self) -> bool:
        """True when an orchestrator URL is configured (orchestrator mode)."""
        return bool(self.orchestrator_url)

    def __post_init__(self) -> None:
        """Normalize and validate fields after dataclass initialization.

        Raises:
            ValueError: if any entry of message_types is not a known type.
        """
        # Convert list/tuple message_types to set
        if isinstance(self.message_types, (list, tuple)):
            self.message_types = set(self.message_types)

        # Validate message types against the authoritative MessageType set
        invalid_types = self.message_types - MessageType.as_set()
        if invalid_types:
            raise ValueError(
                f"Invalid message types: {invalid_types}. "
                f"Valid types are: {MessageType.as_set()}"
            )

        # Initialize detector_params if needed so downstream code can rely
        # on it being a DetectorParams instance rather than None
        if self.detector_params is None:
            self.detector_params = DetectorParams()

        # Handle legacy URL field names
        # NOTE(review): base_url / orchestrator_base_url are not declared
        # dataclass fields, so these hasattr checks only fire when a caller
        # or subclass sets them dynamically — confirm this path is reachable.
        if hasattr(self, "base_url") and self.base_url and not self.detector_url:
            self.detector_url = self.base_url

        if (
            hasattr(self, "orchestrator_base_url")
            and self.orchestrator_base_url
            and not self.orchestrator_url
        ):
            self.orchestrator_url = self.orchestrator_base_url

    def validate(self) -> None:
        """Validate configuration.

        Raises:
            ValueError: if detector_params fail validation, no URL is set,
                or any configured URL is malformed or non-http(s).
        """
        # Validate detector_params
        if self.detector_params:
            self.detector_params.validate()

        # Validate that at least one URL is provided
        if not self.detector_url and not self.orchestrator_url:
            raise ValueError(f"No URL provided for detector {self.detector_id}")

        # Validate URLs if present
        for url_name, url in [
            ("detector_url", self.detector_url),
            ("orchestrator_url", self.orchestrator_url),
        ]:
            if url:
                self._validate_url(url, url_name)

    def _validate_url(self, url: str, url_name: str) -> None:
        """Check that *url* parses with an http(s) scheme and a netloc."""
        parsed = urlparse(url)
        if not all([parsed.scheme, parsed.netloc]):
            raise ValueError(f"Invalid {url_name} format: {url}")
        if parsed.scheme not in {"http", "https"}:
            raise ValueError(f"Invalid {url_name} scheme: {parsed.scheme}")

    def get(self, key: str, default: Any = None) -> Any:
        """Get an attribute by name, falling back to _extra_params then *default*."""
        try:
            return getattr(self, key)
        except AttributeError:
            return self._extra_params.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set a known attribute directly; stash unknown keys in _extra_params."""
        if hasattr(self, key) and key not in ["_extra_params"]:
            setattr(self, key, value)
        else:
            self._extra_params[key] = value

    @property
    def is_chat(self) -> bool:
        """Default implementation, overridden by chat-capable subclasses."""
        return False
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
@dataclass
class ContentDetectorConfig(BaseDetectorConfig):
    """Configuration for detectors that screen raw text content."""

    @property
    def is_chat(self) -> bool:
        """Content detectors never use the chat message format."""
        return False
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
@dataclass
class ChatDetectorConfig(BaseDetectorConfig):
    """Configuration for detectors that operate on chat message lists."""

    @property
    def is_chat(self) -> bool:
        """Chat detectors use the chat message format."""
        return True
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
class FMSSafetyProviderConfig(BaseModel):
    """Configuration for the FMS Safety Provider.

    Accepts raw shield dicts in `shields`, converts each one into a typed
    detector config during pydantic validation, and exposes lookup helpers
    that filter detectors by message type.
    """

    # Raw shield configurations as supplied in the stack config
    shields: Dict[str, Dict[str, Any]] = Field(default_factory=dict)

    # Rename _detectors to remove the leading underscore
    # (excluded from serialization; holds the processed detector objects)
    detectors_internal: Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]] = (
        Field(default_factory=dict, exclude=True)
    )

    # Provider-level orchestrator URL (can be copied to shields if needed)
    orchestrator_url: Optional[str] = None

    class Config:
        # The detector configs are dataclasses, not pydantic models
        arbitrary_types_allowed = True

    # Add a model validator to replace __post_init__
    @model_validator(mode="after")
    def setup_config(self):
        """Process shield configurations after model construction.

        NOTE(review): this rebinds `shields` to the processed detector
        objects, so after validation `shields` no longer matches its declared
        Dict[str, Dict[str, Any]] annotation — confirm downstream consumers
        expect detector config objects there.
        """
        # Process shields into detector objects
        self._process_shields()

        # Replace shields dictionary with processed detector configs
        self.shields = self.detectors_internal

        # Validate all shields (raises ValueError on bad/missing URLs)
        for shield in self.shields.values():
            shield.validate()

        return self

    def _process_shields(self):
        """Process all shield configurations into detector configs.

        Each dict-valued shield entry is normalized (legacy URL names
        rewritten, provider orchestrator_url inherited, detector_params
        wrapped in a DetectorParams) and stored in detectors_internal.
        Non-dict entries are left untouched.
        """
        for shield_id, config in self.shields.items():
            if isinstance(config, dict):
                # Copy the config to avoid modifying the original
                shield_config = dict(config)

                # Check if this shield has nested detectors
                nested_detectors = shield_config.pop("detectors", None)

                # Determine detector type: explicit "type" wins over "is_chat"
                detector_type = shield_config.pop("type", None)
                is_chat = (
                    detector_type == "chat"
                    if detector_type
                    else shield_config.pop("is_chat", False)
                )

                # Set detector ID
                shield_config["detector_id"] = shield_id

                # Handle URL fields
                # First handle legacy field names
                if "base_url" in shield_config and "detector_url" not in shield_config:
                    shield_config["detector_url"] = shield_config.pop("base_url")

                if (
                    "orchestrator_base_url" in shield_config
                    and "orchestrator_url" not in shield_config
                ):
                    shield_config["orchestrator_url"] = shield_config.pop(
                        "orchestrator_base_url"
                    )

                # If no orchestrator_url in shield but provider has one, copy it
                if self.orchestrator_url and "orchestrator_url" not in shield_config:
                    shield_config["orchestrator_url"] = self.orchestrator_url

                # Initialize detector_params with proper structure for nested detectors
                detector_params_dict = shield_config.get("detector_params", {})
                if not isinstance(detector_params_dict, dict):
                    detector_params_dict = {}

                # Create detector_params object
                detector_params = DetectorParams(**detector_params_dict)

                # Add nested detectors if present
                if nested_detectors:
                    detector_params.detectors = nested_detectors

                shield_config["detector_params"] = detector_params

                # Create appropriate detector config
                detector_class = (
                    ChatDetectorConfig if is_chat else ContentDetectorConfig
                )
                self.detectors_internal[shield_id] = detector_class(**shield_config)

    @property
    def all_detectors(
        self,
    ) -> Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]]:
        """Get all detector configurations"""
        return self.detectors_internal

    # Update other methods to use detectors_internal instead of _detectors
    def get_detectors_by_type(
        self, message_type: Union[str, MessageType]
    ) -> Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]]:
        """Get detectors registered for a specific message type.

        Accepts either a MessageType enum member or its string value.
        """
        type_value = (
            message_type.value
            if isinstance(message_type, MessageType)
            else message_type
        )
        return {
            shield_id: shield
            for shield_id, shield in self.detectors_internal.items()
            if type_value in shield.message_types
        }

    # Convenience properties
    @property
    def user_message_detectors(self):
        """Detectors applied to user messages."""
        return self.get_detectors_by_type(MessageType.USER)

    @property
    def system_message_detectors(self):
        """Detectors applied to system messages."""
        return self.get_detectors_by_type(MessageType.SYSTEM)

    @property
    def tool_response_detectors(self):
        """Detectors applied to tool responses."""
        return self.get_detectors_by_type(MessageType.TOOL)

    @property
    def completion_message_detectors(self):
        """Detectors applied to completion (assistant) messages."""
        return self.get_detectors_by_type(MessageType.COMPLETION)
|
1561
llama_stack/providers/remote/safety/trustyai_fms/detectors/base.py
Normal file
1561
llama_stack/providers/remote/safety/trustyai_fms/detectors/base.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,329 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import Message
|
||||||
|
from llama_stack.apis.safety import RunShieldResponse
|
||||||
|
from llama_stack.providers.remote.safety.trustyai_fms.config import ChatDetectorConfig
|
||||||
|
from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import (
|
||||||
|
BaseDetector,
|
||||||
|
DetectionResult,
|
||||||
|
DetectorError,
|
||||||
|
DetectorRequestError,
|
||||||
|
DetectorValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Type aliases for better readability
|
||||||
|
ChatMessage = Dict[
|
||||||
|
str, Any
|
||||||
|
] # Changed from Dict[str, str] to Dict[str, Any] to handle complex content
|
||||||
|
ChatRequest = Dict[str, Any]
|
||||||
|
DetectorResponse = List[Dict[str, Any]]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatDetectorError(DetectorError):
    """Errors raised specifically by chat detector operations."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class ChatDetectionMetadata:
    """Immutable metadata attached to a chat detection result."""

    # Optional human-readable risk identifier reported by the detector
    risk_name: Optional[str] = None
    # Optional longer description of the detected risk
    risk_definition: Optional[str] = None
    # Extra detector-supplied metadata, exported under the "metadata" key
    additional_metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict, omitting unset (falsy) fields.

        Note: additional_metadata is exported under the key "metadata",
        not "additional_metadata".
        """
        out: Dict[str, Any] = {}
        for key, value in (
            ("risk_name", self.risk_name),
            ("risk_definition", self.risk_definition),
            ("metadata", self.additional_metadata),
        ):
            if value:
                out[key] = value
        return out
|
||||||
|
|
||||||
|
|
||||||
|
class ChatDetector(BaseDetector):
    """Detector for chat-based safety checks.

    Formats incoming messages for the detector API, dispatches the request
    either to the orchestrator or the direct detector endpoint, and turns
    over-threshold detections into shield violations.
    """

    def __init__(self, config: ChatDetectorConfig) -> None:
        """Initialize chat detector with configuration.

        Raises:
            DetectorValidationError: if *config* is not a ChatDetectorConfig.
        """
        if not isinstance(config, ChatDetectorConfig):
            raise DetectorValidationError(
                "Config must be an instance of ChatDetectorConfig"
            )
        super().__init__(config)
        self.config: ChatDetectorConfig = config
        # NOTE(review): vars(config) may log secrets such as auth_token —
        # confirm INFO-level logging of the full config is acceptable.
        logger.info(f"Initialized ChatDetector with config: {vars(config)}")

    def _extract_detector_params(self) -> Dict[str, Any]:
        """Extract non-null detector parameters as a plain dict.

        Uses vars() on the DetectorParams instance, so the result may include
        category containers (model_params, metadata, kwargs) as nested dicts.
        """
        if not self.config.detector_params:
            return {}

        params = {
            k: v for k, v in vars(self.config.detector_params).items() if v is not None
        }
        logger.debug(f"Extracted detector params: {params}")
        return params

    def _prepare_chat_request(
        self, messages: List[ChatMessage], params: Optional[Dict[str, Any]] = None
    ) -> ChatRequest:
        """Prepare the request payload based on API mode.

        Orchestrator mode produces {"messages": ..., "detectors": {...}};
        direct mode produces {"messages": ..., "detector_params": {...}}.
        """
        # Format messages for detector API
        formatted_messages: List[Dict[str, str]] = []  # Explicitly typed
        for msg in messages:
            formatted_msg = {
                "content": str(msg.get("content", "")),  # Ensure string type
                "role": "user",  # Always send as user for detector API
            }
            formatted_messages.append(formatted_msg)

        if self.config.use_orchestrator_api:
            payload: Dict[str, Any] = {
                "messages": formatted_messages
            }  # Explicitly typed

            # Initialize detector_config to avoid None
            detector_config: Dict[str, Any] = {}  # Explicitly typed

            # NEW STRUCTURE: Check for top-level detectors first
            # NOTE(review): "detectors" is not a declared config field in the
            # visible dataclasses — presumably set dynamically; confirm.
            if hasattr(self.config, "detectors") and self.config.detectors:
                for detector_id, det_config in self.config.detectors.items():
                    detector_config[detector_id] = det_config.get("detector_params", {})

            # LEGACY STRUCTURE: Check for nested detectors
            elif (
                self.config.detector_params
                and hasattr(self.config.detector_params, "detectors")
                and self.config.detector_params.detectors
            ):
                detector_config = self.config.detector_params.detectors

            # Handle flat params - group them into generic containers
            elif self.config.detector_params:
                detector_params = self._extract_detector_params()

                # Create a flat dictionary of parameters
                flat_params = {}

                # Extract from model_params
                if "model_params" in detector_params and isinstance(
                    detector_params["model_params"], dict
                ):
                    flat_params.update(detector_params["model_params"])

                # Extract from metadata
                if "metadata" in detector_params and isinstance(
                    detector_params["metadata"], dict
                ):
                    flat_params.update(detector_params["metadata"])

                # Extract from kwargs
                if "kwargs" in detector_params and isinstance(
                    detector_params["kwargs"], dict
                ):
                    flat_params.update(detector_params["kwargs"])

                # Add any other direct parameters, but skip container dictionaries
                for k, v in detector_params.items():
                    if (
                        k not in ["model_params", "metadata", "kwargs", "params"]
                        and v is not None
                    ):
                        flat_params[k] = v

                # Add all flattened parameters directly to detector configuration
                detector_config[self.config.detector_id] = flat_params

            # Ensure we have a valid detectors map even if all checks fail
            if not detector_config:
                detector_config = {self.config.detector_id: {}}

            payload["detectors"] = detector_config
            return payload

        # Direct API format remains unchanged
        else:
            # DIRECT MODE: Use flat parameters for API compatibility
            # Don't organize into containers for direct mode
            detector_params = self._extract_detector_params()

            # Flatten the parameters for direct mode too
            flat_params = {}

            # Extract from model_params
            if "model_params" in detector_params and isinstance(
                detector_params["model_params"], dict
            ):
                flat_params.update(detector_params["model_params"])

            # Extract from metadata
            if "metadata" in detector_params and isinstance(
                detector_params["metadata"], dict
            ):
                flat_params.update(detector_params["metadata"])

            # Extract from kwargs
            if "kwargs" in detector_params and isinstance(
                detector_params["kwargs"], dict
            ):
                flat_params.update(detector_params["kwargs"])

            # Add any other direct parameters
            for k, v in detector_params.items():
                if (
                    k not in ["model_params", "metadata", "kwargs", "params"]
                    and v is not None
                ):
                    flat_params[k] = v

            return {
                "messages": formatted_messages,
                "detector_params": flat_params if flat_params else params or {},
            }

    async def _call_detector_api(
        self,
        messages: List[ChatMessage],
        params: Optional[Dict[str, Any]] = None,
    ) -> DetectorResponse:
        """Call chat detector API with proper endpoint selection.

        Raises:
            DetectorRequestError: wrapping any failure during the request.
        """
        try:
            request = self._prepare_chat_request(messages, params)
            headers = self._prepare_headers()

            logger.info("Making detector API request")
            logger.debug(f"Request headers: {headers}")
            logger.debug(f"Request payload: {request}")

            response = await self._make_request(request, headers)
            return self._extract_detections(response)

        except Exception as e:
            logger.error(f"API call failed: {str(e)}", exc_info=True)
            raise DetectorRequestError(
                f"Chat detector API call failed: {str(e)}"
            ) from e

    def _extract_detections(self, response: Dict[str, Any]) -> DetectorResponse:
        """Extract detections from an API response.

        Orchestrator mode reads the "detections" key and synthesizes a
        default non-violating entry when the list is empty. Direct mode
        expects a list whose first element holds the detections.
        """
        if not response:
            logger.debug("Empty response received")
            return []

        if self.config.use_orchestrator_api:
            detections = response.get("detections", [])
            if not detections:
                # Add default detection when none returned
                logger.debug("No detections found, adding default low-score detection")
                return [
                    {
                        "detection_type": "risk",
                        "detection": "No",
                        "detector_id": self.config.detector_id,
                        "score": 0.0,  # Default low score
                    }
                ]
            logger.debug(f"Orchestrator detections: {detections}")
            return cast(
                DetectorResponse, detections
            )  # Explicit cast to correct return type

        # Direct API returns a list where first item contains detections
        if isinstance(response, list) and response:
            detections = (
                [response[0]] if not isinstance(response[0], list) else response[0]
            )
            logger.debug(f"Direct API detections: {detections}")
            return cast(
                DetectorResponse, detections
            )  # Explicit cast to correct return type

        logger.debug("No detections found in response")
        return []

    def _process_detection(
        self, detection: Dict[str, Any]
    ) -> Tuple[
        Optional[DetectionResult], float
    ]:  # Changed return type to match base class
        """Process a detection dict and compare its score to the threshold.

        Returns (DetectionResult, score) when score exceeds the threshold,
        otherwise (None, score).
        """
        score = detection.get("score", 0.0)  # Default to 0.0 if score is missing

        if score > self.score_threshold:
            metadata = ChatDetectionMetadata(
                risk_name=(
                    self.config.detector_params.risk_name
                    if self.config.detector_params
                    else None
                ),
                risk_definition=(
                    self.config.detector_params.risk_definition
                    if self.config.detector_params
                    else None
                ),
                additional_metadata=detection.get("metadata"),
            )

            result = DetectionResult(
                detection="Yes",
                detection_type=detection["detection_type"],
                score=score,
                detector_id=detection.get("detector_id", self.config.detector_id),
                text=detection.get("text", ""),
                start=detection.get("start", 0),
                end=detection.get("end", 0),
                metadata=metadata.to_dict(),
            )
            return (result, score)
        return (None, score)

    async def _run_shield_impl(
        self,
        shield_id: str,
        messages: List[Message],
        params: Optional[Dict[str, Any]] = None,
    ) -> RunShieldResponse:
        """Run shield checks over chat messages.

        Returns a violation response for the first over-threshold detection,
        otherwise an empty RunShieldResponse.

        Raises:
            ChatDetectorError: wrapping any failure during execution.
        """
        try:
            shield = await self.shield_store.get_shield(shield_id)
            self._validate_shield(shield)

            logger.info(f"Processing {len(messages)} message(s)")

            # Convert messages keeping only necessary fields
            chat_messages: List[ChatMessage] = []  # Explicitly typed
            for msg in messages:
                message_dict: ChatMessage = {"content": msg.content, "role": msg.role}
                # Preserve type if present for internal processing
                if hasattr(msg, "type"):
                    message_dict["type"] = msg.type
                chat_messages.append(message_dict)

            logger.debug(f"Prepared messages: {chat_messages}")
            detections = await self._call_detector_api(chat_messages, params)

            for detection in detections:
                processed, score = self._process_detection(detection)
                if processed:
                    logger.info(f"Violation detected: {processed}")
                    return self.create_violation_response(
                        processed, detection.get("detector_id", self.config.detector_id)
                    )

            logger.debug("No violations detected")
            return RunShieldResponse()

        except Exception as e:
            logger.error(f"Shield execution failed: {str(e)}", exc_info=True)
            raise ChatDetectorError(f"Shield execution failed: {str(e)}") from e
|
|
@ -0,0 +1,209 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import Message
|
||||||
|
from llama_stack.apis.safety import RunShieldResponse
|
||||||
|
from llama_stack.providers.remote.safety.trustyai_fms.config import (
|
||||||
|
ContentDetectorConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import (
|
||||||
|
BaseDetector,
|
||||||
|
DetectionResult,
|
||||||
|
DetectorError,
|
||||||
|
DetectorRequestError,
|
||||||
|
DetectorValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Type aliases for better readability
|
||||||
|
ContentRequest = Dict[str, Any]
|
||||||
|
DetectorResponse = List[Dict[str, Any]]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContentDetectorError(DetectorError):
    """Errors raised specifically by content detector operations."""
|
||||||
|
|
||||||
|
|
||||||
|
class ContentDetector(BaseDetector):
|
||||||
|
"""Detector for content-based safety checks"""
|
||||||
|
|
||||||
|
def __init__(self, config: ContentDetectorConfig) -> None:
    """Initialize content detector with configuration.

    Raises:
        DetectorValidationError: if *config* is not a ContentDetectorConfig.
    """
    if not isinstance(config, ContentDetectorConfig):
        raise DetectorValidationError(
            "Config must be an instance of ContentDetectorConfig"
        )
    super().__init__(config)
    self.config: ContentDetectorConfig = config
    # NOTE(review): vars(config) may log secrets such as auth_token —
    # confirm INFO-level logging of the full config is acceptable.
    logger.info(f"Initialized ContentDetector with config: {vars(config)}")
|
||||||
|
|
||||||
|
def _extract_detector_params(self) -> Dict[str, Any]:
    """Flatten the configured DetectorParams into the flat dict the API expects.

    Delegates to DetectorParams.to_dict(); returns an empty dict when no
    detector_params are configured.
    """
    detector_params = self.config.detector_params
    if not detector_params:
        return {}
    flattened = detector_params.to_dict()
    logger.debug(f"Extracted detector params: {flattened}")
    return flattened
|
||||||
|
|
||||||
|
def _prepare_content_request(
    self, content: str, params: Optional[Dict[str, Any]] = None
) -> ContentRequest:
    """Prepare the request payload based on API mode.

    Orchestrator mode sends {"content": ..., "detectors": {...}}; direct
    mode sends {"contents": [...], "detector_params": {...}}.
    """

    if self.config.use_orchestrator_api:
        payload: Dict[str, Any] = {"content": content}  # Always use singular form

        # NEW STRUCTURE: Check for top-level detectors first
        # NOTE(review): "detectors" is not a declared config field in the
        # visible dataclasses — presumably set dynamically; confirm.
        if hasattr(self.config, "detectors") and self.config.detectors:
            detector_config: Dict[str, Any] = {}
            for detector_id, det_config in self.config.detectors.items():
                detector_config[detector_id] = det_config.get("detector_params", {})
            payload["detectors"] = detector_config
            return payload

        # LEGACY STRUCTURE: Check for nested detectors
        # NOTE(review): unlike the chat detector, this branch does not check
        # that detector_params.detectors is non-empty, so an empty/None value
        # is still sent as "detectors" — confirm intended.
        elif self.config.detector_params and hasattr(
            self.config.detector_params, "detectors"
        ):
            detectors = getattr(self.config.detector_params, "detectors", {})
            payload["detectors"] = detectors
            return payload

        # Handle flat params
        else:
            detector_config = {}
            detector_params = self._extract_detector_params()
            if detector_params:
                detector_config[self.config.detector_id] = detector_params
            payload["detectors"] = detector_config
            return payload

    else:
        # DIRECT MODE: Use flat parameters for API compatibility
        detector_params = self._extract_detector_params()

        return {
            "contents": [content],
            "detector_params": detector_params if detector_params else params or {},
        }
|
||||||
|
|
||||||
|
def _extract_detections(self, response: Dict[str, Any]) -> DetectorResponse:
    """Extract detections from an API response.

    Orchestrator mode reads the "detections" key of a dict response.
    Direct mode expects a list of lists where the inner list holds the
    detections; a bare dict first element is wrapped in a one-item list.
    Returns [] for an empty or unrecognized response.
    """
    if not response:
        logger.debug("Empty response received")
        return []

    if self.config.use_orchestrator_api:
        detections = response.get("detections", [])
        # NOTE(review): unlike ChatDetector, no default low-score detection
        # is synthesized when the orchestrator returns none — confirm intended.
        logger.debug(f"Orchestrator detections: {detections}")
        return cast(List[Dict[str, Any]], detections)

    # Direct API returns a list of lists where inner list contains detections
    if isinstance(response, list) and response:
        detections = response[0] if isinstance(response[0], list) else [response[0]]
        logger.debug(f"Direct API detections: {detections}")
        return cast(List[Dict[str, Any]], detections)

    logger.debug("No detections found in response")
    return []
|
||||||
|
|
||||||
|
async def _call_detector_api(
    self,
    content: str,
    params: Optional[Dict[str, Any]] = None,
) -> DetectorResponse:
    """Call detector API with proper endpoint selection.

    Builds the mode-appropriate payload, sends it via the base class's
    _make_request, and normalizes the response into a detection list.

    Raises:
        DetectorRequestError: wrapping any failure during the request.
    """
    try:
        request = self._prepare_content_request(content, params)
        headers = self._prepare_headers()

        logger.info("Making detector API request")
        logger.debug(f"Request headers: {headers}")
        logger.debug(f"Request payload: {request}")

        response = await self._make_request(request, headers)
        return self._extract_detections(response)

    except Exception as e:
        logger.error(f"API call failed: {str(e)}", exc_info=True)
        raise DetectorRequestError(
            f"Content detector API call failed: {str(e)}"
        ) from e
|
||||||
|
|
||||||
|
def _process_detection(
    self, detection: Dict[str, Any]
) -> tuple[Optional[DetectionResult], float]:
    """Validate one detection dict against the configured score threshold.

    Returns a ``(result, score)`` pair:
    - ``(DetectionResult, score)`` when the score strictly exceeds
      ``self.score_threshold``
    - ``(None, score)`` when the score is at or below the threshold
    - ``(None, 0.0)`` when the detection carries no usable score

    Raises KeyError if an above-threshold detection lacks
    ``detection_type`` (callers wrap this in their own error handling).
    """
    # Bug fix: the previous truthiness check (`not detection.get("score")`)
    # misclassified a legitimate score of 0 / 0.0 as a missing field and
    # emitted a spurious warning. Only a genuinely absent or None score
    # should be treated as missing.
    if detection.get("score") is None:
        logger.warning("Detection missing score field")
        return None, 0.0

    score = detection["score"]
    if score > self.score_threshold:
        result = DetectionResult(
            detection="Yes",
            detection_type=detection["detection_type"],
            score=score,
            # Fall back to the shield's own detector id when the
            # detection does not name one.
            detector_id=detection.get("detector_id", self.config.detector_id),
            text=detection.get("text", ""),
            start=detection.get("start", 0),
            end=detection.get("end", 0),
            metadata=detection.get("metadata", {}),
        )
        return result, score
    return None, score
|
||||||
|
|
||||||
|
async def _run_shield_impl(
    self,
    shield_id: str,
    messages: List[Message],
    params: Optional[Dict[str, Any]] = None,
) -> RunShieldResponse:
    """Implementation of shield checks for content messages.

    Looks up and validates the shield, normalizes each message's content
    to a plain string, runs it through the detector API, and returns a
    violation response for the first above-threshold detection found.
    Returns an empty RunShieldResponse when nothing triggers. Any error
    is logged and re-raised as ContentDetectorError.
    """
    try:
        shield = await self.shield_store.get_shield(shield_id)
        self._validate_shield(shield)

        for message in messages:
            raw = message.content

            # Simplified content handling - just check for text attribute or convert to string
            if hasattr(raw, "text"):
                text = str(raw.text)
            elif isinstance(raw, list):
                text = " ".join(str(getattr(part, "text", "")) for part in raw)
            else:
                text = str(raw)

            preview = text if len(text) <= 100 else text[:100] + "..."
            logger.debug(f"Checking content: {preview}")

            detections = await self._call_detector_api(text, params)

            for detection in detections:
                processed, _score = self._process_detection(detection)
                if processed:
                    logger.info(f"Violation detected: {processed}")
                    return self.create_violation_response(
                        processed,
                        detection.get("detector_id", self.config.detector_id),
                    )

        logger.debug("No violations detected")
        return RunShieldResponse()
    except Exception as e:
        logger.error(f"Shield execution failed: {str(e)}", exc_info=True)
        raise ContentDetectorError(f"Shield execution failed: {str(e)}") from e
|
Loading…
Add table
Add a link
Reference in a new issue