Squashed commit of the following:

commit a95d2b15b83057e194cf69e57a03deeeeeadd7c2
Author: m-misiura <mmisiura@redhat.com>
Date:   Mon Mar 24 14:33:50 2025 +0000

    🚧 working on the config file so that it is inheriting from pydantic base models

commit 0546379f817e37bca030247b48c72ce84899a766
Author: m-misiura <mmisiura@redhat.com>
Date:   Mon Mar 24 09:14:31 2025 +0000

    🚧 dealing with ruff checks

commit 8abe39ee4cb4b8fb77c7252342c4809fa6ddc432
Author: m-misiura <mmisiura@redhat.com>
Date:   Mon Mar 24 09:03:18 2025 +0000

    🚧 dealing with mypy errors in `base.py`

commit 045f833e79c9a25af3d46af6c8896da91a0e6e62
Author: m-misiura <mmisiura@redhat.com>
Date:   Fri Mar 21 17:31:25 2025 +0000

    🚧 fixing mypy errors in content.py

commit a9c1ee4e92ad1b5db89039317555cd983edbde65
Author: m-misiura <mmisiura@redhat.com>
Date:   Fri Mar 21 17:09:02 2025 +0000

    🚧 fixing mypy errors in chat.py

commit 69e8ddc2f8a4e13cecbab30272fd7d685d7864ec
Author: m-misiura <mmisiura@redhat.com>
Date:   Fri Mar 21 16:57:28 2025 +0000

    🚧 fixing mypy errors

commit 56739d69a145c55335ac2859ecbe5b43d556e3b1
Author: m-misiura <mmisiura@redhat.com>
Date:   Fri Mar 21 14:01:03 2025 +0000

    🚧 fixing mypy errors in `__init__.py`

commit 4d2e3b55c4102ed75d997c8189847bbc5524cb2c
Author: m-misiura <mmisiura@redhat.com>
Date:   Fri Mar 21 12:58:06 2025 +0000

    🚧 ensuring routing_tables.py do not fail the ci

commit c0cc7b4b09ef50d5ec95fdb0a916c7ed228bf366
Author: m-misiura <mmisiura@redhat.com>
Date:   Fri Mar 21 12:09:24 2025 +0000

    🐛 fixing linter problems

commit 115a50211b604feb4106275204fe7f863da865f6
Author: m-misiura <mmisiura@redhat.com>
Date:   Fri Mar 21 11:47:04 2025 +0000

    🐛 fixing ruff errors

commit 29b5bfaabc77a35ea036b57f75fded711228dbbf
Author: m-misiura <mmisiura@redhat.com>
Date:   Fri Mar 21 11:33:31 2025 +0000

    🎨 automatic ruff fixes

commit 7c5a334c7d4649c2fc297993f89791c1e5643e5b
Author: m-misiura <mmisiura@redhat.com>
Date:   Fri Mar 21 11:15:02 2025 +0000

    Squashed commit of the following:

    commit e671aae5bcd4ea57d601ee73c9e3adf5e223e830
    Merge: b0dd9a4f 9114bef4
    Author: Mac Misiura <82826099+m-misiura@users.noreply.github.com>
    Date:   Fri Mar 21 09:45:08 2025 +0000

        Merge branch 'meta-llama:main' into feat_fms_remote_safety_provider

    commit b0dd9a4f746b0c8c54d1189d381a7ff8e51c812c
    Author: m-misiura <mmisiura@redhat.com>
    Date:   Fri Mar 21 09:27:21 2025 +0000

        📝 updated `provider_id`

    commit 4c8906c1a4e960968b93251d09d5e5735db15026
    Author: m-misiura <mmisiura@redhat.com>
    Date:   Thu Mar 20 16:54:46 2025 +0000

        📝 renaming from `fms` to `trustyai_fms`

    commit 4c0b62abc51b02143b5c818f2d30e1a1fee9e4f3
    Merge: bb842d69 54035825
    Author: m-misiura <mmisiura@redhat.com>
    Date:   Thu Mar 20 16:35:52 2025 +0000

        Merge branch 'main' into feat_fms_remote_safety_provider

    commit bb842d69548df256927465792e0cd107a267d2a0
    Author: m-misiura <mmisiura@redhat.com>
    Date:   Wed Mar 19 15:03:17 2025 +0000

         added a better way of handling params from the configs

    commit 58b6beabf0994849ac50317ed00b748596e8961d
    Merge: a22cf36c 7c044845
    Author: m-misiura <mmisiura@redhat.com>
    Date:   Wed Mar 19 09:19:57 2025 +0000

        Merge main into feat_fms_remote_safety_provider, resolve conflicts by keeping main version

    commit a22cf36c8757f74ed656c1310a4be6b288bf923a
    Author: m-misiura <mmisiura@redhat.com>
    Date:   Wed Mar 5 16:17:46 2025 +0000

        🎉 added a new remote safety provider compatible with FMS Orchestrator API and Detectors API

        Signed-off-by: m-misiura <mmisiura@redhat.com>
This commit is contained in:
m-misiura 2025-03-24 14:46:03 +00:00
parent 9e1ddf2b53
commit 87d209d6ef
7 changed files with 2883 additions and 20 deletions

View file

@ -118,7 +118,9 @@ class CommonRoutingTableImpl(RoutingTable):
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
async def add_objects(
objs: List[RoutableObjectWithProvider], provider_id: str, cls
) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
@ -153,7 +155,9 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
def get_provider_impl(
self, routing_key: str, provider_id: Optional[str] = None
) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
@ -191,24 +195,32 @@ class CommonRoutingTableImpl(RoutingTable):
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
async def get_object_by_identifier(
self, type: str, identifier: str
) -> Optional[RoutableObjectWithProvider]:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
if not check_access(obj, get_auth_attributes()):
logger.debug(
f"Access denied to {type} '{identifier}' based on attribute mismatch"
)
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
await unregister_object_from_provider(
obj, self.impls_by_provider_id[obj.provider_id]
)
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
async def register_object(
self, obj: RoutableObjectWithProvider
) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@ -223,7 +235,9 @@ class CommonRoutingTableImpl(RoutingTable):
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
logger.info(
f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity"
)
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
@ -242,9 +256,7 @@ class CommonRoutingTableImpl(RoutingTable):
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
obj for obj in filtered_objs if check_access(obj, get_auth_attributes())
]
return filtered_objs
@ -283,7 +295,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
model = ModelWithACL(
identifier=model_id,
provider_resource_id=provider_model_id,
@ -302,8 +316,54 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def initialize(self) -> None:
"""Initialize routing table and providers"""
# First do the common initialization
await super().initialize()
# Then explicitly initialize each safety provider
for provider_id, provider in self.impls_by_provider_id.items():
api = get_impl_api(provider)
if api == Api.safety:
logger.info(f"Explicitly initializing safety provider: {provider_id}")
await provider.initialize()
# Fetch shields after initialization - with robust error handling
try:
# Check if the provider implements list_shields
if hasattr(provider, "list_shields") and callable(
getattr(provider, "list_shields")
):
shields_response = await provider.list_shields()
if (
shields_response
and hasattr(shields_response, "data")
and shields_response.data
):
for shield in shields_response.data:
# Ensure type is set
if not hasattr(shield, "type") or not shield.type:
shield.type = ResourceType.shield.value
await self.dist_registry.register(shield)
logger.info(
f"Registered {len(shields_response.data)} shields from provider {provider_id}"
)
else:
logger.info(f"No shields found for provider {provider_id}")
else:
logger.info(
f"Provider {provider_id} does not support listing shields"
)
except Exception as e:
# Log the error but continue initialization
logger.warning(
f"Error listing shields from provider {provider_id}: {str(e)}"
)
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
return ListShieldsResponse(
data=await self.get_all_with_type(ResourceType.shield.value)
)
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
@ -368,14 +428,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
raise ValueError(
"No provider available. Please configure a vector_io provider."
)
model = await self.get_object_by_identifier("model", embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
raise ValueError(
f"Model {embedding_model} does not have an embedding dimension"
)
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,
@ -397,7 +461,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
return ListDatasetsResponse(
data=await self.get_all_with_type(ResourceType.dataset.value)
)
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
@ -459,10 +525,14 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
return ListScoringFunctionsResponse(
data=await self.get_all_with_type(ResourceType.scoring_function.value)
)
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
scoring_fn = await self.get_object_by_identifier(
"scoring_function", scoring_fn_id
)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
@ -565,8 +635,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
args: Optional[Dict[str, Any]] = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
toolgroup_id, mcp_endpoint
)
tool_host = (
ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
)
for tool_def in tool_defs:
tools.append(

View file

@ -64,4 +64,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
),
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="trustyai_fms",
pip_packages=[],
module="llama_stack.providers.remote.safety.trustyai_fms",
config_class="llama_stack.providers.remote.safety.trustyai_fms.config.FMSSafetyProviderConfig",
),
),
]

View file

