mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
get together api key from provider data. If not present, for now use the configured key
This commit is contained in:
parent
06c6b54529
commit
93ad663e29
2 changed files with 26 additions and 137 deletions
|
@ -7,39 +7,40 @@ import pydantic
|
||||||
from together import Together
|
from together import Together
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from llama_stack.distribution.request_headers import get_request_provider_data
|
||||||
from .config import TogetherSafetyConfig
|
from .config import TogetherSafetyConfig
|
||||||
from llama_stack.apis.safety import *
|
from llama_stack.apis.safety import *
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
class TogetherHeaderInfo(BaseModel):
|
||||||
|
together_api_key: str
|
||||||
|
|
||||||
|
|
||||||
class TogetherSafetyImpl(Safety):
|
class TogetherSafetyImpl(Safety):
|
||||||
def __init__(self, config: TogetherSafetyConfig) -> None:
|
def __init__(self, config: TogetherSafetyConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self._client = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client(self) -> Together:
|
|
||||||
if self._client == None:
|
|
||||||
self._client = Together(api_key=self.config.api_key)
|
|
||||||
return self._client
|
|
||||||
|
|
||||||
@client.setter
|
|
||||||
def client(self, client: Together) -> None:
|
|
||||||
self._client = client
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def run_shields(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
shield_type: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
shields: List[ShieldDefinition],
|
params: Dict[str, Any] = None,
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
# support only llama guard shield
|
# support only llama guard shield
|
||||||
for shield in shields:
|
|
||||||
if not isinstance(shield.shield_type, BuiltinShield) or shield.shield_type != BuiltinShield.llama_guard:
|
if shield_type != "llama_guard":
|
||||||
raise ValueError(f"shield type {shield.shield_type} is not supported")
|
raise ValueError(f"shield type {shield_type} is not supported")
|
||||||
|
|
||||||
|
provider_data = get_request_provider_data()
|
||||||
|
together_api_key = self.config.api_key
|
||||||
|
# @TODO error out if together_api_key is missing in the header
|
||||||
|
if provider_data is not None:
|
||||||
|
if isinstance(provider_data, TogetherHeaderInfo):
|
||||||
|
together_api_key = provider_data.together_api_key
|
||||||
|
|
||||||
# messages can have role assistant or user
|
# messages can have role assistant or user
|
||||||
api_messages = []
|
api_messages = []
|
||||||
|
@ -50,32 +51,26 @@ class TogetherSafetyImpl(Safety):
|
||||||
raise ValueError(f"role {message.role} is not supported")
|
raise ValueError(f"role {message.role} is not supported")
|
||||||
|
|
||||||
# construct Together request
|
# construct Together request
|
||||||
responses = await asyncio.gather(*[get_safety_response(self.client, api_messages)])
|
response = await asyncio.run(get_safety_response(together_api_key, api_messages))
|
||||||
return RunShieldResponse(responses=responses)
|
return RunShieldResponse(violation=response)
|
||||||
|
|
||||||
async def get_safety_response(client: Together, messages: List[Dict[str, str]]) -> Optional[ShieldResponse]:
|
async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> Optional[SafetyViolation]:
|
||||||
|
client = Together(api_key=api_key)
|
||||||
response = client.chat.completions.create(messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B")
|
response = client.chat.completions.create(messages=messages, model="meta-llama/Meta-Llama-Guard-3-8B")
|
||||||
if len(response.choices) == 0:
|
if len(response.choices) == 0:
|
||||||
return ShieldResponse(shield_type=BuiltinShield.llama_guard, is_violation=False)
|
return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe")
|
||||||
|
|
||||||
response_text = response.choices[0].message.content
|
response_text = response.choices[0].message.content
|
||||||
if response_text == 'safe':
|
if response_text == 'safe':
|
||||||
return ShieldResponse(
|
return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe")
|
||||||
shield_type=BuiltinShield.llama_guard,
|
|
||||||
is_violation=False,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
parts = response_text.split("\n")
|
parts = response_text.split("\n")
|
||||||
if not len(parts) == 2:
|
if not len(parts) == 2:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if parts[0] == 'unsafe':
|
if parts[0] == 'unsafe':
|
||||||
return ShieldResponse(
|
SafetyViolation(violation_level=ViolationLevel.WARN, user_message="unsafe",
|
||||||
shield_type=BuiltinShield.llama_guard,
|
metadata={"violation_type": parts[1]})
|
||||||
is_violation=True,
|
|
||||||
violation_type=parts[1],
|
|
||||||
violation_return_message="Sorry, I cannot do that"
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
@ -1,106 +0,0 @@
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, AsyncMock
|
|
||||||
from llama_stack.apis.safety import UserMessage, ShieldDefinition, ShieldType, RunShieldResponse, ShieldResponse, \
|
|
||||||
BuiltinShield
|
|
||||||
from llama_stack.providers.adapters.safety.together import TogetherSafetyImpl, TogetherSafetyConfig
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def safety_config():
|
|
||||||
return TogetherSafetyConfig(api_key="test_api_key")
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def safety_impl(safety_config):
|
|
||||||
return TogetherSafetyImpl(safety_config)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize(safety_impl):
|
|
||||||
await safety_impl.initialize()
|
|
||||||
# Add assertions if needed for initialization
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_shields_safe(safety_impl, monkeypatch):
|
|
||||||
# Mock the Together client
|
|
||||||
mock_client = Mock()
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.choices = [Mock(message=Mock(role="assistant", content="safe"))]
|
|
||||||
mock_client.chat.completions.create.return_value = mock_response
|
|
||||||
monkeypatch.setattr(safety_impl, 'client', mock_client)
|
|
||||||
|
|
||||||
messages = [UserMessage(role="user", content="Hello, world!")]
|
|
||||||
shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
|
|
||||||
|
|
||||||
response = await safety_impl.run_shields(messages, shields)
|
|
||||||
|
|
||||||
assert isinstance(response, RunShieldResponse)
|
|
||||||
assert len(response.responses) == 1
|
|
||||||
assert response.responses[0].is_violation == False
|
|
||||||
assert response.responses[0].shield_type == BuiltinShield.llama_guard
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_shields_unsafe(safety_impl, monkeypatch):
|
|
||||||
# Mock the Together client
|
|
||||||
mock_client = Mock()
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.choices = [Mock(message=Mock(role="assistant", content="unsafe\ns2"))]
|
|
||||||
mock_client.chat.completions.create.return_value = mock_response
|
|
||||||
monkeypatch.setattr(safety_impl, 'client', mock_client)
|
|
||||||
|
|
||||||
messages = [UserMessage(role="user", content="Unsafe content")]
|
|
||||||
shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
|
|
||||||
|
|
||||||
response = await safety_impl.run_shields(messages, shields)
|
|
||||||
|
|
||||||
assert isinstance(response, RunShieldResponse)
|
|
||||||
assert len(response.responses) == 1
|
|
||||||
assert response.responses[0].is_violation == True
|
|
||||||
assert response.responses[0].shield_type == BuiltinShield.llama_guard
|
|
||||||
assert response.responses[0].violation_type == "s2"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_shields_unsupported_shield(safety_impl):
|
|
||||||
messages = [UserMessage(role="user", content="Hello")]
|
|
||||||
shields = [ShieldDefinition(shield_type="unsupported_shield")]
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="shield type unsupported_shield is not supported"):
|
|
||||||
await safety_impl.run_shields(messages, shields)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_shields_unsupported_message_type(safety_impl):
|
|
||||||
class UnsupportedMessage:
|
|
||||||
role = "unsupported"
|
|
||||||
content = "Hello"
|
|
||||||
|
|
||||||
messages = [UnsupportedMessage()]
|
|
||||||
shields = [ShieldDefinition(shield_type=BuiltinShield.llama_guard)]
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="role unsupported is not supported"):
|
|
||||||
await safety_impl.run_shields(messages, shields)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.integtest
|
|
||||||
@pytest.mark.skipif("'integtest' not in sys.argv", reason="need -m integtest option to run")
|
|
||||||
async def test_actual_run():
|
|
||||||
safety_impl = TogetherSafetyImpl(config=TogetherSafetyConfig(api_key="<replace your together api key here>"))
|
|
||||||
await safety_impl.initialize()
|
|
||||||
response = await safety_impl.run_shields([UserMessage(role="user", content="Hello")], [ShieldDefinition(shield_type=BuiltinShield.llama_guard)])
|
|
||||||
|
|
||||||
assert isinstance(response, RunShieldResponse)
|
|
||||||
assert len(response.responses) == 1
|
|
||||||
assert response.responses[0].is_violation == False
|
|
||||||
assert response.responses[0].shield_type == BuiltinShield.llama_guard
|
|
||||||
assert response.responses[0].violation_type == None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.integtest
|
|
||||||
@pytest.mark.skipif("'integtest' not in sys.argv", reason="need -m integtest option to run")
|
|
||||||
async def test_actual_run_violation():
|
|
||||||
safety_impl = TogetherSafetyImpl(config=TogetherSafetyConfig(api_key="replace your together api key here"))
|
|
||||||
await safety_impl.initialize()
|
|
||||||
response = await safety_impl.run_shields([UserMessage(role="user", content="can I kill you?")], [ShieldDefinition(shield_type=BuiltinShield.llama_guard)])
|
|
||||||
|
|
||||||
assert isinstance(response, RunShieldResponse)
|
|
||||||
assert len(response.responses) == 1
|
|
||||||
assert response.responses[0].is_violation == True
|
|
||||||
assert response.responses[0].shield_type == BuiltinShield.llama_guard
|
|
||||||
assert response.responses[0].violation_type == "S1"
|
|
Loading…
Add table
Add a link
Reference in a new issue