remote tests are functional!

This commit is contained in:
Ashwin Bharambe 2024-11-12 19:17:58 -08:00
parent 8b7be87bec
commit 0121114a5d
5 changed files with 93 additions and 89 deletions

View file

@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -308,7 +309,7 @@ async def instantiate_provider(
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api, _ = additional_protocols[provider_spec.api]
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
@ -354,3 +355,31 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
raise ValueError(
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
)
async def resolve_remote_stack_impls(
config: RemoteProviderConfig,
apis: List[str],
) -> Dict[Api, Any]:
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
impls = {}
for api_str in apis:
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
None,
config,
{},
)
if api in additional_protocols:
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
None,
config,
{},
)
return impls