fix: address review comments

1. building: bool is now listing: bool
2. self.config.conf is now self.stack_config

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-12-02 11:59:14 -05:00
parent b064251f6a
commit 0f4790f531
6 changed files with 40 additions and 40 deletions

View file

@ -46,7 +46,7 @@ def get_provider_dependencies(
deps = [] deps = []
external_provider_deps = [] external_provider_deps = []
registry = get_provider_registry(config, True) registry = get_provider_registry(config=config, listing=True)
for api_str, provider_or_providers in providers.items(): for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)] providers_for_api = registry[Api(api_str)]

View file

@ -86,7 +86,7 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam
def get_provider_registry( def get_provider_registry(
config: StackConfig | None = None, building: bool = False config: StackConfig | None = None, listing: bool = False
) -> dict[Api, dict[str, ProviderSpec]]: ) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers. """Get the provider registry, optionally including external providers.
@ -111,13 +111,13 @@ def get_provider_registry(
safety/ safety/
llama-guard.yaml llama-guard.yaml
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction. This method is overloaded in that it can be called from a variety of places: during list-deps, during run, during stack construction.
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet. So when listing external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
There is special handling for all of the potential cases this method can be called from. There is special handling for all of the potential cases this method can be called from.
Args: Args:
config: Optional object containing the external providers directory path config: Optional object containing the external providers directory path
building: Optional bool delineating whether or not this is being called from a build process listing: Optional bool delineating whether or not this is being called from a list-deps process
Returns: Returns:
A dictionary mapping APIs to their available providers A dictionary mapping APIs to their available providers
@ -163,7 +163,7 @@ def get_provider_registry(
registry = get_external_providers_from_module( registry = get_external_providers_from_module(
registry=registry, registry=registry,
config=config, config=config,
building=building, listing=listing,
) )
return registry return registry
@ -222,7 +222,7 @@ def get_external_providers_from_dir(
def get_external_providers_from_module( def get_external_providers_from_module(
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool registry: dict[Api, dict[str, ProviderSpec]], config, listing: bool
) -> dict[Api, dict[str, ProviderSpec]]: ) -> dict[Api, dict[str, ProviderSpec]]:
provider_list = None provider_list = None
provider_list = config.providers.items() provider_list = config.providers.items()
@ -235,14 +235,14 @@ def get_external_providers_from_module(
continue continue
# get provider using module # get provider using module
try: try:
if not building: if not listing:
package_name = provider.module.split("==")[0] package_name = provider.module.split("==")[0]
module = importlib.import_module(f"{package_name}.provider") module = importlib.import_module(f"{package_name}.provider")
# if config class is wrong you will get an error saying module could not be imported # if config class is wrong you will get an error saying module could not be imported
spec = module.get_provider_spec() spec = module.get_provider_spec()
else: else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run # pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon list-deps and run
# in the case we are building we CANNOT import this module of course because it has not been installed. # in the case we are listing we CANNOT import this module of course because it has not been installed.
spec = ProviderSpec( spec = ProviderSpec(
api=Api(provider_api), api=Api(provider_api),
provider_type=provider.provider_type, provider_type=provider.provider_type,

View file

@ -33,14 +33,14 @@ async def get_provider_impl(config, deps):
class DistributionInspectImpl(Inspect): class DistributionInspectImpl(Inspect):
def __init__(self, config: DistributionInspectConfig, deps): def __init__(self, config: DistributionInspectConfig, deps):
self.config = config self.stack_config = config.config
self.deps = deps self.deps = deps
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse: async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
config: StackConfig = self.config.config config: StackConfig = self.stack_config
# Helper function to determine if a route should be included based on api_filter # Helper function to determine if a route should be included based on api_filter
def should_include_route(webmethod) -> bool: def should_include_route(webmethod) -> bool:

View file

@ -34,13 +34,13 @@ class PromptServiceImpl(Prompts):
"""Built-in prompt service implementation using KVStore.""" """Built-in prompt service implementation using KVStore."""
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]): def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
self.config = config self.stack_config = config.config
self.deps = deps self.deps = deps
self.kvstore: KVStore self.kvstore: KVStore
async def initialize(self) -> None: async def initialize(self) -> None:
# Use prompts store reference from run config # Use prompts store reference from run config
prompts_ref = self.config.config.storage.stores.prompts prompts_ref = self.stack_config.storage.stores.prompts
if not prompts_ref: if not prompts_ref:
raise ValueError("storage.stores.prompts must be configured in run config") raise ValueError("storage.stores.prompts must be configured in run config")
self.kvstore = await kvstore_impl(prompts_ref) self.kvstore = await kvstore_impl(prompts_ref)

View file

@ -30,7 +30,7 @@ async def get_provider_impl(config, deps):
class ProviderImpl(Providers): class ProviderImpl(Providers):
def __init__(self, config, deps): def __init__(self, config, deps):
self.config = config self.stack_config = config.config
self.deps = deps self.deps = deps
async def initialize(self) -> None: async def initialize(self) -> None:
@ -41,7 +41,7 @@ class ProviderImpl(Providers):
pass pass
async def list_providers(self) -> ListProvidersResponse: async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.config run_config = self.stack_config
safe_config = StackConfig(**redact_sensitive_fields(run_config.model_dump())) safe_config = StackConfig(**redact_sensitive_fields(run_config.model_dump()))
providers_health = await self.get_providers_health() providers_health = await self.get_providers_health()
ret = [] ret = []

View file

@ -270,7 +270,7 @@ class TestProviderRegistry:
external_providers_dir="/nonexistent/dir", external_providers_dir="/nonexistent/dir",
) )
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
get_provider_registry(config) get_provider_registry(config=config)
def test_empty_api_directory(self, api_directories, mock_providers, base_config): def test_empty_api_directory(self, api_directories, mock_providers, base_config):
"""Test handling of empty API directory.""" """Test handling of empty API directory."""
@ -339,7 +339,7 @@ pip_packages:
] ]
}, },
) )
registry = get_provider_registry(config) registry = get_provider_registry(config=config)
assert Api.inference in registry assert Api.inference in registry
assert "external_test" in registry[Api.inference] assert "external_test" in registry[Api.inference]
provider = registry[Api.inference]["external_test"] provider = registry[Api.inference]["external_test"]
@ -368,7 +368,7 @@ pip_packages:
}, },
) )
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
get_provider_registry(config) get_provider_registry(config=config)
assert "get_provider_spec not found" in str(exc_info.value) assert "get_provider_spec not found" in str(exc_info.value)
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers): def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
@ -391,14 +391,14 @@ pip_packages:
}, },
) )
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
get_provider_registry(config) get_provider_registry(config=config)
def test_external_provider_from_module_building(self, mock_providers): def test_external_provider_from_module_listing(self, mock_providers):
"""Test loading an external provider from a module during build (building=True, partial spec).""" """Test loading an external provider from a module during list-deps (listing=True, partial spec)."""
from llama_stack.core.datatypes import StackConfig from llama_stack.core.datatypes import StackConfig
from llama_stack_api import Api from llama_stack_api import Api
# No importlib patch needed, should not import module when building # No importlib patch needed, should not import module when listing
config = StackConfig( config = StackConfig(
image_name="test_image", image_name="test_image",
apis=[], apis=[],
@ -413,7 +413,7 @@ pip_packages:
] ]
}, },
) )
registry = get_provider_registry(config, building=True) registry = get_provider_registry(config=config, listing=True)
assert Api.inference in registry assert Api.inference in registry
assert "external_test" in registry[Api.inference] assert "external_test" in registry[Api.inference]
provider = registry[Api.inference]["external_test"] provider = registry[Api.inference]["external_test"]
@ -446,7 +446,7 @@ class TestGetExternalProvidersFromModule:
}, },
) )
registry = {Api.inference: {}} registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False) result = get_external_providers_from_module(registry, config, listing=False)
# Should not add anything to registry # Should not add anything to registry
assert len(result[Api.inference]) == 0 assert len(result[Api.inference]) == 0
@ -485,12 +485,12 @@ class TestGetExternalProvidersFromModule:
}, },
) )
registry = {Api.inference: {}} registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False) result = get_external_providers_from_module(registry, config, listing=False)
assert "versioned_test" in result[Api.inference] assert "versioned_test" in result[Api.inference]
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0" assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
def test_buildconfig_does_not_import_module(self, mock_providers): def test_buildconfig_does_not_import_module(self, mock_providers):
"""Test that StackConfig does not import the module when building (building=True).""" """Test that StackConfig does not import the module when listing (listing=True)."""
from llama_stack.core.datatypes import StackConfig from llama_stack.core.datatypes import StackConfig
from llama_stack.core.distribution import get_external_providers_from_module from llama_stack.core.distribution import get_external_providers_from_module
@ -509,10 +509,10 @@ class TestGetExternalProvidersFromModule:
}, },
) )
# Should not call import_module at all when building # Should not call import_module at all when listing
with patch("importlib.import_module") as mock_import: with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}} registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=True) result = get_external_providers_from_module(registry, config, listing=True)
# Verify module was NOT imported # Verify module was NOT imported
mock_import.assert_not_called() mock_import.assert_not_called()
@ -543,14 +543,14 @@ class TestGetExternalProvidersFromModule:
with patch("importlib.import_module") as mock_import: with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}} registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=True) result = get_external_providers_from_module(registry, config, listing=True)
mock_import.assert_not_called() mock_import.assert_not_called()
assert "provider1" in result[Api.inference] assert "provider1" in result[Api.inference]
assert "provider2" in result[Api.inference] assert "provider2" in result[Api.inference]
def test_distributionspec_does_not_import_module(self, mock_providers): def test_distributionspec_does_not_import_module(self, mock_providers):
"""Test that DistributionSpec does not import the module (building=True).""" """Test that DistributionSpec does not import the module (listing=True)."""
from llama_stack.core.datatypes import BuildProvider, DistributionSpec from llama_stack.core.datatypes import BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module from llama_stack.core.distribution import get_external_providers_from_module
@ -566,10 +566,10 @@ class TestGetExternalProvidersFromModule:
}, },
) )
# Should not call import_module at all when building # Should not call import_module at all when listing
with patch("importlib.import_module") as mock_import: with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}} registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, dist_spec, building=True) result = get_external_providers_from_module(registry, dist_spec, listing=True)
# Verify module was NOT imported # Verify module was NOT imported
mock_import.assert_not_called() mock_import.assert_not_called()
@ -623,7 +623,7 @@ class TestGetExternalProvidersFromModule:
}, },
) )
registry = {Api.inference: {}} registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False) result = get_external_providers_from_module(registry, config, listing=False)
# Only the matching provider_type should be added # Only the matching provider_type should be added
assert "list_test" in result[Api.inference] assert "list_test" in result[Api.inference]
@ -671,7 +671,7 @@ class TestGetExternalProvidersFromModule:
}, },
) )
registry = {Api.inference: {}} registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False) result = get_external_providers_from_module(registry, config, listing=False)
# Only the matching provider_type should be added # Only the matching provider_type should be added
assert "wanted" in result[Api.inference] assert "wanted" in result[Api.inference]
@ -726,7 +726,7 @@ class TestGetExternalProvidersFromModule:
}, },
) )
registry = {Api.inference: {}} registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False) result = get_external_providers_from_module(registry, config, listing=False)
# Both provider types should be added to registry # Both provider types should be added to registry
assert "remote::ollama" in result[Api.inference] assert "remote::ollama" in result[Api.inference]
@ -760,7 +760,7 @@ class TestGetExternalProvidersFromModule:
registry = {Api.inference: {}} registry = {Api.inference: {}}
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
get_external_providers_from_module(registry, config, building=False) get_external_providers_from_module(registry, config, listing=False)
assert "get_provider_spec not found" in str(exc_info.value) assert "get_provider_spec not found" in str(exc_info.value)
@ -797,7 +797,7 @@ class TestGetExternalProvidersFromModule:
registry = {Api.inference: {}} registry = {Api.inference: {}}
with pytest.raises(RuntimeError) as exc_info: with pytest.raises(RuntimeError) as exc_info:
get_external_providers_from_module(registry, config, building=False) get_external_providers_from_module(registry, config, listing=False)
assert "Something went wrong" in str(exc_info.value) assert "Something went wrong" in str(exc_info.value)
@ -810,7 +810,7 @@ class TestGetExternalProvidersFromModule:
providers={}, providers={},
) )
registry = {Api.inference: {}} registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False) result = get_external_providers_from_module(registry, config, listing=False)
# Should return registry unchanged # Should return registry unchanged
assert result == registry assert result == registry
@ -866,7 +866,7 @@ class TestGetExternalProvidersFromModule:
}, },
) )
registry = {Api.inference: {}, Api.safety: {}} registry = {Api.inference: {}, Api.safety: {}}
result = get_external_providers_from_module(registry, config, building=False) result = get_external_providers_from_module(registry, config, listing=False)
assert "inf_test" in result[Api.inference] assert "inf_test" in result[Api.inference]
assert "safe_test" in result[Api.safety] assert "safe_test" in result[Api.safety]