mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
# What does this PR do? This PR restores the ability to not force users to register resources. Forced registration is not just annoying but sometimes actually infeasible. For example, you may have a Stack that boots up with private inference providers for models A and B. There is currently no way for the user to know which models these providers serve (and therefore to register them). How will this avoid the need for users to do registration? In a follow-up diff, I will update the sample run.yaml files so they explicitly list the models served by each distribution. That way, when users run `llama stack build --template <...>` and start it, their distributions come up with the set of models they expect. For self-hosted distributions, it also gives us a place to explicitly list the models that need to be served to make the "complete" stack (e.g., including safety). ## Test Plan Started ollama locally with two lightweight models: Llama3.2-3B-Instruct and Llama-Guard-3-1B. Updated all the tests, including agents. Here are the tests I ran so far: ```bash pytest -s -v -m "fireworks and llama_3b" test_text_inference.py::TestInference \ --env FIREWORKS_API_KEY=... pytest -s -v -m "ollama and llama_3b" test_text_inference.py::TestInference pytest -s -v -m ollama test_safety.py pytest -s -v -m faiss test_memory.py pytest -s -v -m ollama test_agents.py \ --inference-model=Llama3.2-3B-Instruct --safety-model=Llama-Guard-3-1B ``` These test runs uncovered a few pre-existing bugs here and there, which are fixed in this PR.
142 lines
4.5 KiB
Python
142 lines
4.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from llama_stack.apis.models import Model
|
|
|
|
from llama_stack.apis.shields import Shield, ShieldType
|
|
|
|
from llama_stack.distribution.datatypes import Api, Provider
|
|
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
|
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
|
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
|
|
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
|
|
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
|
from ..env import get_env_or_fail
|
|
|
|
|
|
@pytest.fixture(scope="session")
def safety_remote() -> ProviderFixture:
    """Session-scoped safety fixture pointing at a remote stack.

    All of the work is delegated to ``remote_stack_fixture()`` from the
    shared conftest; this fixture only exposes it under the ``safety_``
    prefix so it can be looked up as ``safety_remote``.
    """
    remote_fixture = remote_stack_fixture()
    return remote_fixture
|
|
|
|
|
|
@pytest.fixture(scope="session")
def safety_model(request):
    """Return the safety model name used by the safety fixtures.

    A value supplied through indirect parametrization (``request.param``)
    wins; otherwise the ``--safety-model`` command-line option is used,
    falling back to ``None`` when it is not available.
    """
    # Guard clause: no indirect parameter -> consult the CLI option.
    if not hasattr(request, "param"):
        return request.config.getoption("--safety-model", None)
    return request.param
|
|
|
|
|
|
@pytest.fixture(scope="session")
def safety_llama_guard(safety_model) -> ProviderFixture:
    """Provider fixture for the inline Llama Guard safety provider.

    The ``safety_model`` fixture value is baked into the provider's
    ``LlamaGuardConfig``; the provider id and type are both
    ``"inline::llama-guard"``.
    """
    guard_config = LlamaGuardConfig(model=safety_model)
    guard_provider = Provider(
        provider_id="inline::llama-guard",
        provider_type="inline::llama-guard",
        config=guard_config.model_dump(),
    )
    return ProviderFixture(providers=[guard_provider])
|
|
|
|
|
|
# TODO: not tested yet. Exercising this fixture requires configuring the
# run_shield() test and parametrizing its "prompt" according to which safety
# fixture is in use.
@pytest.fixture(scope="session")
def safety_prompt_guard() -> ProviderFixture:
    """Provider fixture for the inline Prompt Guard safety provider.

    Uses a default-constructed ``PromptGuardConfig``; the provider id and
    type are both ``"inline::prompt-guard"``.
    """
    prompt_guard_provider = Provider(
        provider_id="inline::prompt-guard",
        provider_type="inline::prompt-guard",
        config=PromptGuardConfig().model_dump(),
    )
    return ProviderFixture(providers=[prompt_guard_provider])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def safety_bedrock() -> ProviderFixture:
    """Provider fixture for the remote AWS Bedrock safety provider.

    Registers a single provider with id ``"bedrock"`` and type
    ``"remote::bedrock"``, configured from a default-constructed
    ``BedrockSafetyConfig``.
    """
    bedrock_config = BedrockSafetyConfig().model_dump()
    bedrock_provider = Provider(
        provider_id="bedrock",
        provider_type="remote::bedrock",
        config=bedrock_config,
    )
    return ProviderFixture(providers=[bedrock_provider])
