From 8be385c99434a30d3a60feaac2dd82cf264a1035 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 9 Oct 2024 22:47:18 -0700 Subject: [PATCH] Make sure we always serve the routing table APIs if the corresponding router APIs are being served --- llama_stack/distribution/resolver.py | 8 ++++---- llama_stack/distribution/server/server.py | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 1de52817f..a05e08cd7 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -109,7 +109,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An providers_with_specs[info.routing_table_api.value] = { "__builtin__": ProviderWithSpec( - provider_id="__builtin__", + provider_id="__routing_table__", provider_type="__routing_table__", config={}, spec=RoutingTableProviderSpec( @@ -124,7 +124,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An providers_with_specs[info.router_api.value] = { "__builtin__": ProviderWithSpec( - provider_id="__builtin__", + provider_id="__autorouted__", provider_type="__autorouted__", config={}, spec=AutoRoutedProviderSpec( @@ -162,9 +162,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An ) ) - print(f"Resolved {len(sorted_providers)} providers in topological order") + print(f"Resolved {len(sorted_providers)} providers") for api_str, provider in sorted_providers: - print(f" {api_str}: ({provider.provider_id}) {provider.spec.provider_type}") + print(f" {api_str} => {provider.provider_id}") print("") impls = {} diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 9f362e023..4348f52f8 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -26,6 +26,8 @@ from pydantic import BaseModel, ValidationError from termcolor import cprint from typing_extensions import Annotated +from llama_stack.distribution.distribution import builtin_automatically_routed_apis + from llama_stack.providers.utils.telemetry.tracing import ( end_trace, setup_logger, @@ -285,6 +287,10 @@ def main( else: apis_to_serve = set(impls.keys()) + for inf in builtin_automatically_routed_apis(): + if inf.router_api.value in apis_to_serve: + apis_to_serve.add(inf.routing_table_api.value) + apis_to_serve.add("inspect") for api_str in apis_to_serve: api = Api(api_str)