@ -0,0 +1,119 @@
import logging
from typing import Any, Dict, Optional, Union
# Remove register_provider import since registration is in registry/safety.py
from llama_stack.apis.safety import Safety
from llama_stack.providers.remote.safety.trustyai_fms.config import (
ChatDetectorConfig,
ContentDetectorConfig,
DetectorParams,
EndpointType,
FMSSafetyProviderConfig,
)
from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import (
BaseDetector,
DetectorProvider,
)
from llama_stack.providers.remote.safety.trustyai_fms.detectors.chat import ChatDetector
from llama_stack.providers.remote.safety.trustyai_fms.detectors.content import (
ContentDetector,
)
# Set up logging
logger = logging.getLogger(__name__)
# Type aliases for better readability
ConfigType = Union[ContentDetectorConfig, ChatDetectorConfig, FMSSafetyProviderConfig]
DetectorType = Union[BaseDetector, DetectorProvider]
class DetectorConfigError(ValueError):
"""Raised when detector configuration is invalid"""
pass
async def create_fms_provider(config: Dict[str, Any]) -> Safety:
"""Create FMS safety provider instance.
Args:
config: Configuration dictionary
Returns:
Safety: Configured FMS safety provider
"""
logger.debug("Creating trustyai-fms provider")
return await get_adapter_impl(FMSSafetyProviderConfig(**config))
async def get_adapter_impl(
config: Union[Dict[str, Any], FMSSafetyProviderConfig],
_deps: Optional[Dict[str, Any]] = None,
) -> DetectorType:
"""Get appropriate detector implementation(s) based on config type.
Args:
config: Configuration dictionary or FMSSafetyProviderConfig instance
_deps: Optional dependencies for testing/injection
Returns:
Configured detector implementation
Raises:
DetectorConfigError: If configuration is invalid
"""
try:
if isinstance(config, FMSSafetyProviderConfig):
provider_config = config
else:
provider_config = FMSSafetyProviderConfig(**config)
detectors: Dict[str, DetectorType] = {}
# Changed from provider_config.detectors to provider_config.shields
for shield_id, shield_config in provider_config.shields.items():
impl: BaseDetector
if isinstance(shield_config, ChatDetectorConfig):
impl = ChatDetector(shield_config)
elif isinstance(shield_config, ContentDetectorConfig):
impl = ContentDetector(shield_config)
else:
raise DetectorConfigError(
f"Invalid shield config type for {shield_id}: {type(shield_config)}"
)
await impl.initialize()
detectors[shield_id] = impl
detectors_for_provider: Dict[str, BaseDetector] = {}
for shield_id, detector in detectors.items():
if isinstance(detector, BaseDetector):
detectors_for_provider[shield_id] = detector
return DetectorProvider(detectors_for_provider)
except Exception as e:
raise DetectorConfigError(
f"Failed to create detector implementation: {str(e)}"
) from e
__all__ = [
# Factory methods
"get_adapter_impl",
"create_fms_provider",
# Configurations
"ContentDetectorConfig",
"ChatDetectorConfig",
"FMSSafetyProviderConfig",
"EndpointType",
"DetectorParams",
# Implementations
"ChatDetector",
"ContentDetector",
"BaseDetector",
"DetectorProvider",
# Types
"ConfigType",
"DetectorType",
"DetectorConfigError",
]

View file

@ -0,0 +1,562 @@
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
from urllib.parse import urlparse
from pydantic import BaseModel, Field, model_validator
from llama_stack.schema_utils import json_schema_type
class MessageType(Enum):
"""Valid message types for detectors"""
USER = "user"
SYSTEM = "system"
TOOL = "tool"
COMPLETION = "completion"
@classmethod
def as_set(cls) -> Set[str]:
"""Get all valid message types as a set"""
return {member.value for member in cls}
class EndpointType(Enum):
"""API endpoint types and their paths"""
DIRECT_CONTENT = {
"path": "/api/v1/text/contents",
"version": "v1",
"type": "content",
}
DIRECT_CHAT = {"path": "/api/v1/text/chat", "version": "v1", "type": "chat"}
ORCHESTRATOR_CONTENT = {
"path": "/api/v2/text/detection/content",
"version": "v2",
"type": "content",
}
ORCHESTRATOR_CHAT = {
"path": "/api/v2/text/detection/chat",
"version": "v2",
"type": "chat",
}
@classmethod
def get_endpoint(cls, is_orchestrator: bool, is_chat: bool) -> EndpointType:
"""Get appropriate endpoint based on configuration"""
if is_orchestrator:
return cls.ORCHESTRATOR_CHAT if is_chat else cls.ORCHESTRATOR_CONTENT
return cls.DIRECT_CHAT if is_chat else cls.DIRECT_CONTENT
@json_schema_type
@dataclass
class DetectorParams:
"""Flexible parameter container supporting nested structure and arbitrary parameters"""
# Store all parameters in a single dictionary for maximum flexibility
params: Dict[str, Any] = field(default_factory=dict)
# Parameter categories for organization
model_params: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
kwargs: Dict[str, Any] = field(default_factory=dict)
# Store detectors directly as an attribute (not in params) for orchestrator mode
_raw_detectors: Optional[Dict[str, Dict[str, Any]]] = None
# Standard parameters kept for backward compatibility
@property
def regex(self) -> Optional[List[str]]:
return self.params.get("regex")
@regex.setter
def regex(self, value: List[str]) -> None:
self.params["regex"] = value
@property
def temperature(self) -> Optional[float]:
return self.model_params.get("temperature") or self.params.get("temperature")
@temperature.setter
def temperature(self, value: float) -> None:
self.model_params["temperature"] = value
@property
def risk_name(self) -> Optional[str]:
return self.metadata.get("risk_name") or self.params.get("risk_name")
@risk_name.setter
def risk_name(self, value: str) -> None:
self.metadata["risk_name"] = value
@property
def risk_definition(self) -> Optional[str]:
return self.metadata.get("risk_definition") or self.params.get(
"risk_definition"
)
@risk_definition.setter
def risk_definition(self, value: str) -> None:
self.metadata["risk_definition"] = value
@property
def orchestrator_detectors(self) -> Dict[str, Dict[str, Any]]:
"""Return detectors in the format required by orchestrator API"""
if (
not hasattr(self, "_detectors") or not self._detectors
): # Direct attribute access
return {}
flattened = {}
for (
detector_id,
detector_config,
) in self._detectors.items(): # Direct attribute access
# Create a flattened version without extra nesting
flat_config = {}
# Extract detector_params if present and flatten them
params = detector_config.get("detector_params", {})
if isinstance(params, dict):
# Move params up to top level
for key, value in params.items():
flat_config[key] = value
flattened[detector_id] = flat_config
return flattened
@property
def formatted_detectors(self) -> Optional[Dict[str, Dict[str, Any]]]:
"""Return detectors properly formatted for orchestrator API"""
# Direct return for API usage - avoid calling other properties
if hasattr(self, "_detectors") and self._detectors:
return self.orchestrator_detectors
return None
@formatted_detectors.setter
def formatted_detectors(self, value: Dict[str, Dict[str, Any]]) -> None:
self._detectors = value
@property
def detectors(self) -> Optional[Dict[str, Dict[str, Any]]]:
"""COMPATIBILITY: Returns the same as formatted_detectors to maintain API compatibility"""
# Using a different implementation to avoid the redefinition error
# while maintaining the same functionality
if not hasattr(self, "_detectors") or not self._detectors:
return None
return self.orchestrator_detectors
@detectors.setter
def detectors(self, value: Dict[str, Dict[str, Any]]) -> None:
"""COMPATIBILITY: Set detectors while maintaining compatibility"""
self._detectors = value
# And fix the __setitem__ method:
def __setitem__(self, key: str, value: Any) -> None:
"""Allow dictionary-like assignment with smart categorization"""
# Special handling for known params
if key == "detectors":
self._detectors = value # Set underlying attribute directly
return
elif key == "regex":
self.params[key] = value
return
# Rest of the method remains unchanged
known_model_params = ["temperature", "top_p", "top_k", "max_tokens", "n"]
known_metadata = ["risk_name", "risk_definition", "category", "severity"]
if key in known_model_params:
self.model_params[key] = value
elif key in known_metadata:
self.metadata[key] = value
else:
self.kwargs[key] = value
def __init__(self, **kwargs):
"""Initialize from any keyword arguments with smart categorization"""
# Initialize containers
self.params = {}
self.model_params = {}
self.metadata = {}
self.kwargs = {}
self._raw_detectors = None
# Special handling for nested detectors structure
if "detectors" in kwargs:
self._raw_detectors = kwargs.pop("detectors")
# Special handling for regex
if "regex" in kwargs:
self.params["regex"] = kwargs.pop("regex")
# Categorize known parameters
known_model_params = ["temperature", "top_p", "top_k", "max_tokens", "n"]
known_metadata = ["risk_name", "risk_definition", "category", "severity"]
# Explicit categories if provided
if "model_params" in kwargs:
self.model_params.update(kwargs.pop("model_params"))
if "metadata" in kwargs:
self.metadata.update(kwargs.pop("metadata"))
if "kwargs" in kwargs:
self.kwargs.update(kwargs.pop("kwargs"))
# Categorize remaining parameters
for key, value in kwargs.items():
if key in known_model_params:
self.model_params[key] = value
elif key in known_metadata:
self.metadata[key] = value
else:
self.kwargs[key] = value
def __getitem__(self, key: str) -> Any:
"""Allow dictionary-like access with category lookup"""
if key in self.params:
return self.params[key]
elif key in self.model_params:
return self.model_params[key]
elif key in self.metadata:
return self.metadata[key]
elif key in self.kwargs:
return self.kwargs[key]
return None
def get(self, key: str, default: Any = None) -> Any:
"""Dictionary-style get with category lookup"""
result = self.__getitem__(key)
return default if result is None else result
def set(self, key: str, value: Any) -> None:
"""Set a parameter value with smart categorization"""
self.__setitem__(key, value)
def update(self, params: Dict[str, Any]) -> None:
"""Update with multiple parameters, respecting categories"""
for key, value in params.items():
self.__setitem__(key, value)
def to_dict(self) -> Dict[str, Any]:
"""Convert all parameters to a flat dictionary for API requests"""
result = {}
# Add core parameters
result.update(self.params)
# Add all categorized parameters, flattened
result.update(self.model_params)
result.update(self.metadata)
result.update(self.kwargs)
return result
def to_categorized_dict(self) -> Dict[str, Any]:
"""Convert to a structured dictionary with categories preserved"""
result = dict(self.params)
if self.model_params:
result["model_params"] = dict(self.model_params)
if self.metadata:
result["metadata"] = dict(self.metadata)
if self.kwargs:
result["kwargs"] = dict(self.kwargs)
return result
def create_flattened_detector_configs(self) -> Dict[str, Dict[str, Any]]:
"""Create flattened detector configurations for orchestrator mode.
This removes the extra detector_params nesting that causes API errors.
"""
if not self.detectors:
return {}
flattened = {}
for detector_id, detector_config in self.detectors.items():
# Create a flattened version without extra nesting
flat_config = {}
# Extract detector_params if present and flatten them
params = detector_config.get("detector_params", {})
if isinstance(params, dict):
# Move params up to top level
for key, value in params.items():
flat_config[key] = value
flattened[detector_id] = flat_config
return flattened
def validate(self) -> None:
"""Validate parameter values"""
if self.temperature is not None and not 0 <= self.temperature <= 1:
raise ValueError("Temperature must be between 0 and 1")
@json_schema_type
@dataclass
class BaseDetectorConfig:
"""Base configuration for all detectors with flexible parameter handling"""
detector_id: str
confidence_threshold: float = 0.5
message_types: Set[str] = field(default_factory=lambda: MessageType.as_set())
auth_token: Optional[str] = None
detector_params: Optional[DetectorParams] = None
# URL fields directly on detector configs
detector_url: Optional[str] = None
orchestrator_url: Optional[str] = None
# Flexible storage for any additional parameters
_extra_params: Dict[str, Any] = field(default_factory=dict)
# Runtime execution parameters
max_concurrency: int = 10 # Maximum concurrent API requests
request_timeout: float = 30.0 # HTTP request timeout in seconds
max_retries: int = 3 # Maximum number of retry attempts
backoff_factor: float = 1.5 # Exponential backoff multiplier
max_keepalive_connections: int = 5 # Max number of keepalive connections
max_connections: int = 10 # Max number of connections in the pool
@property
def use_orchestrator_api(self) -> bool:
"""Determine if orchestrator API should be used"""
return bool(self.orchestrator_url)
def __post_init__(self) -> None:
"""Process configuration after initialization"""
# Convert list/tuple message_types to set
if isinstance(self.message_types, (list, tuple)):
self.message_types = set(self.message_types)
# Validate message types
invalid_types = self.message_types - MessageType.as_set()
if invalid_types:
raise ValueError(
f"Invalid message types: {invalid_types}. "
f"Valid types are: {MessageType.as_set()}"
)
# Initialize detector_params if needed
if self.detector_params is None:
self.detector_params = DetectorParams()
# Handle legacy URL field names
if hasattr(self, "base_url") and self.base_url and not self.detector_url:
self.detector_url = self.base_url
if (
hasattr(self, "orchestrator_base_url")
and self.orchestrator_base_url
and not self.orchestrator_url
):
self.orchestrator_url = self.orchestrator_base_url
def validate(self) -> None:
"""Validate configuration"""
# Validate detector_params
if self.detector_params:
self.detector_params.validate()
# Validate that at least one URL is provided
if not self.detector_url and not self.orchestrator_url:
raise ValueError(f"No URL provided for detector {self.detector_id}")
# Validate URLs if present
for url_name, url in [
("detector_url", self.detector_url),
("orchestrator_url", self.orchestrator_url),
]:
if url:
self._validate_url(url, url_name)
def _validate_url(self, url: str, url_name: str) -> None:
"""Validate URL format"""
parsed = urlparse(url)
if not all([parsed.scheme, parsed.netloc]):
raise ValueError(f"Invalid {url_name} format: {url}")
if parsed.scheme not in {"http", "https"}:
raise ValueError(f"Invalid {url_name} scheme: {parsed.scheme}")
def get(self, key: str, default: Any = None) -> Any:
"""Get parameter with fallback to extra parameters"""
try:
return getattr(self, key)
except AttributeError:
return self._extra_params.get(key, default)
def set(self, key: str, value: Any) -> None:
"""Set parameter, storing in extra_params if not a standard field"""
if hasattr(self, key) and key not in ["_extra_params"]:
setattr(self, key, value)
else:
self._extra_params[key] = value
@property
def is_chat(self) -> bool:
"""Default implementation, should be overridden by subclasses"""
return False
@json_schema_type
@dataclass
class ContentDetectorConfig(BaseDetectorConfig):
"""Configuration for content detectors"""
@property
def is_chat(self) -> bool:
"""Content detectors are not chat detectors"""
return False
@json_schema_type
@dataclass
class ChatDetectorConfig(BaseDetectorConfig):
"""Configuration for chat detectors"""
@property
def is_chat(self) -> bool:
"""Chat detectors are chat detectors"""
return True
@json_schema_type
class FMSSafetyProviderConfig(BaseModel):
"""Configuration for the FMS Safety Provider"""
shields: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
# Rename _detectors to remove the leading underscore
detectors_internal: Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]] = (
Field(default_factory=dict, exclude=True)
)
# Provider-level orchestrator URL (can be copied to shields if needed)
orchestrator_url: Optional[str] = None
class Config:
arbitrary_types_allowed = True
# Add a model validator to replace __post_init__
@model_validator(mode="after")
def setup_config(self):
"""Process shield configurations"""
# Process shields into detector objects
self._process_shields()
# Replace shields dictionary with processed detector configs
self.shields = self.detectors_internal
# Validate all shields
for shield in self.shields.values():
shield.validate()
return self
def _process_shields(self):
"""Process all shield configurations into detector configs"""
for shield_id, config in self.shields.items():
if isinstance(config, dict):
# Copy the config to avoid modifying the original
shield_config = dict(config)
# Check if this shield has nested detectors
nested_detectors = shield_config.pop("detectors", None)
# Determine detector type
detector_type = shield_config.pop("type", None)
is_chat = (
detector_type == "chat"
if detector_type
else shield_config.pop("is_chat", False)
)
# Set detector ID
shield_config["detector_id"] = shield_id
# Handle URL fields
# First handle legacy field names
if "base_url" in shield_config and "detector_url" not in shield_config:
shield_config["detector_url"] = shield_config.pop("base_url")
if (
"orchestrator_base_url" in shield_config
and "orchestrator_url" not in shield_config
):
shield_config["orchestrator_url"] = shield_config.pop(
"orchestrator_base_url"
)
# If no orchestrator_url in shield but provider has one, copy it
if self.orchestrator_url and "orchestrator_url" not in shield_config:
shield_config["orchestrator_url"] = self.orchestrator_url
# Initialize detector_params with proper structure for nested detectors
detector_params_dict = shield_config.get("detector_params", {})
if not isinstance(detector_params_dict, dict):
detector_params_dict = {}
# Create detector_params object
detector_params = DetectorParams(**detector_params_dict)
# Add nested detectors if present
if nested_detectors:
detector_params.detectors = nested_detectors
shield_config["detector_params"] = detector_params
# Create appropriate detector config
detector_class = (
ChatDetectorConfig if is_chat else ContentDetectorConfig
)
self.detectors_internal[shield_id] = detector_class(**shield_config)
@property
def all_detectors(
self,
) -> Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]]:
"""Get all detector configurations"""
return self.detectors_internal
# Update other methods to use detectors_internal instead of _detectors
def get_detectors_by_type(
self, message_type: Union[str, MessageType]
) -> Dict[str, Union[ContentDetectorConfig, ChatDetectorConfig]]:
"""Get detectors for a specific message type"""
type_value = (
message_type.value
if isinstance(message_type, MessageType)
else message_type
)
return {
shield_id: shield
for shield_id, shield in self.detectors_internal.items()
if type_value in shield.message_types
}
# Convenience properties
@property
def user_message_detectors(self):
return self.get_detectors_by_type(MessageType.USER)
@property
def system_message_detectors(self):
return self.get_detectors_by_type(MessageType.SYSTEM)
@property
def tool_response_detectors(self):
return self.get_detectors_by_type(MessageType.TOOL)
@property
def completion_message_detectors(self):
return self.get_detectors_by_type(MessageType.COMPLETION)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,329 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, cast
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse
from llama_stack.providers.remote.safety.trustyai_fms.config import ChatDetectorConfig
from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import (
BaseDetector,
DetectionResult,
DetectorError,
DetectorRequestError,
DetectorValidationError,
)
# Type aliases for better readability
ChatMessage = Dict[
str, Any
] # Changed from Dict[str, str] to Dict[str, Any] to handle complex content
ChatRequest = Dict[str, Any]
DetectorResponse = List[Dict[str, Any]]
logger = logging.getLogger(__name__)
class ChatDetectorError(DetectorError):
"""Specific errors for chat detector operations"""
pass
@dataclass(frozen=True)
class ChatDetectionMetadata:
"""Structured metadata for chat detections"""
risk_name: Optional[str] = None
risk_definition: Optional[str] = None
additional_metadata: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert metadata to dictionary format"""
result: Dict[str, Any] = {} # Fixed the type annotation here
if self.risk_name:
result["risk_name"] = self.risk_name
if self.risk_definition:
result["risk_definition"] = self.risk_definition
if self.additional_metadata:
result["metadata"] = self.additional_metadata
return result
class ChatDetector(BaseDetector):
"""Detector for chat-based safety checks"""
def __init__(self, config: ChatDetectorConfig) -> None:
"""Initialize chat detector with configuration"""
if not isinstance(config, ChatDetectorConfig):
raise DetectorValidationError(
"Config must be an instance of ChatDetectorConfig"
)
super().__init__(config)
self.config: ChatDetectorConfig = config
logger.info(f"Initialized ChatDetector with config: {vars(config)}")
def _extract_detector_params(self) -> Dict[str, Any]:
"""Extract non-null detector parameters"""
if not self.config.detector_params:
return {}
params = {
k: v for k, v in vars(self.config.detector_params).items() if v is not None
}
logger.debug(f"Extracted detector params: {params}")
return params
def _prepare_chat_request(
self, messages: List[ChatMessage], params: Optional[Dict[str, Any]] = None
) -> ChatRequest:
"""Prepare the request based on API mode"""
# Format messages for detector API
formatted_messages: List[Dict[str, str]] = [] # Explicitly typed
for msg in messages:
formatted_msg = {
"content": str(msg.get("content", "")), # Ensure string type
"role": "user", # Always send as user for detector API
}
formatted_messages.append(formatted_msg)
if self.config.use_orchestrator_api:
payload: Dict[str, Any] = {
"messages": formatted_messages
} # Explicitly typed
# Initialize detector_config to avoid None
detector_config: Dict[str, Any] = {} # Explicitly typed
# NEW STRUCTURE: Check for top-level detectors first
if hasattr(self.config, "detectors") and self.config.detectors:
for detector_id, det_config in self.config.detectors.items():
detector_config[detector_id] = det_config.get("detector_params", {})
# LEGACY STRUCTURE: Check for nested detectors
elif (
self.config.detector_params
and hasattr(self.config.detector_params, "detectors")
and self.config.detector_params.detectors
):
detector_config = self.config.detector_params.detectors
# Handle flat params - group them into generic containers
elif self.config.detector_params:
detector_params = self._extract_detector_params()
# Create a flat dictionary of parameters
flat_params = {}
# Extract from model_params
if "model_params" in detector_params and isinstance(
detector_params["model_params"], dict
):
flat_params.update(detector_params["model_params"])
# Extract from metadata
if "metadata" in detector_params and isinstance(
detector_params["metadata"], dict
):
flat_params.update(detector_params["metadata"])
# Extract from kwargs
if "kwargs" in detector_params and isinstance(
detector_params["kwargs"], dict
):
flat_params.update(detector_params["kwargs"])
# Add any other direct parameters, but skip container dictionaries
for k, v in detector_params.items():
if (
k not in ["model_params", "metadata", "kwargs", "params"]
and v is not None
):
flat_params[k] = v
# Add all flattened parameters directly to detector configuration
detector_config[self.config.detector_id] = flat_params
# Ensure we have a valid detectors map even if all checks fail
if not detector_config:
detector_config = {self.config.detector_id: {}}
payload["detectors"] = detector_config
return payload
# Direct API format remains unchanged
else:
# DIRECT MODE: Use flat parameters for API compatibility
# Don't organize into containers for direct mode
detector_params = self._extract_detector_params()
# Flatten the parameters for direct mode too
flat_params = {}
# Extract from model_params
if "model_params" in detector_params and isinstance(
detector_params["model_params"], dict
):
flat_params.update(detector_params["model_params"])
# Extract from metadata
if "metadata" in detector_params and isinstance(
detector_params["metadata"], dict
):
flat_params.update(detector_params["metadata"])
# Extract from kwargs
if "kwargs" in detector_params and isinstance(
detector_params["kwargs"], dict
):
flat_params.update(detector_params["kwargs"])
# Add any other direct parameters
for k, v in detector_params.items():
if (
k not in ["model_params", "metadata", "kwargs", "params"]
and v is not None
):
flat_params[k] = v
return {
"messages": formatted_messages,
"detector_params": flat_params if flat_params else params or {},
}
async def _call_detector_api(
self,
messages: List[ChatMessage],
params: Optional[Dict[str, Any]] = None,
) -> DetectorResponse:
"""Call chat detector API with proper endpoint selection"""
try:
request = self._prepare_chat_request(messages, params)
headers = self._prepare_headers()
logger.info("Making detector API request")
logger.debug(f"Request headers: {headers}")
logger.debug(f"Request payload: {request}")
response = await self._make_request(request, headers)
return self._extract_detections(response)
except Exception as e:
logger.error(f"API call failed: {str(e)}", exc_info=True)
raise DetectorRequestError(
f"Chat detector API call failed: {str(e)}"
) from e
def _extract_detections(self, response: Dict[str, Any]) -> DetectorResponse:
"""Extract detections from API response"""
if not response:
logger.debug("Empty response received")
return []
if self.config.use_orchestrator_api:
detections = response.get("detections", [])
if not detections:
# Add default detection when none returned
logger.debug("No detections found, adding default low-score detection")
return [
{
"detection_type": "risk",
"detection": "No",
"detector_id": self.config.detector_id,
"score": 0.0, # Default low score
}
]
logger.debug(f"Orchestrator detections: {detections}")
return cast(
DetectorResponse, detections
) # Explicit cast to correct return type
# Direct API returns a list where first item contains detections
if isinstance(response, list) and response:
detections = (
[response[0]] if not isinstance(response[0], list) else response[0]
)
logger.debug(f"Direct API detections: {detections}")
return cast(
DetectorResponse, detections
) # Explicit cast to correct return type
logger.debug("No detections found in response")
return []
def _process_detection(
self, detection: Dict[str, Any]
) -> Tuple[
Optional[DetectionResult], float
]: # Changed return type to match base class
"""Process detection result and validate against threshold"""
score = detection.get("score", 0.0) # Default to 0.0 if score is missing
if score > self.score_threshold:
metadata = ChatDetectionMetadata(
risk_name=(
self.config.detector_params.risk_name
if self.config.detector_params
else None
),
risk_definition=(
self.config.detector_params.risk_definition
if self.config.detector_params
else None
),
additional_metadata=detection.get("metadata"),
)
result = DetectionResult(
detection="Yes",
detection_type=detection["detection_type"],
score=score,
detector_id=detection.get("detector_id", self.config.detector_id),
text=detection.get("text", ""),
start=detection.get("start", 0),
end=detection.get("end", 0),
metadata=metadata.to_dict(),
)
return (result, score)
return (None, score)
async def _run_shield_impl(
self,
shield_id: str,
messages: List[Message],
params: Optional[Dict[str, Any]] = None,
) -> RunShieldResponse:
"""Implementation of shield checks for chat messages"""
try:
shield = await self.shield_store.get_shield(shield_id)
self._validate_shield(shield)
logger.info(f"Processing {len(messages)} message(s)")
# Convert messages keeping only necessary fields
chat_messages: List[ChatMessage] = [] # Explicitly typed
for msg in messages:
message_dict: ChatMessage = {"content": msg.content, "role": msg.role}
# Preserve type if present for internal processing
if hasattr(msg, "type"):
message_dict["type"] = msg.type
chat_messages.append(message_dict)
logger.debug(f"Prepared messages: {chat_messages}")
detections = await self._call_detector_api(chat_messages, params)
for detection in detections:
processed, score = self._process_detection(detection)
if processed:
logger.info(f"Violation detected: {processed}")
return self.create_violation_response(
processed, detection.get("detector_id", self.config.detector_id)
)
logger.debug("No violations detected")
return RunShieldResponse()
except Exception as e:
logger.error(f"Shield execution failed: {str(e)}", exc_info=True)
raise ChatDetectorError(f"Shield execution failed: {str(e)}") from e

View file

@ -0,0 +1,209 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional, cast
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse
from llama_stack.providers.remote.safety.trustyai_fms.config import (
ContentDetectorConfig,
)
from llama_stack.providers.remote.safety.trustyai_fms.detectors.base import (
BaseDetector,
DetectionResult,
DetectorError,
DetectorRequestError,
DetectorValidationError,
)
# Type aliases for better readability
ContentRequest = Dict[str, Any]
DetectorResponse = List[Dict[str, Any]]
logger = logging.getLogger(__name__)
class ContentDetectorError(DetectorError):
"""Specific errors for content detector operations"""
pass
class ContentDetector(BaseDetector):
"""Detector for content-based safety checks"""
def __init__(self, config: ContentDetectorConfig) -> None:
"""Initialize content detector with configuration"""
if not isinstance(config, ContentDetectorConfig):
raise DetectorValidationError(
"Config must be an instance of ContentDetectorConfig"
)
super().__init__(config)
self.config: ContentDetectorConfig = config
logger.info(f"Initialized ContentDetector with config: {vars(config)}")
def _extract_detector_params(self) -> Dict[str, Any]:
"""Extract detector parameters with support for generic format"""
if not self.config.detector_params:
return {}
# Use to_dict() to flatten our categorized structure into what the API expects
params = self.config.detector_params.to_dict()
logger.debug(f"Extracted detector params: {params}")
return params
def _prepare_content_request(
self, content: str, params: Optional[Dict[str, Any]] = None
) -> ContentRequest:
"""Prepare the request based on API mode"""
if self.config.use_orchestrator_api:
payload: Dict[str, Any] = {"content": content} # Always use singular form
# NEW STRUCTURE: Check for top-level detectors first
if hasattr(self.config, "detectors") and self.config.detectors:
detector_config: Dict[str, Any] = {}
for detector_id, det_config in self.config.detectors.items():
detector_config[detector_id] = det_config.get("detector_params", {})
payload["detectors"] = detector_config
return payload
# LEGACY STRUCTURE: Check for nested detectors
elif self.config.detector_params and hasattr(
self.config.detector_params, "detectors"
):
detectors = getattr(self.config.detector_params, "detectors", {})
payload["detectors"] = detectors
return payload
# Handle flat params
else:
detector_config = {}
detector_params = self._extract_detector_params()
if detector_params:
detector_config[self.config.detector_id] = detector_params
payload["detectors"] = detector_config
return payload
else:
# DIRECT MODE: Use flat parameters for API compatibility
detector_params = self._extract_detector_params()
return {
"contents": [content],
"detector_params": detector_params if detector_params else params or {},
}
def _extract_detections(self, response: Dict[str, Any]) -> DetectorResponse:
"""Extract detections from API response"""
if not response:
logger.debug("Empty response received")
return []
if self.config.use_orchestrator_api:
detections = response.get("detections", [])
logger.debug(f"Orchestrator detections: {detections}")
return cast(List[Dict[str, Any]], detections)
# Direct API returns a list of lists where inner list contains detections
if isinstance(response, list) and response:
detections = response[0] if isinstance(response[0], list) else [response[0]]
logger.debug(f"Direct API detections: {detections}")
return cast(List[Dict[str, Any]], detections)
logger.debug("No detections found in response")
return []
async def _call_detector_api(
self,
content: str,
params: Optional[Dict[str, Any]] = None,
) -> DetectorResponse:
"""Call detector API with proper endpoint selection"""
try:
request = self._prepare_content_request(content, params)
headers = self._prepare_headers()
logger.info("Making detector API request")
logger.debug(f"Request headers: {headers}")
logger.debug(f"Request payload: {request}")
response = await self._make_request(request, headers)
return self._extract_detections(response)
except Exception as e:
logger.error(f"API call failed: {str(e)}", exc_info=True)
raise DetectorRequestError(
f"Content detector API call failed: {str(e)}"
) from e
def _process_detection(
self, detection: Dict[str, Any]
) -> tuple[Optional[DetectionResult], float]:
"""Process detection result and validate against threshold"""
if not detection.get("score"):
logger.warning("Detection missing score field")
return None, 0.0
score = detection.get("score", 0)
if score > self.score_threshold:
result = DetectionResult(
detection="Yes",
detection_type=detection["detection_type"],
score=score,
detector_id=detection.get("detector_id", self.config.detector_id),
text=detection.get("text", ""),
start=detection.get("start", 0),
end=detection.get("end", 0),
metadata=detection.get("metadata", {}),
)
return result, score
return None, score
async def _run_shield_impl(
self,
shield_id: str,
messages: List[Message],
params: Optional[Dict[str, Any]] = None,
) -> RunShieldResponse:
"""Implementation of shield checks for content messages"""
try:
shield = await self.shield_store.get_shield(shield_id)
self._validate_shield(shield)
for msg in messages:
content = msg.content
content_str: str
# Simplified content handling - just check for text attribute or convert to string
if hasattr(content, "text"):
content_str = str(content.text)
elif isinstance(content, list):
content_str = " ".join(
str(getattr(item, "text", "")) for item in content
)
else:
content_str = str(content)
truncated_content = (
content_str[:100] + "..." if len(content_str) > 100 else content_str
)
logger.debug(f"Checking content: {truncated_content}")
detections = await self._call_detector_api(content_str, params)
for detection in detections:
processed, score = self._process_detection(detection)
if processed:
logger.info(f"Violation detected: {processed}")
return self.create_violation_response(
processed,
detection.get("detector_id", self.config.detector_id),
)
logger.debug("No violations detected")
return RunShieldResponse()
except Exception as e:
logger.error(f"Shield execution failed: {str(e)}", exc_info=True)
raise ContentDetectorError(f"Shield execution failed: {str(e)}") from e