|
|
|
|
|
|
# Suffixes of the ``safety_<name>`` fixtures defined in this module.
# "prompt_guard" is deliberately left out because it is not tested yet.
SAFETY_FIXTURES = [
    "llama_guard",
    "bedrock",
    "remote",
]
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session")
async def safety_stack(inference_model, safety_model, request):
    """Resolve a test stack exposing the safety, shields and inference APIs.

    ``request.param`` is expected to be a dict with ``"inference"`` and
    ``"safety"`` keys; their values select the ``inference_<value>`` and
    ``safety_<value>`` fixtures to combine. The first inference provider is
    registered with ``inference_model`` as its model, and one shield
    (built by ``get_shield_to_register``) is registered for the first
    safety provider.

    Returns:
        A ``(safety_impl, shields_impl, shield)`` tuple.
    """
    # Safety checks need an inference provider as well as a safety provider.
    selection = request.param
    inference_fixture = request.getfixturevalue(f"inference_{selection['inference']}")
    safety_fixture = request.getfixturevalue(f"safety_{selection['safety']}")

    providers = {
        "inference": inference_fixture.providers,
        "safety": safety_fixture.providers,
    }

    # Merge provider data in order, so safety entries override inference ones.
    provider_data = {}
    for fixture in (inference_fixture, safety_fixture):
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    safety_provider = safety_fixture.providers[0]
    shield = get_shield_to_register(
        safety_provider.provider_type, safety_provider.provider_id, safety_model
    )

    inference_model_entry = Model(
        identifier=inference_model,
        provider_id=inference_fixture.providers[0].provider_id,
        provider_resource_id=inference_model,
    )

    impls = await resolve_impls_for_test_v2(
        [Api.safety, Api.shields, Api.inference],
        providers,
        provider_data,
        models=[inference_model_entry],
        shields=[shield],
    )

    return impls[Api.safety], impls[Api.shields], shield
|
|
|
|
|
|
def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
    """Build the ``Shield`` that the safety test stack registers up front.

    Args:
        provider_type: ``provider_type`` of the safety provider under test.
        provider_id: id of that provider; copied onto the returned shield.
        safety_model: model name placed in the shield params (as ``"model"``)
            when ``provider_type`` is ``"meta-reference"`` or
            ``"remote::together"``.

    Returns:
        A ``Shield`` whose ``provider_resource_id`` equals its identifier.
        By default it has type ``ShieldType.llama_guard``, identifier
        ``"llama_guard"`` and empty params. For ``"remote::bedrock"`` the
        identifier and the ``guardrailVersion`` param are read via
        ``get_env_or_fail`` (``BEDROCK_GUARDRAIL_IDENTIFIER`` /
        ``BEDROCK_GUARDRAIL_VERSION``), and the type is
        ``ShieldType.generic_content_shield``.
    """
    if provider_type == "remote::bedrock":
        identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
        kind = ShieldType.generic_content_shield
    else:
        identifier = "llama_guard"
        kind = ShieldType.llama_guard
        # NOTE(review): the llama-guard fixture in this module uses provider_type
        # "inline::llama-guard", so the "meta-reference" match below never fires
        # for it. Confirm whether that string is still an expected provider type.
        if provider_type in ("meta-reference", "remote::together"):
            params = {"model": safety_model}
        else:
            params = {}

    return Shield(
        identifier=identifier,
        shield_type=kind,
        params=params,
        provider_id=provider_id,
        provider_resource_id=identifier,
    )
|