Make sure we always serve the routing table APIs if the corresponding router APIs are being served

This commit is contained in:
Ashwin Bharambe 2024-10-09 22:47:18 -07:00
parent a5d7caf21b
commit 8be385c994
2 changed files with 10 additions and 4 deletions

View file

@@ -109,7 +109,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
providers_with_specs[info.routing_table_api.value] = { providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec( "__builtin__": ProviderWithSpec(
provider_id="__builtin__", provider_id="__routing_table__",
provider_type="__routing_table__", provider_type="__routing_table__",
config={}, config={},
spec=RoutingTableProviderSpec( spec=RoutingTableProviderSpec(
@@ -124,7 +124,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
providers_with_specs[info.router_api.value] = { providers_with_specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec( "__builtin__": ProviderWithSpec(
provider_id="__builtin__", provider_id="__autorouted__",
provider_type="__autorouted__", provider_type="__autorouted__",
config={}, config={},
spec=AutoRoutedProviderSpec( spec=AutoRoutedProviderSpec(
@@ -162,9 +162,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
) )
) )
print(f"Resolved {len(sorted_providers)} providers in topological order") print(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers: for api_str, provider in sorted_providers:
print(f" {api_str}: ({provider.provider_id}) {provider.spec.provider_type}") print(f" {api_str} => {provider.provider_id}")
print("") print("")
impls = {} impls = {}

View file

@@ -26,6 +26,8 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint from termcolor import cprint
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
setup_logger, setup_logger,
@@ -285,6 +287,10 @@ def main(
else: else:
apis_to_serve = set(impls.keys()) apis_to_serve = set(impls.keys())
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in apis_to_serve:
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect") apis_to_serve.add("inspect")
for api_str in apis_to_serve: for api_str in apis_to_serve:
api = Api(api_str) api = Api(api_str)