addressed the PR comments

Yogish Baliga 2024-09-24 11:11:49 -07:00
parent cb00e5933b
commit 2df858bdd6
3 changed files with 5 additions and 15 deletions

@@ -25,13 +25,8 @@ class TogetherSafetyImpl(Safety):
         pass
 
     async def run_shield(
-        self,
-        shield_type: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+        self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
         # support only llama guard shield
         if shield_type != "llama_guard":
             raise ValueError(f"shield type {shield_type} is not supported")
@@ -47,13 +42,11 @@ class TogetherSafetyImpl(Safety):
         # messages can have role assistant or user
         api_messages = []
         for message in messages:
-            if type(message) is UserMessage:
+            if message.role == Role.user.value:
                 api_messages.append({'role': message.role, 'content': message.content})
             else:
                 raise ValueError(f"role {message.role} is not supported")
         # construct Together request
-        response = await asyncio.run(get_safety_response(together_api_key, api_messages))
+        response = await get_safety_response(together_api_key, api_messages)
         return RunShieldResponse(violation=response)
 
 async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> Optional[SafetyViolation]:
@@ -67,7 +60,7 @@ async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> Optional[SafetyViolation]:
         return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe")
     else:
         parts = response_text.split("\n")
-        if not len(parts) == 2:
+        if len(parts) != 2:
             return None
         if parts[0] == 'unsafe':

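The substantive fix in the hunks above is replacing await asyncio.run(...) with a plain await: asyncio.run() starts a new event loop, so calling it inside run_shield (which already executes under a running loop) raises "RuntimeError: asyncio.run() cannot be called from a running event loop", and its return value is a plain result rather than an awaitable. A small self-contained illustration of the two patterns, independent of this codebase:

    import asyncio

    async def fetch() -> str:
        await asyncio.sleep(0)  # stand-in for the Together API call
        return "safe"

    async def caller() -> str:
        # Inside a running event loop, await the coroutine directly.
        return await fetch()
        # asyncio.run(fetch()) here would raise RuntimeError because a loop
        # is already running, and its plain return value could not be awaited.

    print(asyncio.run(caller()))  # asyncio.run() belongs only at the top level
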
@@ -33,7 +33,6 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
             )
         ),
         remote_provider_spec(
             api=Api.safety,
             adapter=AdapterSpec(

@@ -7,6 +7,4 @@ prompt-toolkit
 python-dotenv
 pydantic
 requests
-termcolor
-pytest
-pytest-asyncio
+termcolor
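
The requirements change drops the test-only packages pytest and pytest-asyncio from the runtime requirements and keeps termcolor. pytest-asyncio is what allows a test to be written as a coroutine so an async call like run_shield can be awaited directly; a minimal, generic sketch (not from this repo):

    import pytest

    @pytest.mark.asyncio  # marker provided by the pytest-asyncio plugin
    async def test_coroutine_result():
        async def fake_shield_call() -> str:
            return "safe"
        assert await fake_shield_call() == "safe"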