More work towards making remote stacks usable from tests

This commit is contained in:
Ashwin Bharambe 2024-11-12 17:09:31 -08:00
parent 8645f8bc9e
commit 8b7be87bec
7 changed files with 91 additions and 99 deletions

View file

@ -28,11 +28,16 @@ from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.resolver import (
additional_protocols_map,
api_protocol_map,
resolve_impls,
)
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
from llama_stack.providers.datatypes import Api, RemoteProviderConfig
class LlamaStack(
@ -65,7 +70,9 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
impls = await maybe_get_remote_stack_impls(run_config)
if impls is None:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
resources = [
("models", Api.models, "register_model", "list_models"),
@ -97,3 +104,54 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
print("")
return impls
# NOTE: this code path is really for the tests so you can send HTTP requests
# to the remote stack without needing to use llama-stack-client
async def maybe_get_remote_stack_impls(
run_config: StackRunConfig,
) -> Optional[Dict[Api, Any]]:
remote_config = remote_provider_config(run_config)
if not remote_config:
return None
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
impls = {}
for api_str in run_config.apis:
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
None,
remote_config,
{},
)
if api in additional_protocols:
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
None,
remote_config,
{},
)
return impls
def remote_provider_config(
run_config: StackRunConfig,
) -> Optional[RemoteProviderConfig]:
remote_config = None
has_non_remote = False
for api_providers in run_config.providers.values():
for provider in api_providers:
if provider.provider_type == "remote":
remote_config = RemoteProviderConfig(**provider.config)
else:
has_non_remote = True
if remote_config:
assert not has_non_remote, "Remote stack cannot have non-remote providers"
return remote_config