Make sure we always serve the routing table APIs if the corresponding router APIs are being served

This commit is contained in:
Ashwin Bharambe 2024-10-09 22:47:18 -07:00
parent a5d7caf21b
commit 8be385c994
2 changed files with 10 additions and 4 deletions

View file

@ -109,7 +109,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
providers_with_specs[info.routing_table_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__builtin__",
provider_id="__routing_table__",
provider_type="__routing_table__",
config={},
spec=RoutingTableProviderSpec(
@ -124,7 +124,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
providers_with_specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__builtin__",
provider_id="__autorouted__",
provider_type="__autorouted__",
config={},
spec=AutoRoutedProviderSpec(
@ -162,9 +162,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
)
)
print(f"Resolved {len(sorted_providers)} providers in topological order")
print(f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
print(f" {api_str}: ({provider.provider_id}) {provider.spec.provider_type}")
print(f" {api_str} => {provider.provider_id}")
print("")
impls = {}

View file

@ -26,6 +26,8 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
setup_logger,
@ -285,6 +287,10 @@ def main(
else:
apis_to_serve = set(impls.keys())
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in apis_to_serve:
apis_to_serve.add(inf.routing_table_api.value)
apis_to_serve.add("inspect")
for api_str in apis_to_serve:
api = Api(api_str)