From 2df858bdd6d0cc025ff984577131a2811ff36c4b Mon Sep 17 00:00:00 2001
From: Yogish Baliga
Date: Tue, 24 Sep 2024 11:11:49 -0700
Subject: [PATCH] addressed the PR comments

---
 .../providers/adapters/safety/together/safety.py | 15 ++++-----------
 llama_stack/providers/registry/safety.py         |  1 -
 requirements.txt                                 |  4 +---
 3 files changed, 5 insertions(+), 15 deletions(-)

diff --git a/llama_stack/providers/adapters/safety/together/safety.py b/llama_stack/providers/adapters/safety/together/safety.py
index 6dedc0650..37515d90e 100644
--- a/llama_stack/providers/adapters/safety/together/safety.py
+++ b/llama_stack/providers/adapters/safety/together/safety.py
@@ -25,13 +25,8 @@ class TogetherSafetyImpl(Safety):
         pass
 
     async def run_shield(
-        self,
-        shield_type: str,
-        messages: List[Message],
-        params: Dict[str, Any] = None,
+        self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        # support only llama guard shield
-
         if shield_type != "llama_guard":
             raise ValueError(f"shield type {shield_type} is not supported")
 
@@ -47,13 +42,11 @@ class TogetherSafetyImpl(Safety):
         # messages can have role assistant or user
         api_messages = []
         for message in messages:
-            if type(message) is UserMessage:
+            if message.role == Role.user.value:
                 api_messages.append({'role': message.role, 'content': message.content})
-            else:
-                raise ValueError(f"role {message.role} is not supported")
 
         # construct Together request
-        response = await asyncio.run(get_safety_response(together_api_key, api_messages))
+        response = await get_safety_response(together_api_key, api_messages)
         return RunShieldResponse(violation=response)
 
 async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> Optional[SafetyViolation]:
@@ -67,7 +60,7 @@ async def get_safety_response(api_key: str, messages: List[Dict[str, str]]) -> O
         return SafetyViolation(violation_level=ViolationLevel.INFO, user_message="safe")
     else:
         parts = response_text.split("\n")
-        if not len(parts) == 2:
+        if len(parts) != 2:
             return None
 
         if parts[0] == 'unsafe':
diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py
index b617ece7f..778792c90 100644
--- a/llama_stack/providers/registry/safety.py
+++ b/llama_stack/providers/registry/safety.py
@@ -33,7 +33,6 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
             )
         ),
-
         remote_provider_spec(
             api=Api.safety,
             adapter=AdapterSpec(
diff --git a/requirements.txt b/requirements.txt
index 56dfed264..f14023e26 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,4 @@ prompt-toolkit
 python-dotenv
 pydantic
 requests
-termcolor
-pytest
-pytest-asyncio
+termcolor
\ No newline at end of file
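
The main functional fix in run_shield() above is the switch from await asyncio.run(...) to a plain await: asyncio.run() cannot be called from inside a coroutine that is already running on an event loop, so the old line would have raised a RuntimeError at request time. The following is a minimal, self-contained sketch of the two patterns, using only the standard asyncio module; fetch_verdict() is a hypothetical stand-in for get_safety_response() and is not part of the patch.

import asyncio


async def fetch_verdict(prompt: str) -> str:
    """Hypothetical async helper standing in for get_safety_response()."""
    await asyncio.sleep(0)  # pretend this is the network call to Together
    return "safe"


async def old_pattern() -> str:
    # What the removed line did: calling asyncio.run() inside an
    # already-running event loop raises RuntimeError before any request is sent.
    return asyncio.run(fetch_verdict("hello"))


async def new_pattern() -> str:
    # What the added line does: await the coroutine on the current loop.
    return await fetch_verdict("hello")


if __name__ == "__main__":
    print(asyncio.run(new_pattern()))  # prints "safe"
    try:
        asyncio.run(old_pattern())
    except RuntimeError as err:
        print(f"old pattern fails: {err}")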