mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Allow specifying resources in StackRunConfig (#425)
# What does this PR do? This PR brings back the facility to not force registration of resources onto the user. This is not just annoying but actually not feasible sometimes. For example, you may have a Stack which boots up with private providers for inference for models A and B. There is no way for the user to actually know which model is being served by these providers now (to be able to register it.) How will this avoid the users needing to do registration? In a follow-up diff, I will make sure I update the sample run.yaml files so they list the models served by the distributions explicitly. So when users do `llama stack build --template <...>` and run it, their distributions come up with the right set of models they expect. For self-hosted distributions, it also allows us to have a place to explicit list the models that need to be served to make the "complete" stack (including safety, e.g.) ## Test Plan Started ollama locally with two lightweight models: Llama3.2-3B-Instruct and Llama-Guard-3-1B. Updated all the tests including agents. Here's the tests I ran so far: ```bash pytest -s -v -m "fireworks and llama_3b" test_text_inference.py::TestInference \ --env FIREWORKS_API_KEY=... pytest -s -v -m "ollama and llama_3b" test_text_inference.py::TestInference pytest -s -v -m ollama test_safety.py pytest -s -v -m faiss test_memory.py pytest -s -v -m ollama test_agents.py \ --inference-model=Llama3.2-3B-Instruct --safety-model=Llama-Guard-3-1B ``` Found a few bugs here and there pre-existing that these test runs fixed.
This commit is contained in:
parent
8035fa1869
commit
d9d271a684
15 changed files with 221 additions and 124 deletions
|
@ -31,48 +31,7 @@ from .strong_typing.schema import json_schema_type
|
|||
|
||||
schema_utils.json_schema_type = json_schema_type
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.eval import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.batch_inference import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
from llama_stack.apis.post_training import * # noqa: F403
|
||||
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
MemoryBanks,
|
||||
Inference,
|
||||
BatchInference,
|
||||
Agents,
|
||||
Safety,
|
||||
SyntheticDataGeneration,
|
||||
Datasets,
|
||||
Telemetry,
|
||||
PostTraining,
|
||||
Memory,
|
||||
Eval,
|
||||
EvalTasks,
|
||||
Scoring,
|
||||
ScoringFunctions,
|
||||
DatasetIO,
|
||||
Models,
|
||||
Shields,
|
||||
Inspect,
|
||||
):
|
||||
pass
|
||||
from llama_stack.distribution.stack import LlamaStack
|
||||
|
||||
|
||||
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
|
||||
|
|
|
@ -22,6 +22,9 @@ class ResourceType(Enum):
|
|||
class Resource(BaseModel):
|
||||
"""Base class for all Llama Stack resources"""
|
||||
|
||||
# TODO: I think we need to move these into the child classes
|
||||
# and make them `model_id`, `shield_id`, etc. because otherwise
|
||||
# the config file has these confusing generic names in there
|
||||
identifier: str = Field(
|
||||
description="Unique identifier for this resource in llama stack"
|
||||
)
|
||||
|
|
|
@ -151,6 +151,14 @@ Configuration for the persistence store used by the distribution registry. If no
|
|||
a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
models: List[Model] = Field(default_factory=list)
|
||||
shields: List[Shield] = Field(default_factory=list)
|
||||
memory_banks: List[MemoryBank] = Field(default_factory=list)
|
||||
datasets: List[Dataset] = Field(default_factory=list)
|
||||
scoring_fns: List[ScoringFn] = Field(default_factory=list)
|
||||
eval_tasks: List[EvalTask] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BuildConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
|
|
|
@ -27,12 +27,7 @@ from pydantic import BaseModel, ValidationError
|
|||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.distribution import (
|
||||
builtin_automatically_routed_apis,
|
||||
get_provider_registry,
|
||||
)
|
||||
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
|
@ -42,14 +37,15 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
|||
)
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import InvalidProviderError, resolve_impls
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import construct_stack
|
||||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
||||
def create_sse_event(data: Any) -> str:
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.json()
|
||||
data = data.model_dump_json()
|
||||
else:
|
||||
data = json.dumps(data)
|
||||
|
||||
|
@ -281,12 +277,8 @@ def main(
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
|
||||
|
||||
try:
|
||||
impls = asyncio.run(
|
||||
resolve_impls(config, get_provider_registry(), dist_registry)
|
||||
)
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
except InvalidProviderError:
|
||||
sys.exit(1)
|
||||
|
||||
|
|
100
llama_stack/distribution/stack.py
Normal file
100
llama_stack/distribution/stack.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.eval import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.batch_inference import * # noqa: F403
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
from llama_stack.apis.post_training import * # noqa: F403
|
||||
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
MemoryBanks,
|
||||
Inference,
|
||||
BatchInference,
|
||||
Agents,
|
||||
Safety,
|
||||
SyntheticDataGeneration,
|
||||
Datasets,
|
||||
Telemetry,
|
||||
PostTraining,
|
||||
Memory,
|
||||
Eval,
|
||||
EvalTasks,
|
||||
Scoring,
|
||||
ScoringFunctions,
|
||||
DatasetIO,
|
||||
Models,
|
||||
Shields,
|
||||
Inspect,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(
|
||||
run_config.metadata_store, run_config.image_name
|
||||
)
|
||||
|
||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||
|
||||
objects = [
|
||||
*run_config.models,
|
||||
*run_config.shields,
|
||||
*run_config.memory_banks,
|
||||
*run_config.datasets,
|
||||
*run_config.scoring_fns,
|
||||
*run_config.eval_tasks,
|
||||
]
|
||||
for obj in objects:
|
||||
await dist_registry.register(obj)
|
||||
|
||||
resources = [
|
||||
("models", Api.models),
|
||||
("shields", Api.shields),
|
||||
("memory_banks", Api.memory_banks),
|
||||
("datasets", Api.datasets),
|
||||
("scoring_fns", Api.scoring_functions),
|
||||
("eval_tasks", Api.eval_tasks),
|
||||
]
|
||||
for rsrc, api in resources:
|
||||
if api not in impls:
|
||||
continue
|
||||
|
||||
method = getattr(impls[api], f"list_{api.value}")
|
||||
for obj in await method():
|
||||
print(
|
||||
f"{rsrc.capitalize()}: {colored(obj.identifier, 'white', attrs=['bold'])} served by {colored(obj.provider_id, 'white', attrs=['bold'])}",
|
||||
)
|
||||
|
||||
print("")
|
||||
return impls
|
|
@ -5,14 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Protocol
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
|
||||
import pydantic
|
||||
|
||||
from llama_stack.distribution.datatypes import (
|
||||
RoutableObjectWithProvider,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
|
||||
from llama_stack.providers.utils.kvstore import (
|
||||
|
@ -144,17 +141,16 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
|||
|
||||
|
||||
async def create_dist_registry(
|
||||
config: StackRunConfig,
|
||||
metadata_store: Optional[KVStoreConfig],
|
||||
image_name: str,
|
||||
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
|
||||
# instantiate kvstore for storing and retrieving distribution metadata
|
||||
if config.metadata_store:
|
||||
dist_kvstore = await kvstore_impl(config.metadata_store)
|
||||
if metadata_store:
|
||||
dist_kvstore = await kvstore_impl(metadata_store)
|
||||
else:
|
||||
dist_kvstore = await kvstore_impl(
|
||||
SqliteKVStoreConfig(
|
||||
db_path=(
|
||||
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
|
||||
).as_posix()
|
||||
db_path=(DISTRIBS_BASE_DIR / image_name / "kvstore.db").as_posix()
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -641,12 +641,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if session_info.memory_bank_id is None:
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
memory_bank = VectorMemoryBank(
|
||||
identifier=bank_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
await self.memory_banks_api.register_memory_bank(
|
||||
memory_bank_id=bank_id,
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
)
|
||||
await self.memory_banks_api.register_memory_bank(memory_bank)
|
||||
await self.storage.add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
bank_id = session_info.memory_bank_id
|
||||
|
|
|
@ -19,7 +19,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "meta_reference",
|
||||
"safety": "llama_guard",
|
||||
"memory": "meta_reference",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="meta_reference",
|
||||
|
@ -29,7 +29,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
{
|
||||
"inference": "ollama",
|
||||
"safety": "llama_guard",
|
||||
"memory": "meta_reference",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="ollama",
|
||||
|
@ -40,7 +40,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
"inference": "together",
|
||||
"safety": "llama_guard",
|
||||
# make this work with Weaviate which is what the together distro supports
|
||||
"memory": "meta_reference",
|
||||
"memory": "faiss",
|
||||
"agents": "meta_reference",
|
||||
},
|
||||
id="together",
|
||||
|
|
|
@ -9,6 +9,7 @@ import tempfile
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference import (
|
||||
|
@ -17,8 +18,18 @@ from llama_stack.providers.inline.agents.meta_reference import (
|
|||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..safety.fixtures import get_shield_to_register
|
||||
|
||||
|
||||
def pick_inference_model(inference_model):
|
||||
# This is not entirely satisfactory. The fixture `inference_model` can correspond to
|
||||
# multiple models when you need to run a safety model in addition to normal agent
|
||||
# inference model. We filter off the safety model by looking for "Llama-Guard"
|
||||
if isinstance(inference_model, list):
|
||||
inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
|
||||
assert inference_model is not None
|
||||
return inference_model
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -49,7 +60,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
|
|||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def agents_stack(request):
|
||||
async def agents_stack(request, inference_model, safety_model):
|
||||
fixture_dict = request.param
|
||||
|
||||
providers = {}
|
||||
|
@ -60,9 +71,28 @@ async def agents_stack(request):
|
|||
if fixture.provider_data:
|
||||
provider_data.update(fixture.provider_data)
|
||||
|
||||
inf_provider_id = providers["inference"][0].provider_id
|
||||
safety_provider_id = providers["safety"][0].provider_id
|
||||
|
||||
shield = get_shield_to_register(
|
||||
providers["safety"][0].provider_type, safety_provider_id, safety_model
|
||||
)
|
||||
|
||||
inference_models = (
|
||||
inference_model if isinstance(inference_model, list) else [inference_model]
|
||||
)
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
||||
providers,
|
||||
provider_data,
|
||||
models=[
|
||||
Model(
|
||||
identifier=model,
|
||||
provider_id=inf_provider_id,
|
||||
provider_resource_id=model,
|
||||
)
|
||||
for model in inference_models
|
||||
],
|
||||
shields=[shield],
|
||||
)
|
||||
return impls[Api.agents], impls[Api.memory]
|
||||
|
|
|
@ -16,15 +16,12 @@ from llama_stack.providers.datatypes import * # noqa: F403
|
|||
# pytest -v -s llama_stack/providers/tests/agents/test_agents.py
|
||||
# -m "meta_reference"
|
||||
|
||||
from .fixtures import pick_inference_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def common_params(inference_model):
|
||||
# This is not entirely satisfactory. The fixture `inference_model` can correspond to
|
||||
# multiple models when you need to run a safety model in addition to normal agent
|
||||
# inference model. We filter off the safety model by looking for "Llama-Guard"
|
||||
if isinstance(inference_model, list):
|
||||
inference_model = next(m for m in inference_model if "Llama-Guard" not in m)
|
||||
assert inference_model is not None
|
||||
inference_model = pick_inference_model(inference_model)
|
||||
|
||||
return dict(
|
||||
model=inference_model,
|
||||
|
|
|
@ -9,6 +9,8 @@ import os
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.inference.meta_reference import (
|
||||
MetaReferenceInferenceConfig,
|
||||
|
@ -159,13 +161,13 @@ async def inference_stack(request, inference_model):
|
|||
[Api.inference],
|
||||
{"inference": inference_fixture.providers},
|
||||
inference_fixture.provider_data,
|
||||
)
|
||||
|
||||
provider_id = inference_fixture.providers[0].provider_id
|
||||
print(f"Registering model {inference_model} with provider {provider_id}")
|
||||
await impls[Api.models].register_model(
|
||||
model_id=inference_model,
|
||||
provider_id=provider_id,
|
||||
models=[
|
||||
Model(
|
||||
identifier=inference_model,
|
||||
provider_resource_id=inference_model,
|
||||
provider_id=inference_fixture.providers[0].provider_id,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return (impls[Api.inference], impls[Api.models])
|
||||
|
|
|
@ -26,13 +26,13 @@ def memory_remote() -> ProviderFixture:
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_meta_reference() -> ProviderFixture:
|
||||
def memory_faiss() -> ProviderFixture:
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="meta-reference",
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissImplConfig(
|
||||
kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(),
|
||||
).model_dump(),
|
||||
|
@ -93,7 +93,7 @@ def memory_chroma() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
|
||||
MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
|
@ -44,7 +44,6 @@ def sample_documents():
|
|||
|
||||
|
||||
async def register_memory_bank(banks_impl: MemoryBanks):
|
||||
|
||||
return await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank",
|
||||
params=VectorMemoryBankParams(
|
||||
|
@ -71,7 +70,7 @@ class TestMemory:
|
|||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
_, banks_impl = memory_stack
|
||||
|
||||
bank = await banks_impl.register_memory_bank(
|
||||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank_no_provider",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
|
|
|
@ -17,29 +17,38 @@ from llama_stack.distribution.build import print_pip_install_help
|
|||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.store import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||
from llama_stack.distribution.stack import construct_stack
|
||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||
|
||||
|
||||
async def resolve_impls_for_test_v2(
|
||||
apis: List[Api],
|
||||
providers: Dict[str, List[Provider]],
|
||||
provider_data: Optional[Dict[str, Any]] = None,
|
||||
models: Optional[List[Model]] = None,
|
||||
shields: Optional[List[Shield]] = None,
|
||||
memory_banks: Optional[List[MemoryBank]] = None,
|
||||
datasets: Optional[List[Dataset]] = None,
|
||||
scoring_fns: Optional[List[ScoringFn]] = None,
|
||||
eval_tasks: Optional[List[EvalTask]] = None,
|
||||
):
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
run_config = dict(
|
||||
built_at=datetime.now(),
|
||||
image_name="test-fixture",
|
||||
apis=apis,
|
||||
providers=providers,
|
||||
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
|
||||
models=models or [],
|
||||
shields=shields or [],
|
||||
memory_banks=memory_banks or [],
|
||||
datasets=datasets or [],
|
||||
scoring_fns=scoring_fns or [],
|
||||
eval_tasks=eval_tasks or [],
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
|
||||
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
|
||||
try:
|
||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||
impls = await construct_stack(run_config)
|
||||
except ModuleNotFoundError as e:
|
||||
print_pip_install_help(providers)
|
||||
raise e
|
||||
|
|
|
@ -7,7 +7,9 @@
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.shields import ShieldType
|
||||
from llama_stack.apis.models import Model
|
||||
|
||||
from llama_stack.apis.shields import Shield, ShieldType
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||
|
@ -96,32 +98,29 @@ async def safety_stack(inference_model, safety_model, request):
|
|||
if safety_fixture.provider_data:
|
||||
provider_data.update(safety_fixture.provider_data)
|
||||
|
||||
shield_provider_type = safety_fixture.providers[0].provider_type
|
||||
shield = get_shield_to_register(
|
||||
shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
|
||||
)
|
||||
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
[Api.safety, Api.shields, Api.inference],
|
||||
providers,
|
||||
provider_data,
|
||||
models=[
|
||||
Model(
|
||||
identifier=inference_model,
|
||||
provider_id=inference_fixture.providers[0].provider_id,
|
||||
provider_resource_id=inference_model,
|
||||
)
|
||||
],
|
||||
shields=[shield],
|
||||
)
|
||||
|
||||
safety_impl = impls[Api.safety]
|
||||
shields_impl = impls[Api.shields]
|
||||
|
||||
# Register the appropriate shield based on provider type
|
||||
provider_type = safety_fixture.providers[0].provider_type
|
||||
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
|
||||
|
||||
provider_id = inference_fixture.providers[0].provider_id
|
||||
print(f"Registering model {inference_model} with provider {provider_id}")
|
||||
await impls[Api.models].register_model(
|
||||
model_id=inference_model,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
|
||||
return safety_impl, shields_impl, shield
|
||||
return impls[Api.safety], impls[Api.shields], shield
|
||||
|
||||
|
||||
async def create_and_register_shield(
|
||||
provider_type: str, safety_model: str, shields_impl
|
||||
):
|
||||
def get_shield_to_register(provider_type: str, provider_id: str, safety_model: str):
|
||||
shield_config = {}
|
||||
shield_type = ShieldType.llama_guard
|
||||
identifier = "llama_guard"
|
||||
|
@ -134,8 +133,10 @@ async def create_and_register_shield(
|
|||
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
|
||||
shield_type = ShieldType.generic_content_shield
|
||||
|
||||
return await shields_impl.register_shield(
|
||||
shield_id=identifier,
|
||||
return Shield(
|
||||
identifier=identifier,
|
||||
shield_type=shield_type,
|
||||
params=shield_config,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=identifier,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue