fixing safety inference and safety adapter for new API spec. Pinned the llama_models version to 0.0.24 as the latest version 0.0.35 has the model descriptor name changed. I was getting the missing package error during runtime as well, hence added the dependency to requirements.txt

This commit is contained in:
Yogish Baliga 2024-09-25 14:14:15 -07:00
parent 53070e34a3
commit 9bb0c8f4fc
4 changed files with 33 additions and 26 deletions

View file

@ -6,9 +6,16 @@
from together import Together
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.distribution.request_headers import get_request_provider_data
from .config import TogetherProviderDataValidator, TogetherSafetyConfig
from .config import TogetherSafetyConfig
class TogetherSafetyImpl(Safety):
@ -24,21 +31,13 @@ class TogetherSafetyImpl(Safety):
if shield_type != "llama_guard":
raise ValueError(f"shield type {shield_type} is not supported")
provider_data = get_request_provider_data()
together_api_key = None
if provider_data is not None:
if not isinstance(provider_data, TogetherProviderDataValidator):
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
if not together_api_key:
together_api_key = self.config.api_key
if not together_api_key:
raise ValueError("The API key must be provider in the header or config")
provider_data = get_request_provider_data()
if provider_data is None or not provider_data.together_api_key:
raise ValueError(
'Pass Together API Key in the header X-LlamaStack-ProviderData as { "together_api_key": <your api key>}'
)
together_api_key = provider_data.together_api_key
# messages can have role assistant or user
api_messages = []
@ -62,7 +61,9 @@ async def get_safety_response(
response_text = response.choices[0].message.content
if response_text == "safe":
return None
return SafetyViolation(
violation_level=ViolationLevel.INFO, user_message="safe", metadata={}
)
parts = response_text.split("\n")
if len(parts) != 2: