Allow passing provider_registry to resolve_impls()

This commit is contained in:
Ashwin Bharambe 2024-10-28 11:58:16 -07:00
parent 8a3b64d1be
commit b7d2b83d55
3 changed files with 13 additions and 12 deletions

View file

@ -13,6 +13,7 @@ import yaml
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
@ -36,7 +37,7 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
providers=chosen,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls(run_config)
impls = await resolve_impls(run_config, get_provider_registry())
if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id