Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 02:32:40 +00:00)
Squashed commit of the following:

commit a95d2b15b83057e194cf69e57a03deeeeeadd7c2 (m-misiura <mmisiura@redhat.com>, Mon Mar 24 14:33:50 2025 +0000)
    🚧 working on the config file so that it is inheriting from pydantic base models
commit 0546379f817e37bca030247b48c72ce84899a766 (m-misiura, Mon Mar 24 09:14:31 2025 +0000)
    🚧 dealing with ruff checks
commit 8abe39ee4cb4b8fb77c7252342c4809fa6ddc432 (m-misiura, Mon Mar 24 09:03:18 2025 +0000)
    🚧 dealing with mypy errors in `base.py`
commit 045f833e79c9a25af3d46af6c8896da91a0e6e62 (m-misiura, Fri Mar 21 17:31:25 2025 +0000)
    🚧 fixing mypy errors in content.py
commit a9c1ee4e92ad1b5db89039317555cd983edbde65 (m-misiura, Fri Mar 21 17:09:02 2025 +0000)
    🚧 fixing mypy errors in chat.py
commit 69e8ddc2f8a4e13cecbab30272fd7d685d7864ec (m-misiura, Fri Mar 21 16:57:28 2025 +0000)
    🚧 fixing mypy errors
commit 56739d69a145c55335ac2859ecbe5b43d556e3b1 (m-misiura, Fri Mar 21 14:01:03 2025 +0000)
    🚧 fixing mypy errors in `__init__.py`
commit 4d2e3b55c4102ed75d997c8189847bbc5524cb2c (m-misiura, Fri Mar 21 12:58:06 2025 +0000)
    🚧 ensuring routing_tables.py do not fail the ci
commit c0cc7b4b09ef50d5ec95fdb0a916c7ed228bf366 (m-misiura, Fri Mar 21 12:09:24 2025 +0000)
    🐛 fixing linter problems
commit 115a50211b604feb4106275204fe7f863da865f6 (m-misiura, Fri Mar 21 11:47:04 2025 +0000)
    🐛 fixing ruff errors
commit 29b5bfaabc77a35ea036b57f75fded711228dbbf (m-misiura, Fri Mar 21 11:33:31 2025 +0000)
    🎨 automatic ruff fixes
commit 7c5a334c7d4649c2fc297993f89791c1e5643e5b (m-misiura, Fri Mar 21 11:15:02 2025 +0000)
    Squashed commit of the following:
    commit e671aae5bcd4ea57d601ee73c9e3adf5e223e830 (Merge: b0dd9a4f 9114bef4; Mac Misiura <82826099+m-misiura@users.noreply.github.com>, Fri Mar 21 09:45:08 2025 +0000)
        Merge branch 'meta-llama:main' into feat_fms_remote_safety_provider
    commit b0dd9a4f746b0c8c54d1189d381a7ff8e51c812c (m-misiura, Fri Mar 21 09:27:21 2025 +0000)
        📝 updated `provider_id`
    commit 4c8906c1a4e960968b93251d09d5e5735db15026 (m-misiura, Thu Mar 20 16:54:46 2025 +0000)
        📝 renaming from `fms` to `trustyai_fms`
    commit 4c0b62abc51b02143b5c818f2d30e1a1fee9e4f3 (Merge: bb842d69 54035825; m-misiura, Thu Mar 20 16:35:52 2025 +0000)
        Merge branch 'main' into feat_fms_remote_safety_provider
    commit bb842d69548df256927465792e0cd107a267d2a0 (m-misiura, Wed Mar 19 15:03:17 2025 +0000)
        ✨ added a better way of handling params from the configs
    commit 58b6beabf0994849ac50317ed00b748596e8961d (Merge: a22cf36c 7c044845; m-misiura, Wed Mar 19 09:19:57 2025 +0000)
        Merge main into feat_fms_remote_safety_provider, resolve conflicts by keeping main version
    commit a22cf36c8757f74ed656c1310a4be6b288bf923a (m-misiura, Wed Mar 5 16:17:46 2025 +0000)
        🎉 added a new remote safety provider compatible with FMS Orchestrator API and Detectors API

Signed-off-by: m-misiura <mmisiura@redhat.com>
This commit is contained in:
parent 9e1ddf2b53
commit 87d209d6ef
7 changed files with 2883 additions and 20 deletions
@@ -118,7 +118,9 @@ class CommonRoutingTableImpl(RoutingTable):
         self.dist_registry = dist_registry

     async def initialize(self) -> None:
-        async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
+        async def add_objects(
+            objs: List[RoutableObjectWithProvider], provider_id: str, cls
+        ) -> None:
             for obj in objs:
                 if cls is None:
                     obj.provider_id = provider_id
@@ -153,7 +155,9 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()

-    def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
+    def get_provider_impl(
+        self, routing_key: str, provider_id: Optional[str] = None
+    ) -> Any:
         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
@@ -191,24 +195,32 @@ class CommonRoutingTableImpl(RoutingTable):

         raise ValueError(f"Provider not found for `{routing_key}`")

-    async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+    async def get_object_by_identifier(
+        self, type: str, identifier: str
+    ) -> Optional[RoutableObjectWithProvider]:
         # Get from disk registry
         obj = await self.dist_registry.get(type, identifier)
         if not obj:
             return None

         # Check if user has permission to access this object
-        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
-            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
+        if not check_access(obj, get_auth_attributes()):
+            logger.debug(
+                f"Access denied to {type} '{identifier}' based on attribute mismatch"
+            )
             return None

         return obj

     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
         await self.dist_registry.delete(obj.type, obj.identifier)
-        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
+        await unregister_object_from_provider(
+            obj, self.impls_by_provider_id[obj.provider_id]
+        )

-    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
+    async def register_object(
+        self, obj: RoutableObjectWithProvider
+    ) -> RoutableObjectWithProvider:
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -223,7 +235,9 @@ class CommonRoutingTableImpl(RoutingTable):
             creator_attributes = get_auth_attributes()
             if creator_attributes:
                 obj.access_attributes = AccessAttributes(**creator_attributes)
-                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
+                logger.info(
+                    f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity"
+                )

         registered_obj = await register_object_with_provider(obj, p)
         # TODO: This needs to be fixed for all APIs once they return the registered object
@@ -242,9 +256,7 @@ class CommonRoutingTableImpl(RoutingTable):
         # Apply attribute-based access control filtering
         if filtered_objs:
             filtered_objs = [
-                obj
-                for obj in filtered_objs
-                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
+                obj for obj in filtered_objs if check_access(obj, get_auth_attributes())
             ]

         return filtered_objs
@@ -283,7 +295,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if model_type is None:
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
-            raise ValueError("Embedding model must have an embedding dimension in its metadata")
+            raise ValueError(
+                "Embedding model must have an embedding dimension in its metadata"
+            )
         model = ModelWithACL(
             identifier=model_id,
             provider_resource_id=provider_model_id,
@@ -302,8 +316,54 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):


 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
+    async def initialize(self) -> None:
+        """Initialize routing table and providers"""
+        # First do the common initialization
+        await super().initialize()
+
+        # Then explicitly initialize each safety provider
+        for provider_id, provider in self.impls_by_provider_id.items():
+            api = get_impl_api(provider)
+            if api == Api.safety:
+                logger.info(f"Explicitly initializing safety provider: {provider_id}")
+                await provider.initialize()
+
+                # Fetch shields after initialization - with robust error handling
+                try:
+                    # Check if the provider implements list_shields
+                    if hasattr(provider, "list_shields") and callable(
+                        getattr(provider, "list_shields")
+                    ):
+                        shields_response = await provider.list_shields()
+                        if (
+                            shields_response
+                            and hasattr(shields_response, "data")
+                            and shields_response.data
+                        ):
+                            for shield in shields_response.data:
+                                # Ensure type is set
+                                if not hasattr(shield, "type") or not shield.type:
+                                    shield.type = ResourceType.shield.value
+                                await self.dist_registry.register(shield)
+                            logger.info(
+                                f"Registered {len(shields_response.data)} shields from provider {provider_id}"
+                            )
+                        else:
+                            logger.info(f"No shields found for provider {provider_id}")
+                    else:
+                        logger.info(
+                            f"Provider {provider_id} does not support listing shields"
+                        )
+                except Exception as e:
+                    # Log the error but continue initialization
+                    logger.warning(
+                        f"Error listing shields from provider {provider_id}: {str(e)}"
+                    )
+
     async def list_shields(self) -> ListShieldsResponse:
-        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
+        return ListShieldsResponse(
+            data=await self.get_all_with_type(ResourceType.shield.value)
+        )

     async def get_shield(self, identifier: str) -> Shield:
         shield = await self.get_object_by_identifier("shield", identifier)
@@ -368,14 +428,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                 f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
             )
         else:
-            raise ValueError("No provider available. Please configure a vector_io provider.")
+            raise ValueError(
+                "No provider available. Please configure a vector_io provider."
+            )
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
-            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
+            raise ValueError(
+                f"Model {embedding_model} does not have an embedding dimension"
+            )
         vector_db_data = {
             "identifier": vector_db_id,
             "type": ResourceType.vector_db.value,
@@ -397,7 +461,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
     async def list_datasets(self) -> ListDatasetsResponse:
-        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
+        return ListDatasetsResponse(
+            data=await self.get_all_with_type(ResourceType.dataset.value)
+        )

     async def get_dataset(self, dataset_id: str) -> Dataset:
         dataset = await self.get_object_by_identifier("dataset", dataset_id)
@@ -459,10 +525,14 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
     async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
-        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
+        return ListScoringFunctionsResponse(
+            data=await self.get_all_with_type(ResourceType.scoring_function.value)
+        )

     async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
-        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
+        scoring_fn = await self.get_object_by_identifier(
+            "scoring_function", scoring_fn_id
+        )
         if scoring_fn is None:
             raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
         return scoring_fn
@@ -565,8 +635,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
-        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
+            toolgroup_id, mcp_endpoint
+        )
+        tool_host = (
+            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
+        )

         for tool_def in tool_defs:
             tools.append(
@@ -64,4 +64,13 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.safety,
+            adapter=AdapterSpec(
+                adapter_type="trustyai_fms",
+                pip_packages=[],
+                module="llama_stack.providers.remote.safety.trustyai_fms",
+                config_class="llama_stack.providers.remote.safety.trustyai_fms.config.FMSSafetyProviderConfig",
+            ),
+        ),
     ]
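The registry entry above ties the `trustyai_fms` adapter type to the `FMSSafetyProviderConfig` class added in `config.py` below. For orientation, a minimal sketch of what such a configuration could look like; the shield IDs, URLs and detector parameters are purely illustrative and are not part of this commit:

# Hypothetical provider configuration; field names follow FMSSafetyProviderConfig.
fms_safety_config = {
    "orchestrator_url": "https://fms-orchestrator.example.com",  # illustrative URL
    "shields": {
        "email_guard": {                      # illustrative shield id
            "type": "content",                # becomes a ContentDetectorConfig
            "detector_url": "https://regex-detector.example.com",
            "confidence_threshold": 0.8,
            "detector_params": {"regex": ["email"]},
        },
        "granite_guard": {                    # illustrative shield id
            "type": "chat",                   # becomes a ChatDetectorConfig
            "confidence_threshold": 0.5,
            "detector_params": {"temperature": 0.0, "risk_name": "harm"},
        },
    },
}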
llama_stack/providers/remote/safety/trustyai_fms/__init__.py (new file, 119 lines)

@@ -0,0 +1,119 @@
import logging
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
# Remove register_provider import since registration is in registry/safety.py
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.providers.remote.safety.trustyai_fms.config import (
|
||||
ChatDetectorConfig,
|
||||
ContentDetectorConfig,
|
||||
DetectorParams,
|
||||
EndpointType,
|
||||
FMSSafetyProviderConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import (
|
||||
BaseDetector,
|
||||
DetectorProvider,
|
||||
)
|
||||
from llama_stack.providers.remote.safety.trustyai_fms.detectors.chat import ChatDetector
|
||||
from llama_stack.providers.remote.safety.trustyai_fms.detectors.content import (
|
||||
ContentDetector,
|
||||
)
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type aliases for better readability
|
||||
ConfigType = Union[ContentDetectorConfig, ChatDetectorConfig, FMSSafetyProviderConfig]
|
||||
DetectorType = Union[BaseDetector, DetectorProvider]
|
||||
|
||||
|
||||
class DetectorConfigError(ValueError):
|
||||
"""Raised when detector configuration is invalid"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def create_fms_provider(config: Dict[str, Any]) -> Safety:
|
||||
"""Create FMS safety provider instance.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Safety: Configured FMS safety provider
|
||||
"""
|
||||
logger.debug("Creating trustyai-fms provider")
|
||||
return await get_adapter_impl(FMSSafetyProviderConfig(**config))
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: Union[Dict[str, Any], FMSSafetyProviderConfig],
|
||||
_deps: Optional[Dict[str, Any]] = None,
|
||||
) -> DetectorType:
|
||||
"""Get appropriate detector implementation(s) based on config type.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary or FMSSafetyProviderConfig instance
|
||||
_deps: Optional dependencies for testing/injection
|
||||
|
||||
Returns:
|
||||
Configured detector implementation
|
||||
|
||||
Raises:
|
||||
DetectorConfigError: If configuration is invalid
|
||||
"""
|
||||
try:
|
||||
if isinstance(config, FMSSafetyProviderConfig):
|
||||
provider_config = config
|
||||
else:
|
||||
provider_config = FMSSafetyProviderConfig(**config)
|
||||
|
||||
detectors: Dict[str, DetectorType] = {}
|
||||
|
||||
# Changed from provider_config.detectors to provider_config.shields
|
||||
for shield_id, shield_config in provider_config.shields.items():
|
||||
impl: BaseDetector
|
||||
if isinstance(shield_config, ChatDetectorConfig):
|
||||
impl = ChatDetector(shield_config)
|
||||
elif isinstance(shield_config, ContentDetectorConfig):
|
||||
impl = ContentDetector(shield_config)
|
||||
else:
|
||||
raise DetectorConfigError(
|
||||
f"Invalid shield config type for {shield_id}: {type(shield_config)}"
|
||||
)
|
||||
await impl.initialize()
|
||||
detectors[shield_id] = impl
|
||||
|
||||
detectors_for_provider: Dict[str, BaseDetector] = {}
|
||||
for shield_id, detector in detectors.items():
|
||||
if isinstance(detector, BaseDetector):
|
||||
detectors_for_provider[shield_id] = detector
|
||||
|
||||
return DetectorProvider(detectors_for_provider)
|
||||
|
||||
except Exception as e:
|
||||
raise DetectorConfigError(
|
||||
f"Failed to create detector implementation: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Factory methods
|
||||
"get_adapter_impl",
|
||||
"create_fms_provider",
|
||||
# Configurations
|
||||
"ContentDetectorConfig",
|
||||
"ChatDetectorConfig",
|
||||
"FMSSafetyProviderConfig",
|
||||
"EndpointType",
|
||||
"DetectorParams",
|
||||
# Implementations
|
||||
"ChatDetector",
|
||||
"ContentDetector",
|
||||
"BaseDetector",
|
||||
"DetectorProvider",
|
||||
# Types
|
||||
"ConfigType",
|
||||
"DetectorType",
|
||||
"DetectorConfigError",
|
||||
]
|
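As a rough orientation only, the factory above could be exercised as sketched below. This is not part of the commit: the shield id, URLs and message text are invented, `UserMessage` is assumed to come from `llama_stack.apis.inference`, and in a running stack the `ShieldsRoutingTable` changes shown earlier are what actually wire shields to the provider before `run_shield` is dispatched.

import asyncio

from llama_stack.apis.inference import UserMessage  # assumed message type
from llama_stack.providers.remote.safety.trustyai_fms import create_fms_provider

async def main() -> None:
    # Illustrative config: one content detector reached through an orchestrator.
    provider = await create_fms_provider(
        {
            "orchestrator_url": "https://fms-orchestrator.example.com",
            "shields": {
                "email_guard": {
                    "type": "content",
                    "detector_params": {"regex": ["email"]},
                }
            },
        }
    )
    # Schematic call; in practice the stack's shield store must be attached first.
    response = await provider.run_shield(
        shield_id="email_guard",
        messages=[UserMessage(content="please screen this text")],
    )
    print(response.violation)  # None when no detector crossed its threshold

asyncio.run(main())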
llama_stack/providers/remote/safety/trustyai_fms/config.py (new file, 562 lines)

@@ -0,0 +1,562 @@
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Valid message types for detectors"""
|
||||
|
||||
USER = "user"
|
||||
SYSTEM = "system"
|
||||
TOOL = "tool"
|
||||
COMPLETION = "completion"
|
||||
|
||||
@classmethod
|
||||
def as_set(cls) -> Set[str]:
|
||||
"""Get all valid message types as a set"""
|
||||
return {member.value for member in cls}
|
||||
|
||||
|
||||
class EndpointType(Enum):
|
||||
"""API endpoint types and their paths"""
|
||||
|
||||
DIRECT_CONTENT = {
|
||||
"path": "/api/v1/text/contents",
|
||||
"version": "v1",
|
||||
"type": "content",
|
||||
}
|
||||
DIRECT_CHAT = {"path": "/api/v1/text/chat", "version": "v1", "type": "chat"}
|
||||
ORCHESTRATOR_CONTENT = {
|
||||
"path": "/api/v2/text/detection/content",
|
||||
"version": "v2",
|
||||
"type": "content",
|
||||
}
|
||||
ORCHESTRATOR_CHAT = {
|
||||
"path": "/api/v2/text/detection/chat",
|
||||
"version": "v2",
|
||||
"type": "chat",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_endpoint(cls, is_orchestrator: bool, is_chat: bool) -> EndpointType:
|
||||
"""Get appropriate endpoint based on configuration"""
|
||||
if is_orchestrator:
|
||||
return cls.ORCHESTRATOR_CHAT if is_chat else cls.ORCHESTRATOR_CONTENT
|
||||
return cls.DIRECT_CHAT if is_chat else cls.DIRECT_CONTENT
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class DetectorParams:
|
||||
"""Flexible parameter container supporting nested structure and arbitrary parameters"""
|
||||
|
||||
# Store all parameters in a single dictionary for maximum flexibility
|
||||
params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Parameter categories for organization
|
||||
model_params: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
kwargs: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Store detectors directly as an attribute (not in params) for orchestrator mode
|
||||
_raw_detectors: Optional[Dict[str, Dict[str, Any]]] = None
|
||||
|
||||
# Standard parameters kept for backward compatibility
|
||||
@property
|
||||
def regex(self) -> Optional[List[str]]:
|
||||
return self.params.get("regex")
|
||||
|
||||
@regex.setter
|
||||
def regex(self, value: List[str]) -> None:
|
||||
self.params["regex"] = value
|
||||
|
||||
@property
|
||||
def temperature(self) -> Optional[float]:
|
||||
return self.model_params.get("temperature") or self.params.get("temperature")
|
||||
|
||||
@temperature.setter
|
||||
def temperature(self, value: float) -> None:
|
||||
self.model_params["temperature"] = value
|
||||
|
||||
@property
|
||||
def risk_name(self) -> Optional[str]:
|
||||
return self.metadata.get("risk_name") or self.params.get("risk_name")
|
||||
|
||||
@risk_name.setter
|
||||
def risk_name(self, value: str) -> None:
|
||||
self.metadata["risk_name"] = value
|
||||
|
||||
@property
|
||||
def risk_definition(self) -> Optional[str]:
|
||||
return self.metadata.get("risk_definition") or self.params.get(
|
||||
"risk_definition"
|
||||
)
|
||||
|
||||
@risk_definition.setter
|
||||
def risk_definition(self, value: str) -> None:
|
||||
self.metadata["risk_definition"] = value
|
||||
|
||||
@property
|
||||
def orchestrator_detectors(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Return detectors in the format required by orchestrator API"""
|
||||
if (
|
||||
not hasattr(self, "_detectors") or not self._detectors
|
||||
): # Direct attribute access
|
||||
return {}
|
||||
|
||||
flattened = {}
|
||||
for (
|
||||
detector_id,
|
||||
detector_config,
|
||||
) in self._detectors.items(): # Direct attribute access
|
||||
# Create a flattened version without extra nesting
|
||||
flat_config = {}
|
||||
|
||||
# Extract detector_params if present and flatten them
|
||||
params = detector_config.get("detector_params", {})
|
||||
if isinstance(params, dict):
|
||||
# Move params up to top level
|
||||
for key, value in params.items():
|
||||
flat_config[key] = value
|
||||
|
||||
flattened[detector_id] = flat_config
|
||||
|
||||
return flattened
|
||||
|
||||
@property
|
||||
def formatted_detectors(self) -> Optional[Dict[str, Dict[str, Any]]]:
|
||||
"""Return detectors properly formatted for orchestrator API"""
|
||||
# Direct return for API usage - avoid calling other properties
|
||||
if hasattr(self, "_detectors") and self._detectors:
|
||||
return self.orchestrator_detectors
|
||||
return None
|
||||
|
||||
@formatted_detectors.setter
|
||||
def formatted_detectors(self, value: Dict[str, Dict[str, Any]]) -> None:
|
||||
self._detectors = value
|
||||
|
||||
@property
|
||||
def detectors(self) -> Optional[Dict[str, Dict[str, Any]]]:
|
||||
"""COMPATIBILITY: Returns the same as formatted_detectors to maintain API compatibility"""
|
||||
# Using a different implementation to avoid the redefinition error
|
||||
# while maintaining the same functionality
|
||||
if not hasattr(self, "_detectors") or not self._detectors:
|
||||
return None
|
||||
return self.orchestrator_detectors
|
||||
|
||||
@detectors.setter
|
||||
def detectors(self, value: Dict[str, Dict[str, Any]]) -> None:
|
||||
"""COMPATIBILITY: Set detectors while maintaining compatibility"""
|
||||
self._detectors = value
|
||||
|
||||
# And fix the __setitem__ method:
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""Allow dictionary-like assignment with smart categorization"""
|
||||
# Special handling for known params
|
||||
if key == "detectors":
|
||||
self._detectors = value # Set underlying attribute directly
|
||||
return
|
||||
elif key == "regex":
|
||||
self.params[key] = value
|
||||
return
|
||||
|
||||
# Rest of the method remains unchanged
|
||||
known_model_params = ["temperature", "top_p", "top_k", "max_tokens", "n"]
|
||||
known_metadata = ["risk_name", "risk_definition", "category", "severity"]
|
||||
|
||||
if key in known_model_params:
|
||||
self.model_params[key] = value
|
||||
elif key in known_metadata:
|
||||
self.metadata[key] = value
|
||||
else:
|
||||
self.kwargs[key] = value
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize from any keyword arguments with smart categorization"""
|
||||
# Initialize containers
|
||||
self.params = {}
|
||||
self.model_params = {}
|
||||
self.metadata = {}
|
||||
self.kwargs = {}
|
||||
self._raw_detectors = None
|
||||
|
||||
# Special handling for nested detectors structure
|
||||
if "detectors" in kwargs:
|
||||
self._raw_detectors = kwargs.pop("detectors")
|
||||
|
||||
# Special handling for regex
|
||||
if "regex" in kwargs:
|
||||
self.params["regex"] = kwargs.pop("regex")
|
||||
|
||||
# Categorize known parameters
|
||||
known_model_params = ["temperature", "top_p", "top_k", "max_tokens", "n"]
|
||||
known_metadata = ["risk_name", "risk_definition", "category", "severity"]
|
||||
|
||||
# Explicit categories if provided
|
||||
if "model_params" in kwargs:
|
||||
self.model_params.update(kwargs.pop("model_params"))
|
||||
|
||||
if "metadata" in kwargs:
|
||||
self.metadata.update(kwargs.pop("metadata"))
|
||||
|
||||
if "kwargs" in kwargs:
|
||||
self.kwargs.update(kwargs.pop("kwargs"))
|
||||
|
||||
# Categorize remaining parameters
|
||||
for key, value in kwargs.items():
|
||||
if key in known_model_params:
|
||||
self.model_params[key] = value
|
||||
elif key in known_metadata:
|
||||
self.metadata[key] = value
|
||||
else:
|
||||
self.kwargs[key] = value
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Allow dictionary-like access with category lookup"""
|
||||
if key in self.params:
|
||||
return self.params[key]
|
||||
elif key in self.model_params:
|
||||
return self.model_params[key]
|
||||
elif key in self.metadata:
|
||||
return self.metadata[key]
|
||||
elif key in self.kwargs:
|
||||
return self.kwargs[key]
|
||||
return None
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Dictionary-style get with category lookup"""
|
||||
result = self.__getitem__(key)
|
||||
return default if result is None else result
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""Set a parameter value with smart categorization"""
|
||||
self.__setitem__(key, value)
|
||||
|
||||
def update(self, params: Dict[str, Any]) -> None:
|
||||
"""Update with multiple parameters, respecting categories"""
|
||||
for key, value in params.items():
|
||||
self.__setitem__(key, value)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert all parameters to a flat dictionary for API requests"""
|
||||
result = {}
|
||||
|
||||
# Add core parameters
|
||||
result.update(self.params)
|
||||
|
||||
# Add all categorized parameters, flattened
|
||||
result.update(self.model_params)
|
||||
result.update(self.metadata)
|
||||
result.update(self.kwargs)
|
||||
|
||||
return result
|
||||
|
||||
def to_categorized_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to a structured dictionary with categories preserved"""
|
||||
result = dict(self.params)
|
||||
|
||||
if self.model_params:
|
||||
result["model_params"] = dict(self.model_params)
|
||||
|
||||
if self.metadata:
|
||||
result["metadata"] = dict(self.metadata)
|
||||
|
||||
if self.kwargs:
|
||||
result["kwargs"] = dict(self.kwargs)
|
||||
|
||||
return result
|
||||
|
||||
def create_flattened_detector_configs(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Create flattened detector configurations for orchestrator mode.
|
||||
This removes the extra detector_params nesting that causes API errors.
|
||||
"""
|
||||
if not self.detectors:
|
||||
return {}
|
||||
|
||||
flattened = {}
|
||||
for detector_id, detector_config in self.detectors.items():
|
||||
# Create a flattened version without extra nesting
|
||||
flat_config = {}
|
||||
|
||||
# Extract detector_params if present and flatten them
|
||||
params = detector_config.get("detector_params", {})
|
||||
if isinstance(params, dict):
|
||||
# Move params up to top level
|
||||
for key, value in params.items():
|
||||
flat_config[key] = value
|
||||
|
||||
flattened[detector_id] = flat_config
|
||||
|
||||
return flattened
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate parameter values"""
|
||||
if self.temperature is not None and not 0 <= self.temperature <= 1:
|
||||
raise ValueError("Temperature must be between 0 and 1")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class BaseDetectorConfig:
|
||||
"""Base configuration for all detectors with flexible parameter handling"""
|
||||
|
||||
detector_id: str
|
||||
confidence_threshold: float = 0.5
|
||||
message_types: Set[str] = field(default_factory=lambda: MessageType.as_set())
|
||||
auth_token: Optional[str] = None
|
||||
detector_params: Optional[DetectorParams] = None
|
||||
|
||||
# URL fields directly on detector configs
|
||||
detector_url: Optional[str] = None
|
||||
orchestrator_url: Optional[str] = None
|
||||
|
||||
# Flexible storage for any additional parameters
|
||||
_extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Runtime execution parameters
|
||||
max_concurrency: int = 10 # Maximum concurrent API requests
|
||||
request_timeout: float = 30.0 # HTTP request timeout in seconds
|
||||
max_retries: int = 3 # Maximum number of retry attempts
|
||||
backoff_factor: float = 1.5 # Exponential backoff multiplier
|
||||
max_keepalive_connections: int = 5 # Max number of keepalive connections
|
||||
max_connections: int = 10 # Max number of connections in the pool
|
||||
|
||||
@property
|
||||
def use_orchestrator_api(self) -> bool:
|
||||
"""Determine if orchestrator API should be used"""
|
||||
return bool(self.orchestrator_url)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Process configuration after initialization"""
|
||||
# Convert list/tuple message_types to set
|
||||
if isinstance(self.message_types, (list, tuple)):
|
||||
self.message_types = set(self.message_types)
|
||||
|
||||
# Validate message types
|
||||
invalid_types = self.message_types - MessageType.as_set()
|
||||
if invalid_types:
|
||||
raise ValueError(
|
||||
f"Invalid message types: {invalid_types}. "
|
||||
f"Valid types are: {MessageType.as_set()}"
|
||||
)
|
||||
|
||||
# Initialize detector_params if needed
|
||||
if self.detector_params is None:
|
||||
self.detector_params = DetectorParams()
|
||||
|
||||
# Handle legacy URL field names
|
||||
if hasattr(self, "base_url") and self.base_url and not self.detector_url:
|
||||
self.detector_url = self.base_url
|
||||
|
||||
if (
|
||||
hasattr(self, "orchestrator_base_url")
|
||||
and self.orchestrator_base_url
|
||||
and not self.orchestrator_url
|
||||
):
|
||||
self.orchestrator_url = self.orchestrator_base_url
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate configuration"""
|
||||
# Validate detector_params
|
||||
if self.detector_params:
|
||||
self.detector_params.validate()
|
||||
|
||||
# Validate that at least one URL is provided
|
||||
if not self.detector_url and not self.orchestrator_url:
|
||||
raise ValueError(f"No URL provided for detector {self.detector_id}")
|
||||
|
||||
# Validate URLs if present
|
||||
for url_name, url in [
|
||||
("detector_url", self.detector_url),
|
||||
("orchestrator_url", self.orchestrator_url),
|
||||
]:
|
||||
if url:
|
||||
self._validate_url(url, url_name)
|
||||
|
||||
def _validate_url(self, url: str, url_name: str) -> None:
|
||||
"""Validate URL format"""
|
||||
parsed = urlparse(url)
|
||||
if not all([parsed.scheme, parsed.netloc]):
|
||||
raise ValueError(f"Invalid {url_name} format: {url}")
|
||||
if parsed.scheme not in {"http", "https"}:
|
||||
raise ValueError(f"Invalid {url_name} scheme: {parsed.scheme}")
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""Get parameter with fallback to extra parameters"""
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
return self._extra_params.get(key, default)
|
||||
|
||||
def set(self, key: str, value: Any) -> None:
|
||||
"""Set parameter, storing in extra_params if not a standard field"""
|
||||
if hasattr(self, key) and key not in ["_extra_params"]:
|
||||
setattr(self, key, value)
|
||||
else:
|
||||
self._extra_params[key] = value
|
||||
|
||||
@property
|
||||
def is_chat(self) -> bool:
|
||||
"""Default implementation, should be overridden by subclasses"""
|
||||
return False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class ContentDetectorConfig(BaseDetectorConfig):
|
||||
"""Configuration for content detectors"""
|
||||
|
||||
@property
|
||||
def is_chat(self) -> bool:
|
||||
"""Content detectors are not chat detectors"""
|
||||
return False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@dataclass
|
||||
class ChatDetectorConfig(BaseDetectorConfig):
|
||||
"""Configuration for chat detectors"""
|
||||
|
||||
@property
|
||||
def is_chat(self) -> bool:
|
||||
"""Chat detectors are chat detectors"""
|
||||
return True
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FMSSafetyProviderConfig(BaseModel):
|
||||
"""Configuration for the FMS Safety Provider"""
|
||||
|
||||
shields: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
# Rename _detectors to remove the leading underscore
|
||||
detectors_internal: Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]] = (
|
||||
Field(default_factory=dict, exclude=True)
|
||||
)
|
||||
|
||||
# Provider-level orchestrator URL (can be copied to shields if needed)
|
||||
orchestrator_url: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# Add a model validator to replace __post_init__
|
||||
@model_validator(mode="after")
|
||||
def setup_config(self):
|
||||
"""Process shield configurations"""
|
||||
# Process shields into detector objects
|
||||
self._process_shields()
|
||||
|
||||
# Replace shields dictionary with processed detector configs
|
||||
self.shields = self.detectors_internal
|
||||
|
||||
# Validate all shields
|
||||
for shield in self.shields.values():
|
||||
shield.validate()
|
||||
|
||||
return self
|
||||
|
||||
def _process_shields(self):
|
||||
"""Process all shield configurations into detector configs"""
|
||||
for shield_id, config in self.shields.items():
|
||||
if isinstance(config, dict):
|
||||
# Copy the config to avoid modifying the original
|
||||
shield_config = dict(config)
|
||||
|
||||
# Check if this shield has nested detectors
|
||||
nested_detectors = shield_config.pop("detectors", None)
|
||||
|
||||
# Determine detector type
|
||||
detector_type = shield_config.pop("type", None)
|
||||
is_chat = (
|
||||
detector_type == "chat"
|
||||
if detector_type
|
||||
else shield_config.pop("is_chat", False)
|
||||
)
|
||||
|
||||
# Set detector ID
|
||||
shield_config["detector_id"] = shield_id
|
||||
|
||||
# Handle URL fields
|
||||
# First handle legacy field names
|
||||
if "base_url" in shield_config and "detector_url" not in shield_config:
|
||||
shield_config["detector_url"] = shield_config.pop("base_url")
|
||||
|
||||
if (
|
||||
"orchestrator_base_url" in shield_config
|
||||
and "orchestrator_url" not in shield_config
|
||||
):
|
||||
shield_config["orchestrator_url"] = shield_config.pop(
|
||||
"orchestrator_base_url"
|
||||
)
|
||||
|
||||
# If no orchestrator_url in shield but provider has one, copy it
|
||||
if self.orchestrator_url and "orchestrator_url" not in shield_config:
|
||||
shield_config["orchestrator_url"] = self.orchestrator_url
|
||||
|
||||
# Initialize detector_params with proper structure for nested detectors
|
||||
detector_params_dict = shield_config.get("detector_params", {})
|
||||
if not isinstance(detector_params_dict, dict):
|
||||
detector_params_dict = {}
|
||||
|
||||
# Create detector_params object
|
||||
detector_params = DetectorParams(**detector_params_dict)
|
||||
|
||||
# Add nested detectors if present
|
||||
if nested_detectors:
|
||||
detector_params.detectors = nested_detectors
|
||||
|
||||
shield_config["detector_params"] = detector_params
|
||||
|
||||
# Create appropriate detector config
|
||||
detector_class = (
|
||||
ChatDetectorConfig if is_chat else ContentDetectorConfig
|
||||
)
|
||||
self.detectors_internal[shield_id] = detector_class(**shield_config)
|
||||
|
||||
@property
|
||||
def all_detectors(
|
||||
self,
|
||||
) -> Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]]:
|
||||
"""Get all detector configurations"""
|
||||
return self.detectors_internal
|
||||
|
||||
# Update other methods to use detectors_internal instead of _detectors
|
||||
def get_detectors_by_type(
|
||||
self, message_type: Union[str, MessageType]
|
||||
) -> Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]]:
|
||||
"""Get detectors for a specific message type"""
|
||||
type_value = (
|
||||
message_type.value
|
||||
if isinstance(message_type, MessageType)
|
||||
else message_type
|
||||
)
|
||||
return {
|
||||
shield_id: shield
|
||||
for shield_id, shield in self.detectors_internal.items()
|
||||
if type_value in shield.message_types
|
||||
}
|
||||
|
||||
# Convenience properties
|
||||
@property
|
||||
def user_message_detectors(self):
|
||||
return self.get_detectors_by_type(MessageType.USER)
|
||||
|
||||
@property
|
||||
def system_message_detectors(self):
|
||||
return self.get_detectors_by_type(MessageType.SYSTEM)
|
||||
|
||||
@property
|
||||
def tool_response_detectors(self):
|
||||
return self.get_detectors_by_type(MessageType.TOOL)
|
||||
|
||||
@property
|
||||
def completion_message_detectors(self):
|
||||
return self.get_detectors_by_type(MessageType.COMPLETION)
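A short, illustrative sketch of how the validator above turns a raw `shields` mapping into typed detector configs; the shield IDs and URL are hypothetical:

cfg = FMSSafetyProviderConfig(
    orchestrator_url="https://fms-orchestrator.example.com",  # illustrative URL
    shields={
        "email_guard": {"type": "content", "detector_params": {"regex": ["email"]}},
        "granite_guard": {"type": "chat", "detector_params": {"risk_name": "harm"}},
    },
)
assert isinstance(cfg.shields["email_guard"], ContentDetectorConfig)
assert isinstance(cfg.shields["granite_guard"], ChatDetectorConfig)
# The provider-level orchestrator URL is copied onto shields that do not set one.
assert cfg.shields["granite_guard"].orchestrator_url == "https://fms-orchestrator.example.com"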
|
llama_stack/providers/remote/safety/trustyai_fms/detectors/base.py (new file, 1561 lines)
File diff suppressed because it is too large.

llama_stack/providers/remote/safety/trustyai_fms/detectors/chat.py (new file, 329 lines)

@@ -0,0 +1,329 @@
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import RunShieldResponse
|
||||
from llama_stack.providers.remote.safety.trustyai_fms.config import ChatDetectorConfig
|
||||
from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import (
|
||||
BaseDetector,
|
||||
DetectionResult,
|
||||
DetectorError,
|
||||
DetectorRequestError,
|
||||
DetectorValidationError,
|
||||
)
|
||||
|
||||
# Type aliases for better readability
|
||||
ChatMessage = Dict[
|
||||
str, Any
|
||||
] # Changed from Dict[str, str] to Dict[str, Any] to handle complex content
|
||||
ChatRequest = Dict[str, Any]
|
||||
DetectorResponse = List[Dict[str, Any]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatDetectorError(DetectorError):
|
||||
"""Specific errors for chat detector operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChatDetectionMetadata:
|
||||
"""Structured metadata for chat detections"""
|
||||
|
||||
risk_name: Optional[str] = None
|
||||
risk_definition: Optional[str] = None
|
||||
additional_metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert metadata to dictionary format"""
|
||||
result: Dict[str, Any] = {} # Fixed the type annotation here
|
||||
if self.risk_name:
|
||||
result["risk_name"] = self.risk_name
|
||||
if self.risk_definition:
|
||||
result["risk_definition"] = self.risk_definition
|
||||
if self.additional_metadata:
|
||||
result["metadata"] = self.additional_metadata
|
||||
return result
|
||||
|
||||
|
||||
class ChatDetector(BaseDetector):
|
||||
"""Detector for chat-based safety checks"""
|
||||
|
||||
def __init__(self, config: ChatDetectorConfig) -> None:
|
||||
"""Initialize chat detector with configuration"""
|
||||
if not isinstance(config, ChatDetectorConfig):
|
||||
raise DetectorValidationError(
|
||||
"Config must be an instance of ChatDetectorConfig"
|
||||
)
|
||||
super().__init__(config)
|
||||
self.config: ChatDetectorConfig = config
|
||||
logger.info(f"Initialized ChatDetector with config: {vars(config)}")
|
||||
|
||||
def _extract_detector_params(self) -> Dict[str, Any]:
|
||||
"""Extract non-null detector parameters"""
|
||||
if not self.config.detector_params:
|
||||
return {}
|
||||
|
||||
params = {
|
||||
k: v for k, v in vars(self.config.detector_params).items() if v is not None
|
||||
}
|
||||
logger.debug(f"Extracted detector params: {params}")
|
||||
return params
|
||||
|
||||
def _prepare_chat_request(
|
||||
self, messages: List[ChatMessage], params: Optional[Dict[str, Any]] = None
|
||||
) -> ChatRequest:
|
||||
"""Prepare the request based on API mode"""
|
||||
# Format messages for detector API
|
||||
formatted_messages: List[Dict[str, str]] = [] # Explicitly typed
|
||||
for msg in messages:
|
||||
formatted_msg = {
|
||||
"content": str(msg.get("content", "")), # Ensure string type
|
||||
"role": "user", # Always send as user for detector API
|
||||
}
|
||||
formatted_messages.append(formatted_msg)
|
||||
|
||||
if self.config.use_orchestrator_api:
|
||||
payload: Dict[str, Any] = {
|
||||
"messages": formatted_messages
|
||||
} # Explicitly typed
|
||||
|
||||
# Initialize detector_config to avoid None
|
||||
detector_config: Dict[str, Any] = {} # Explicitly typed
|
||||
|
||||
# NEW STRUCTURE: Check for top-level detectors first
|
||||
if hasattr(self.config, "detectors") and self.config.detectors:
|
||||
for detector_id, det_config in self.config.detectors.items():
|
||||
detector_config[detector_id] = det_config.get("detector_params", {})
|
||||
|
||||
# LEGACY STRUCTURE: Check for nested detectors
|
||||
elif (
|
||||
self.config.detector_params
|
||||
and hasattr(self.config.detector_params, "detectors")
|
||||
and self.config.detector_params.detectors
|
||||
):
|
||||
detector_config = self.config.detector_params.detectors
|
||||
|
||||
# Handle flat params - group them into generic containers
|
||||
elif self.config.detector_params:
|
||||
detector_params = self._extract_detector_params()
|
||||
|
||||
# Create a flat dictionary of parameters
|
||||
flat_params = {}
|
||||
|
||||
# Extract from model_params
|
||||
if "model_params" in detector_params and isinstance(
|
||||
detector_params["model_params"], dict
|
||||
):
|
||||
flat_params.update(detector_params["model_params"])
|
||||
|
||||
# Extract from metadata
|
||||
if "metadata" in detector_params and isinstance(
|
||||
detector_params["metadata"], dict
|
||||
):
|
||||
flat_params.update(detector_params["metadata"])
|
||||
|
||||
# Extract from kwargs
|
||||
if "kwargs" in detector_params and isinstance(
|
||||
detector_params["kwargs"], dict
|
||||
):
|
||||
flat_params.update(detector_params["kwargs"])
|
||||
|
||||
# Add any other direct parameters, but skip container dictionaries
|
||||
for k, v in detector_params.items():
|
||||
if (
|
||||
k not in ["model_params", "metadata", "kwargs", "params"]
|
||||
and v is not None
|
||||
):
|
||||
flat_params[k] = v
|
||||
|
||||
# Add all flattened parameters directly to detector configuration
|
||||
detector_config[self.config.detector_id] = flat_params
|
||||
|
||||
# Ensure we have a valid detectors map even if all checks fail
|
||||
if not detector_config:
|
||||
detector_config = {self.config.detector_id: {}}
|
||||
|
||||
payload["detectors"] = detector_config
|
||||
return payload
|
||||
|
||||
# Direct API format remains unchanged
|
||||
else:
|
||||
# DIRECT MODE: Use flat parameters for API compatibility
|
||||
# Don't organize into containers for direct mode
|
||||
detector_params = self._extract_detector_params()
|
||||
|
||||
# Flatten the parameters for direct mode too
|
||||
flat_params = {}
|
||||
|
||||
# Extract from model_params
|
||||
if "model_params" in detector_params and isinstance(
|
||||
detector_params["model_params"], dict
|
||||
):
|
||||
flat_params.update(detector_params["model_params"])
|
||||
|
||||
# Extract from metadata
|
||||
if "metadata" in detector_params and isinstance(
|
||||
detector_params["metadata"], dict
|
||||
):
|
||||
flat_params.update(detector_params["metadata"])
|
||||
|
||||
# Extract from kwargs
|
||||
if "kwargs" in detector_params and isinstance(
|
||||
detector_params["kwargs"], dict
|
||||
):
|
||||
flat_params.update(detector_params["kwargs"])
|
||||
|
||||
# Add any other direct parameters
|
||||
for k, v in detector_params.items():
|
||||
if (
|
||||
k not in ["model_params", "metadata", "kwargs", "params"]
|
||||
and v is not None
|
||||
):
|
||||
flat_params[k] = v
|
||||
|
||||
return {
|
||||
"messages": formatted_messages,
|
||||
"detector_params": flat_params if flat_params else params or {},
|
||||
}
|
||||
|
||||
async def _call_detector_api(
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> DetectorResponse:
|
||||
"""Call chat detector API with proper endpoint selection"""
|
||||
try:
|
||||
request = self._prepare_chat_request(messages, params)
|
||||
headers = self._prepare_headers()
|
||||
|
||||
logger.info("Making detector API request")
|
||||
logger.debug(f"Request headers: {headers}")
|
||||
logger.debug(f"Request payload: {request}")
|
||||
|
||||
response = await self._make_request(request, headers)
|
||||
return self._extract_detections(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API call failed: {str(e)}", exc_info=True)
|
||||
raise DetectorRequestError(
|
||||
f"Chat detector API call failed: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _extract_detections(self, response: Dict[str, Any]) -> DetectorResponse:
|
||||
"""Extract detections from API response"""
|
||||
if not response:
|
||||
logger.debug("Empty response received")
|
||||
return []
|
||||
|
||||
if self.config.use_orchestrator_api:
|
||||
detections = response.get("detections", [])
|
||||
if not detections:
|
||||
# Add default detection when none returned
|
||||
logger.debug("No detections found, adding default low-score detection")
|
||||
return [
|
||||
{
|
||||
"detection_type": "risk",
|
||||
"detection": "No",
|
||||
"detector_id": self.config.detector_id,
|
||||
"score": 0.0, # Default low score
|
||||
}
|
||||
]
|
||||
logger.debug(f"Orchestrator detections: {detections}")
|
||||
return cast(
|
||||
DetectorResponse, detections
|
||||
) # Explicit cast to correct return type
|
||||
|
||||
# Direct API returns a list where first item contains detections
|
||||
if isinstance(response, list) and response:
|
||||
detections = (
|
||||
[response[0]] if not isinstance(response[0], list) else response[0]
|
||||
)
|
||||
logger.debug(f"Direct API detections: {detections}")
|
||||
return cast(
|
||||
DetectorResponse, detections
|
||||
) # Explicit cast to correct return type
|
||||
|
||||
logger.debug("No detections found in response")
|
||||
return []
|
||||
|
||||
def _process_detection(
|
||||
self, detection: Dict[str, Any]
|
||||
) -> Tuple[
|
||||
Optional[DetectionResult], float
|
||||
]: # Changed return type to match base class
|
||||
"""Process detection result and validate against threshold"""
|
||||
score = detection.get("score", 0.0) # Default to 0.0 if score is missing
|
||||
|
||||
if score > self.score_threshold:
|
||||
metadata = ChatDetectionMetadata(
|
||||
risk_name=(
|
||||
self.config.detector_params.risk_name
|
||||
if self.config.detector_params
|
||||
else None
|
||||
),
|
||||
risk_definition=(
|
||||
self.config.detector_params.risk_definition
|
||||
if self.config.detector_params
|
||||
else None
|
||||
),
|
||||
additional_metadata=detection.get("metadata"),
|
||||
)
|
||||
|
||||
result = DetectionResult(
|
||||
detection="Yes",
|
||||
detection_type=detection["detection_type"],
|
||||
score=score,
|
||||
detector_id=detection.get("detector_id", self.config.detector_id),
|
||||
text=detection.get("text", ""),
|
||||
start=detection.get("start", 0),
|
||||
end=detection.get("end", 0),
|
||||
metadata=metadata.to_dict(),
|
||||
)
|
||||
return (result, score)
|
||||
return (None, score)
|
||||
|
||||
async def _run_shield_impl(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> RunShieldResponse:
|
||||
"""Implementation of shield checks for chat messages"""
|
||||
try:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
self._validate_shield(shield)
|
||||
|
||||
logger.info(f"Processing {len(messages)} message(s)")
|
||||
|
||||
# Convert messages keeping only necessary fields
|
||||
chat_messages: List[ChatMessage] = [] # Explicitly typed
|
||||
for msg in messages:
|
||||
message_dict: ChatMessage = {"content": msg.content, "role": msg.role}
|
||||
# Preserve type if present for internal processing
|
||||
if hasattr(msg, "type"):
|
||||
message_dict["type"] = msg.type
|
||||
chat_messages.append(message_dict)
|
||||
|
||||
logger.debug(f"Prepared messages: {chat_messages}")
|
||||
detections = await self._call_detector_api(chat_messages, params)
|
||||
|
||||
for detection in detections:
|
||||
processed, score = self._process_detection(detection)
|
||||
if processed:
|
||||
logger.info(f"Violation detected: {processed}")
|
||||
return self.create_violation_response(
|
||||
processed, detection.get("detector_id", self.config.detector_id)
|
||||
)
|
||||
|
||||
logger.debug("No violations detected")
|
||||
return RunShieldResponse()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Shield execution failed: {str(e)}", exc_info=True)
|
||||
raise ChatDetectorError(f"Shield execution failed: {str(e)}") from e
|
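For reference, the approximate request payloads produced by `_prepare_chat_request` in the two API modes; the detector id and parameters are illustrative, not part of the commit:

# Orchestrator mode: one entry per detector, with flattened parameters.
orchestrator_payload = {
    "messages": [{"content": "some user text", "role": "user"}],
    "detectors": {"granite_guard": {"risk_name": "harm", "temperature": 0.0}},
}

# Direct mode: flat detector_params alongside the messages.
direct_payload = {
    "messages": [{"content": "some user text", "role": "user"}],
    "detector_params": {"risk_name": "harm", "temperature": 0.0},
}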
llama_stack/providers/remote/safety/trustyai_fms/detectors/content.py (new file, 209 lines)

@@ -0,0 +1,209 @@
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import RunShieldResponse
|
||||
from llama_stack.providers.remote.safety.trustyai_fms.config import (
|
||||
ContentDetectorConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import (
|
||||
BaseDetector,
|
||||
DetectionResult,
|
||||
DetectorError,
|
||||
DetectorRequestError,
|
||||
DetectorValidationError,
|
||||
)
|
||||
|
||||
# Type aliases for better readability
|
||||
ContentRequest = Dict[str, Any]
|
||||
DetectorResponse = List[Dict[str, Any]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentDetectorError(DetectorError):
|
||||
"""Specific errors for content detector operations"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ContentDetector(BaseDetector):
|
||||
"""Detector for content-based safety checks"""
|
||||
|
||||
def __init__(self, config: ContentDetectorConfig) -> None:
|
||||
"""Initialize content detector with configuration"""
|
||||
if not isinstance(config, ContentDetectorConfig):
|
||||
raise DetectorValidationError(
|
||||
"Config must be an instance of ContentDetectorConfig"
|
||||
)
|
||||
super().__init__(config)
|
||||
self.config: ContentDetectorConfig = config
|
||||
logger.info(f"Initialized ContentDetector with config: {vars(config)}")
|
||||
|
||||
def _extract_detector_params(self) -> Dict[str, Any]:
|
||||
"""Extract detector parameters with support for generic format"""
|
||||
if not self.config.detector_params:
|
||||
return {}
|
||||
|
||||
# Use to_dict() to flatten our categorized structure into what the API expects
|
||||
params = self.config.detector_params.to_dict()
|
||||
logger.debug(f"Extracted detector params: {params}")
|
||||
return params
|
||||
|
||||
def _prepare_content_request(
|
||||
self, content: str, params: Optional[Dict[str, Any]] = None
|
||||
) -> ContentRequest:
|
||||
"""Prepare the request based on API mode"""
|
||||
|
||||
if self.config.use_orchestrator_api:
|
||||
payload: Dict[str, Any] = {"content": content} # Always use singular form
|
||||
|
||||
# NEW STRUCTURE: Check for top-level detectors first
|
||||
if hasattr(self.config, "detectors") and self.config.detectors:
|
||||
detector_config: Dict[str, Any] = {}
|
||||
for detector_id, det_config in self.config.detectors.items():
|
||||
detector_config[detector_id] = det_config.get("detector_params", {})
|
||||
payload["detectors"] = detector_config
|
||||
return payload
|
||||
|
||||
# LEGACY STRUCTURE: Check for nested detectors
|
||||
elif self.config.detector_params and hasattr(
|
||||
self.config.detector_params, "detectors"
|
||||
):
|
||||
detectors = getattr(self.config.detector_params, "detectors", {})
|
||||
payload["detectors"] = detectors
|
||||
return payload
|
||||
|
||||
# Handle flat params
|
||||
else:
|
||||
detector_config = {}
|
||||
detector_params = self._extract_detector_params()
|
||||
if detector_params:
|
||||
detector_config[self.config.detector_id] = detector_params
|
||||
payload["detectors"] = detector_config
|
||||
return payload
|
||||
|
||||
else:
|
||||
# DIRECT MODE: Use flat parameters for API compatibility
|
||||
detector_params = self._extract_detector_params()
|
||||
|
||||
return {
|
||||
"contents": [content],
|
||||
"detector_params": detector_params if detector_params else params or {},
|
||||
}
|
||||
|
||||
def _extract_detections(self, response: Dict[str, Any]) -> DetectorResponse:
|
||||
"""Extract detections from API response"""
|
||||
if not response:
|
||||
logger.debug("Empty response received")
|
||||
return []
|
||||
|
||||
if self.config.use_orchestrator_api:
|
||||
detections = response.get("detections", [])
|
||||
logger.debug(f"Orchestrator detections: {detections}")
|
||||
return cast(List[Dict[str, Any]], detections)
|
||||
|
||||
# Direct API returns a list of lists where inner list contains detections
|
||||
if isinstance(response, list) and response:
|
||||
detections = response[0] if isinstance(response[0], list) else [response[0]]
|
||||
logger.debug(f"Direct API detections: {detections}")
|
||||
return cast(List[Dict[str, Any]], detections)
|
||||
|
||||
logger.debug("No detections found in response")
|
||||
return []
|
||||
|
||||
async def _call_detector_api(
|
||||
self,
|
||||
content: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> DetectorResponse:
|
||||
"""Call detector API with proper endpoint selection"""
|
||||
try:
|
||||
request = self._prepare_content_request(content, params)
|
||||
headers = self._prepare_headers()
|
||||
|
||||
logger.info("Making detector API request")
|
||||
logger.debug(f"Request headers: {headers}")
|
||||
logger.debug(f"Request payload: {request}")
|
||||
|
||||
response = await self._make_request(request, headers)
|
||||
return self._extract_detections(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API call failed: {str(e)}", exc_info=True)
|
||||
raise DetectorRequestError(
|
||||
f"Content detector API call failed: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _process_detection(
|
||||
self, detection: Dict[str, Any]
|
||||
) -> tuple[Optional[DetectionResult], float]:
|
||||
"""Process detection result and validate against threshold"""
|
||||
if not detection.get("score"):
|
||||
logger.warning("Detection missing score field")
|
||||
return None, 0.0
|
||||
|
||||
score = detection.get("score", 0)
|
||||
if score > self.score_threshold:
|
||||
result = DetectionResult(
|
||||
detection="Yes",
|
||||
detection_type=detection["detection_type"],
|
||||
score=score,
|
||||
detector_id=detection.get("detector_id", self.config.detector_id),
|
||||
text=detection.get("text", ""),
|
||||
start=detection.get("start", 0),
|
||||
end=detection.get("end", 0),
|
||||
metadata=detection.get("metadata", {}),
|
||||
)
|
||||
return result, score
|
||||
return None, score
|
||||
|
||||
async def _run_shield_impl(
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> RunShieldResponse:
|
||||
"""Implementation of shield checks for content messages"""
|
||||
try:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
self._validate_shield(shield)
|
||||
|
||||
for msg in messages:
|
||||
content = msg.content
|
||||
content_str: str
|
||||
|
||||
# Simplified content handling - just check for text attribute or convert to string
|
||||
if hasattr(content, "text"):
|
||||
content_str = str(content.text)
|
||||
elif isinstance(content, list):
|
||||
content_str = " ".join(
|
||||
str(getattr(item, "text", "")) for item in content
|
||||
)
|
||||
else:
|
||||
content_str = str(content)
|
||||
|
||||
truncated_content = (
|
||||
content_str[:100] + "..." if len(content_str) > 100 else content_str
|
||||
)
|
||||
logger.debug(f"Checking content: {truncated_content}")
|
||||
|
||||
detections = await self._call_detector_api(content_str, params)
|
||||
|
||||
for detection in detections:
|
||||
processed, score = self._process_detection(detection)
|
||||
if processed:
|
||||
logger.info(f"Violation detected: {processed}")
|
||||
return self.create_violation_response(
|
||||
processed,
|
||||
detection.get("detector_id", self.config.detector_id),
|
||||
)
|
||||
|
||||
logger.debug("No violations detected")
|
||||
return RunShieldResponse()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Shield execution failed: {str(e)}", exc_info=True)
|
||||
raise ContentDetectorError(f"Shield execution failed: {str(e)}") from e
|
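And the corresponding approximate payload shapes from `_prepare_content_request`; the detector id and parameters are again illustrative:

# Orchestrator mode: singular "content" plus a per-detector mapping.
orchestrator_payload = {
    "content": "text to screen",
    "detectors": {"email_guard": {"regex": ["email"]}},
}

# Direct mode: a list of contents plus flat detector_params.
direct_payload = {
    "contents": ["text to screen"],
    "detector_params": {"regex": ["email"]},
}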