mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-18 23:28:53 +00:00
More work towards making remote stacks usable from tests
This commit is contained in:
parent
8645f8bc9e
commit
8b7be87bec
7 changed files with 91 additions and 99 deletions
|
@ -28,11 +28,16 @@ from llama_stack.apis.shields import * # noqa: F403
|
|||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.resolver import (
|
||||
additional_protocols_map,
|
||||
api_protocol_map,
|
||||
resolve_impls,
|
||||
)
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.datatypes import Api, RemoteProviderConfig
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
|
@ -65,7 +70,9 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
|||
run_config.metadata_store, run_config.image_name
|
||||
)
|
||||
|
||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||
impls = await maybe_get_remote_stack_impls(run_config)
|
||||
if impls is None:
|
||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||
|
||||
resources = [
|
||||
("models", Api.models, "register_model", "list_models"),
|
||||
|
@ -97,3 +104,54 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
|||
|
||||
print("")
|
||||
return impls
|
||||
|
||||
|
||||
# NOTE: this code path is really for the tests so you can send HTTP requests
|
||||
# to the remote stack without needing to use llama-stack-client
|
||||
async def maybe_get_remote_stack_impls(
|
||||
run_config: StackRunConfig,
|
||||
) -> Optional[Dict[Api, Any]]:
|
||||
remote_config = remote_provider_config(run_config)
|
||||
if not remote_config:
|
||||
return None
|
||||
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
impls = {}
|
||||
for api_str in run_config.apis:
|
||||
api = Api(api_str)
|
||||
impls[api] = await get_client_impl(
|
||||
protocols[api],
|
||||
None,
|
||||
remote_config,
|
||||
{},
|
||||
)
|
||||
if api in additional_protocols:
|
||||
_, additional_protocol, additional_api = additional_protocols[api]
|
||||
impls[additional_api] = await get_client_impl(
|
||||
additional_protocol,
|
||||
None,
|
||||
remote_config,
|
||||
{},
|
||||
)
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
def remote_provider_config(
|
||||
run_config: StackRunConfig,
|
||||
) -> Optional[RemoteProviderConfig]:
|
||||
remote_config = None
|
||||
has_non_remote = False
|
||||
for api_providers in run_config.providers.values():
|
||||
for provider in api_providers:
|
||||
if provider.provider_type == "remote":
|
||||
remote_config = RemoteProviderConfig(**provider.config)
|
||||
else:
|
||||
has_non_remote = True
|
||||
|
||||
if remote_config:
|
||||
assert not has_non_remote, "Remote stack cannot have non-remote providers"
|
||||
|
||||
return remote_config
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue