From 9f3300df2581ce0d77aaa430cd8b1aa5faa3d5f4 Mon Sep 17 00:00:00 2001 From: Yogish Baliga Date: Fri, 20 Sep 2024 09:35:01 -0700 Subject: [PATCH] adding safety adapter for Together --- .../templates/local-together-build.yaml | 2 +- .../adapters/safety/together/__init__.py | 18 +++ .../adapters/safety/together/config.py | 20 ++++ .../adapters/safety/together/safety.py | 83 ++++++++++++++ .../adapters/safety/together/test_safety.py | 106 ++++++++++++++++++ llama_stack/providers/registry/safety.py | 15 +++ 6 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 llama_stack/providers/adapters/safety/together/__init__.py create mode 100644 llama_stack/providers/adapters/safety/together/config.py create mode 100644 llama_stack/providers/adapters/safety/together/safety.py create mode 100644 llama_stack/providers/adapters/safety/together/test_safety.py diff --git a/llama_stack/distribution/templates/local-together-build.yaml b/llama_stack/distribution/templates/local-together-build.yaml index 1ab891518..ebf0bf1fb 100644 --- a/llama_stack/distribution/templates/local-together-build.yaml +++ b/llama_stack/distribution/templates/local-together-build.yaml @@ -4,7 +4,7 @@ distribution_spec: providers: inference: remote::together memory: meta-reference - safety: meta-reference + safety: remote::together agents: meta-reference telemetry: meta-reference image_type: conda diff --git a/llama_stack/providers/adapters/safety/together/__init__.py b/llama_stack/providers/adapters/safety/together/__init__.py new file mode 100644 index 000000000..634659558 --- /dev/null +++ b/llama_stack/providers/adapters/safety/together/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
# ---- llama_stack/providers/adapters/safety/together/__init__.py ----
from .config import TogetherSafetyConfig
from .safety import TogetherSafetyImpl


async def get_adapter_impl(config: TogetherSafetyConfig, _deps):
    """Build and initialize the Together safety adapter.

    Args:
        config: Adapter configuration; must be a TogetherSafetyConfig.
        _deps: Provider dependencies (unused by this adapter).

    Returns:
        An initialized TogetherSafetyImpl instance.
    """
    # TogetherSafetyImpl is already imported at module scope; the original
    # re-imported it here redundantly.
    assert isinstance(
        config, TogetherSafetyConfig
    ), f"Unexpected config type: {type(config)}"
    impl = TogetherSafetyImpl(config)
    await impl.initialize()
    return impl


# ---- llama_stack/providers/adapters/safety/together/config.py ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class TogetherSafetyConfig(BaseModel):
    # Base URL of the Together API endpoint.
    url: str = Field(
        default="https://api.together.xyz/v1",
        description="The URL for the Together AI server",
    )
    # API key for Together; empty by default so deployments must supply it.
    api_key: str = Field(
        default="",
        description="The Together AI API Key",
    )
# ---- llama_stack/providers/adapters/safety/together/safety.py ----
import asyncio
import logging

from together import Together

from llama_stack.apis.safety import *  # noqa: F403

from .config import TogetherSafetyConfig


