mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
remove debug logs. wip tests
remove excess debug logs provider tests for fiddlecube safety provider setup todo: fix test and get it to run with an inference provider
This commit is contained in:
parent
4b4b592bb1
commit
20f0836d13
3 changed files with 26 additions and 11 deletions
|
@ -4,7 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
@ -57,18 +56,12 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Response:::", response.status_code)
|
|
||||||
|
|
||||||
# Check if the response is successful
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.error(f"FiddleCube API error: {response.status_code} - {response.text}")
|
logger.error(f"FiddleCube API error: {response.status_code} - {response.text}")
|
||||||
raise RuntimeError("Failed to run shield with FiddleCube API")
|
raise RuntimeError("Failed to run shield with FiddleCube API")
|
||||||
|
|
||||||
# Convert the response into the format RunShieldResponse expects
|
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
logger.debug("Response data: %s", json.dumps(response_data, indent=2))
|
|
||||||
|
|
||||||
# Check if there's a violation based on the response structure
|
|
||||||
if response_data.get("action") == "GUARDRAIL_INTERVENED":
|
if response_data.get("action") == "GUARDRAIL_INTERVENED":
|
||||||
user_message = ""
|
user_message = ""
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
|
@ -53,11 +53,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
id="remote",
|
id="remote",
|
||||||
marks=pytest.mark.remote,
|
marks=pytest.mark.remote,
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
"inference": "bedrock",
|
||||||
|
"safety": "fiddlecube",
|
||||||
|
},
|
||||||
|
id="fiddlecube",
|
||||||
|
marks=pytest.mark.fiddlecube,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
for mark in ["meta_reference", "ollama", "together", "remote", "bedrock"]:
|
for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "fiddlecube"]:
|
||||||
config.addinivalue_line(
|
config.addinivalue_line(
|
||||||
"markers",
|
"markers",
|
||||||
f"{mark}: marks tests as {mark} specific",
|
f"{mark}: marks tests as {mark} specific",
|
||||||
|
@ -65,7 +73,7 @@ def pytest_configure(config):
|
||||||
|
|
||||||
|
|
||||||
SAFETY_SHIELD_PARAMS = [
|
SAFETY_SHIELD_PARAMS = [
|
||||||
pytest.param("meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
|
pytest.param("meta-llama/Llama3.1-70B-Instruct", marks=pytest.mark.guard_1b, id="guard_1b"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,7 +85,7 @@ def pytest_generate_tests(metafunc):
|
||||||
if "safety_shield" in metafunc.fixturenames:
|
if "safety_shield" in metafunc.fixturenames:
|
||||||
shield_id = metafunc.config.getoption("--safety-shield")
|
shield_id = metafunc.config.getoption("--safety-shield")
|
||||||
if shield_id:
|
if shield_id:
|
||||||
assert shield_id.startswith("meta-llama/")
|
# assert shield_id.startswith("meta-llama/")
|
||||||
params = [pytest.param(shield_id, id="")]
|
params = [pytest.param(shield_id, id="")]
|
||||||
else:
|
else:
|
||||||
params = SAFETY_SHIELD_PARAMS
|
params = SAFETY_SHIELD_PARAMS
|
||||||
|
|
|
@ -16,6 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||||
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
||||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
||||||
|
|
||||||
|
from llama_stack.providers.remote.safety.fiddlecube.config import FiddlecubeSafetyConfig
|
||||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
|
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
@ -98,7 +99,20 @@ def safety_bedrock() -> ProviderFixture:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
|
@pytest.fixture(scope="session")
|
||||||
|
def safety_fiddlecube() -> ProviderFixture:
|
||||||
|
return ProviderFixture(
|
||||||
|
providers=[
|
||||||
|
Provider(
|
||||||
|
provider_id="fiddlecube",
|
||||||
|
provider_type="remote::fiddlecube",
|
||||||
|
config=FiddlecubeSafetyConfig().model_dump(),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "fiddlecube"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue