From 87d209d6ef09193ec1500e01be9993ae43e0c5e4 Mon Sep 17 00:00:00 2001
From: m-misiura
Date: Mon, 24 Mar 2025 14:46:03 +0000
Subject: [PATCH] Squashed commit of the following:

commit a95d2b15b83057e194cf69e57a03deeeeeadd7c2
Author: m-misiura
Date:   Mon Mar 24 14:33:50 2025 +0000

    :construction: working on the config file so that it is inheriting from pydantic base models

commit 0546379f817e37bca030247b48c72ce84899a766
Author: m-misiura
Date:   Mon Mar 24 09:14:31 2025 +0000

    :construction: dealing with ruff checks

commit 8abe39ee4cb4b8fb77c7252342c4809fa6ddc432
Author: m-misiura
Date:   Mon Mar 24 09:03:18 2025 +0000

    :construction: dealing with mypy errors in `base.py`

commit 045f833e79c9a25af3d46af6c8896da91a0e6e62
Author: m-misiura
Date:   Fri Mar 21 17:31:25 2025 +0000

    :construction: fixing mypy errors in content.py

commit a9c1ee4e92ad1b5db89039317555cd983edbde65
Author: m-misiura
Date:   Fri Mar 21 17:09:02 2025 +0000

    :construction: fixing mypy errors in chat.py

commit 69e8ddc2f8a4e13cecbab30272fd7d685d7864ec
Author: m-misiura
Date:   Fri Mar 21 16:57:28 2025 +0000

    :construction: fixing mypy errors

commit 56739d69a145c55335ac2859ecbe5b43d556e3b1
Author: m-misiura
Date:   Fri Mar 21 14:01:03 2025 +0000

    :construction: fixing mypy errors in `__init__.py`

commit 4d2e3b55c4102ed75d997c8189847bbc5524cb2c
Author: m-misiura
Date:   Fri Mar 21 12:58:06 2025 +0000

    :construction: ensuring routing_tables.py do not fail the ci

commit c0cc7b4b09ef50d5ec95fdb0a916c7ed228bf366
Author: m-misiura
Date:   Fri Mar 21 12:09:24 2025 +0000

    :bug: fixing linter problems

commit 115a50211b604feb4106275204fe7f863da865f6
Author: m-misiura
Date:   Fri Mar 21 11:47:04 2025 +0000

    :bug: fixing ruff errors

commit 29b5bfaabc77a35ea036b57f75fded711228dbbf
Author: m-misiura
Date:   Fri Mar 21 11:33:31 2025 +0000

    :art: automatic ruff fixes

commit 7c5a334c7d4649c2fc297993f89791c1e5643e5b
Author: m-misiura
Date:   Fri Mar 21 11:15:02 2025 +0000

    Squashed commit of the following:

commit e671aae5bcd4ea57d601ee73c9e3adf5e223e830
Merge: b0dd9a4f 9114bef4
Author: Mac Misiura <82826099+m-misiura@users.noreply.github.com>
Date:   Fri Mar 21 09:45:08 2025 +0000

    Merge branch 'meta-llama:main' into feat_fms_remote_safety_provider

commit b0dd9a4f746b0c8c54d1189d381a7ff8e51c812c
Author: m-misiura
Date:   Fri Mar 21 09:27:21 2025 +0000

    :memo: updated `provider_id`

commit 4c8906c1a4e960968b93251d09d5e5735db15026
Author: m-misiura
Date:   Thu Mar 20 16:54:46 2025 +0000

    :memo: renaming from `fms` to `trustyai_fms`

commit 4c0b62abc51b02143b5c818f2d30e1a1fee9e4f3
Merge: bb842d69 54035825
Author: m-misiura
Date:   Thu Mar 20 16:35:52 2025 +0000

    Merge branch 'main' into feat_fms_remote_safety_provider

commit bb842d69548df256927465792e0cd107a267d2a0
Author: m-misiura
Date:   Wed Mar 19 15:03:17 2025 +0000

    :sparkles: added a better way of handling params from the configs

commit 58b6beabf0994849ac50317ed00b748596e8961d
Merge: a22cf36c 7c044845
Author: m-misiura
Date:   Wed Mar 19 09:19:57 2025 +0000

    Merge main into feat_fms_remote_safety_provider, resolve conflicts by keeping main version

commit a22cf36c8757f74ed656c1310a4be6b288bf923a
Author: m-misiura
Date:   Wed Mar 5 16:17:46 2025 +0000

    :tada: added a new remote safety provider compatible with FMS Orchestrator API and Detectors API

Signed-off-by: m-misiura
---
 .../distribution/routers/routing_tables.py | 114 +-
 llama_stack/providers/registry/safety.py   |   9 +
 .../remote/safety/trustyai_fms/__init__.py |  119 ++
 .../remote/safety/trustyai_fms/config.py   |  562 ++++++
 .../safety/trustyai_fms/detectors/base.py  | 1561 +++++++++++++++++