class TogetherSafetyImpl(Safety):
    """Safety-API adapter that runs shield checks through Together's hosted
    Meta-Llama-Guard-3-8B model."""

    def __init__(self, config: TogetherSafetyConfig) -> None:
        self.config = config
        # Created lazily in the `client` property; tests may inject a mock
        # via the setter instead.
        self._client = None

    @property
    def client(self) -> Together:
        # Lazy construction so no network client is built until first use.
        # (Original compared with `== None`; identity check is correct.)
        if self._client is None:
            self._client = Together(api_key=self.config.api_key)
        return self._client

    @client.setter
    def client(self, client: Together) -> None:
        self._client = client

    async def initialize(self) -> None:
        # Nothing to do: the client is built lazily.
        pass

    async def run_shields(
        self,
        messages: List[Message],
        shields: List[ShieldDefinition],
    ) -> RunShieldResponse:
        """Run the requested shields over `messages`.

        Only the built-in llama_guard shield and user-role messages are
        supported; anything else raises ValueError.
        """
        for shield in shields:
            if (
                not isinstance(shield.shield_type, BuiltinShield)
                or shield.shield_type != BuiltinShield.llama_guard
            ):
                raise ValueError(f"shield type {shield.shield_type} is not supported")

        # Convert to the Together chat-message wire format.
        api_messages = []
        for message in messages:
            # isinstance instead of the original `type(...) is` check.
            if isinstance(message, UserMessage):
                api_messages.append({'role': message.role, 'content': message.content})
            else:
                raise ValueError(f"role {message.role} is not supported")

        # Single request -> single response; the original wrapped one
        # coroutine in asyncio.gather for no benefit.
        response = await get_safety_response(self.client, api_messages)
        return RunShieldResponse(responses=[response])


async def get_safety_response(
    client: Together, messages: List[Dict[str, str]]
) -> Optional[ShieldResponse]:
    """Run one Llama Guard completion and parse its verdict.

    Returns a ShieldResponse, or None when the model reply does not match
    the expected "safe" / "unsafe\\n<category>" format.
    """
    # NOTE(review): the Together SDK call is synchronous and blocks the
    # event loop; consider asyncio.to_thread if this becomes a bottleneck.
    response = client.chat.completions.create(
        messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B"
    )
    if len(response.choices) == 0:
        # No choices returned: treat as non-violation rather than failing.
        return ShieldResponse(shield_type=BuiltinShield.llama_guard, is_violation=False)

    response_text = response.choices[0].message.content
    if response_text == 'safe':
        return ShieldResponse(
            shield_type=BuiltinShield.llama_guard,
            is_violation=False,
        )

    # Llama Guard reports violations as "unsafe\n<category>", e.g. "unsafe\nS1".
    parts = response_text.split("\n")
    if len(parts) != 2:
        return None

    if parts[0] == 'unsafe':
        return ShieldResponse(
            shield_type=BuiltinShield.llama_guard,
            is_violation=True,
            violation_type=parts[1],
            violation_return_message="Sorry, I cannot do that",
        )

    return None


# ---- llama_stack/providers/adapters/safety/together/test_safety.py ----
import pytest
from unittest.mock import Mock  # AsyncMock was imported but never used

from llama_stack.apis.safety import (
    BuiltinShield,
    RunShieldResponse,
    ShieldDefinition,
    UserMessage,
)
from llama_stack.providers.adapters.safety.together import (
    TogetherSafetyConfig,
    TogetherSafetyImpl,
)


@pytest.fixture
def safety_config():
    return TogetherSafetyConfig(api_key="test_api_key")


@pytest.fixture
def safety_impl(safety_config):
    return TogetherSafetyImpl(safety_config)


@pytest.mark.asyncio
async def test_initialize(safety_impl):
    # initialize() is a no-op; just verify it completes without error.
    await safety_impl.initialize()


@pytest.mark.asyncio
async def test_run_shields_safe(safety_impl, monkeypatch):
    # Inject a mocked Together client that reports "safe".
    mock_client = Mock()
    mock_response = Mock()
    mock_response.choices = [Mock(message=Mock(role="assistant", content="safe"))]
    mock_client.chat.completions.create.return_value = mock_response
    monkeypatch.setattr(safety_impl, 'client', mock_client)

    messages = [UserMessage(role="user", content="Hello, world!")]
    shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]

    response = await safety_impl.run_shields(messages, shields)

    assert isinstance(response, RunShieldResponse)
    assert len(response.responses) == 1
    assert not response.responses[0].is_violation
    assert response.responses[0].shield_type == BuiltinShield.llama_guard
# ---- llama_stack/providers/adapters/safety/together/test_safety.py (cont.) ----

@pytest.mark.asyncio
async def test_run_shields_unsafe(safety_impl, monkeypatch):
    # Inject a mocked Together client that reports an "unsafe" category.
    mock_client = Mock()
    mock_response = Mock()
    mock_response.choices = [Mock(message=Mock(role="assistant", content="unsafe\ns2"))]
    mock_client.chat.completions.create.return_value = mock_response
    monkeypatch.setattr(safety_impl, 'client', mock_client)

    messages = [UserMessage(role="user", content="Unsafe content")]
    shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]

    response = await safety_impl.run_shields(messages, shields)

    assert isinstance(response, RunShieldResponse)
    assert len(response.responses) == 1
    assert response.responses[0].is_violation
    assert response.responses[0].shield_type == BuiltinShield.llama_guard
    assert response.responses[0].violation_type == "s2"


@pytest.mark.asyncio
async def test_run_shields_unsupported_shield(safety_impl):
    messages = [UserMessage(role="user", content="Hello")]
    shields = [ShieldDefinition(shield_type="unsupported_shield")]

    with pytest.raises(ValueError, match="shield type unsupported_shield is not supported"):
        await safety_impl.run_shields(messages, shields)


@pytest.mark.asyncio
async def test_run_shields_unsupported_message_type(safety_impl):
    # A message that is not a UserMessage must be rejected.
    class UnsupportedMessage:
        role = "unsupported"
        content = "Hello"

    messages = [UnsupportedMessage()]
    shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]

    with pytest.raises(ValueError, match="role unsupported is not supported"):
        await safety_impl.run_shields(messages, shields)


@pytest.mark.asyncio
@pytest.mark.integtest
@pytest.mark.skipif("'integtest' not in sys.argv", reason="need -m integtest option to run")
async def test_actual_run():
    # Integration test against the real Together API (skipped by default).
    # NOTE(review): an empty api_key will not authenticate; supply a real key
    # (e.g. via an environment variable) before running with -m integtest.
    safety_impl = TogetherSafetyImpl(config=TogetherSafetyConfig(api_key=""))
    await safety_impl.initialize()
    response = await safety_impl.run_shields(
        [UserMessage(role="user", content="Hello")],
        [ShieldDefinition(shield_type=BuiltinShield.llama_guard)],
    )

    assert isinstance(response, RunShieldResponse)
    assert len(response.responses) == 1
    assert not response.responses[0].is_violation
    assert response.responses[0].shield_type == BuiltinShield.llama_guard
    # `is None` instead of the original `== None` comparison.
    assert response.responses[0].violation_type is None


@pytest.mark.asyncio
@pytest.mark.integtest
@pytest.mark.skipif("'integtest' not in sys.argv", reason="need -m integtest option to run")
async def test_actual_run_violation():
    # NOTE(review): placeholder key left intentionally; replace before running.
    safety_impl = TogetherSafetyImpl(
        config=TogetherSafetyConfig(api_key="replace your together api key here")
    )
    await safety_impl.initialize()
    response = await safety_impl.run_shields(
        [UserMessage(role="user", content="can I kill you?")],
        [ShieldDefinition(shield_type=BuiltinShield.llama_guard)],
    )

    assert isinstance(response, RunShieldResponse)
    assert len(response.responses) == 1
    assert response.responses[0].is_violation
    assert response.responses[0].shield_type == BuiltinShield.llama_guard
    assert response.responses[0].violation_type == "S1"


# ---- llama_stack/providers/registry/safety.py (addition) ----
# Register the remote Together safety adapter alongside the existing providers.
remote_provider_spec(
    api=Api.safety,
    adapter=AdapterSpec(
        adapter_id="together",
        pip_packages=[
            "together",
        ],
        module="llama_stack.providers.adapters.safety.together",
        config_class="llama_stack.providers.adapters.safety.together.TogetherSafetyConfig",
    ),
),