mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-10 04:08:31 +00:00
remove debug logs. wip tests
Remove excess debug logs. Provider tests for the fiddlecube safety provider setup. TODO: fix the test and get it to run with an inference provider.
parent 4b4b592bb1
commit 20f0836d13

3 changed files with 26 additions and 11 deletions
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging

from typing import Any, Dict, List
@@ -57,18 +56,12 @@ class FiddlecubeSafetyAdapter(Safety, ShieldsProtocolPrivate):
            headers=headers,
        )

-        logger.debug("Response:::", response.status_code)

        # Check if the response is successful
        if response.status_code != 200:
            logger.error(f"FiddleCube API error: {response.status_code} - {response.text}")
            raise RuntimeError("Failed to run shield with FiddleCube API")

        # Convert the response into the format RunShieldResponse expects
        response_data = response.json()
-        logger.debug("Response data: %s", json.dumps(response_data, indent=2))

        # Check if there's a violation based on the response structure
        if response_data.get("action") == "GUARDRAIL_INTERVENED":
            user_message = ""
            metadata = {}

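The hunk ends inside the GUARDRAIL_INTERVENED branch, before a response is returned. As orientation only, a minimal sketch of how such a branch can be mapped onto llama-stack's RunShieldResponse is shown below; it is not the adapter's actual code, and the user message and metadata values are assumptions.

from llama_stack.apis.safety import RunShieldResponse, SafetyViolation, ViolationLevel

def to_run_shield_response(response_data: dict) -> RunShieldResponse:
    # Illustrative sketch (not from this commit): translate an intervened
    # guardrail result into the shape the Safety API expects; a non-intervened
    # result carries no violation.
    if response_data.get("action") != "GUARDRAIL_INTERVENED":
        return RunShieldResponse(violation=None)
    return RunShieldResponse(
        violation=SafetyViolation(
            violation_level=ViolationLevel.ERROR,
            user_message="Content flagged by the FiddleCube guardrail",  # assumed wording
            metadata={"action": response_data.get("action")},
        )
    )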
@@ -53,11 +53,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
        id="remote",
        marks=pytest.mark.remote,
    ),
+    pytest.param(
+        {
+            "inference": "bedrock",
+            "safety": "fiddlecube",
+        },
+        id="fiddlecube",
+        marks=pytest.mark.fiddlecube,
+    ),
]


def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together", "remote", "bedrock"]:
+    for mark in ["meta_reference", "ollama", "together", "remote", "bedrock", "fiddlecube"]:
        config.addinivalue_line(
            "markers",
            f"{mark}: marks tests as {mark} specific",
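With the marker registered above, the new combination can be selected the same way as the existing ones, for example by filtering on the fiddlecube marker when invoking the safety provider tests (pytest's standard `-m fiddlecube` selection); nothing beyond ordinary marker selection is introduced here.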
@@ -65,7 +73,7 @@ def pytest_configure(config):


SAFETY_SHIELD_PARAMS = [
-    pytest.param("meta-llama/Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
+    pytest.param("meta-llama/Llama3.1-70B-Instruct", marks=pytest.mark.guard_1b, id="guard_1b"),
]

@@ -77,7 +85,7 @@ def pytest_generate_tests(metafunc):
    if "safety_shield" in metafunc.fixturenames:
        shield_id = metafunc.config.getoption("--safety-shield")
        if shield_id:
-            assert shield_id.startswith("meta-llama/")
+            # assert shield_id.startswith("meta-llama/")
            params = [pytest.param(shield_id, id="")]
        else:
            params = SAFETY_SHIELD_PARAMS

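Commenting out the assert lets a shield id that does not start with "meta-llama/" be passed via --safety-shield, which the fiddlecube provider presumably needs. The hunk ends before params is consumed; in this conftest pattern the list is normally handed back to pytest via metafunc.parametrize, so the line below is a hedged sketch of the usual continuation (the safety_shield fixture name and indirect=True routing are assumptions, as that part of the file is not shown in the diff).

# Hedged sketch of the typical continuation (not shown in this hunk):
metafunc.parametrize("safety_shield", params, indirect=True)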
@@ -16,6 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
+from llama_stack.providers.remote.safety.fiddlecube.config import FiddlecubeSafetyConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test

from ..conftest import ProviderFixture, remote_stack_fixture
@@ -98,7 +99,20 @@ def safety_bedrock() -> ProviderFixture:
    )


-SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
+@pytest.fixture(scope="session")
+def safety_fiddlecube() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="fiddlecube",
+                provider_type="remote::fiddlecube",
+                config=FiddlecubeSafetyConfig().model_dump(),
+            )
+        ],
+    )
+
+
+SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote", "fiddlecube"]


@pytest_asyncio.fixture(scope="session")
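The safety_fiddlecube fixture builds its provider config with FiddlecubeSafetyConfig().model_dump(), which implies a pydantic v2 model whose fields all have defaults. The config module itself is not part of this diff; purely as an illustration, such a config could look like the following, with entirely hypothetical field names and defaults.

from pydantic import BaseModel

class FiddlecubeSafetyConfig(BaseModel):
    # Hypothetical fields; the real config module is not shown in this commit.
    api_url: str = "http://localhost:8000"  # base URL of the FiddleCube service (assumed)
    api_key: str | None = None  # optional credential (assumed)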