From 20f0836d13f9a0d11ce43b9c29d51ad63cff4888 Mon Sep 17 00:00:00 2001 From: Kaushik Date: Fri, 14 Feb 2025 12:37:28 -0800 Subject: [PATCH] remove debug logs. wip tests remove excess debug logs provider tests for fiddlecube safety provider setup todo: fix test and get it to run with an inference provider --- .../remote/safety/fiddlecube/fiddlecube.py | 7 ------- llama_stack/providers/tests/safety/conftest.py | 14 +++++++++++--- llama_stack/providers/tests/safety/fixtures.py | 16 +++++++++++++++- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py index 3ed8a6285..9a468f663 100644 --- a/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py +++ b/llama_stack/providers/remote/safety/fiddlecube/fiddlecube.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json import logging from typing import Any, Dict, List @@ -57,18 +56,12 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate): headers=headers, ) - logger.debug("Response:::", response.status_code) - - # Check if the response is successful if response.status_code != 200: logger.error(f"FiddleCube API error: {response.status_code} - {response.text}") raise RuntimeError("Failed to run shield with FiddleCube API") - # Convert the response into the format RunShieldResponse expects response_data = response.json() - logger.debug("Response data: %s", json.dumps(response_data, indent=2)) - # Check if there's a violation based on the response structure if response_data.get("action") == "GUARDRAIL_INTERVENED": user_message = "" metadata = {} diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index 10a8517fc..b5e0db859 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -53,11 +53,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="remote", marks=pytest.mark.remote, ), + pytest.param( + { + "inference": "bedrock", + "safety": "fiddlecube", + }, + id="fiddlecube", + marks=pytest.mark.fiddlecube, + ), ] def pytest_configure(config): - for mark in ["meta_reference", "ollama", "together", "remote", "bedrock"]: + for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "fiddlecube"]: config.addinivalue_line( "markers", f"{mark}: marks tests as {mark} specific", @@ -65,7 +73,7 @@ def pytest_configure(config): SAFETY_SHIELD_PARAMS = [ - pytest.param("meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"), + pytest.param("meta-llama/Llama3.1-70B-Instruct", marks=pytest.mark.guard_1b, id="guard_1b"), ] @@ -77,7 +85,7 @@ def pytest_generate_tests(metafunc): if "safety_shield" in metafunc.fixturenames: shield_id = metafunc.config.getoption("--safety-shield") if shield_id: - assert shield_id.startswith("meta-llama/") + # assert shield_id.startswith("meta-llama/") params = [pytest.param(shield_id, id="")] else: params = SAFETY_SHIELD_PARAMS diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 32883bfab..457e886a7 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -16,6 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig +from llama_stack.providers.remote.safety.fiddlecube.config import FiddlecubeSafetyConfig from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture @@ -98,7 +99,20 @@ def safety_bedrock() -> ProviderFixture: ) -SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"] +@pytest.fixture(scope="session") +def safety_fiddlecube() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="fiddlecube", + provider_type="remote::fiddlecube", + config=FiddlecubeSafetyConfig().model_dump(), + ) + ], + ) + + +SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "fiddlecube"] @pytest_asyncio.fixture(scope="session")