Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)
addressed the PR comments

commit 2df858bdd6
parent cb00e5933b

3 changed files with 5 additions and 15 deletions
@@ -25,13 +25,8 @@ class TogetherSafetyImpl(Safety):
         pass
 
-    async def run_shield(
-        self,
-        shield_type: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+    async def run_shield(
+        self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        # support only llama guard shield
-
         if shield_type != "llama_guard":
             raise ValueError(f"shield type {shield_type} is not supported")
 
@@ -47,13 +42,11 @@ class TogetherSafetyImpl(Safety):
         # messages can have role assistant or user
         api_messages = []
         for message in messages:
-            if type(message) is UserMessage:
+            if message.role == Role.user.value:
                 api_messages.append({'role': message.role, 'content': message.content})
-            else:
-                raise ValueError(f"role {message.role} is not supported")
 
         # construct Together request
-        response = await asyncio.run(get_safety_response(together_api_key, api_messages))
+        response = await get_safety_response(together_api_key, api_messages)
         return RunShieldResponse(violation=response)
 
 async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> Optional[SafetyViolation]:
@@ -67,7 +60,7 @@ async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> Optional[SafetyViolation]:
         return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe")
     else:
         parts = response_text.split("\n")
-        if not len(parts) == 2:
+        if len(parts) != 2:
             return None
 
         if parts[0] == 'unsafe':
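For readability, here is a minimal sketch of how run_shield reads after this commit, pieced together from the TogetherSafetyImpl hunks above. Only the lines visible in the diff are authoritative; the imports (Message, Role, RunShieldResponse, SafetyViolation come from llama_stack), the source of together_api_key, and the body of the Together API call inside get_safety_response are not part of these hunks and are stubbed or marked as assumptions below.

    # Sketch only -- reassembled from the diff above, not the full file.
    # Method of TogetherSafetyImpl (class definition and imports elided).
    async def run_shield(
        self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
    ) -> RunShieldResponse:
        if shield_type != "llama_guard":
            raise ValueError(f"shield type {shield_type} is not supported")

        # ... lines elided in the hunk; together_api_key is resolved here ...

        # messages can have role assistant or user; non-user messages are now
        # skipped instead of raising ValueError
        api_messages = []
        for message in messages:
            if message.role == Role.user.value:
                api_messages.append({'role': message.role, 'content': message.content})

        # construct Together request; the coroutine is awaited directly rather
        # than being wrapped in asyncio.run(), which cannot itself be awaited
        response = await get_safety_response(together_api_key, api_messages)
        return RunShieldResponse(violation=response)

The remaining change in get_safety_response is cosmetic: `if not len(parts) == 2` becomes the equivalent but clearer `if len(parts) != 2`.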
@@ -33,7 +33,6 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
         )
     ),
-
     remote_provider_spec(
         api=Api.safety,
         adapter=AdapterSpec(
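The available_providers() hunk above only drops a stray blank line between provider entries. For context, a sketch of the shape of one remote provider entry, using only the names visible in this hunk; any other AdapterSpec fields are elided rather than guessed:

    # Shape only -- fields not shown in the hunk are elided.
    remote_provider_spec(
        api=Api.safety,
        adapter=AdapterSpec(
            # ... other adapter fields (not shown in this hunk) ...
            config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
        ),
    ),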
@@ -7,6 +7,4 @@ prompt-toolkit
 python-dotenv
 pydantic
 requests
 termcolor
-pytest
-pytest-asyncio
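pytest and pytest-asyncio drop out of the runtime requirements here; they remain what an async test for run_shield would rely on. A minimal sketch, assuming a hypothetical test module and a hypothetical safety_impl fixture that yields a configured TogetherSafetyImpl (neither is part of this commit):

    # Hypothetical test sketch -- not part of this commit.
    import pytest

    @pytest.mark.asyncio
    async def test_run_shield_rejects_unknown_shield(safety_impl):
        # safety_impl is an assumed fixture providing a configured TogetherSafetyImpl
        with pytest.raises(ValueError):
            await safety_impl.run_shield(shield_type="other_shield", messages=[])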