Allow specifying resources in StackRunConfig

This commit is contained in:
Ashwin Bharambe 2024-11-11 22:08:51 -08:00
parent 8035fa1869
commit 38257a9cbe
9 changed files with 151 additions and 102 deletions

View file

@@ -9,6 +9,8 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.models import Model
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
@@ -159,13 +161,13 @@ async def inference_stack(request, inference_model):
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
models=[
Model(
identifier=inference_model,
provider_resource_id=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
)
],
)
return (impls[Api.inference], impls[Api.models])

View file

@@ -17,29 +17,38 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
async def resolve_impls_for_test_v2(
apis: List[Api],
providers: Dict[str, List[Provider]],
provider_data: Optional[Dict[str, Any]] = None,
models: Optional[List[Model]] = None,
shields: Optional[List[Shield]] = None,
memory_banks: Optional[List[MemoryBank]] = None,
datasets: Optional[List[Dataset]] = None,
scoring_fns: Optional[List[ScoringFn]] = None,
eval_tasks: Optional[List[EvalTask]] = None,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=apis,
providers=providers,
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
models=models or [],
shields=shields or [],
memory_banks=memory_banks or [],
datasets=datasets or [],
scoring_fns=scoring_fns or [],
eval_tasks=eval_tasks or [],
)
run_config = parse_and_maybe_upgrade_config(run_config)
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
try:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
impls = await construct_stack(run_config)
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e

View file

@@ -7,7 +7,9 @@
import pytest
import pytest_asyncio
from llama_stack.apis.shields import ShieldType
from llama_stack.apis.models import Model
from llama_stack.apis.shields import Shield, ShieldType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
@@ -96,32 +98,29 @@ async def safety_stack(inference_model, safety_model, request):
if safety_fixture.provider_data:
provider_data.update(safety_fixture.provider_data)
shield_provider_type = safety_fixture.providers[0].provider_type
shield = get_shield(
shield_provider_type, safety_fixture.providers[0].provider_id, safety_model
)
impls = await resolve_impls_for_test_v2(
[Api.safety, Api.shields, Api.inference],
providers,
provider_data,
models=[
Model(
identifier=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
provider_resource_id=inference_model,
)
],
shields=[shield],
)
safety_impl = impls[Api.safety]
shields_impl = impls[Api.shields]
# Register the appropriate shield based on provider type
provider_type = safety_fixture.providers[0].provider_type
shield = await create_and_register_shield(provider_type, safety_model, shields_impl)
provider_id = inference_fixture.providers[0].provider_id
print(f"Registering model {inference_model} with provider {provider_id}")
await impls[Api.models].register_model(
model_id=inference_model,
provider_id=provider_id,
)
return safety_impl, shields_impl, shield
return impls[Api.safety], impls[Api.shields], shield
async def create_and_register_shield(
provider_type: str, safety_model: str, shields_impl
):
def get_shield(provider_type: str, provider_id: str, safety_model: str):
shield_config = {}
shield_type = ShieldType.llama_guard
identifier = "llama_guard"
@@ -134,8 +133,10 @@ async def create_and_register_shield(
shield_config["guardrailVersion"] = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
shield_type = ShieldType.generic_content_shield
return await shields_impl.register_shield(
shield_id=identifier,
return Shield(
identifier=identifier,
shield_type=shield_type,
params=shield_config,
provider_id=provider_id,
provider_resource_id=identifier,
)