From 0f4790f531339544c63399eef34ce627f1f8a930 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Tue, 2 Dec 2025 11:59:14 -0500 Subject: [PATCH] fix: address review comments 1. building: bool is now listing: bool 2. self.config.conf is now self.stack_config Signed-off-by: Charlie Doern --- src/llama_stack/core/build.py | 2 +- src/llama_stack/core/distribution.py | 18 ++++---- src/llama_stack/core/inspect.py | 4 +- src/llama_stack/core/prompts/prompts.py | 4 +- src/llama_stack/core/providers.py | 4 +- tests/unit/distribution/test_distribution.py | 48 ++++++++++---------- 6 files changed, 40 insertions(+), 40 deletions(-) diff --git a/src/llama_stack/core/build.py b/src/llama_stack/core/build.py index 6c53e1439..52478472c 100644 --- a/src/llama_stack/core/build.py +++ b/src/llama_stack/core/build.py @@ -46,7 +46,7 @@ def get_provider_dependencies( deps = [] external_provider_deps = [] - registry = get_provider_registry(config, True) + registry = get_provider_registry(config=config, listing=True) for api_str, provider_or_providers in providers.items(): providers_for_api = registry[Api(api_str)] diff --git a/src/llama_stack/core/distribution.py b/src/llama_stack/core/distribution.py index 554b99ddb..97f1edcd5 100644 --- a/src/llama_stack/core/distribution.py +++ b/src/llama_stack/core/distribution.py @@ -86,7 +86,7 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam def get_provider_registry( - config: StackConfig | None = None, building: bool = False + config: StackConfig | None = None, listing: bool = False ) -> dict[Api, dict[str, ProviderSpec]]: """Get the provider registry, optionally including external providers. @@ -111,13 +111,13 @@ def get_provider_registry( safety/ llama-guard.yaml - This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction. - So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet. + This method is overloaded in that it can be called from a variety of places: during list-deps, during run, during stack construction. + So when listing external providers from a module, there are scenarios where the pip package required to import the module might not be available yet. There is special handling for all of the potential cases this method can be called from. Args: config: Optional object containing the external providers directory path - building: Optional bool delineating whether or not this is being called from a build process + listing: Optional bool delineating whether or not this is being called from a list-deps process Returns: A dictionary mapping APIs to their available providers @@ -163,7 +163,7 @@ def get_provider_registry( registry = get_external_providers_from_module( registry=registry, config=config, - building=building, + listing=listing, ) return registry @@ -222,7 +222,7 @@ def get_external_providers_from_dir( def get_external_providers_from_module( - registry: dict[Api, dict[str, ProviderSpec]], config, building: bool + registry: dict[Api, dict[str, ProviderSpec]], config, listing: bool ) -> dict[Api, dict[str, ProviderSpec]]: provider_list = None provider_list = config.providers.items() @@ -235,14 +235,14 @@ def get_external_providers_from_module( continue # get provider using module try: - if not building: + if not listing: package_name = provider.module.split("==")[0] module = importlib.import_module(f"{package_name}.provider") # if config class is wrong you will get an error saying module could not be imported spec = module.get_provider_spec() else: - # pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run - # in the case we are building we CANNOT import this module of course because it has not been installed. + # pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon list-deps and run + # in the case we are listing we CANNOT import this module of course because it has not been installed. spec = ProviderSpec( api=Api(provider_api), provider_type=provider.provider_type, diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index f14326f2d..d6d89a82c 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -33,14 +33,14 @@ async def get_provider_impl(config, deps): class DistributionInspectImpl(Inspect): def __init__(self, config: DistributionInspectConfig, deps): - self.config = config + self.stack_config = config.config self.deps = deps async def initialize(self) -> None: pass async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse: - config: StackConfig = self.config.config + config: StackConfig = self.stack_config # Helper function to determine if a route should be included based on api_filter def should_include_route(webmethod) -> bool: diff --git a/src/llama_stack/core/prompts/prompts.py b/src/llama_stack/core/prompts/prompts.py index 44e560091..f2a604b37 100644 --- a/src/llama_stack/core/prompts/prompts.py +++ b/src/llama_stack/core/prompts/prompts.py @@ -34,13 +34,13 @@ class PromptServiceImpl(Prompts): """Built-in prompt service implementation using KVStore.""" def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]): - self.config = config + self.stack_config = config.config self.deps = deps self.kvstore: KVStore async def initialize(self) -> None: # Use prompts store reference from run config - prompts_ref = self.config.config.storage.stores.prompts + prompts_ref = self.stack_config.storage.stores.prompts if not prompts_ref: raise ValueError("storage.stores.prompts must be configured in run config") self.kvstore = await kvstore_impl(prompts_ref) diff --git a/src/llama_stack/core/providers.py b/src/llama_stack/core/providers.py index 85f2f9221..2514e8775 100644 --- a/src/llama_stack/core/providers.py +++ b/src/llama_stack/core/providers.py @@ -30,7 +30,7 @@ async def get_provider_impl(config, deps): class ProviderImpl(Providers): def __init__(self, config, deps): - self.config = config + self.stack_config = config.config self.deps = deps async def initialize(self) -> None: @@ -41,7 +41,7 @@ class ProviderImpl(Providers): pass async def list_providers(self) -> ListProvidersResponse: - run_config = self.config.config + run_config = self.stack_config safe_config = StackConfig(**redact_sensitive_fields(run_config.model_dump())) providers_health = await self.get_providers_health() ret = [] diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index 762d8219f..4884f70ba 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -270,7 +270,7 @@ class TestProviderRegistry: external_providers_dir="/nonexistent/dir", ) with pytest.raises(FileNotFoundError): - get_provider_registry(config) + get_provider_registry(config=config) def test_empty_api_directory(self, api_directories, mock_providers, base_config): """Test handling of empty API directory.""" @@ -339,7 +339,7 @@ pip_packages: ] }, ) - registry = get_provider_registry(config) + registry = get_provider_registry(config=config) assert Api.inference in registry assert "external_test" in registry[Api.inference] provider = registry[Api.inference]["external_test"] @@ -368,7 +368,7 @@ pip_packages: }, ) with pytest.raises(ValueError) as exc_info: - get_provider_registry(config) + get_provider_registry(config=config) assert "get_provider_spec not found" in str(exc_info.value) def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers): @@ -391,14 +391,14 @@ pip_packages: }, ) with pytest.raises(AttributeError): - get_provider_registry(config) + get_provider_registry(config=config) - def test_external_provider_from_module_building(self, mock_providers): - """Test loading an external provider from a module during build (building=True, partial spec).""" + def test_external_provider_from_module_listing(self, mock_providers): + """Test loading an external provider from a module during list-deps (listing=True, partial spec).""" from llama_stack.core.datatypes import StackConfig from llama_stack_api import Api - # No importlib patch needed, should not import module when building + # No importlib patch needed, should not import module when listing config = StackConfig( image_name="test_image", apis=[], @@ -413,7 +413,7 @@ pip_packages: ] }, ) - registry = get_provider_registry(config, building=True) + registry = get_provider_registry(config=config, listing=True) assert Api.inference in registry assert "external_test" in registry[Api.inference] provider = registry[Api.inference]["external_test"] @@ -446,7 +446,7 @@ class TestGetExternalProvidersFromModule: }, ) registry = {Api.inference: {}} - result = get_external_providers_from_module(registry, config, building=False) + result = get_external_providers_from_module(registry, config, listing=False) # Should not add anything to registry assert len(result[Api.inference]) == 0 @@ -485,12 +485,12 @@ class TestGetExternalProvidersFromModule: }, ) registry = {Api.inference: {}} - result = get_external_providers_from_module(registry, config, building=False) + result = get_external_providers_from_module(registry, config, listing=False) assert "versioned_test" in result[Api.inference] assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0" def test_buildconfig_does_not_import_module(self, mock_providers): - """Test that StackConfig does not import the module when building (building=True).""" + """Test that StackConfig does not import the module when listing (listing=True).""" from llama_stack.core.datatypes import StackConfig from llama_stack.core.distribution import get_external_providers_from_module @@ -509,10 +509,10 @@ class TestGetExternalProvidersFromModule: }, ) - # Should not call import_module at all when building + # Should not call import_module at all when listing with patch("importlib.import_module") as mock_import: registry = {Api.inference: {}} - result = get_external_providers_from_module(registry, config, building=True) + result = get_external_providers_from_module(registry, config, listing=True) # Verify module was NOT imported mock_import.assert_not_called() @@ -543,14 +543,14 @@ class TestGetExternalProvidersFromModule: with patch("importlib.import_module") as mock_import: registry = {Api.inference: {}} - result = get_external_providers_from_module(registry, config, building=True) + result = get_external_providers_from_module(registry, config, listing=True) mock_import.assert_not_called() assert "provider1" in result[Api.inference] assert "provider2" in result[Api.inference] def test_distributionspec_does_not_import_module(self, mock_providers): - """Test that DistributionSpec does not import the module (building=True).""" + """Test that DistributionSpec does not import the module (listing=True).""" from llama_stack.core.datatypes import BuildProvider, DistributionSpec from llama_stack.core.distribution import get_external_providers_from_module @@ -566,10 +566,10 @@ class TestGetExternalProvidersFromModule: }, ) - # Should not call import_module at all when building + # Should not call import_module at all when listing with patch("importlib.import_module") as mock_import: registry = {Api.inference: {}} - result = get_external_providers_from_module(registry, dist_spec, building=True) + result = get_external_providers_from_module(registry, dist_spec, listing=True) # Verify module was NOT imported mock_import.assert_not_called() @@ -623,7 +623,7 @@ class TestGetExternalProvidersFromModule: }, ) registry = {Api.inference: {}} - result = get_external_providers_from_module(registry, config, building=False) + result = get_external_providers_from_module(registry, config, listing=False) # Only the matching provider_type should be added assert "list_test" in result[Api.inference] @@ -671,7 +671,7 @@ class TestGetExternalProvidersFromModule: }, ) registry = {Api.inference: {}} - result = get_external_providers_from_module(registry, config, building=False) + result = get_external_providers_from_module(registry, config, listing=False) # Only the matching provider_type should be added assert "wanted" in result[Api.inference] @@ -726,7 +726,7 @@ class TestGetExternalProvidersFromModule: }, ) registry = {Api.inference: {}} - result = get_external_providers_from_module(registry, config, building=False) + result = get_external_providers_from_module(registry, config, listing=False) # Both provider types should be added to registry assert "remote::ollama" in result[Api.inference] @@ -760,7 +760,7 @@ class TestGetExternalProvidersFromModule: registry = {Api.inference: {}} with pytest.raises(ValueError) as exc_info: - get_external_providers_from_module(registry, config, building=False) + get_external_providers_from_module(registry, config, listing=False) assert "get_provider_spec not found" in str(exc_info.value) @@ -797,7 +797,7 @@ class TestGetExternalProvidersFromModule: registry = {Api.inference: {}} with pytest.raises(RuntimeError) as exc_info: - get_external_providers_from_module(registry, config, building=False) + get_external_providers_from_module(registry, config, listing=False) assert "Something went wrong" in str(exc_info.value) @@ -810,7 +810,7 @@ class TestGetExternalProvidersFromModule: providers={}, ) registry = {Api.inference: {}} - result = get_external_providers_from_module(registry, config, building=False) + result = get_external_providers_from_module(registry, config, listing=False) # Should return registry unchanged assert result == registry @@ -866,7 +866,7 @@ class TestGetExternalProvidersFromModule: }, ) registry = {Api.inference: {}, Api.safety: {}} - result = get_external_providers_from_module(registry, config, building=False) + result = get_external_providers_from_module(registry, config, listing=False) assert "inf_test" in result[Api.inference] assert "safe_test" in result[Api.safety]