diff --git a/llama_stack/providers/adapters/safety/together/safety.py b/llama_stack/providers/adapters/safety/together/safety.py index f89d361e9..337d3332c 100644 --- a/llama_stack/providers/adapters/safety/together/safety.py +++ b/llama_stack/providers/adapters/safety/together/safety.py @@ -7,39 +7,40 @@ import pydantic from together import Together import asyncio + +from llama_stack.distribution.request_headers import get_request_provider_data from .config import TogetherSafetyConfig from llama_stack.apis.safety import * import logging +class TogetherHeaderInfo(BaseModel): + together_api_key: str + class TogetherSafetyImpl(Safety): def __init__(self, config: TogetherSafetyConfig) -> None: self.config = config - self._client = None - - @property - def client(self) -> Together: - if self._client == None: - self._client = Together(api_key=self.config.api_key) - return self._client - - @client.setter - def client(self, client: Together) -> None: - self._client = client - async def initialize(self) -> None: pass - async def run_shields( + async def run_shield( self, + shield_type: str, messages: List[Message], - shields: List[ShieldDefinition], + params: Dict[str, Any] = None, ) -> RunShieldResponse: # support only llama guard shield - for shield in shields: - if not isinstance(shield.shield_type, BuiltinShield) or shield.shield_type != BuiltinShield.llama_guard: - raise ValueError(f"shield type {shield.shield_type} is not supported") + + if shield_type != "llama_guard": + raise ValueError(f"shield type {shield_type} is not supported") + + provider_data = get_request_provider_data() + together_api_key = self.config.api_key + # @TODO error out if together_api_key is missing in the header + if provider_data is not None: + if isinstance(provider_data, TogetherHeaderInfo): + together_api_key = provider_data.together_api_key # messages can have role assistant or user api_messages = [] @@ -50,32 +51,26 @@ class TogetherSafetyImpl(Safety): raise ValueError(f"role {message.role} is not supported") # construct Together request - responses = await asyncio.gather(*[get_safety_response(self.client, api_messages)]) - return RunShieldResponse(responses=responses) + response = await asyncio.run(get_safety_response(together_api_key, api_messages)) + return RunShieldResponse(violation=response) -async def get_safety_response(client: Together, messages: List[Dict[str, str]]) -> Optional[ShieldResponse]: +async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> Optional[SafetyViolation]: + client = Together(api_key=api_key) response = client.chat.completions.create(messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B") if len(response.choices) == 0: - return ShieldResponse(shield_type=BuiltinShield.llama_guard, is_violation=False) + return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe") response_text = response.choices[0].message.content if response_text == 'safe': - return ShieldResponse( - shield_type=BuiltinShield.llama_guard, - is_violation=False, - ) + return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe") else: parts = response_text.split("\n") if not len(parts) == 2: return None if parts[0] == 'unsafe': - return ShieldResponse( - shield_type=BuiltinShield.llama_guard, - is_violation=True, - violation_type=parts[1], - violation_return_message="Sorry, I cannot do that" - ) + SafetyViolation(violation_level=ViolationLevel.WARN, user_message="unsafe", + metadata={"violation_type": parts[1]}) return None diff --git a/llama_stack/providers/adapters/safety/together/test_safety.py b/llama_stack/providers/adapters/safety/together/test_safety.py deleted file mode 100644 index e8ba46846..000000000 --- a/llama_stack/providers/adapters/safety/together/test_safety.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest -from unittest.mock import Mock, AsyncMock -from llama_stack.apis.safety import UserMessage, ShieldDefinition, ShieldType, RunShieldResponse, ShieldResponse, \ - BuiltinShield -from llama_stack.providers.adapters.safety.together import TogetherSafetyImpl, TogetherSafetyConfig - -@pytest.fixture -def safety_config(): - return TogetherSafetyConfig(api_key="test_api_key") - -@pytest.fixture -def safety_impl(safety_config): - return TogetherSafetyImpl(safety_config) - -@pytest.mark.asyncio -async def test_initialize(safety_impl): - await safety_impl.initialize() - # Add assertions if needed for initialization - -@pytest.mark.asyncio -async def test_run_shields_safe(safety_impl, monkeypatch): - # Mock the Together client - mock_client = Mock() - mock_response = Mock() - mock_response.choices = [Mock(message=Mock(role="assistant", content="safe"))] - mock_client.chat.completions.create.return_value = mock_response - monkeypatch.setattr(safety_impl, 'client', mock_client) - - messages = [UserMessage(role="user", content="Hello, world!")] - shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)] - - response = await safety_impl.run_shields(messages, shields) - - assert isinstance(response, RunShieldResponse) - assert len(response.responses) == 1 - assert response.responses[0].is_violation == False - assert response.responses[0].shield_type == BuiltinShield.llama_guard - -@pytest.mark.asyncio -async def test_run_shields_unsafe(safety_impl, monkeypatch): - # Mock the Together client - mock_client = Mock() - mock_response = Mock() - mock_response.choices = [Mock(message=Mock(role="assistant", content="unsafe\ns2"))] - mock_client.chat.completions.create.return_value = mock_response - monkeypatch.setattr(safety_impl, 'client', mock_client) - - messages = [UserMessage(role="user", content="Unsafe content")] - shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)] - - response = await safety_impl.run_shields(messages, shields) - - assert isinstance(response, RunShieldResponse) - assert len(response.responses) == 1 - assert response.responses[0].is_violation == True - assert response.responses[0].shield_type == BuiltinShield.llama_guard - assert response.responses[0].violation_type == "s2" - -@pytest.mark.asyncio -async def test_run_shields_unsupported_shield(safety_impl): - messages = [UserMessage(role="user", content="Hello")] - shields = [ShieldDefinition(shield_type="unsupported_shield")] - - with pytest.raises(ValueError, match="shield type unsupported_shield is not supported"): - await safety_impl.run_shields(messages, shields) - -@pytest.mark.asyncio -async def test_run_shields_unsupported_message_type(safety_impl): - class UnsupportedMessage: - role = "unsupported" - content = "Hello" - - messages = [UnsupportedMessage()] - shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)] - - with pytest.raises(ValueError, match="role unsupported is not supported"): - await safety_impl.run_shields(messages, shields) - - -@pytest.mark.asyncio -@pytest.mark.integtest -@pytest.mark.skipif("'integtest' not in sys.argv", reason="need -m integtest option to run") -async def test_actual_run(): - safety_impl = TogetherSafetyImpl(config=TogetherSafetyConfig(api_key="")) - await safety_impl.initialize() - response = await safety_impl.run_shields([UserMessage(role="user", content="Hello")], [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]) - - assert isinstance(response, RunShieldResponse) - assert len(response.responses) == 1 - assert response.responses[0].is_violation == False - assert response.responses[0].shield_type == BuiltinShield.llama_guard - assert response.responses[0].violation_type == None - -@pytest.mark.asyncio -@pytest.mark.integtest -@pytest.mark.skipif("'integtest' not in sys.argv", reason="need -m integtest option to run") -async def test_actual_run_violation(): - safety_impl = TogetherSafetyImpl(config=TogetherSafetyConfig(api_key="replace your together api key here")) - await safety_impl.initialize() - response = await safety_impl.run_shields([UserMessage(role="user", content="can I kill you?")], [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]) - - assert isinstance(response, RunShieldResponse) - assert len(response.responses) == 1 - assert response.responses[0].is_violation == True - assert response.responses[0].shield_type == BuiltinShield.llama_guard - assert response.responses[0].violation_type == "S1" \ No newline at end of file