mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 00:05:18 +00:00
remote tests are functional!
This commit is contained in:
parent
8b7be87bec
commit
0121114a5d
5 changed files with 93 additions and 89 deletions
|
@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring
|
||||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||||
from llama_stack.apis.shields import Shields
|
from llama_stack.apis.shields import Shields
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
|
from llama_stack.distribution.client import get_client_impl
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
@ -308,7 +309,7 @@ async def instantiate_provider(
|
||||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||||
and provider_spec.api in additional_protocols
|
and provider_spec.api in additional_protocols
|
||||||
):
|
):
|
||||||
additional_api, _ = additional_protocols[provider_spec.api]
|
additional_api, _, _ = additional_protocols[provider_spec.api]
|
||||||
check_protocol_compliance(impl, additional_api)
|
check_protocol_compliance(impl, additional_api)
|
||||||
|
|
||||||
return impl
|
return impl
|
||||||
|
@ -354,3 +355,31 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
|
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_remote_stack_impls(
|
||||||
|
config: RemoteProviderConfig,
|
||||||
|
apis: List[str],
|
||||||
|
) -> Dict[Api, Any]:
|
||||||
|
protocols = api_protocol_map()
|
||||||
|
additional_protocols = additional_protocols_map()
|
||||||
|
|
||||||
|
impls = {}
|
||||||
|
for api_str in apis:
|
||||||
|
api = Api(api_str)
|
||||||
|
impls[api] = await get_client_impl(
|
||||||
|
protocols[api],
|
||||||
|
None,
|
||||||
|
config,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
if api in additional_protocols:
|
||||||
|
_, additional_protocol, additional_api = additional_protocols[api]
|
||||||
|
impls[additional_api] = await get_client_impl(
|
||||||
|
additional_protocol,
|
||||||
|
None,
|
||||||
|
config,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
return impls
|
||||||
|
|
|
@ -87,13 +87,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
p.model_store = self
|
p.model_store = self
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
p.shield_store = self
|
p.shield_store = self
|
||||||
|
|
||||||
elif api == Api.memory:
|
elif api == Api.memory:
|
||||||
p.memory_bank_store = self
|
p.memory_bank_store = self
|
||||||
|
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
p.dataset_store = self
|
p.dataset_store = self
|
||||||
|
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
p.scoring_function_store = self
|
p.scoring_function_store = self
|
||||||
scoring_functions = await p.list_scoring_functions()
|
scoring_functions = await p.list_scoring_functions()
|
||||||
|
|
|
@ -28,16 +28,11 @@ from llama_stack.apis.shields import * # noqa: F403
|
||||||
from llama_stack.apis.inspect import * # noqa: F403
|
from llama_stack.apis.inspect import * # noqa: F403
|
||||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.client import get_client_impl
|
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.resolver import (
|
from llama_stack.distribution.resolver import resolve_impls
|
||||||
additional_protocols_map,
|
|
||||||
api_protocol_map,
|
|
||||||
resolve_impls,
|
|
||||||
)
|
|
||||||
from llama_stack.distribution.store.registry import create_dist_registry
|
from llama_stack.distribution.store.registry import create_dist_registry
|
||||||
from llama_stack.providers.datatypes import Api, RemoteProviderConfig
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
|
||||||
class LlamaStack(
|
class LlamaStack(
|
||||||
|
@ -63,31 +58,23 @@ class LlamaStack(
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
RESOURCES = [
|
||||||
# asked for in the run config.
|
("models", Api.models, "register_model", "list_models"),
|
||||||
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
("shields", Api.shields, "register_shield", "list_shields"),
|
||||||
dist_registry, _ = await create_dist_registry(
|
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
|
||||||
run_config.metadata_store, run_config.image_name
|
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||||
)
|
(
|
||||||
|
"scoring_fns",
|
||||||
|
Api.scoring_functions,
|
||||||
|
"register_scoring_function",
|
||||||
|
"list_scoring_functions",
|
||||||
|
),
|
||||||
|
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
||||||
|
]
|
||||||
|
|
||||||
impls = await maybe_get_remote_stack_impls(run_config)
|
|
||||||
if impls is None:
|
|
||||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
|
||||||
|
|
||||||
resources = [
|
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
||||||
("models", Api.models, "register_model", "list_models"),
|
for rsrc, api, register_method, list_method in RESOURCES:
|
||||||
("shields", Api.shields, "register_shield", "list_shields"),
|
|
||||||
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
|
|
||||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
|
||||||
(
|
|
||||||
"scoring_fns",
|
|
||||||
Api.scoring_functions,
|
|
||||||
"register_scoring_function",
|
|
||||||
"list_scoring_functions",
|
|
||||||
),
|
|
||||||
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
|
||||||
]
|
|
||||||
for rsrc, api, register_method, list_method in resources:
|
|
||||||
objects = getattr(run_config, rsrc)
|
objects = getattr(run_config, rsrc)
|
||||||
if api not in impls:
|
if api not in impls:
|
||||||
continue
|
continue
|
||||||
|
@ -103,55 +90,14 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||||
)
|
)
|
||||||
|
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
|
|
||||||
|
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||||
|
# asked for in the run config.
|
||||||
|
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||||
|
dist_registry, _ = await create_dist_registry(
|
||||||
|
run_config.metadata_store, run_config.image_name
|
||||||
|
)
|
||||||
|
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||||
|
await register_resources(run_config, impls)
|
||||||
return impls
|
return impls
|
||||||
|
|
||||||
|
|
||||||
# NOTE: this code path is really for the tests so you can send HTTP requests
|
|
||||||
# to the remote stack without needing to use llama-stack-client
|
|
||||||
async def maybe_get_remote_stack_impls(
|
|
||||||
run_config: StackRunConfig,
|
|
||||||
) -> Optional[Dict[Api, Any]]:
|
|
||||||
remote_config = remote_provider_config(run_config)
|
|
||||||
if not remote_config:
|
|
||||||
return None
|
|
||||||
|
|
||||||
protocols = api_protocol_map()
|
|
||||||
additional_protocols = additional_protocols_map()
|
|
||||||
|
|
||||||
impls = {}
|
|
||||||
for api_str in run_config.apis:
|
|
||||||
api = Api(api_str)
|
|
||||||
impls[api] = await get_client_impl(
|
|
||||||
protocols[api],
|
|
||||||
None,
|
|
||||||
remote_config,
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
if api in additional_protocols:
|
|
||||||
_, additional_protocol, additional_api = additional_protocols[api]
|
|
||||||
impls[additional_api] = await get_client_impl(
|
|
||||||
additional_protocol,
|
|
||||||
None,
|
|
||||||
remote_config,
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
|
|
||||||
return impls
|
|
||||||
|
|
||||||
|
|
||||||
def remote_provider_config(
|
|
||||||
run_config: StackRunConfig,
|
|
||||||
) -> Optional[RemoteProviderConfig]:
|
|
||||||
remote_config = None
|
|
||||||
has_non_remote = False
|
|
||||||
for api_providers in run_config.providers.values():
|
|
||||||
for provider in api_providers:
|
|
||||||
if provider.provider_type == "remote":
|
|
||||||
remote_config = RemoteProviderConfig(**provider.config)
|
|
||||||
else:
|
|
||||||
has_non_remote = True
|
|
||||||
|
|
||||||
if remote_config:
|
|
||||||
assert not has_non_remote, "Remote stack cannot have non-remote providers"
|
|
||||||
|
|
||||||
return remote_config
|
|
||||||
|
|
|
@ -35,8 +35,8 @@ def remote_stack_fixture() -> ProviderFixture:
|
||||||
return ProviderFixture(
|
return ProviderFixture(
|
||||||
providers=[
|
providers=[
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="remote",
|
provider_id="test::remote",
|
||||||
provider_type="remote",
|
provider_type="test::remote",
|
||||||
config=config.model_dump(),
|
config=config.model_dump(),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
|
|
@ -17,11 +17,25 @@ from llama_stack.distribution.build import print_pip_install_help
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.resolver import resolve_impls, resolve_remote_stack_impls
|
||||||
from llama_stack.distribution.stack import construct_stack
|
from llama_stack.distribution.stack import construct_stack
|
||||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def construct_stack_for_test(run_config: StackRunConfig):
|
||||||
|
remote_config = remote_provider_config(run_config)
|
||||||
|
if not remote_config:
|
||||||
|
return await construct_stack(run_config)
|
||||||
|
|
||||||
|
impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
|
||||||
|
|
||||||
|
# we don't register resources for a remote stack as part of the fixture setup
|
||||||
|
# because the stack is already "up". if a test needs to register resources, it
|
||||||
|
# can do so manually always.
|
||||||
|
|
||||||
|
return impls
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_for_test_v2(
|
async def resolve_impls_for_test_v2(
|
||||||
apis: List[Api],
|
apis: List[Api],
|
||||||
providers: Dict[str, List[Provider]],
|
providers: Dict[str, List[Provider]],
|
||||||
|
@ -49,7 +63,7 @@ async def resolve_impls_for_test_v2(
|
||||||
)
|
)
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||||
try:
|
try:
|
||||||
impls = await construct_stack(run_config)
|
impls = await construct_stack_for_test(run_config)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
print_pip_install_help(providers)
|
print_pip_install_help(providers)
|
||||||
raise e
|
raise e
|
||||||
|
@ -62,6 +76,24 @@ async def resolve_impls_for_test_v2(
|
||||||
return impls
|
return impls
|
||||||
|
|
||||||
|
|
||||||
|
def remote_provider_config(
|
||||||
|
run_config: StackRunConfig,
|
||||||
|
) -> Optional[RemoteProviderConfig]:
|
||||||
|
remote_config = None
|
||||||
|
has_non_remote = False
|
||||||
|
for api_providers in run_config.providers.values():
|
||||||
|
for provider in api_providers:
|
||||||
|
if provider.provider_type == "test::remote":
|
||||||
|
remote_config = RemoteProviderConfig(**provider.config)
|
||||||
|
else:
|
||||||
|
has_non_remote = True
|
||||||
|
|
||||||
|
if remote_config:
|
||||||
|
assert not has_non_remote, "Remote stack cannot have non-remote providers"
|
||||||
|
|
||||||
|
return remote_config
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
|
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
|
||||||
if "PROVIDER_CONFIG" not in os.environ:
|
if "PROVIDER_CONFIG" not in os.environ:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue