fix: address review comments

1. building: bool is now listing: bool
2. self.config.conf is now self.stack_config

Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
Charlie Doern 2025-12-02 11:59:14 -05:00
parent b064251f6a
commit 0f4790f531
6 changed files with 40 additions and 40 deletions

View file

@ -46,7 +46,7 @@ def get_provider_dependencies(
deps = []
external_provider_deps = []
registry = get_provider_registry(config, True)
registry = get_provider_registry(config=config, listing=True)
for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)]

View file

@ -86,7 +86,7 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam
def get_provider_registry(
config: StackConfig | None = None, building: bool = False
config: StackConfig | None = None, listing: bool = False
) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers.
@ -111,13 +111,13 @@ def get_provider_registry(
safety/
llama-guard.yaml
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
This method is overloaded in that it can be called from a variety of places: during list-deps, during run, during stack construction.
So when listing external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
There is special handling for all of the potential cases this method can be called from.
Args:
config: Optional object containing the external providers directory path
building: Optional bool delineating whether or not this is being called from a build process
listing: Optional bool delineating whether or not this is being called from a list-deps process
Returns:
A dictionary mapping APIs to their available providers
@ -163,7 +163,7 @@ def get_provider_registry(
registry = get_external_providers_from_module(
registry=registry,
config=config,
building=building,
listing=listing,
)
return registry
@ -222,7 +222,7 @@ def get_external_providers_from_dir(
def get_external_providers_from_module(
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
registry: dict[Api, dict[str, ProviderSpec]], config, listing: bool
) -> dict[Api, dict[str, ProviderSpec]]:
provider_list = None
provider_list = config.providers.items()
@ -235,14 +235,14 @@ def get_external_providers_from_module(
continue
# get provider using module
try:
if not building:
if not listing:
package_name = provider.module.split("==")[0]
module = importlib.import_module(f"{package_name}.provider")
# if config class is wrong you will get an error saying module could not be imported
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
# in the case we are building we CANNOT import this module of course because it has not been installed.
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon list-deps and run
# in the case we are listing we CANNOT import this module of course because it has not been installed.
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,

View file

@ -33,14 +33,14 @@ async def get_provider_impl(config, deps):
class DistributionInspectImpl(Inspect):
def __init__(self, config: DistributionInspectConfig, deps):
self.config = config
self.stack_config = config.config
self.deps = deps
async def initialize(self) -> None:
pass
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
config: StackConfig = self.config.config
config: StackConfig = self.stack_config
# Helper function to determine if a route should be included based on api_filter
def should_include_route(webmethod) -> bool:

View file

@ -34,13 +34,13 @@ class PromptServiceImpl(Prompts):
"""Built-in prompt service implementation using KVStore."""
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
self.config = config
self.stack_config = config.config
self.deps = deps
self.kvstore: KVStore
async def initialize(self) -> None:
# Use prompts store reference from run config
prompts_ref = self.config.config.storage.stores.prompts
prompts_ref = self.stack_config.storage.stores.prompts
if not prompts_ref:
raise ValueError("storage.stores.prompts must be configured in run config")
self.kvstore = await kvstore_impl(prompts_ref)

View file

@ -30,7 +30,7 @@ async def get_provider_impl(config, deps):
class ProviderImpl(Providers):
def __init__(self, config, deps):
self.config = config
self.stack_config = config.config
self.deps = deps
async def initialize(self) -> None:
@ -41,7 +41,7 @@ class ProviderImpl(Providers):
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.config
run_config = self.stack_config
safe_config = StackConfig(**redact_sensitive_fields(run_config.model_dump()))
providers_health = await self.get_providers_health()
ret = []

View file

@ -270,7 +270,7 @@ class TestProviderRegistry:
external_providers_dir="/nonexistent/dir",
)
with pytest.raises(FileNotFoundError):
get_provider_registry(config)
get_provider_registry(config=config)
def test_empty_api_directory(self, api_directories, mock_providers, base_config):
"""Test handling of empty API directory."""
@ -339,7 +339,7 @@ pip_packages:
]
},
)
registry = get_provider_registry(config)
registry = get_provider_registry(config=config)
assert Api.inference in registry
assert "external_test" in registry[Api.inference]
provider = registry[Api.inference]["external_test"]
@ -368,7 +368,7 @@ pip_packages:
},
)
with pytest.raises(ValueError) as exc_info:
get_provider_registry(config)
get_provider_registry(config=config)
assert "get_provider_spec not found" in str(exc_info.value)
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
@ -391,14 +391,14 @@ pip_packages:
},
)
with pytest.raises(AttributeError):
get_provider_registry(config)
get_provider_registry(config=config)
def test_external_provider_from_module_building(self, mock_providers):
"""Test loading an external provider from a module during build (building=True, partial spec)."""
def test_external_provider_from_module_listing(self, mock_providers):
"""Test loading an external provider from a module during list-deps (listing=True, partial spec)."""
from llama_stack.core.datatypes import StackConfig
from llama_stack_api import Api
# No importlib patch needed, should not import module when building
# No importlib patch needed, should not import module when listing
config = StackConfig(
image_name="test_image",
apis=[],
@ -413,7 +413,7 @@ pip_packages:
]
},
)
registry = get_provider_registry(config, building=True)
registry = get_provider_registry(config=config, listing=True)
assert Api.inference in registry
assert "external_test" in registry[Api.inference]
provider = registry[Api.inference]["external_test"]
@ -446,7 +446,7 @@ class TestGetExternalProvidersFromModule:
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
result = get_external_providers_from_module(registry, config, listing=False)
# Should not add anything to registry
assert len(result[Api.inference]) == 0
@ -485,12 +485,12 @@ class TestGetExternalProvidersFromModule:
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
result = get_external_providers_from_module(registry, config, listing=False)
assert "versioned_test" in result[Api.inference]
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
def test_buildconfig_does_not_import_module(self, mock_providers):
"""Test that StackConfig does not import the module when building (building=True)."""
"""Test that StackConfig does not import the module when listing (listing=True)."""
from llama_stack.core.datatypes import StackConfig
from llama_stack.core.distribution import get_external_providers_from_module
@ -509,10 +509,10 @@ class TestGetExternalProvidersFromModule:
},
)
# Should not call import_module at all when building
# Should not call import_module at all when listing
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=True)
result = get_external_providers_from_module(registry, config, listing=True)
# Verify module was NOT imported
mock_import.assert_not_called()
@ -543,14 +543,14 @@ class TestGetExternalProvidersFromModule:
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=True)
result = get_external_providers_from_module(registry, config, listing=True)
mock_import.assert_not_called()
assert "provider1" in result[Api.inference]
assert "provider2" in result[Api.inference]
def test_distributionspec_does_not_import_module(self, mock_providers):
"""Test that DistributionSpec does not import the module (building=True)."""
"""Test that DistributionSpec does not import the module (listing=True)."""
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
from llama_stack.core.distribution import get_external_providers_from_module
@ -566,10 +566,10 @@ class TestGetExternalProvidersFromModule:
},
)
# Should not call import_module at all when building
# Should not call import_module at all when listing
with patch("importlib.import_module") as mock_import:
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, dist_spec, building=True)
result = get_external_providers_from_module(registry, dist_spec, listing=True)
# Verify module was NOT imported
mock_import.assert_not_called()
@ -623,7 +623,7 @@ class TestGetExternalProvidersFromModule:
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
result = get_external_providers_from_module(registry, config, listing=False)
# Only the matching provider_type should be added
assert "list_test" in result[Api.inference]
@ -671,7 +671,7 @@ class TestGetExternalProvidersFromModule:
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
result = get_external_providers_from_module(registry, config, listing=False)
# Only the matching provider_type should be added
assert "wanted" in result[Api.inference]
@ -726,7 +726,7 @@ class TestGetExternalProvidersFromModule:
},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
result = get_external_providers_from_module(registry, config, listing=False)
# Both provider types should be added to registry
assert "remote::ollama" in result[Api.inference]
@ -760,7 +760,7 @@ class TestGetExternalProvidersFromModule:
registry = {Api.inference: {}}
with pytest.raises(ValueError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
get_external_providers_from_module(registry, config, listing=False)
assert "get_provider_spec not found" in str(exc_info.value)
@ -797,7 +797,7 @@ class TestGetExternalProvidersFromModule:
registry = {Api.inference: {}}
with pytest.raises(RuntimeError) as exc_info:
get_external_providers_from_module(registry, config, building=False)
get_external_providers_from_module(registry, config, listing=False)
assert "Something went wrong" in str(exc_info.value)
@ -810,7 +810,7 @@ class TestGetExternalProvidersFromModule:
providers={},
)
registry = {Api.inference: {}}
result = get_external_providers_from_module(registry, config, building=False)
result = get_external_providers_from_module(registry, config, listing=False)
# Should return registry unchanged
assert result == registry
@ -866,7 +866,7 @@ class TestGetExternalProvidersFromModule:
},
)
registry = {Api.inference: {}, Api.safety: {}}
result = get_external_providers_from_module(registry, config, building=False)
result = get_external_providers_from_module(registry, config, listing=False)
assert "inf_test" in result[Api.inference]
assert "safe_test" in result[Api.safety]