.../safety/trustyai_fms/detectors/chat.py | 329 ++++ .../safety/trustyai_fms/detectors/content.py | 209 +++ 7 files changed, 2883 insertions(+), 20 deletions(-) create mode 100644 llama_stack/providers/remote/safety/trustyai_fms/__init__.py create mode 100644 llama_stack/providers/remote/safety/trustyai_fms/config.py create mode 100644 llama_stack/providers/remote/safety/trustyai_fms/detectors/base.py create mode 100644 llama_stack/providers/remote/safety/trustyai_fms/detectors/chat.py create mode 100644 llama_stack/providers/remote/safety/trustyai_fms/detectors/content.py diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index d444b03a3..048a67c3b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -118,7 +118,9 @@ class CommonRoutingTableImpl(RoutingTable): self.dist_registry = dist_registry async def initialize(self) -> None: - async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None: + async def add_objects( + objs: List[RoutableObjectWithProvider], provider_id: str, cls + ) -> None: for obj in objs: if cls is None: obj.provider_id = provider_id @@ -153,7 +155,9 @@ class CommonRoutingTableImpl(RoutingTable): for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any: + def get_provider_impl( + self, routing_key: str, provider_id: Optional[str] = None + ) -> Any: def apiname_object(): if isinstance(self, ModelsRoutingTable): return ("Inference", "model") @@ -191,24 +195,32 @@ class CommonRoutingTableImpl(RoutingTable): raise ValueError(f"Provider not found for `{routing_key}`") - async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]: + async def get_object_by_identifier( + self, type: str, identifier: str + ) -> Optional[RoutableObjectWithProvider]: # Get from disk registry obj = await self.dist_registry.get(type, identifier) if not obj: return None # Check if user has permission to access this object - if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): - logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") + if not check_access(obj, get_auth_attributes()): + logger.debug( + f"Access denied to {type} '{identifier}' based on attribute mismatch" + ) return None return obj async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + await unregister_object_from_provider( + obj, self.impls_by_provider_id[obj.provider_id] + ) - async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: + async def register_object( + self, obj: RoutableObjectWithProvider + ) -> RoutableObjectWithProvider: # if provider_id is not specified, pick an arbitrary one from existing entries if not obj.provider_id and len(self.impls_by_provider_id) > 0: obj.provider_id = list(self.impls_by_provider_id.keys())[0] @@ -223,7 +235,9 @@ class CommonRoutingTableImpl(RoutingTable): creator_attributes = get_auth_attributes() if creator_attributes: obj.access_attributes = AccessAttributes(**creator_attributes) - logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") + logger.info( + 
f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity" + ) registered_obj = await register_object_with_provider(obj, p) # TODO: This needs to be fixed for all APIs once they return the registered object @@ -242,9 +256,7 @@ class CommonRoutingTableImpl(RoutingTable): # Apply attribute-based access control filtering if filtered_objs: filtered_objs = [ - obj - for obj in filtered_objs - if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) + obj for obj in filtered_objs if check_access(obj, get_auth_attributes()) ] return filtered_objs @@ -283,7 +295,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if model_type is None: model_type = ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError("Embedding model must have an embedding dimension in its metadata") + raise ValueError( + "Embedding model must have an embedding dimension in its metadata" + ) model = ModelWithACL( identifier=model_id, provider_resource_id=provider_model_id, @@ -302,8 +316,54 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): + async def initialize(self) -> None: + """Initialize routing table and providers""" + # First do the common initialization + await super().initialize() + + # Then explicitly initialize each safety provider + for provider_id, provider in self.impls_by_provider_id.items(): + api = get_impl_api(provider) + if api == Api.safety: + logger.info(f"Explicitly initializing safety provider: {provider_id}") + await provider.initialize() + + # Fetch shields after initialization - with robust error handling + try: + # Check if the provider implements list_shields + if hasattr(provider, "list_shields") and callable( + getattr(provider, "list_shields") + ): + shields_response = await provider.list_shields() + if ( + shields_response + and hasattr(shields_response, "data") + and shields_response.data + ): + for shield in shields_response.data: + # Ensure type is set + if not hasattr(shield, "type") or not shield.type: + shield.type = ResourceType.shield.value + await self.dist_registry.register(shield) + logger.info( + f"Registered {len(shields_response.data)} shields from provider {provider_id}" + ) + else: + logger.info(f"No shields found for provider {provider_id}") + else: + logger.info( + f"Provider {provider_id} does not support listing shields" + ) + except Exception as e: + # Log the error but continue initialization + logger.warning( + f"Error listing shields from provider {provider_id}: {str(e)}" + ) + async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) + return ListShieldsResponse( + data=await self.get_all_with_type(ResourceType.shield.value) + ) async def get_shield(self, identifier: str) -> Shield: shield = await self.get_object_by_identifier("shield", identifier) @@ -368,14 +428,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." ) else: - raise ValueError("No provider available. Please configure a vector_io provider.") + raise ValueError( + "No provider available. Please configure a vector_io provider." 
+ ) model = await self.get_object_by_identifier("model", embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found") if model.model_type != ModelType.embedding: raise ValueError(f"Model {embedding_model} is not an embedding model") if "embedding_dimension" not in model.metadata: - raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + raise ValueError( + f"Model {embedding_model} does not have an embedding dimension" + ) vector_db_data = { "identifier": vector_db_id, "type": ResourceType.vector_db.value, @@ -397,7 +461,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + return ListDatasetsResponse( + data=await self.get_all_with_type(ResourceType.dataset.value) + ) async def get_dataset(self, dataset_id: str) -> Dataset: dataset = await self.get_object_by_identifier("dataset", dataset_id) @@ -459,10 +525,14 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) + return ListScoringFunctionsResponse( + data=await self.get_all_with_type(ResourceType.scoring_function.value) + ) async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: - scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) + scoring_fn = await self.get_object_by_identifier( + "scoring_function", scoring_fn_id + ) if scoring_fn is None: raise ValueError(f"Scoring function '{scoring_fn_id}' not found") return scoring_fn @@ -565,8 +635,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): args: Optional[Dict[str, Any]] = None, ) -> None: tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools( + toolgroup_id, mcp_endpoint + ) + tool_host = ( + ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution + ) for tool_def in tool_defs: tools.append( diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 54dc51034..d18a68176 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -64,4 +64,13 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", ), ), + remote_provider_spec( + api=Api.safety, + adapter=AdapterSpec( + adapter_type="trustyai_fms", + pip_packages=[], + module="llama_stack.providers.remote.safety.trustyai_fms", + config_class="llama_stack.providers.remote.safety.trustyai_fms.config.FMSSafetyProviderConfig", + ), + ), ] diff --git a/llama_stack/providers/remote/safety/trustyai_fms/__init__.py b/llama_stack/providers/remote/safety/trustyai_fms/__init__.py new file mode 100644 index 000000000..2fa0811a0 --- /dev/null +++ b/llama_stack/providers/remote/safety/trustyai_fms/__init__.py @@ -0,0 +1,119 @@ +import logging +from typing import Any, Dict, Optional, Union + +# Remove register_provider import since 
registration is in registry/safety.py +from llama_stack.apis.safety import Safety +from llama_stack.providers.remote.safety.trustyai_fms.config import ( + ChatDetectorConfig, + ContentDetectorConfig, + DetectorParams, + EndpointType, + FMSSafetyProviderConfig, +) +from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import ( + BaseDetector, + DetectorProvider, +) +from llama_stack.providers.remote.safety.trustyai_fms.detectors.chat import ChatDetector +from llama_stack.providers.remote.safety.trustyai_fms.detectors.content import ( + ContentDetector, +) + +# Set up logging +logger = logging.getLogger(__name__) + +# Type aliases for better readability +ConfigType = Union[ContentDetectorConfig, ChatDetectorConfig, FMSSafetyProviderConfig] +DetectorType = Union[BaseDetector, DetectorProvider] + + +class DetectorConfigError(ValueError): + """Raised when detector configuration is invalid""" + + pass + + +async def create_fms_provider(config: Dict[str, Any]) -> Safety: + """Create FMS safety provider instance. + + Args: + config: Configuration dictionary + + Returns: + Safety: Configured FMS safety provider + """ + logger.debug("Creating trustyai-fms provider") + return await get_adapter_impl(FMSSafetyProviderConfig(**config)) + + +async def get_adapter_impl( + config: Union[Dict[str, Any], FMSSafetyProviderConfig], + _deps: Optional[Dict[str, Any]] = None, +) -> DetectorType: + """Get appropriate detector implementation(s) based on config type. + + Args: + config: Configuration dictionary or FMSSafetyProviderConfig instance + _deps: Optional dependencies for testing/injection + + Returns: + Configured detector implementation + + Raises: + DetectorConfigError: If configuration is invalid + """ + try: + if isinstance(config, FMSSafetyProviderConfig): + provider_config = config + else: + provider_config = FMSSafetyProviderConfig(**config) + + detectors: Dict[str, DetectorType] = {} + + # Changed from provider_config.detectors to provider_config.shields + for shield_id, shield_config in provider_config.shields.items(): + impl: BaseDetector + if isinstance(shield_config, ChatDetectorConfig): + impl = ChatDetector(shield_config) + elif isinstance(shield_config, ContentDetectorConfig): + impl = ContentDetector(shield_config) + else: + raise DetectorConfigError( + f"Invalid shield config type for {shield_id}: {type(shield_config)}" + ) + await impl.initialize() + detectors[shield_id] = impl + + detectors_for_provider: Dict[str, BaseDetector] = {} + for shield_id, detector in detectors.items(): + if isinstance(detector, BaseDetector): + detectors_for_provider[shield_id] = detector + + return DetectorProvider(detectors_for_provider) + + except Exception as e: + raise DetectorConfigError( + f"Failed to create detector implementation: {str(e)}" + ) from e + + +__all__ = [ + # Factory methods + "get_adapter_impl", + "create_fms_provider", + # Configurations + "ContentDetectorConfig", + "ChatDetectorConfig", + "FMSSafetyProviderConfig", + "EndpointType", + "DetectorParams", + # Implementations + "ChatDetector", + "ContentDetector", + "BaseDetector", + "DetectorProvider", + # Types + "ConfigType", + "DetectorType", + "DetectorConfigError", +] diff --git a/llama_stack/providers/remote/safety/trustyai_fms/config.py b/llama_stack/providers/remote/safety/trustyai_fms/config.py new file mode 100644 index 000000000..a11ce6c86 --- /dev/null +++ b/llama_stack/providers/remote/safety/trustyai_fms/config.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +from dataclasses import dataclass, field 
+from enum import Enum +from typing import Any, Dict, List, Optional, Set, Union +from urllib.parse import urlparse + +from pydantic import BaseModel, Field, model_validator + +from llama_stack.schema_utils import json_schema_type + + +class MessageType(Enum): + """Valid message types for detectors""" + + USER = "user" + SYSTEM = "system" + TOOL = "tool" + COMPLETION = "completion" + + @classmethod + def as_set(cls) -> Set[str]: + """Get all valid message types as a set""" + return {member.value for member in cls} + + +class EndpointType(Enum): + """API endpoint types and their paths""" + + DIRECT_CONTENT = { + "path": "/api/v1/text/contents", + "version": "v1", + "type": "content", + } + DIRECT_CHAT = {"path": "/api/v1/text/chat", "version": "v1", "type": "chat"} + ORCHESTRATOR_CONTENT = { + "path": "/api/v2/text/detection/content", + "version": "v2", + "type": "content", + } + ORCHESTRATOR_CHAT = { + "path": "/api/v2/text/detection/chat", + "version": "v2", + "type": "chat", + } + + @classmethod + def get_endpoint(cls, is_orchestrator: bool, is_chat: bool) -> EndpointType: + """Get appropriate endpoint based on configuration""" + if is_orchestrator: + return cls.ORCHESTRATOR_CHAT if is_chat else cls.ORCHESTRATOR_CONTENT + return cls.DIRECT_CHAT if is_chat else cls.DIRECT_CONTENT + + +@json_schema_type +@dataclass +class DetectorParams: + """Flexible parameter container supporting nested structure and arbitrary parameters""" + + # Store all parameters in a single dictionary for maximum flexibility + params: Dict[str, Any] = field(default_factory=dict) + + # Parameter categories for organization + model_params: Dict[str, Any] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + kwargs: Dict[str, Any] = field(default_factory=dict) + + # Store detectors directly as an attribute (not in params) for orchestrator mode + _raw_detectors: Optional[Dict[str, Dict[str, Any]]] = None + + # Standard parameters kept for backward compatibility + @property + def regex(self) -> Optional[List[str]]: + return self.params.get("regex") + + @regex.setter + def regex(self, value: List[str]) -> None: + self.params["regex"] = value + + @property + def temperature(self) -> Optional[float]: + return self.model_params.get("temperature") or self.params.get("temperature") + + @temperature.setter + def temperature(self, value: float) -> None: + self.model_params["temperature"] = value + + @property + def risk_name(self) -> Optional[str]: + return self.metadata.get("risk_name") or self.params.get("risk_name") + + @risk_name.setter + def risk_name(self, value: str) -> None: + self.metadata["risk_name"] = value + + @property + def risk_definition(self) -> Optional[str]: + return self.metadata.get("risk_definition") or self.params.get( + "risk_definition" + ) + + @risk_definition.setter + def risk_definition(self, value: str) -> None: + self.metadata["risk_definition"] = value + + @property + def orchestrator_detectors(self) -> Dict[str, Dict[str, Any]]: + """Return detectors in the format required by orchestrator API""" + if ( + not hasattr(self, "_detectors") or not self._detectors + ): # Direct attribute access + return {} + + flattened = {} + for ( + detector_id, + detector_config, + ) in self._detectors.items(): # Direct attribute access + # Create a flattened version without extra nesting + flat_config = {} + + # Extract detector_params if present and flatten them + params = detector_config.get("detector_params", {}) + if isinstance(params, dict): + # Move params up to top level + 
for key, value in params.items(): + flat_config[key] = value + + flattened[detector_id] = flat_config + + return flattened + + @property + def formatted_detectors(self) -> Optional[Dict[str, Dict[str, Any]]]: + """Return detectors properly formatted for orchestrator API""" + # Direct return for API usage - avoid calling other properties + if hasattr(self, "_detectors") and self._detectors: + return self.orchestrator_detectors + return None + + @formatted_detectors.setter + def formatted_detectors(self, value: Dict[str, Dict[str, Any]]) -> None: + self._detectors = value + + @property + def detectors(self) -> Optional[Dict[str, Dict[str, Any]]]: + """COMPATIBILITY: Returns the same as formatted_detectors to maintain API compatibility""" + # Using a different implementation to avoid the redefinition error + # while maintaining the same functionality + if not hasattr(self, "_detectors") or not self._detectors: + return None + return self.orchestrator_detectors + + @detectors.setter + def detectors(self, value: Dict[str, Dict[str, Any]]) -> None: + """COMPATIBILITY: Set detectors while maintaining compatibility""" + self._detectors = value + + # And fix the __setitem__ method: + def __setitem__(self, key: str, value: Any) -> None: + """Allow dictionary-like assignment with smart categorization""" + # Special handling for known params + if key == "detectors": + self._detectors = value # Set underlying attribute directly + return + elif key == "regex": + self.params[key] = value + return + + # Rest of the method remains unchanged + known_model_params = ["temperature", "top_p", "top_k", "max_tokens", "n"] + known_metadata = ["risk_name", "risk_definition", "category", "severity"] + + if key in known_model_params: + self.model_params[key] = value + elif key in known_metadata: + self.metadata[key] = value + else: + self.kwargs[key] = value + + def __init__(self, **kwargs): + """Initialize from any keyword arguments with smart categorization""" + # Initialize containers + self.params = {} + self.model_params = {} + self.metadata = {} + self.kwargs = {} + self._raw_detectors = None + + # Special handling for nested detectors structure + if "detectors" in kwargs: + self._raw_detectors = kwargs.pop("detectors") + + # Special handling for regex + if "regex" in kwargs: + self.params["regex"] = kwargs.pop("regex") + + # Categorize known parameters + known_model_params = ["temperature", "top_p", "top_k", "max_tokens", "n"] + known_metadata = ["risk_name", "risk_definition", "category", "severity"] + + # Explicit categories if provided + if "model_params" in kwargs: + self.model_params.update(kwargs.pop("model_params")) + + if "metadata" in kwargs: + self.metadata.update(kwargs.pop("metadata")) + + if "kwargs" in kwargs: + self.kwargs.update(kwargs.pop("kwargs")) + + # Categorize remaining parameters + for key, value in kwargs.items(): + if key in known_model_params: + self.model_params[key] = value + elif key in known_metadata: + self.metadata[key] = value + else: + self.kwargs[key] = value + + def __getitem__(self, key: str) -> Any: + """Allow dictionary-like access with category lookup""" + if key in self.params: + return self.params[key] + elif key in self.model_params: + return self.model_params[key] + elif key in self.metadata: + return self.metadata[key] + elif key in self.kwargs: + return self.kwargs[key] + return None + + def get(self, key: str, default: Any = None) -> Any: + """Dictionary-style get with category lookup""" + result = self.__getitem__(key) + return default if result is None else 
result + + def set(self, key: str, value: Any) -> None: + """Set a parameter value with smart categorization""" + self.__setitem__(key, value) + + def update(self, params: Dict[str, Any]) -> None: + """Update with multiple parameters, respecting categories""" + for key, value in params.items(): + self.__setitem__(key, value) + + def to_dict(self) -> Dict[str, Any]: + """Convert all parameters to a flat dictionary for API requests""" + result = {} + + # Add core parameters + result.update(self.params) + + # Add all categorized parameters, flattened + result.update(self.model_params) + result.update(self.metadata) + result.update(self.kwargs) + + return result + + def to_categorized_dict(self) -> Dict[str, Any]: + """Convert to a structured dictionary with categories preserved""" + result = dict(self.params) + + if self.model_params: + result["model_params"] = dict(self.model_params) + + if self.metadata: + result["metadata"] = dict(self.metadata) + + if self.kwargs: + result["kwargs"] = dict(self.kwargs) + + return result + + def create_flattened_detector_configs(self) -> Dict[str, Dict[str, Any]]: + """Create flattened detector configurations for orchestrator mode. + This removes the extra detector_params nesting that causes API errors. + """ + if not self.detectors: + return {} + + flattened = {} + for detector_id, detector_config in self.detectors.items(): + # Create a flattened version without extra nesting + flat_config = {} + + # Extract detector_params if present and flatten them + params = detector_config.get("detector_params", {}) + if isinstance(params, dict): + # Move params up to top level + for key, value in params.items(): + flat_config[key] = value + + flattened[detector_id] = flat_config + + return flattened + + def validate(self) -> None: + """Validate parameter values""" + if self.temperature is not None and not 0 <= self.temperature <= 1: + raise ValueError("Temperature must be between 0 and 1") + + +@json_schema_type +@dataclass +class BaseDetectorConfig: + """Base configuration for all detectors with flexible parameter handling""" + + detector_id: str + confidence_threshold: float = 0.5 + message_types: Set[str] = field(default_factory=lambda: MessageType.as_set()) + auth_token: Optional[str] = None + detector_params: Optional[DetectorParams] = None + + # URL fields directly on detector configs + detector_url: Optional[str] = None + orchestrator_url: Optional[str] = None + + # Flexible storage for any additional parameters + _extra_params: Dict[str, Any] = field(default_factory=dict) + + # Runtime execution parameters + max_concurrency: int = 10 # Maximum concurrent API requests + request_timeout: float = 30.0 # HTTP request timeout in seconds + max_retries: int = 3 # Maximum number of retry attempts + backoff_factor: float = 1.5 # Exponential backoff multiplier + max_keepalive_connections: int = 5 # Max number of keepalive connections + max_connections: int = 10 # Max number of connections in the pool + + @property + def use_orchestrator_api(self) -> bool: + """Determine if orchestrator API should be used""" + return bool(self.orchestrator_url) + + def __post_init__(self) -> None: + """Process configuration after initialization""" + # Convert list/tuple message_types to set + if isinstance(self.message_types, (list, tuple)): + self.message_types = set(self.message_types) + + # Validate message types + invalid_types = self.message_types - MessageType.as_set() + if invalid_types: + raise ValueError( + f"Invalid message types: {invalid_types}. 
" + f"Valid types are: {MessageType.as_set()}" + ) + + # Initialize detector_params if needed + if self.detector_params is None: + self.detector_params = DetectorParams() + + # Handle legacy URL field names + if hasattr(self, "base_url") and self.base_url and not self.detector_url: + self.detector_url = self.base_url + + if ( + hasattr(self, "orchestrator_base_url") + and self.orchestrator_base_url + and not self.orchestrator_url + ): + self.orchestrator_url = self.orchestrator_base_url + + def validate(self) -> None: + """Validate configuration""" + # Validate detector_params + if self.detector_params: + self.detector_params.validate() + + # Validate that at least one URL is provided + if not self.detector_url and not self.orchestrator_url: + raise ValueError(f"No URL provided for detector {self.detector_id}") + + # Validate URLs if present + for url_name, url in [ + ("detector_url", self.detector_url), + ("orchestrator_url", self.orchestrator_url), + ]: + if url: + self._validate_url(url, url_name) + + def _validate_url(self, url: str, url_name: str) -> None: + """Validate URL format""" + parsed = urlparse(url) + if not all([parsed.scheme, parsed.netloc]): + raise ValueError(f"Invalid {url_name} format: {url}") + if parsed.scheme not in {"http", "https"}: + raise ValueError(f"Invalid {url_name} scheme: {parsed.scheme}") + + def get(self, key: str, default: Any = None) -> Any: + """Get parameter with fallback to extra parameters""" + try: + return getattr(self, key) + except AttributeError: + return self._extra_params.get(key, default) + + def set(self, key: str, value: Any) -> None: + """Set parameter, storing in extra_params if not a standard field""" + if hasattr(self, key) and key not in ["_extra_params"]: + setattr(self, key, value) + else: + self._extra_params[key] = value + + @property + def is_chat(self) -> bool: + """Default implementation, should be overridden by subclasses""" + return False + + +@json_schema_type +@dataclass +class ContentDetectorConfig(BaseDetectorConfig): + """Configuration for content detectors""" + + @property + def is_chat(self) -> bool: + """Content detectors are not chat detectors""" + return False + + +@json_schema_type +@dataclass +class ChatDetectorConfig(BaseDetectorConfig): + """Configuration for chat detectors""" + + @property + def is_chat(self) -> bool: + """Chat detectors are chat detectors""" + return True + + +@json_schema_type +class FMSSafetyProviderConfig(BaseModel): + """Configuration for the FMS Safety Provider""" + + shields: Dict[str, Dict[str, Any]] = Field(default_factory=dict) + + # Rename _detectors to remove the leading underscore + detectors_internal: Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]] = ( + Field(default_factory=dict, exclude=True) + ) + + # Provider-level orchestrator URL (can be copied to shields if needed) + orchestrator_url: Optional[str] = None + + class Config: + arbitrary_types_allowed = True + + # Add a model validator to replace __post_init__ + @model_validator(mode="after") + def setup_config(self): + """Process shield configurations""" + # Process shields into detector objects + self._process_shields() + + # Replace shields dictionary with processed detector configs + self.shields = self.detectors_internal + + # Validate all shields + for shield in self.shields.values(): + shield.validate() + + return self + + def _process_shields(self): + """Process all shield configurations into detector configs""" + for shield_id, config in self.shields.items(): + if isinstance(config, dict): + # Copy the 
config to avoid modifying the original + shield_config = dict(config) + + # Check if this shield has nested detectors + nested_detectors = shield_config.pop("detectors", None) + + # Determine detector type + detector_type = shield_config.pop("type", None) + is_chat = ( + detector_type == "chat" + if detector_type + else shield_config.pop("is_chat", False) + ) + + # Set detector ID + shield_config["detector_id"] = shield_id + + # Handle URL fields + # First handle legacy field names + if "base_url" in shield_config and "detector_url" not in shield_config: + shield_config["detector_url"] = shield_config.pop("base_url") + + if ( + "orchestrator_base_url" in shield_config + and "orchestrator_url" not in shield_config + ): + shield_config["orchestrator_url"] = shield_config.pop( + "orchestrator_base_url" + ) + + # If no orchestrator_url in shield but provider has one, copy it + if self.orchestrator_url and "orchestrator_url" not in shield_config: + shield_config["orchestrator_url"] = self.orchestrator_url + + # Initialize detector_params with proper structure for nested detectors + detector_params_dict = shield_config.get("detector_params", {}) + if not isinstance(detector_params_dict, dict): + detector_params_dict = {} + + # Create detector_params object + detector_params = DetectorParams(**detector_params_dict) + + # Add nested detectors if present + if nested_detectors: + detector_params.detectors = nested_detectors + + shield_config["detector_params"] = detector_params + + # Create appropriate detector config + detector_class = ( + ChatDetectorConfig if is_chat else ContentDetectorConfig + ) + self.detectors_internal[shield_id] = detector_class(**shield_config) + + @property + def all_detectors( + self, + ) -> Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]]: + """Get all detector configurations""" + return self.detectors_internal + + # Update other methods to use detectors_internal instead of _detectors + def get_detectors_by_type( + self, message_type: Union[str, MessageType] + ) -> Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]]: + """Get detectors for a specific message type""" + type_value = ( + message_type.value + if isinstance(message_type, MessageType) + else message_type + ) + return { + shield_id: shield + for shield_id, shield in self.detectors_internal.items() + if type_value in shield.message_types + } + + # Convenience properties + @property + def user_message_detectors(self): + return self.get_detectors_by_type(MessageType.USER) + + @property + def system_message_detectors(self): + return self.get_detectors_by_type(MessageType.SYSTEM) + + @property + def tool_response_detectors(self): + return self.get_detectors_by_type(MessageType.TOOL) + + @property + def completion_message_detectors(self): + return self.get_detectors_by_type(MessageType.COMPLETION) diff --git a/llama_stack/providers/remote/safety/trustyai_fms/detectors/base.py b/llama_stack/providers/remote/safety/trustyai_fms/detectors/base.py new file mode 100644 index 000000000..e132db695 --- /dev/null +++ b/llama_stack/providers/remote/safety/trustyai_fms/detectors/base.py @@ -0,0 +1,1561 @@ +from __future__ import annotations + +import asyncio +import datetime +import logging +import random +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum, auto +from typing import Any, ClassVar, Dict, List, Optional, Tuple, cast +from urllib.parse import urlparse + +import httpx + +from llama_stack.apis.inference import ( + CompletionMessage, + Message, + 
SystemMessage, + ToolResponseMessage, + UserMessage, +) +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ShieldStore, + ViolationLevel, +) +from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields +from llama_stack.providers.datatypes import ShieldsProtocolPrivate +from llama_stack.providers.remote.safety.trustyai_fms.config import ( + BaseDetectorConfig, + EndpointType, +) + +# Configure logging +logger = logging.getLogger(__name__) + + +# Custom exceptions +class DetectorError(Exception): + """Base exception for detector errors""" + + pass + + +class DetectorConfigError(DetectorError): + """Configuration related errors""" + + pass + + +class DetectorRequestError(DetectorError): + """HTTP request related errors""" + + pass + + +class DetectorValidationError(DetectorError): + """Validation related errors""" + + pass + + +class DetectorNetworkError(DetectorError): + """Network connectivity issues""" + + +class DetectorTimeoutError(DetectorError): + """Request timeout errors""" + + +class DetectorRateLimitError(DetectorError): + """Rate limiting errors""" + + +class DetectorAuthError(DetectorError): + """Authentication errors""" + + +# Type aliases +MessageDict = Dict[str, Any] +DetectorResponse = Dict[str, Any] +Headers = Dict[str, str] +RequestPayload = Dict[str, Any] + + +class MessageTypes(Enum): + """Message type constants""" + + USER = auto() + SYSTEM = auto() + TOOL = auto() + COMPLETION = auto() + + @classmethod + def to_str(cls, value: MessageTypes) -> str: + """Convert enum to string representation""" + return value.name.lower() + + +@dataclass(frozen=True) +class DetectionResult: + """Structured detection result""" + + detection: str + detection_type: str + score: float + detector_id: str + text: str = "" + start: int = 0 + end: int = 0 + metadata: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary representation""" + return { + "detection": self.detection, + "detection_type": self.detection_type, + "score": self.score, + "detector_id": self.detector_id, + "text": self.text, + "start": self.start, + "end": self.end, + **({"metadata": self.metadata} if self.metadata else {}), + } + + +class BaseDetector(Safety, ShieldsProtocolPrivate, ABC): + """Base class for all safety detectors""" + + # Class constants + VALID_SCHEMES: ClassVar[set] = {"http", "https"} + + def __init__(self, config: BaseDetectorConfig) -> None: + """Initialize detector with configuration""" + self.config = config + self.registered_shields: List[Shield] = [] + self.score_threshold: float = config.confidence_threshold + self._http_client: Optional[httpx.AsyncClient] = None + self._shield_store: ShieldStore = SimpleShieldStore() + self._validate_config() + + @property + def shield_store(self) -> ShieldStore: + """Get shield store instance""" + if self._shield_store is None: + self._shield_store = SimpleShieldStore() + return self._shield_store + + @shield_store.setter + def shield_store(self, value: ShieldStore) -> None: + """Set shield store instance""" + self._shield_store = value + + def _validate_config(self) -> None: + """Validate detector configuration""" + if not self.config: + raise DetectorConfigError("Configuration is required") + if not isinstance(self.config, BaseDetectorConfig): + raise DetectorConfigError(f"Invalid config type: {type(self.config)}") + + async def initialize(self) -> None: + """Initialize detector resources""" + 
logger.info(f"Initializing {self.__class__.__name__}") + self._http_client = httpx.AsyncClient( + timeout=self.config.request_timeout, + limits=httpx.Limits( + max_keepalive_connections=self.config.max_keepalive_connections, + max_connections=self.config.max_connections, + ), + ) + + async def shutdown(self) -> None: + """Clean up detector resources""" + logger.info(f"Shutting down {self.__class__.__name__}") + if self._http_client: + await self._http_client.aclose() + + async def register_shield(self, shield: Shield) -> None: + """Register a shield with the detector""" + if not shield or not shield.identifier: + raise DetectorValidationError("Invalid shield configuration") + logger.info(f"Registering shield {shield.identifier}") + self.registered_shields.append(shield) + + def _should_process_message(self, message: Message) -> bool: + """Check if this detector should process the given message type""" + # Get exact message type + if isinstance(message, UserMessage): + message_type = "user" + elif isinstance(message, SystemMessage): + message_type = "system" + elif isinstance(message, ToolResponseMessage): + message_type = "tool" + elif isinstance(message, CompletionMessage): + message_type = "completion" + else: + logger.warning(f"Unknown message type: {type(message)}") + return False + + # Debug logging + logger.debug( + f"Message type check - type:'{message_type}', " + f"config_types:{self.config.message_types}, " + f"detector:{self.config.detector_id}" + ) + + # Explicit type check + is_supported = message_type in self.config.message_types + if not is_supported: + logger.warning( + f"Message type '{message_type}' not in configured types " + f"{self.config.message_types} for detector {self.config.detector_id}" + ) + return is_supported + + def _filter_messages(self, messages: List[Message]) -> List[Message]: + """Filter messages based on configured message types""" + return [msg for msg in messages if self._should_process_message(msg)] + + def _validate_url(self, url: str) -> None: + """Validate URL format""" + parsed = urlparse(url) + if not all([parsed.scheme, parsed.netloc]): + raise DetectorConfigError(f"Invalid URL format: {url}") + if parsed.scheme not in self.VALID_SCHEMES: + raise DetectorConfigError(f"Invalid URL scheme: {parsed.scheme}") + + def _construct_url(self) -> str: + """Construct API URL based on configuration""" + if self.config.use_orchestrator_api: + if not self.config.orchestrator_url: + raise DetectorConfigError( + "orchestrator_url is required when use_orchestrator_api is True" + ) + base_url = self.config.orchestrator_url + endpoint_info = ( + EndpointType.ORCHESTRATOR_CHAT.value + if self.config.is_chat + else EndpointType.ORCHESTRATOR_CONTENT.value + ) + else: + if not self.config.detector_url: + raise DetectorConfigError( + "detector_url is required when use_orchestrator_api is False" + ) + base_url = self.config.detector_url + endpoint_info = ( + EndpointType.DIRECT_CHAT.value + if self.config.is_chat + else EndpointType.DIRECT_CONTENT.value + ) + + url = f"{base_url.rstrip('/')}{endpoint_info['path']}" + self._validate_url(url) + logger.debug( + f"Constructed URL: {url} for {'chat' if self.config.is_chat else 'content'} endpoint" + ) + return url + + def _extract_detector_params(self) -> Dict[str, Any]: + """Extract detector parameters from configuration""" + detector_params: Dict[str, Any] = {} + + if ( + hasattr(self.config, "detector_params") + and self.config.detector_params is not None + ): + # For chat detectors, extract model_params and metadata 
directly + if hasattr(self.config.detector_params, "model_params"): + detector_params.update(self.config.detector_params.model_params) + + if hasattr(self.config.detector_params, "metadata"): + detector_params.update(self.config.detector_params.metadata) + + # Include any direct parameters + for k, v in vars(self.config.detector_params).items(): + if v is not None and k not in [ + "model_params", + "metadata", + "kwargs", + "params", + ]: + if not (isinstance(v, (dict, list)) and len(v) == 0): + detector_params[k] = v + + return detector_params + + def _prepare_headers(self) -> Headers: + """Prepare request headers based on configuration""" + headers: Headers = { + "accept": "application/json", + "Content-Type": "application/json", + } + + if not self.config.use_orchestrator_api and self.config.detector_id: + headers["detector-id"] = self.config.detector_id + + if self.config.auth_token: + headers["Authorization"] = f"Bearer {self.config.auth_token}" + + return headers + + def _prepare_request_payload( + self, messages: List[Message], params: Optional[Dict[str, Any]] = None + ) -> RequestPayload: + """Prepare request payload based on endpoint type and orchestrator mode""" + logger.debug( + f"Preparing payload - use_orchestrator: {self.config.use_orchestrator_api}, " + f"detector_id: {self.config.detector_id}" + ) + + if self.config.use_orchestrator_api: + payload: RequestPayload = {} + + # NEW STRUCTURE: Handle detectors at top level instead of under detector_params + if hasattr(self.config, "detectors") and self.config.detectors: + # Process the new structure with detectors at top level + detector_config: Dict[str, Any] = {} + for detector_id, det_config in self.config.detectors.items(): + detector_config[detector_id] = det_config.get("detector_params", {}) + + payload["detectors"] = detector_config + + # BACKWARD COMPATIBILITY: Handle legacy structures + elif ( + hasattr(self.config, "detector_params") + and self.config.detector_params is not None + ): + # Create detector configuration + detector_config = {} + + # Extract parameters directly without wrapping them + detector_params = {} + + # For chat detectors, extract model_params and metadata properly + if hasattr(self.config.detector_params, "model_params"): + detector_params.update(self.config.detector_params.model_params) + + if hasattr(self.config.detector_params, "metadata"): + detector_params.update(self.config.detector_params.metadata) + + # Include direct parameters + for k, v in vars(self.config.detector_params).items(): + if v is not None and k not in [ + "model_params", + "metadata", + "kwargs", + "params", + "detectors", + ]: + if not (isinstance(v, (dict, list)) and len(v) == 0): + detector_params[k] = v + + # Handle composite detectors + if ( + hasattr(self.config.detector_params, "detectors") + and self.config.detector_params.detectors + ): + payload["detectors"] = self.config.detector_params.detectors + else: + # Add detector configuration to payload + detector_config[self.config.detector_id] = detector_params + payload["detectors"] = detector_config + + # Add content or messages based on mode + if self.config.is_chat: + payload["messages"] = [msg.dict() for msg in messages] + else: + payload["content"] = messages[0].content + + logger.debug(f"Prepared orchestrator payload: {payload}") + return payload + else: + # DIRECT MODE: Respect API-specific formats + detector_params = self._extract_detector_params() + + # Extract parameters from nested containers if present + flattened_params = {} + + # Handle complex parameter 
structures by flattening them for direct mode + if isinstance(detector_params, dict): + # First level: check for container structure + for container_name in ["metadata", "model_params", "kwargs"]: + if container_name in detector_params: + # Extract and flatten parameters from containers + container = detector_params.get(container_name, {}) + if isinstance(container, dict): + flattened_params.update(container) + + # If no container structure was found, use params directly + if not flattened_params: + flattened_params = detector_params + else: + flattened_params = detector_params + + # Merge with any passed parameters + if params: + flattened_params.update(params) + + # Remove empty params dictionary if present + if "params" in flattened_params and ( + flattened_params["params"] == {} or flattened_params["params"] is None + ): + del flattened_params["params"] + + if self.config.is_chat: + payload = { + "messages": [msg.dict() for msg in messages], + "detector_params": flattened_params if flattened_params else {}, + } + else: + # For content APIs in direct mode, use plural form for compatibility + payload = { + "contents": [ + messages[0].content + ], # Send as array for all content APIs + "detector_params": flattened_params if flattened_params else {}, + } + + logger.debug(f"Direct mode payload: {payload}") + return payload + + async def _make_request( + self, + request: RequestPayload, + headers: Optional[Headers] = None, + timeout: Optional[float] = None, + ) -> DetectorResponse: + """Make HTTP request with error handling and retries""" + if not self._http_client: + raise DetectorError("HTTP client not initialized") + + url = self._construct_url() + default_headers = self._prepare_headers() + headers = {**default_headers, **(headers or {})} + + for attempt in range(self.config.max_retries): + try: + response = await self._http_client.post( + url, + json=request, + headers=headers, + timeout=timeout or self.config.request_timeout, + ) + + # Handle different error codes specifically + if response.status_code == 429: + # Rate limit handling + retry_after = int( + response.headers.get( + "Retry-After", self.config.backoff_factor * 2 + ) + ) + logger.warning(f"Rate limited, waiting {retry_after}s before retry") + await asyncio.sleep(retry_after) + continue + + elif response.status_code == 401: + raise DetectorAuthError(f"Authentication failed: {response.text}") + + elif response.status_code == 503: + # Service unavailable - return informative error if this is our last retry + if attempt == self.config.max_retries - 1: + error_details = { + "timestamp": datetime.datetime.now( + datetime.timezone.utc + ).isoformat(), + "service": urlparse(url).netloc, + "detector_id": self.config.detector_id, + "retries_attempted": self.config.max_retries, + "status_code": 503, + } + + logger.error( + f"Service unavailable after {self.config.max_retries} attempts: " + f"{error_details['service']} for detector {self.config.detector_id}" + ) + + raise DetectorNetworkError( + f"Safety service is currently unavailable. The system attempted {self.config.max_retries}" + f"retries but couldn't connect to {error_details['service']}. Please try again " + f"later or contact your administrator if the problem persists." + ) + + # Continue with backoff if we have more retries + logger.warning( + f"Service unavailable (attempt {attempt+1}/{self.config.max_retries}), retrying..." 
+ ) + else: + # SUCCESS PATH: Return immediately for successful responses + response.raise_for_status() + return cast(DetectorResponse, response.json()) + + except httpx.TimeoutException as e: + logger.error( + f"Request timed out (attempt {attempt + 1}/{self.config.max_retries})" + ) + if attempt == self.config.max_retries - 1: + raise DetectorTimeoutError( + f"Request timed out after {self.config.max_retries} attempts" + ) from e + + except httpx.HTTPStatusError as e: + # More specific error handling based on status code + logger.error( + f"HTTP error {e.response.status_code} (attempt {attempt + 1}/{self.config.max_retries}): {e.response.text}" + ) + if attempt == self.config.max_retries - 1: + raise DetectorRequestError( + f"API Error after {self.config.max_retries} attempts: {e.response.text}" + ) from e + + # Exponential backoff + jitter = random.uniform(0.8, 1.2) + await asyncio.sleep((self.config.backoff_factor**attempt) * jitter) + raise DetectorRequestError( + f"Request failed after {self.config.max_retries} attempts" + ) + + def _process_detection( + self, detection: Dict[str, Any] + ) -> Tuple[Optional[DetectionResult], float]: + """Process detection result and return both result and score""" + score = detection.get("score", 0.0) + + if "score" not in detection: + logger.warning("Detection missing score field") + return None, 0.0 + + if score > self.score_threshold: + return ( + DetectionResult( + detection="Yes", + detection_type=detection["detection_type"], + score=score, + detector_id=detection.get("detector_id", self.config.detector_id), + text=detection.get("text", ""), + start=detection.get("start", 0), + end=detection.get("end", 0), + metadata=detection.get("metadata"), + ), + score, + ) + return None, score + + def create_violation_response( + self, + detection: DetectionResult, + detector_id: str, + level: ViolationLevel = ViolationLevel.ERROR, + ) -> RunShieldResponse: + """Create standardized violation response""" + return RunShieldResponse( + violation=SafetyViolation( + user_message=f"Content flagged by {detector_id} as {detection.detection_type} with confidence {detection.score:.2f}", + violation_level=level, + metadata=detection.to_dict(), + ) + ) + + def _validate_shield(self, shield: Shield) -> None: + """Validate shield configuration""" + if not shield: + raise DetectorValidationError("Shield not found") + if not shield.identifier: + raise DetectorValidationError("Shield missing identifier") + + @abstractmethod + async def _run_shield_impl( + self, + shield_id: str, + messages: List[Message], + params: Optional[Dict[str, Any]] = None, + ) -> RunShieldResponse: + """Implementation specific shield running logic""" + pass + + async def run_shield( + self, + shield_id: str, + messages: List[Message], + params: Optional[Dict[str, Any]] = None, + ) -> RunShieldResponse: + """Run safety checks using configured shield""" + try: + if not messages: + return RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.INFO, + user_message="No messages to process", + metadata={"status": "skipped", "shield_id": shield_id}, + ) + ) + + supported_messages = [] + unsupported_types = set() + + for msg in messages: + if self._should_process_message(msg): + supported_messages.append(msg) + else: + msg_type = msg.type if hasattr(msg, "type") else type(msg).__name__ + unsupported_types.add(msg_type) + logger.warning( + f"Message type '{msg_type}' not supported by shield {shield_id}. 
" + f"Allowed types: {list(self.config.message_types)}" + ) + + if not supported_messages: + return RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.WARN, + user_message=( + f"No supported message types to process. Shield {shield_id} only handles: " + f"{list(self.config.message_types)}" + ), + metadata={ + "status": "skipped", + "error_type": "no_supported_messages", + "message_type": list(unsupported_types), + "supported_types": list(self.config.message_types), + "shield_id": shield_id, + }, + ) + ) + + # Step 4: Process supported messages + return await self._run_shield_impl(shield_id, supported_messages, params) + + except Exception as e: + logger.error(f"Shield execution failed: {str(e)}", exc_info=True) + return RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message=f"Shield execution error: {str(e)}", + metadata={ + "status": "error", + "error_type": "execution_error", + "shield_id": shield_id, + "error": str(e), + }, + ) + ) + + +class SimpleShieldStore(ShieldStore): + """Simplified shield store with caching""" + + def __init__(self): + self._shields: Dict[str, Shield] = {} + self._detector_configs = {} + self._pending_configs = {} # Add this to store configs before initialization + self._store_id = id(self) + self._initialized = False + self._lock = asyncio.Lock() # Add lock + logger.info(f"Created SimpleShieldStore: {self._store_id}") + + async def register_detector_config(self, detector_id: str, config: Any) -> None: + """Register detector configuration""" + async with self._lock: + if self._initialized: + self._detector_configs[detector_id] = config + else: + self._pending_configs[detector_id] = config + logger.info( + f"Shield store {self._store_id} registered config for: {detector_id}" + ) + + async def initialize(self) -> None: + """Initialize store and process pending configurations""" + if self._initialized: + return + async with self._lock: + # Process any pending configurations + self._detector_configs.update(self._pending_configs) + self._pending_configs.clear() + self._initialized = True + logger.info( + f"Shield store {self._store_id} initialized with {len(self._detector_configs)} configs" + ) + + async def get_shield(self, identifier: str) -> Shield: + """Get or create shield by identifier""" + await self.initialize() + + if identifier in self._shields: + logger.debug( + f"Shield store {self._store_id} found existing shield: {identifier}" + ) + return self._shields[identifier] + + config = self._detector_configs.get(identifier) + if config: + logger.info( + f"Shield store {self._store_id} creating shield for {identifier} using config" + ) + + # Extract detector params with full support for all structures + detector_params: Dict[str, Any] = {} + + # NEW STRUCTURE: Check for top-level detectors first + if hasattr(config, "detectors") and config.detectors is not None: + detector_params = {"detectors": {}} + for det_id, det_config in config.detectors.items(): + detector_params["detectors"][det_id] = det_config.get( + "detector_params", {} + ) + + # LEGACY STRUCTURES: Handle detector_params variations + elif ( + hasattr(config, "detector_params") + and config.detector_params is not None + ): + # Check for generic parameter containers first + for param_key in ["model_params", "kwargs", "metadata"]: + if ( + hasattr(config.detector_params, param_key) + and getattr(config.detector_params, param_key) is not None + ): + generic_params = getattr(config.detector_params, param_key) + if 
generic_params: + detector_params = {param_key: generic_params} + break + + # If no generic containers, check for detectors object + if ( + not detector_params + and hasattr(config.detector_params, "detectors") + and config.detector_params.detectors is not None + ): + detector_params = {"detectors": config.detector_params.detectors} + + # If still empty, extract flat params + if not detector_params: + detector_params = { + k: v + for k, v in vars(config.detector_params).items() + if v is not None and k != "detectors" + } + + # Include display and metadata information in params + detector_params.update( + { + "display_name": f"{identifier} Shield", + "display_description": f"Safety shield for {identifier}", + "detector_type": "content" if not config.is_chat else "chat", + "message_types": list(config.message_types), + "confidence_threshold": config.confidence_threshold, + } + ) + + # Create shield with only the valid fields and explicit type annotation + shield: Shield = Shield( + identifier=identifier, + provider_id="trustyai_fms", + provider_resource_id=identifier, + type=ResourceType.shield.value, + params=detector_params, + ) + + logger.info( + f"Shield store {self._store_id} created shield: {identifier} with params: {detector_params}" + ) + self._shields[identifier] = shield + return shield + else: + # Fail explicitly if no config found + logger.error( + f"Shield store {self._store_id} failed to create shield: no configuration found for {identifier}" + ) + raise DetectorValidationError( + f"Cannot create shield '{identifier}': no detector configuration found. " + "Shields must have a valid detector configuration to ensure proper safety checks." + ) + + async def list_shields(self) -> ListShieldsResponse: + """List all registered shields""" + await self.initialize() + shields = list(self._shields.values()) + shield_ids = [s.identifier for s in shields] + logger.info( + f"Shield store {self._store_id} listing {len(shields)} shields: {shield_ids}" + ) + return ListShieldsResponse(data=shields) + + +class DetectorProvider(Safety, Shields): + """Provider for managing safety detectors and shields""" + + def __init__(self, detectors: Dict[str, BaseDetector]) -> None: + self.detectors = detectors + self._shield_store: ShieldStore = SimpleShieldStore() + self._shields: Dict[str, Shield] = {} + self._initialized = False + self._provider_id = id(self) + self._detector_key_to_id = {} # Add mapping dict + self._pending_configs = [] # Store configurations for later registration + + # Store configurations for async registration + for detector_key, detector in detectors.items(): + detector.shield_store = self._shield_store + config_id = detector.config.detector_id + self._detector_key_to_id[detector_key] = config_id + self._pending_configs.append((config_id, detector.config)) + logger.info( + f"Created DetectorProvider {self._provider_id} with {len(detectors)} detectors" + ) + + @property + def shield_store(self) -> ShieldStore: + return self._shield_store + + @shield_store.setter + def shield_store(self, value: ShieldStore) -> None: + """Set shield store instance""" + if not value: + logger.warning(f"Provider {self._provider_id} received null shield store") + return + + logger.info( + f"Provider {self._provider_id} setting new shield store: {id(value)}" + ) + self._shield_store = value + + # Update detectors and sync shields + for detector_id, detector in self.detectors.items(): + detector.shield_store = value + logger.debug( + f"Provider {self._provider_id} updated detector {detector_id} with shield 
store {id(value)}" + ) + + # Register detector configs if possible using getattr for safe access + if hasattr(value, "register_detector_config") and hasattr( + detector, "config" + ): + # Use getattr to get the method safely + register_method = getattr(value, "register_detector_config", None) + if callable(register_method): + asyncio.create_task( + register_method(detector.config.detector_id, detector.config) + ) + + async def initialize(self) -> None: + """Initialize provider and register initial shields""" + if self._initialized: + return + + logger.info(f"Provider {self._provider_id} starting initialization") + + try: + # First register all configurations if supported + if hasattr(self._shield_store, "register_detector_config"): + # Process these in parallel + tasks = [] + for config_id, config in self._pending_configs: + tasks.append( + self._shield_store.register_detector_config(config_id, config) + ) + + if tasks: + await asyncio.gather(*tasks) + else: + logger.debug( + f"Provider {self._provider_id} shield store doesn't support register_detector_config" + ) + + # Clear pending configs + self._pending_configs.clear() + + # Initialize detectors in parallel with controlled concurrency + detector_init_tasks = [] + for detector in self.detectors.values(): + detector_init_tasks.append(detector.initialize()) + + if detector_init_tasks: + await asyncio.gather(*detector_init_tasks) + + shields_to_register: List[Tuple[BaseDetector, Shield]] = [] + + # Create shields directly without relying on shield store methods + for detector in self.detectors.values(): + config_id = detector.config.detector_id + detector_params: Dict[str, Any] = {} + + # NEW STRUCTURE: Check for top-level detectors first + if ( + hasattr(detector.config, "detectors") + and detector.config.detectors is not None + ): + detector_params = {"detectors": {}} + for det_id, det_config in detector.config.detectors.items(): + detector_params["detectors"][det_id] = det_config.get( + "detector_params", {} + ) + # LEGACY STRUCTURES: Handle detector_params variations + elif ( + hasattr(detector.config, "detector_params") + and detector.config.detector_params is not None + ): + # Create flat_params by extracting from all containers + flat_params: Dict[str, Any] = {} + + # Extract parameters from model_params, metadata, kwargs containers + if ( + hasattr(detector.config.detector_params, "model_params") + and detector.config.detector_params.model_params is not None + ): + flat_params.update(detector.config.detector_params.model_params) + + if ( + hasattr(detector.config.detector_params, "metadata") + and detector.config.detector_params.metadata is not None + ): + flat_params.update(detector.config.detector_params.metadata) + + if ( + hasattr(detector.config.detector_params, "kwargs") + and detector.config.detector_params.kwargs is not None + ): + flat_params.update(detector.config.detector_params.kwargs) + + # Also include direct properties, skipping empty containers + for k, v in vars(detector.config.detector_params).items(): + if v is not None and k not in [ + "detectors", + "model_params", + "metadata", + "kwargs", + "params", + ]: + # Skip empty dictionaries and lists + if not (isinstance(v, (dict, list)) and len(v) == 0): + flat_params[k] = v + + # Initialize empty detector_params + detector_params = {} + + # Special handling for chat detectors + if detector.config.is_chat: + # Create a clean model_params dictionary with only the parameters we need + model_params: Dict[str, Any] = {} + + # Add relevant parameters from flat_params, 
excluding "params" + for k, v in flat_params.items(): + if ( + k != "params" + ): # Explicitly exclude the empty params dict + model_params[k] = v + + # Set model_params in detector_params + detector_params["model_params"] = model_params + elif ( + hasattr(detector.config.detector_params, "detectors") + and detector.config.detector_params.detectors is not None + ): + # Handle composite detectors + detector_params["detectors"] = ( + detector.config.detector_params.detectors + ) + else: + # For non-chat detectors, use params as-is + detector_params = flat_params + + # Add display information to params + detector_params.update( + { + "display_name": f"{config_id} Shield", + "display_description": f"Safety shield for {config_id}", + "detector_type": ( + "content" if not detector.config.is_chat else "chat" + ), + "message_types": list(detector.config.message_types), + "confidence_threshold": detector.config.confidence_threshold, + } + ) + + # Create shield with valid parameters only + shield = Shield( + identifier=config_id, + provider_id="trustyai_fms", + provider_resource_id=config_id, + type=ResourceType.shield.value, + params=detector_params, + ) + + # Special handling for different detector configurations + if detector.config.is_chat: + # Chat detectors already work correctly - no changes needed + pass + elif ( + detector.config.detector_params + is not None # Add explicit null check here + and hasattr(detector.config.detector_params, "detectors") + and detector.config.detector_params.detectors is not None + ): + # Orchestrator configuration with multiple detectors + nested_detectors: Dict[str, Any] = {} + + # Access the detectors through detector_params where they're actually stored + for ( + det_id, + det_config, + ) in detector.config.detector_params.detectors.items(): + # Extract detector parameters if present + if ( + "detector_params" in det_config + and det_config["detector_params"] + ): + nested_detectors[det_id] = det_config["detector_params"] + + # Set structured parameters + if nested_detectors: + shield.params = {"detectors": nested_detectors} + + elif detector.config.detector_params is not None: + # Standard content detector with direct parameters + if hasattr(detector.config.detector_params, "to_categorized_dict"): + shield.params = ( + detector.config.detector_params.to_categorized_dict() + ) + + self._shields[config_id] = shield + + # Register shields in parallel + register_tasks = [] + for detector, shield in shields_to_register: + register_tasks.append(detector.register_shield(shield)) + + if register_tasks: + await asyncio.gather(*register_tasks) + + self._initialized = True + logger.info( + f"Provider {self._provider_id} initialization complete with {len(self._shields)} shields" + ) + + except Exception as e: + logger.error(f"Provider {self._provider_id} initialization failed: {e}") + raise + + async def list_shields(self) -> ListShieldsResponse: + """List all registered shields""" + if not self._initialized: + await self.initialize() # Just await it, don't return its result + + shields = list(self._shields.values()) + shield_ids = [s.identifier for s in shields] + logger.info( + f"Provider {self._provider_id} listing {len(shields)} shields: {shield_ids}" + ) + return ListShieldsResponse(data=shields) + + async def get_shield(self, identifier: str) -> Shield: + """Get shield by identifier""" + await self.initialize() + + # Return existing shield + if identifier in self._shields: + return self._shields[identifier] + + # Get detector and config + detector = 
self.detectors.get(identifier) + if not detector: + raise DetectorValidationError(f"Shield not found: {identifier}") + + # Create shield from store + shield = await self._shield_store.get_shield(identifier) + if shield: + self._shields[identifier] = shield + return shield + + raise DetectorValidationError(f"Failed to get shield: {identifier}") + + async def register_shield( + self, + shield_id: str, + provider_shield_id: Optional[str] = None, + provider_id: Optional[str] = None, + params: Optional[Dict[str, Any]] = None, + ) -> Shield: + """Register a new shield""" + if not self._initialized: + await self.initialize() + + # Return existing shield if already registered + if shield_id in self._shields: + return self._shields[shield_id] + + # Create new shield + shield = await self._shield_store.get_shield(shield_id) + if not shield: + raise DetectorValidationError(f"Failed to create shield: {shield_id}") + + # Update fields if provided + if provider_id: + shield.provider_id = provider_id + if provider_shield_id: + shield.provider_resource_id = provider_shield_id + if params is not None: + shield.params = params + + # Register shield + self._shields[shield_id] = shield + + # Register with detectors + for detector in self.detectors.values(): + await detector.register_shield(shield) + + return shield + + async def run_shield( + self, + shield_id: str, + messages: List[Message], + params: Optional[Dict[str, Any]] = None, + ) -> RunShieldResponse: + """Run shield against messages with enhanced composite handling""" + try: + # Step 1: Initial validation and initialization + if not self._initialized: + await self.initialize() + + if not messages: + return RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.INFO, + user_message="No messages to process", + metadata={"status": "skipped", "shield_id": shield_id}, + ) + ) + + # Step 2: Get and validate shield configuration + shield_detectors = [ + detector + for detector in self.detectors.values() + if detector.config.detector_id == shield_id + ] + + if not shield_detectors: + return RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message=f"No detectors found for shield: {shield_id}", + metadata={ + "status": "error", + "error_type": "detectors_not_found", + "shield_id": shield_id, + }, + ) + ) + + detector = shield_detectors[0] + + # Step 3: Filter messages and track skipped ones + skipped_messages = [] + filtered_messages = [] + + for idx, msg in enumerate(messages): + if detector._should_process_message(msg): + filtered_messages.append((idx, msg)) + else: + msg_type = msg.type if hasattr(msg, "type") else type(msg).__name__ + skipped_messages.append( + { + "index": idx, + "type": msg_type, + "reason": f"Message type '{msg_type}' not supported", + } + ) + + if not filtered_messages: + return RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.WARN, + user_message=( + f"No supported message types to process. 
Shield {shield_id} only handles: " + f"{list(detector.config.message_types)}" + ), + metadata={ + "status": "skipped", + "error_type": "no_supported_messages", + "supported_types": list(detector.config.message_types), + "shield_id": shield_id, + "skipped_messages": skipped_messages, + }, + ) + ) + + # Step 4: Initialize result tracking + message_results = [] + has_violation = False + highest_violation_score = 0.0 + total_detections = 0 + + # Step 5: Determine if detector is composite + is_composite = ( + detector.config.detector_params is not None + and hasattr(detector.config.detector_params, "detectors") + and detector.config.detector_params.detectors is not None + ) + + # Step 6: Process messages in parallel + if is_composite: + # Define function to process composite detector message + async def process_composite_message(orig_idx, message): + try: + current_result = { + "message_index": orig_idx, + "text": message.content, + "status": "pass", + "score": None, + "detection_type": None, + } + + # Make API request for this message + request = detector._prepare_request_payload([message], params) + response = await detector._make_request(request) + detections = response.get("detections", []) + configured_detectors = [] + if ( + detector.config.detector_params is not None + and hasattr(detector.config.detector_params, "detectors") + and detector.config.detector_params.detectors is not None + ): + configured_detectors = list( + detector.config.detector_params.detectors.keys() + ) + + individual_results = [] + message_has_violation = False + message_highest_score = 0.0 + message_detections = 0 + + for det_id in configured_detectors: + detection = next( + ( + d + for d in detections + if d.get("detector_id") == det_id + ), + None, + ) + + score = detection.get("score") if detection else None + is_violation = ( + score is not None and score > detector.score_threshold + ) + + if is_violation: + message_has_violation = True + message_detections += 1 + if score > message_highest_score: + message_highest_score = score + current_result.update( + { + "status": "violation", + "score": score, + "detection_type": detection.get( + "detection_type" + ), + } + ) + + individual_results.append( + { + "detector_id": det_id, + "status": "violation" if is_violation else "pass", + "score": score, + "detection_type": ( + detection.get("detection_type") + if detection + else None + ), + } + ) + + current_result["individual_detector_results"] = ( + individual_results + ) + + return { + "result": current_result, + "has_violation": message_has_violation, + "highest_score": message_highest_score, + "detections": message_detections, + } + except Exception as e: + logger.error( + f"Message processing failed for message {orig_idx}: {e}" + ) + return { + "result": { + "message_index": orig_idx, + "text": message.content if message else "", + "status": "error", + "error": str(e), + }, + "has_violation": False, + "highest_score": 0.0, + "detections": 0, + "error": str(e), + } + + # Create tasks for all messages with controlled concurrency + # Use semaphore to limit concurrent API calls + semaphore = asyncio.Semaphore(detector.config.max_concurrency) + + async def process_with_semaphore(orig_idx, message): + async with semaphore: + return await process_composite_message(orig_idx, message) + + # Create and execute tasks + tasks = [ + process_with_semaphore(orig_idx, message) + for orig_idx, message in filtered_messages + ] + + # Await all tasks + task_results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process 
results + for result in task_results: + if isinstance(result, Exception): + # Handle unexpected exceptions + logger.error(f"Task execution failed: {result}") + continue + + # Extract data + assert isinstance( + result, dict + ), "Expected result to be a dictionary" + message_results.append(result["result"]) + + # Update aggregate metrics + if result["has_violation"]: + has_violation = True + total_detections += result["detections"] + if result["highest_score"] > highest_violation_score: + highest_violation_score = result["highest_score"] + + else: + # For non-composite detectors + async def process_standard_message(orig_idx, message): + try: + current_result = { + "message_index": orig_idx, + "text": message.content, + "status": "pass", + "score": None, + "detection_type": None, + } + + # Make API request for this message + response = await detector._run_shield_impl( + shield_id, [message], params + ) + + if response.violation: + score = response.violation.metadata.get("score") + current_result.update( + { + "status": "violation", + "score": score, + "detection_type": response.violation.metadata.get( + "detection_type" + ), + } + ) + + return { + "result": current_result, + "has_violation": True, + "highest_score": score or 0.0, + "detections": 1, + } + + return { + "result": current_result, + "has_violation": False, + "highest_score": 0.0, + "detections": 0, + } + except Exception as e: + logger.error( + f"Message processing failed for message {orig_idx}: {e}" + ) + return { + "result": { + "message_index": orig_idx, + "text": message.content if message else "", + "status": "error", + "error": str(e), + }, + "has_violation": False, + "highest_score": 0.0, + "detections": 0, + "error": str(e), + } + + # Create tasks with controlled concurrency + semaphore = asyncio.Semaphore(detector.config.max_concurrency) + + async def process_with_semaphore(orig_idx, message): + async with semaphore: + return await process_standard_message(orig_idx, message) + + # Create and execute tasks + tasks = [ + process_with_semaphore(orig_idx, message) + for orig_idx, message in filtered_messages + ] + + # Await all tasks + task_results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process results + for result in task_results: + if isinstance(result, Exception): + # Handle unexpected exceptions + logger.error(f"Task execution failed: {result}") + continue + + # Extract data + assert isinstance( + result, dict + ), "Expected result to be a dictionary" + message_results.append(result["result"]) + + # Update aggregate metrics + if result["has_violation"]: + has_violation = True + total_detections += result["detections"] + if result["highest_score"] > highest_violation_score: + highest_violation_score = result["highest_score"] + + # Step 7: Calculate summary statistics + total_filtered = len(filtered_messages) + violated_messages = sum( + 1 for r in message_results if r["status"] == "violation" + ) + passed_messages = total_filtered - violated_messages + + message_pass_rate = round( + passed_messages / total_filtered if total_filtered > 0 else 0, + 3, + ) + message_fail_rate = round( + violated_messages / total_filtered if total_filtered > 0 else 0, + 3, + ) + + # Step 8: Prepare summary + summary = { + "total_messages": len(messages), + "processed_messages": total_filtered, + "skipped_messages": len(skipped_messages), + "messages_with_violations": violated_messages, + "messages_passed": passed_messages, + "message_fail_rate": message_fail_rate, + "message_pass_rate": message_pass_rate, + 
"total_detections": total_detections, + "detector_breakdown": { + "active_detectors": ( + len(detector.config.detector_params.detectors) + if is_composite + and detector.config.detector_params is not None + and hasattr(detector.config.detector_params, "detectors") + and detector.config.detector_params.detectors is not None + else 1 + ), + "total_checks_performed": ( + total_filtered * len(detector.config.detector_params.detectors) + if is_composite + and detector.config.detector_params is not None + and hasattr(detector.config.detector_params, "detectors") + and detector.config.detector_params.detectors is not None + else total_filtered + ), + "total_violations_found": total_detections, + "violations_per_message": round( + total_detections / total_filtered if total_filtered > 0 else 0, + 3, + ), + }, + } + + # Step 9: Prepare metadata + metadata = { + "status": "violation" if has_violation else "pass", + "shield_id": shield_id, + "confidence_threshold": detector.score_threshold, + "summary": summary, + "results": message_results, + } + + # Step 10: Prepare response message + skipped_msg = ( + f" ({len(skipped_messages)} messages skipped)" + if skipped_messages + else "" + ) + base_msg = ( + f"Content violation detected by shield {shield_id} " + f"(confidence: {highest_violation_score:.2f}, " + f"{violated_messages}/{total_filtered} processed messages violated)" + if has_violation + else f"Content verified by shield {shield_id} " + f"({total_filtered} messages processed)" + ) + + # Step 11: Return final response + return RunShieldResponse( + violation=SafetyViolation( + violation_level=( + ViolationLevel.ERROR if has_violation else ViolationLevel.INFO + ), + user_message=f"{base_msg}{skipped_msg}", + metadata=metadata, + ) + ) + + except Exception as e: + logger.error(f"Shield execution failed: {str(e)}", exc_info=True) + return RunShieldResponse( + violation=SafetyViolation( + violation_level=ViolationLevel.ERROR, + user_message=f"Shield execution error: {str(e)}", + metadata={ + "status": "error", + "error_type": "execution_error", + "shield_id": shield_id, + "error": str(e), + }, + ) + ) + + async def shutdown(self) -> None: + """Cleanup resources""" + logger.info(f"Provider {self._provider_id} shutting down") + errors = [] + + for detector_id, detector in self.detectors.items(): + try: + await detector.shutdown() + logger.debug( + f"Provider {self._provider_id} shutdown detector: {detector_id}" + ) + except Exception as e: + error_msg = f"Error shutting down detector {detector_id}: {e}" + logger.error(f"Provider {self._provider_id} {error_msg}") + errors.append(error_msg) + + if errors: + raise DetectorError( + f"Provider {self._provider_id} shutdown errors: {', '.join(errors)}" + ) + + logger.info(f"Provider {self._provider_id} shutdown complete") diff --git a/llama_stack/providers/remote/safety/trustyai_fms/detectors/chat.py b/llama_stack/providers/remote/safety/trustyai_fms/detectors/chat.py new file mode 100644 index 000000000..20e3a1b81 --- /dev/null +++ b/llama_stack/providers/remote/safety/trustyai_fms/detectors/chat.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, cast + +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import RunShieldResponse +from llama_stack.providers.remote.safety.trustyai_fms.config import ChatDetectorConfig +from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import ( + BaseDetector, + 
DetectionResult, + DetectorError, + DetectorRequestError, + DetectorValidationError, +) + +# Type aliases for better readability +ChatMessage = Dict[ + str, Any +] # Changed from Dict[str, str] to Dict[str, Any] to handle complex content +ChatRequest = Dict[str, Any] +DetectorResponse = List[Dict[str, Any]] + +logger = logging.getLogger(__name__) + + +class ChatDetectorError(DetectorError): + """Specific errors for chat detector operations""" + + pass + + +@dataclass(frozen=True) +class ChatDetectionMetadata: + """Structured metadata for chat detections""" + + risk_name: Optional[str] = None + risk_definition: Optional[str] = None + additional_metadata: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert metadata to dictionary format""" + result: Dict[str, Any] = {} # Fixed the type annotation here + if self.risk_name: + result["risk_name"] = self.risk_name + if self.risk_definition: + result["risk_definition"] = self.risk_definition + if self.additional_metadata: + result["metadata"] = self.additional_metadata + return result + + +class ChatDetector(BaseDetector): + """Detector for chat-based safety checks""" + + def __init__(self, config: ChatDetectorConfig) -> None: + """Initialize chat detector with configuration""" + if not isinstance(config, ChatDetectorConfig): + raise DetectorValidationError( + "Config must be an instance of ChatDetectorConfig" + ) + super().__init__(config) + self.config: ChatDetectorConfig = config + logger.info(f"Initialized ChatDetector with config: {vars(config)}") + + def _extract_detector_params(self) -> Dict[str, Any]: + """Extract non-null detector parameters""" + if not self.config.detector_params: + return {} + + params = { + k: v for k, v in vars(self.config.detector_params).items() if v is not None + } + logger.debug(f"Extracted detector params: {params}") + return params + + def _prepare_chat_request( + self, messages: List[ChatMessage], params: Optional[Dict[str, Any]] = None + ) -> ChatRequest: + """Prepare the request based on API mode""" + # Format messages for detector API + formatted_messages: List[Dict[str, str]] = [] # Explicitly typed + for msg in messages: + formatted_msg = { + "content": str(msg.get("content", "")), # Ensure string type + "role": "user", # Always send as user for detector API + } + formatted_messages.append(formatted_msg) + + if self.config.use_orchestrator_api: + payload: Dict[str, Any] = { + "messages": formatted_messages + } # Explicitly typed + + # Initialize detector_config to avoid None + detector_config: Dict[str, Any] = {} # Explicitly typed + + # NEW STRUCTURE: Check for top-level detectors first + if hasattr(self.config, "detectors") and self.config.detectors: + for detector_id, det_config in self.config.detectors.items(): + detector_config[detector_id] = det_config.get("detector_params", {}) + + # LEGACY STRUCTURE: Check for nested detectors + elif ( + self.config.detector_params + and hasattr(self.config.detector_params, "detectors") + and self.config.detector_params.detectors + ): + detector_config = self.config.detector_params.detectors + + # Handle flat params - group them into generic containers + elif self.config.detector_params: + detector_params = self._extract_detector_params() + + # Create a flat dictionary of parameters + flat_params = {} + + # Extract from model_params + if "model_params" in detector_params and isinstance( + detector_params["model_params"], dict + ): + flat_params.update(detector_params["model_params"]) + + # Extract from metadata + if "metadata" in 
detector_params and isinstance( + detector_params["metadata"], dict + ): + flat_params.update(detector_params["metadata"]) + + # Extract from kwargs + if "kwargs" in detector_params and isinstance( + detector_params["kwargs"], dict + ): + flat_params.update(detector_params["kwargs"]) + + # Add any other direct parameters, but skip container dictionaries + for k, v in detector_params.items(): + if ( + k not in ["model_params", "metadata", "kwargs", "params"] + and v is not None + ): + flat_params[k] = v + + # Add all flattened parameters directly to detector configuration + detector_config[self.config.detector_id] = flat_params + + # Ensure we have a valid detectors map even if all checks fail + if not detector_config: + detector_config = {self.config.detector_id: {}} + + payload["detectors"] = detector_config + return payload + + # Direct API format remains unchanged + else: + # DIRECT MODE: Use flat parameters for API compatibility + # Don't organize into containers for direct mode + detector_params = self._extract_detector_params() + + # Flatten the parameters for direct mode too + flat_params = {} + + # Extract from model_params + if "model_params" in detector_params and isinstance( + detector_params["model_params"], dict + ): + flat_params.update(detector_params["model_params"]) + + # Extract from metadata + if "metadata" in detector_params and isinstance( + detector_params["metadata"], dict + ): + flat_params.update(detector_params["metadata"]) + + # Extract from kwargs + if "kwargs" in detector_params and isinstance( + detector_params["kwargs"], dict + ): + flat_params.update(detector_params["kwargs"]) + + # Add any other direct parameters + for k, v in detector_params.items(): + if ( + k not in ["model_params", "metadata", "kwargs", "params"] + and v is not None + ): + flat_params[k] = v + + return { + "messages": formatted_messages, + "detector_params": flat_params if flat_params else params or {}, + } + + async def _call_detector_api( + self, + messages: List[ChatMessage], + params: Optional[Dict[str, Any]] = None, + ) -> DetectorResponse: + """Call chat detector API with proper endpoint selection""" + try: + request = self._prepare_chat_request(messages, params) + headers = self._prepare_headers() + + logger.info("Making detector API request") + logger.debug(f"Request headers: {headers}") + logger.debug(f"Request payload: {request}") + + response = await self._make_request(request, headers) + return self._extract_detections(response) + + except Exception as e: + logger.error(f"API call failed: {str(e)}", exc_info=True) + raise DetectorRequestError( + f"Chat detector API call failed: {str(e)}" + ) from e + + def _extract_detections(self, response: Dict[str, Any]) -> DetectorResponse: + """Extract detections from API response""" + if not response: + logger.debug("Empty response received") + return [] + + if self.config.use_orchestrator_api: + detections = response.get("detections", []) + if not detections: + # Add default detection when none returned + logger.debug("No detections found, adding default low-score detection") + return [ + { + "detection_type": "risk", + "detection": "No", + "detector_id": self.config.detector_id, + "score": 0.0, # Default low score + } + ] + logger.debug(f"Orchestrator detections: {detections}") + return cast( + DetectorResponse, detections + ) # Explicit cast to correct return type + + # Direct API returns a list where first item contains detections + if isinstance(response, list) and response: + detections = ( + [response[0]] if not 
isinstance(response[0], list) else response[0] + ) + logger.debug(f"Direct API detections: {detections}") + return cast( + DetectorResponse, detections + ) # Explicit cast to correct return type + + logger.debug("No detections found in response") + return [] + + def _process_detection( + self, detection: Dict[str, Any] + ) -> Tuple[ + Optional[DetectionResult], float + ]: # Changed return type to match base class + """Process detection result and validate against threshold""" + score = detection.get("score", 0.0) # Default to 0.0 if score is missing + + if score > self.score_threshold: + metadata = ChatDetectionMetadata( + risk_name=( + self.config.detector_params.risk_name + if self.config.detector_params + else None + ), + risk_definition=( + self.config.detector_params.risk_definition + if self.config.detector_params + else None + ), + additional_metadata=detection.get("metadata"), + ) + + result = DetectionResult( + detection="Yes", + detection_type=detection["detection_type"], + score=score, + detector_id=detection.get("detector_id", self.config.detector_id), + text=detection.get("text", ""), + start=detection.get("start", 0), + end=detection.get("end", 0), + metadata=metadata.to_dict(), + ) + return (result, score) + return (None, score) + + async def _run_shield_impl( + self, + shield_id: str, + messages: List[Message], + params: Optional[Dict[str, Any]] = None, + ) -> RunShieldResponse: + """Implementation of shield checks for chat messages""" + try: + shield = await self.shield_store.get_shield(shield_id) + self._validate_shield(shield) + + logger.info(f"Processing {len(messages)} message(s)") + + # Convert messages keeping only necessary fields + chat_messages: List[ChatMessage] = [] # Explicitly typed + for msg in messages: + message_dict: ChatMessage = {"content": msg.content, "role": msg.role} + # Preserve type if present for internal processing + if hasattr(msg, "type"): + message_dict["type"] = msg.type + chat_messages.append(message_dict) + + logger.debug(f"Prepared messages: {chat_messages}") + detections = await self._call_detector_api(chat_messages, params) + + for detection in detections: + processed, score = self._process_detection(detection) + if processed: + logger.info(f"Violation detected: {processed}") + return self.create_violation_response( + processed, detection.get("detector_id", self.config.detector_id) + ) + + logger.debug("No violations detected") + return RunShieldResponse() + + except Exception as e: + logger.error(f"Shield execution failed: {str(e)}", exc_info=True) + raise ChatDetectorError(f"Shield execution failed: {str(e)}") from e diff --git a/llama_stack/providers/remote/safety/trustyai_fms/detectors/content.py b/llama_stack/providers/remote/safety/trustyai_fms/detectors/content.py new file mode 100644 index 000000000..396b23284 --- /dev/null +++ b/llama_stack/providers/remote/safety/trustyai_fms/detectors/content.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, cast + +from llama_stack.apis.inference import Message +from llama_stack.apis.safety import RunShieldResponse +from llama_stack.providers.remote.safety.trustyai_fms.config import ( + ContentDetectorConfig, +) +from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import ( + BaseDetector, + DetectionResult, + DetectorError, + DetectorRequestError, + DetectorValidationError, +) + +# Type aliases for better readability +ContentRequest = Dict[str, Any] +DetectorResponse = List[Dict[str, Any]] + +logger = 
logging.getLogger(__name__)
+
+
+class ContentDetectorError(DetectorError):
+    """Specific errors for content detector operations"""
+
+    pass
+
+
+class ContentDetector(BaseDetector):
+    """Detector for content-based safety checks"""
+
+    def __init__(self, config: ContentDetectorConfig) -> None:
+        """Initialize content detector with configuration"""
+        if not isinstance(config, ContentDetectorConfig):
+            raise DetectorValidationError(
+                "Config must be an instance of ContentDetectorConfig"
+            )
+        super().__init__(config)
+        self.config: ContentDetectorConfig = config
+        logger.info(f"Initialized ContentDetector with config: {vars(config)}")
+
+    def _extract_detector_params(self) -> Dict[str, Any]:
+        """Extract detector parameters with support for generic format"""
+        if not self.config.detector_params:
+            return {}
+
+        # Use to_dict() to flatten our categorized structure into what the API expects
+        params = self.config.detector_params.to_dict()
+        logger.debug(f"Extracted detector params: {params}")
+        return params
+
+    def _prepare_content_request(
+        self, content: str, params: Optional[Dict[str, Any]] = None
+    ) -> ContentRequest:
+        """Prepare the request based on API mode"""
+
+        if self.config.use_orchestrator_api:
+            payload: Dict[str, Any] = {"content": content}  # Always use singular form
+
+            # NEW STRUCTURE: Check for top-level detectors first
+            if hasattr(self.config, "detectors") and self.config.detectors:
+                detector_config: Dict[str, Any] = {}
+                for detector_id, det_config in self.config.detectors.items():
+                    detector_config[detector_id] = det_config.get("detector_params", {})
+                payload["detectors"] = detector_config
+                return payload
+
+            # LEGACY STRUCTURE: Check for nested detectors
+            elif self.config.detector_params and hasattr(
+                self.config.detector_params, "detectors"
+            ):
+                detectors = getattr(self.config.detector_params, "detectors", {})
+                payload["detectors"] = detectors
+                return payload
+
+            # Handle flat params
+            else:
+                detector_config = {}
+                detector_params = self._extract_detector_params()
+                if detector_params:
+                    detector_config[self.config.detector_id] = detector_params
+                payload["detectors"] = detector_config
+                return payload
+
+        else:
+            # DIRECT MODE: Use flat parameters for API compatibility
+            detector_params = self._extract_detector_params()
+
+            return {
+                "contents": [content],
+                "detector_params": detector_params if detector_params else params or {},
+            }
+
+    def _extract_detections(self, response: Dict[str, Any]) -> DetectorResponse:
+        """Extract detections from API response"""
+        if not response:
+            logger.debug("Empty response received")
+            return []
+
+        if self.config.use_orchestrator_api:
+            detections = response.get("detections", [])
+            logger.debug(f"Orchestrator detections: {detections}")
+            return cast(List[Dict[str, Any]], detections)
+
+        # Direct API returns a list of lists where inner list contains detections
+        if isinstance(response, list) and response:
+            detections = response[0] if isinstance(response[0], list) else [response[0]]
+            logger.debug(f"Direct API detections: {detections}")
+            return cast(List[Dict[str, Any]], detections)
+
+        logger.debug("No detections found in response")
+        return []
+
+    async def _call_detector_api(
+        self,
+        content: str,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> DetectorResponse:
+        """Call detector API with proper endpoint selection"""
+        try:
+            request = self._prepare_content_request(content, params)
+            headers = self._prepare_headers()
+
+            logger.info("Making detector API request")
+            logger.debug(f"Request headers: {headers}")
+            logger.debug(f"Request payload: {request}")
+
+            response = await self._make_request(request, headers)
+            return self._extract_detections(response)
+
+        except Exception as e:
+            logger.error(f"API call failed: {str(e)}", exc_info=True)
+            raise DetectorRequestError(
+                f"Content detector API call failed: {str(e)}"
+            ) from e
+
+    def _process_detection(
+        self, detection: Dict[str, Any]
+    ) -> tuple[Optional[DetectionResult], float]:
+        """Process detection result and validate against threshold"""
+        if not detection.get("score"):
+            logger.warning("Detection missing score field")
+            return None, 0.0
+
+        score = detection.get("score", 0)
+        if score > self.score_threshold:
+            result = DetectionResult(
+                detection="Yes",
+                detection_type=detection["detection_type"],
+                score=score,
+                detector_id=detection.get("detector_id", self.config.detector_id),
+                text=detection.get("text", ""),
+                start=detection.get("start", 0),
+                end=detection.get("end", 0),
+                metadata=detection.get("metadata", {}),
+            )
+            return result, score
+        return None, score
+
+    async def _run_shield_impl(
+        self,
+        shield_id: str,
+        messages: List[Message],
+        params: Optional[Dict[str, Any]] = None,
+    ) -> RunShieldResponse:
+        """Implementation of shield checks for content messages"""
+        try:
+            shield = await self.shield_store.get_shield(shield_id)
+            self._validate_shield(shield)
+
+            for msg in messages:
+                content = msg.content
+                content_str: str
+
+                # Simplified content handling - just check for text attribute or convert to string
+                if hasattr(content, "text"):
+                    content_str = str(content.text)
+                elif isinstance(content, list):
+                    content_str = " ".join(
+                        str(getattr(item, "text", "")) for item in content
+                    )
+                else:
+                    content_str = str(content)
+
+                truncated_content = (
+                    content_str[:100] + "..." if len(content_str) > 100 else content_str
+                )
+                logger.debug(f"Checking content: {truncated_content}")
+
+                detections = await self._call_detector_api(content_str, params)
+
+                for detection in detections:
+                    processed, score = self._process_detection(detection)
+                    if processed:
+                        logger.info(f"Violation detected: {processed}")
+                        return self.create_violation_response(
+                            processed,
+                            detection.get("detector_id", self.config.detector_id),
+                        )
+
+            logger.debug("No violations detected")
+            return RunShieldResponse()
+
+        except Exception as e:
+            logger.error(f"Shield execution failed: {str(e)}", exc_info=True)
+            raise ContentDetectorError(f"Shield execution failed: {str(e)}") from e