mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
addressed the PR comments
This commit is contained in:
parent
cb00e5933b
commit
2df858bdd6
3 changed files with 5 additions and 15 deletions
|
@ -25,13 +25,8 @@ class TogetherSafetyImpl(Safety):
|
|||
pass
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
# support only llama guard shield
|
||||
|
||||
if shield_type != "llama_guard":
|
||||
raise ValueError(f"shield type {shield_type} is not supported")
|
||||
|
||||
|
@ -47,13 +42,11 @@ class TogetherSafetyImpl(Safety):
|
|||
# messages can have role assistant or user
|
||||
api_messages = []
|
||||
for message in messages:
|
||||
if type(message) is UserMessage:
|
||||
if message.role == Role.user.value:
|
||||
api_messages.append({'role': message.role, 'content': message.content})
|
||||
else:
|
||||
raise ValueError(f"role {message.role} is not supported")
|
||||
|
||||
# construct Together request
|
||||
response = await asyncio.run(get_safety_response(together_api_key, api_messages))
|
||||
response = await get_safety_response(together_api_key, api_messages)
|
||||
return RunShieldResponse(violation=response)
|
||||
|
||||
async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> Optional[SafetyViolation]:
|
||||
|
@ -67,7 +60,7 @@ async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> O
|
|||
return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe")
|
||||
else:
|
||||
parts = response_text.split("\n")
|
||||
if not len(parts) == 2:
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
if parts[0] == 'unsafe':
|
||||
|
|
|
@ -33,7 +33,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
|
||||
)
|
||||
),
|
||||
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -8,5 +8,3 @@ python-dotenv
|
|||
pydantic
|
||||
requests
|
||||
termcolor
|
||||
pytest
|
||||
pytest-asyncio
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue