mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
adding safety adapter for Together
This commit is contained in:
parent
e617273d8c
commit
9f3300df25
6 changed files with 243 additions and 1 deletions
|
@ -4,7 +4,7 @@ distribution_spec:
|
|||
providers:
|
||||
inference: remote::together
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
safety: remote::together
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
||||
image_type: conda
|
||||
|
|
18
llama_stack/providers/adapters/safety/together/__init__.py
Normal file
18
llama_stack/providers/adapters/safety/together/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from .config import TogetherSafetyConfig
|
||||
from .safety import TogetherSafetyImpl
|
||||
|
||||
|
||||
async def get_adapter_impl(config: TogetherSafetyConfig, _deps):
|
||||
from .safety import TogetherSafetyImpl
|
||||
|
||||
assert isinstance(
|
||||
config, TogetherSafetyConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
impl = TogetherSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
20
llama_stack/providers/adapters/safety/together/config.py
Normal file
20
llama_stack/providers/adapters/safety/together/config.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherSafetyConfig(BaseModel):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
)
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
description="The Together AI API Key",
|
||||
)
|
83
llama_stack/providers/adapters/safety/together/safety.py
Normal file
83
llama_stack/providers/adapters/safety/together/safety.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from together import Together
|
||||
|
||||
import asyncio
|
||||
from .config import TogetherSafetyConfig
|
||||
from llama_stack.apis.safety import *
|
||||
import logging
|
||||
|
||||
|
||||
class TogetherSafetyImpl(Safety):
|
||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> Together:
|
||||
if self._client == None:
|
||||
self._client = Together(api_key=self.config.api_key)
|
||||
return self._client
|
||||
|
||||
@client.setter
|
||||
def client(self, client: Together) -> None:
|
||||
self._client = client
|
||||
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def run_shields(
|
||||
self,
|
||||
messages: List[Message],
|
||||
shields: List[ShieldDefinition],
|
||||
) -> RunShieldResponse:
|
||||
# support only llama guard shield
|
||||
for shield in shields:
|
||||
if not isinstance(shield.shield_type, BuiltinShield) or shield.shield_type != BuiltinShield.llama_guard:
|
||||
raise ValueError(f"shield type {shield.shield_type} is not supported")
|
||||
|
||||
# messages can have role assistant or user
|
||||
api_messages = []
|
||||
for message in messages:
|
||||
if type(message) is UserMessage:
|
||||
api_messages.append({'role': message.role, 'content': message.content})
|
||||
else:
|
||||
raise ValueError(f"role {message.role} is not supported")
|
||||
|
||||
# construct Together request
|
||||
responses = await asyncio.gather(*[get_safety_response(self.client, api_messages)])
|
||||
return RunShieldResponse(responses=responses)
|
||||
|
||||
async def get_safety_response(client: Together, messages: List[Dict[str, str]]) -> Optional[ShieldResponse]:
|
||||
response = client.chat.completions.create(messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B")
|
||||
if len(response.choices) == 0:
|
||||
return ShieldResponse(shield_type=BuiltinShield.llama_guard, is_violation=False)
|
||||
|
||||
response_text = response.choices[0].message.content
|
||||
if response_text == 'safe':
|
||||
return ShieldResponse(
|
||||
shield_type=BuiltinShield.llama_guard,
|
||||
is_violation=False,
|
||||
)
|
||||
else:
|
||||
parts = response_text.split("\n")
|
||||
if not len(parts) == 2:
|
||||
return None
|
||||
|
||||
if parts[0] == 'unsafe':
|
||||
return ShieldResponse(
|
||||
shield_type=BuiltinShield.llama_guard,
|
||||
is_violation=True,
|
||||
violation_type=parts[1],
|
||||
violation_return_message="Sorry, I cannot do that"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
|
106
llama_stack/providers/adapters/safety/together/test_safety.py
Normal file
106
llama_stack/providers/adapters/safety/together/test_safety.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from llama_stack.apis.safety import UserMessage, ShieldDefinition, ShieldType, RunShieldResponse, ShieldResponse, \
|
||||
BuiltinShield
|
||||
from llama_stack.providers.adapters.safety.together import TogetherSafetyImpl, TogetherSafetyConfig
|
||||
|
||||
@pytest.fixture
|
||||
def safety_config():
|
||||
return TogetherSafetyConfig(api_key="test_api_key")
|
||||
|
||||
@pytest.fixture
|
||||
def safety_impl(safety_config):
|
||||
return TogetherSafetyImpl(safety_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(safety_impl):
|
||||
await safety_impl.initialize()
|
||||
# Add assertions if needed for initialization
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_shields_safe(safety_impl, monkeypatch):
|
||||
# Mock the Together client
|
||||
mock_client = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock(message=Mock(role="assistant", content="safe"))]
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
monkeypatch.setattr(safety_impl, 'client', mock_client)
|
||||
|
||||
messages = [UserMessage(role="user", content="Hello, world!")]
|
||||
shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
|
||||
|
||||
response = await safety_impl.run_shields(messages, shields)
|
||||
|
||||
assert isinstance(response, RunShieldResponse)
|
||||
assert len(response.responses) == 1
|
||||
assert response.responses[0].is_violation == False
|
||||
assert response.responses[0].shield_type == BuiltinShield.llama_guard
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_shields_unsafe(safety_impl, monkeypatch):
|
||||
# Mock the Together client
|
||||
mock_client = Mock()
|
||||
mock_response = Mock()
|
||||
mock_response.choices = [Mock(message=Mock(role="assistant", content="unsafe\ns2"))]
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
monkeypatch.setattr(safety_impl, 'client', mock_client)
|
||||
|
||||
messages = [UserMessage(role="user", content="Unsafe content")]
|
||||
shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
|
||||
|
||||
response = await safety_impl.run_shields(messages, shields)
|
||||
|
||||
assert isinstance(response, RunShieldResponse)
|
||||
assert len(response.responses) == 1
|
||||
assert response.responses[0].is_violation == True
|
||||
assert response.responses[0].shield_type == BuiltinShield.llama_guard
|
||||
assert response.responses[0].violation_type == "s2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_shields_unsupported_shield(safety_impl):
|
||||
messages = [UserMessage(role="user", content="Hello")]
|
||||
shields = [ShieldDefinition(shield_type="unsupported_shield")]
|
||||
|
||||
with pytest.raises(ValueError, match="shield type unsupported_shield is not supported"):
|
||||
await safety_impl.run_shields(messages, shields)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_shields_unsupported_message_type(safety_impl):
|
||||
class UnsupportedMessage:
|
||||
role = "unsupported"
|
||||
content = "Hello"
|
||||
|
||||
messages = [UnsupportedMessage()]
|
||||
shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
|
||||
|
||||
with pytest.raises(ValueError, match="role unsupported is not supported"):
|
||||
await safety_impl.run_shields(messages, shields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integtest
|
||||
@pytest.mark.skipif("'integtest' not in sys.argv", reason="need -m integtest option to run")
|
||||
async def test_actual_run():
|
||||
safety_impl = TogetherSafetyImpl(config=TogetherSafetyConfig(api_key="<replace your together api key here>"))
|
||||
await safety_impl.initialize()
|
||||
response = await safety_impl.run_shields([UserMessage(role="user", content="Hello")], [ShieldDefinition(shield_type=BuiltinShield.llama_guard)])
|
||||
|
||||
assert isinstance(response, RunShieldResponse)
|
||||
assert len(response.responses) == 1
|
||||
assert response.responses[0].is_violation == False
|
||||
assert response.responses[0].shield_type == BuiltinShield.llama_guard
|
||||
assert response.responses[0].violation_type == None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.integtest
|
||||
@pytest.mark.skipif("'integtest' not in sys.argv", reason="need -m integtest option to run")
|
||||
async def test_actual_run_violation():
|
||||
safety_impl = TogetherSafetyImpl(config=TogetherSafetyConfig(api_key="replace your together api key here"))
|
||||
await safety_impl.initialize()
|
||||
response = await safety_impl.run_shields([UserMessage(role="user", content="can I kill you?")], [ShieldDefinition(shield_type=BuiltinShield.llama_guard)])
|
||||
|
||||
assert isinstance(response, RunShieldResponse)
|
||||
assert len(response.responses) == 1
|
||||
assert response.responses[0].is_violation == True
|
||||
assert response.responses[0].shield_type == BuiltinShield.llama_guard
|
||||
assert response.responses[0].violation_type == "S1"
|
|
@ -7,6 +7,7 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec, remote_provider_spec, AdapterSpec
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
|
@ -30,6 +31,20 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=[],
|
||||
module="llama_stack.providers.adapters.safety.sample",
|
||||
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
|
||||
)
|
||||
),
|
||||
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_id="together",
|
||||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
module="llama_stack.providers.adapters.safety.together",
|
||||
config_class="llama_stack.providers.adapters.safety.together.TogetherSafetyConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue