remove debug logs. wip tests

remove excess debug logs
provider tests for fiddlecube safety provider setup
todo: fix test and get it to run with an inference provider
This commit is contained in:
Kaushik 2025-02-14 12:37:28 -08:00
parent 4b4b592bb1
commit 20f0836d13
3 changed files with 26 additions and 11 deletions

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any, Dict, List
@ -57,18 +56,12 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
headers=headers,
)
logger.debug("Response:::", response.status_code)
# Check if the response is successful
if response.status_code != 200:
logger.error(f"FiddleCube API error: {response.status_code} - {response.text}")
raise RuntimeError("Failed to run shield with FiddleCube API")
# Convert the response into the format RunShieldResponse expects
response_data = response.json()
logger.debug("Response data: %s", json.dumps(response_data, indent=2))
# Check if there's a violation based on the response structure
if response_data.get("action") == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}

View file

@ -53,11 +53,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="remote",
marks=pytest.mark.remote,
),
pytest.param(
{
"inference": "bedrock",
"safety": "fiddlecube",
},
id="fiddlecube",
marks=pytest.mark.fiddlecube,
),
]
def pytest_configure(config):
for mark in ["meta_reference", "ollama", "together", "remote", "bedrock"]:
for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "fiddlecube"]:
config.addinivalue_line(
"markers",
f"{mark}: marks tests as {mark} specific",
@ -65,7 +73,7 @@ def pytest_configure(config):
SAFETY_SHIELD_PARAMS = [
pytest.param("meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
pytest.param("meta-llama/Llama3.1-70B-Instruct", marks=pytest.mark.guard_1b, id="guard_1b"),
]
@ -77,7 +85,7 @@ def pytest_generate_tests(metafunc):
if "safety_shield" in metafunc.fixturenames:
shield_id = metafunc.config.getoption("--safety-shield")
if shield_id:
assert shield_id.startswith("meta-llama/")
# assert shield_id.startswith("meta-llama/")
params = [pytest.param(shield_id, id="")]
else:
params = SAFETY_SHIELD_PARAMS

View file

@ -16,6 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.remote.safety.fiddlecube.config import FiddlecubeSafetyConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
@ -98,7 +99,20 @@ def safety_bedrock() -> ProviderFixture:
)
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
@pytest.fixture(scope="session")
def safety_fiddlecube() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="fiddlecube",
provider_type="remote::fiddlecube",
config=FiddlecubeSafetyConfig().model_dump(),
)
],
)
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "fiddlecube"]
@pytest_asyncio.fixture(scope="session")