Add an introspection "Api.inspect" API

This commit is contained in:
Ashwin Bharambe 2024-10-02 15:13:24 -07:00
parent 01d93be948
commit 8d049000e3
14 changed files with 619 additions and 174 deletions

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
from typing import Any, Dict, List, Set
@ -11,7 +12,8 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis,
get_provider_registry,
)
from llama_stack.distribution.utils.dynamic import instantiate_provider
from llama_stack.distribution.inspect import DistributionInspectImpl
from llama_stack.distribution.utils.dynamic import instantiate_class_type
async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
@ -57,7 +59,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
if info.router_api.value not in apis_to_serve:
continue
print("router_api", info.router_api)
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
@ -104,6 +105,14 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
impls[api] = impl
impls[Api.inspect] = DistributionInspectImpl()
specs[Api.inspect] = InlineProviderSpec(
api=Api.inspect,
provider_type="__distribution_builtin__",
config_class="",
module="",
)
return impls, specs
@ -127,3 +136,60 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
dfs(a, visited, stack)
return [by_id[x] for x in stack]
# returns a class implementing the protocol corresponding to the Api
async def instantiate_provider(
provider_spec: ProviderSpec,
deps: Dict[str, Any],
provider_config: Union[GenericProviderConfig, RoutingTable],
):
module = importlib.import_module(provider_spec.module)
args = []
if isinstance(provider_spec, RemoteProviderSpec):
if provider_spec.adapter:
method = "get_adapter_impl"
else:
method = "get_client_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_type],
deps,
routing_entry,
)
inner_impls.append((routing_entry.routing_key, impl))
config = None
args = [provider_spec.api, inner_impls, routing_table, deps]
else:
method = "get_provider_impl"
assert isinstance(provider_config, GenericProviderConfig)
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider_config.config)
args = [config, deps]
fn = getattr(module, method)
impl = await fn(*args)
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
return impl