mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-19 07:37:18 +00:00
fixed agent persistence test, more cleanup
This commit is contained in:
parent
4f3b009980
commit
22aedd0277
14 changed files with 202 additions and 310 deletions
|
@ -16,7 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
|||
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
@ -102,22 +102,16 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
|
|||
async def safety_stack(inference_model, safety_shield, request):
|
||||
# We need an inference + safety fixture to test safety
|
||||
fixture_dict = request.param
|
||||
inference_fixture = request.getfixturevalue(
|
||||
f"inference_{fixture_dict['inference']}"
|
||||
)
|
||||
safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")
|
||||
|
||||
providers = {
|
||||
"inference": inference_fixture.providers,
|
||||
"safety": safety_fixture.providers,
|
||||
}
|
||||
providers = {}
|
||||
provider_data = {}
|
||||
if inference_fixture.provider_data:
|
||||
provider_data.update(inference_fixture.provider_data)
|
||||
if safety_fixture.provider_data:
|
||||
provider_data.update(safety_fixture.provider_data)
|
||||
for key in ["inference", "safety"]:
|
||||
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||
providers[key] = fixture.providers
|
||||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
test_stack = await construct_stack_for_test(
|
||||
[Api.safety, Api.shields, Api.inference],
|
||||
providers,
|
||||
provider_data,
|
||||
|
@ -125,5 +119,5 @@ async def safety_stack(inference_model, safety_shield, request):
|
|||
shields=[safety_shield],
|
||||
)
|
||||
|
||||
shield = await impls[Api.shields].get_shield(safety_shield.shield_id)
|
||||
return impls[Api.safety], impls[Api.shields], shield
|
||||
shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id)
|
||||
return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue