diff --git a/llama_stack/core/distribution.py b/llama_stack/core/distribution.py index 977eb5393..c104b6764 100644 --- a/llama_stack/core/distribution.py +++ b/llama_stack/core/distribution.py @@ -26,6 +26,9 @@ from llama_stack.providers.datatypes import ( logger = get_logger(name=__name__, category="core") +INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts} + + def stack_apis() -> list[Api]: return list(Api) @@ -70,7 +73,7 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]: def providable_apis() -> list[Api]: routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()} - return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers] + return [api for api in Api if api not in routing_table_apis and api not in INTERNAL_APIS] def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec: diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index c72106e46..c6c2eb2c7 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -12,7 +12,7 @@ import yaml from pydantic import BaseModel, Field, ValidationError from llama_stack.core.datatypes import Api, Provider, StackRunConfig -from llama_stack.core.distribution import get_provider_registry +from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis from llama_stack.providers.datatypes import ProviderSpec @@ -152,6 +152,24 @@ class TestProviderRegistry: assert registry[Api.inference]["test_provider"].provider_type == "test_provider" assert registry[Api.inference]["test_provider"].api == Api.inference + def test_internal_apis_excluded(self): + """Test that internal APIs are excluded and APIs without provider registries are marked as internal.""" + import importlib + + apis = providable_apis() + + for internal_api in INTERNAL_APIS: + assert internal_api not in apis, f"Internal API {internal_api} should not be in providable_apis" + + for api in apis: + module_name = f"llama_stack.providers.registry.{api.name.lower()}" + try: + importlib.import_module(module_name) + except ImportError as err: + raise AssertionError( + f"API {api} is in providable_apis but has no provider registry module ({module_name})" + ) from err + def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml): """Test loading external remote providers from YAML files.""" remote_dir, _ = api_directories