remote tests are functional!

This commit is contained in:
Ashwin Bharambe 2024-11-12 19:17:58 -08:00
parent 8b7be87bec
commit 0121114a5d
5 changed files with 93 additions and 89 deletions

View file

@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -308,7 +309,7 @@ async def instantiate_provider(
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api, _ = additional_protocols[provider_spec.api]
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
@ -354,3 +355,31 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
raise ValueError(
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
)
async def resolve_remote_stack_impls(
config: RemoteProviderConfig,
apis: List[str],
) -> Dict[Api, Any]:
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
impls = {}
for api_str in apis:
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
None,
config,
{},
)
if api in additional_protocols:
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
None,
config,
{},
)
return impls

View file

@ -87,13 +87,10 @@ class CommonRoutingTableImpl(RoutingTable):
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.memory:
p.memory_bank_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()

View file

@ -28,16 +28,11 @@ from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import (
additional_protocols_map,
api_protocol_map,
resolve_impls,
)
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api, RemoteProviderConfig
from llama_stack.providers.datatypes import Api
class LlamaStack(
@ -63,31 +58,23 @@ class LlamaStack(
pass
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
]
impls = await maybe_get_remote_stack_impls(run_config)
if impls is None:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
resources = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
]
for rsrc, api, register_method, list_method in resources:
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
if api not in impls:
continue
@ -103,55 +90,14 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
)
print("")
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
await register_resources(run_config, impls)
return impls
# NOTE: this code path is really for the tests so you can send HTTP requests
# to the remote stack without needing to use llama-stack-client
async def maybe_get_remote_stack_impls(
run_config: StackRunConfig,
) -> Optional[Dict[Api, Any]]:
remote_config = remote_provider_config(run_config)
if not remote_config:
return None
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
impls = {}
for api_str in run_config.apis:
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
None,
remote_config,
{},
)
if api in additional_protocols:
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
None,
remote_config,
{},
)
return impls
def remote_provider_config(
run_config: StackRunConfig,
) -> Optional[RemoteProviderConfig]:
remote_config = None
has_non_remote = False
for api_providers in run_config.providers.values():
for provider in api_providers:
if provider.provider_type == "remote":
remote_config = RemoteProviderConfig(**provider.config)
else:
has_non_remote = True
if remote_config:
assert not has_non_remote, "Remote stack cannot have non-remote providers"
return remote_config

View file

@ -35,8 +35,8 @@ def remote_stack_fixture() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="remote",
provider_type="remote",
provider_id="test::remote",
provider_type="test::remote",
config=config.model_dump(),
)
],

View file

@ -17,11 +17,25 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.resolver import resolve_impls, resolve_remote_stack_impls
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
async def construct_stack_for_test(run_config: StackRunConfig):
remote_config = remote_provider_config(run_config)
if not remote_config:
return await construct_stack(run_config)
impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
# we don't register resources for a remote stack as part of the fixture setup
# because the stack is already "up". if a test needs to register resources, it
# can do so manually always.
return impls
async def resolve_impls_for_test_v2(
apis: List[Api],
providers: Dict[str, List[Provider]],
@ -49,7 +63,7 @@ async def resolve_impls_for_test_v2(
)
run_config = parse_and_maybe_upgrade_config(run_config)
try:
impls = await construct_stack(run_config)
impls = await construct_stack_for_test(run_config)
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e
@ -62,6 +76,24 @@ async def resolve_impls_for_test_v2(
return impls
def remote_provider_config(
run_config: StackRunConfig,
) -> Optional[RemoteProviderConfig]:
remote_config = None
has_non_remote = False
for api_providers in run_config.providers.values():
for provider in api_providers:
if provider.provider_type == "test::remote":
remote_config = RemoteProviderConfig(**provider.config)
else:
has_non_remote = True
if remote_config:
assert not has_non_remote, "Remote stack cannot have non-remote providers"
return remote_config
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
if "PROVIDER_CONFIG" not in os.environ:
raise ValueError(