mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
fix: address review comments
1. building: bool is now listing: bool 2. self.config.conf is now self.stack_config Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
b064251f6a
commit
0f4790f531
6 changed files with 40 additions and 40 deletions
|
|
@ -46,7 +46,7 @@ def get_provider_dependencies(
|
|||
|
||||
deps = []
|
||||
external_provider_deps = []
|
||||
registry = get_provider_registry(config, True)
|
||||
registry = get_provider_registry(config=config, listing=True)
|
||||
for api_str, provider_or_providers in providers.items():
|
||||
providers_for_api = registry[Api(api_str)]
|
||||
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam
|
|||
|
||||
|
||||
def get_provider_registry(
|
||||
config: StackConfig | None = None, building: bool = False
|
||||
config: StackConfig | None = None, listing: bool = False
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
"""Get the provider registry, optionally including external providers.
|
||||
|
||||
|
|
@ -111,13 +111,13 @@ def get_provider_registry(
|
|||
safety/
|
||||
llama-guard.yaml
|
||||
|
||||
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
|
||||
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
|
||||
This method is overloaded in that it can be called from a variety of places: during list-deps, during run, during stack construction.
|
||||
So when listing external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
|
||||
There is special handling for all of the potential cases this method can be called from.
|
||||
|
||||
Args:
|
||||
config: Optional object containing the external providers directory path
|
||||
building: Optional bool delineating whether or not this is being called from a build process
|
||||
listing: Optional bool delineating whether or not this is being called from a list-deps process
|
||||
|
||||
Returns:
|
||||
A dictionary mapping APIs to their available providers
|
||||
|
|
@ -163,7 +163,7 @@ def get_provider_registry(
|
|||
registry = get_external_providers_from_module(
|
||||
registry=registry,
|
||||
config=config,
|
||||
building=building,
|
||||
listing=listing,
|
||||
)
|
||||
|
||||
return registry
|
||||
|
|
@ -222,7 +222,7 @@ def get_external_providers_from_dir(
|
|||
|
||||
|
||||
def get_external_providers_from_module(
|
||||
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
|
||||
registry: dict[Api, dict[str, ProviderSpec]], config, listing: bool
|
||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||
provider_list = None
|
||||
provider_list = config.providers.items()
|
||||
|
|
@ -235,14 +235,14 @@ def get_external_providers_from_module(
|
|||
continue
|
||||
# get provider using module
|
||||
try:
|
||||
if not building:
|
||||
if not listing:
|
||||
package_name = provider.module.split("==")[0]
|
||||
module = importlib.import_module(f"{package_name}.provider")
|
||||
# if config class is wrong you will get an error saying module could not be imported
|
||||
spec = module.get_provider_spec()
|
||||
else:
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon list-deps and run
|
||||
# in the case we are listing we CANNOT import this module of course because it has not been installed.
|
||||
spec = ProviderSpec(
|
||||
api=Api(provider_api),
|
||||
provider_type=provider.provider_type,
|
||||
|
|
|
|||
|
|
@ -33,14 +33,14 @@ async def get_provider_impl(config, deps):
|
|||
|
||||
class DistributionInspectImpl(Inspect):
|
||||
def __init__(self, config: DistributionInspectConfig, deps):
|
||||
self.config = config
|
||||
self.stack_config = config.config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
|
||||
config: StackConfig = self.config.config
|
||||
config: StackConfig = self.stack_config
|
||||
|
||||
# Helper function to determine if a route should be included based on api_filter
|
||||
def should_include_route(webmethod) -> bool:
|
||||
|
|
|
|||
|
|
@ -34,13 +34,13 @@ class PromptServiceImpl(Prompts):
|
|||
"""Built-in prompt service implementation using KVStore."""
|
||||
|
||||
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
|
||||
self.config = config
|
||||
self.stack_config = config.config
|
||||
self.deps = deps
|
||||
self.kvstore: KVStore
|
||||
|
||||
async def initialize(self) -> None:
|
||||
# Use prompts store reference from run config
|
||||
prompts_ref = self.config.config.storage.stores.prompts
|
||||
prompts_ref = self.stack_config.storage.stores.prompts
|
||||
if not prompts_ref:
|
||||
raise ValueError("storage.stores.prompts must be configured in run config")
|
||||
self.kvstore = await kvstore_impl(prompts_ref)
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ async def get_provider_impl(config, deps):
|
|||
|
||||
class ProviderImpl(Providers):
|
||||
def __init__(self, config, deps):
|
||||
self.config = config
|
||||
self.stack_config = config.config
|
||||
self.deps = deps
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
@ -41,7 +41,7 @@ class ProviderImpl(Providers):
|
|||
pass
|
||||
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
run_config = self.config.config
|
||||
run_config = self.stack_config
|
||||
safe_config = StackConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||
providers_health = await self.get_providers_health()
|
||||
ret = []
|
||||
|
|
|
|||
|
|
@ -270,7 +270,7 @@ class TestProviderRegistry:
|
|||
external_providers_dir="/nonexistent/dir",
|
||||
)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
get_provider_registry(config)
|
||||
get_provider_registry(config=config)
|
||||
|
||||
def test_empty_api_directory(self, api_directories, mock_providers, base_config):
|
||||
"""Test handling of empty API directory."""
|
||||
|
|
@ -339,7 +339,7 @@ pip_packages:
|
|||
]
|
||||
},
|
||||
)
|
||||
registry = get_provider_registry(config)
|
||||
registry = get_provider_registry(config=config)
|
||||
assert Api.inference in registry
|
||||
assert "external_test" in registry[Api.inference]
|
||||
provider = registry[Api.inference]["external_test"]
|
||||
|
|
@ -368,7 +368,7 @@ pip_packages:
|
|||
},
|
||||
)
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_provider_registry(config)
|
||||
get_provider_registry(config=config)
|
||||
assert "get_provider_spec not found" in str(exc_info.value)
|
||||
|
||||
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
|
||||
|
|
@ -391,14 +391,14 @@ pip_packages:
|
|||
},
|
||||
)
|
||||
with pytest.raises(AttributeError):
|
||||
get_provider_registry(config)
|
||||
get_provider_registry(config=config)
|
||||
|
||||
def test_external_provider_from_module_building(self, mock_providers):
|
||||
"""Test loading an external provider from a module during build (building=True, partial spec)."""
|
||||
def test_external_provider_from_module_listing(self, mock_providers):
|
||||
"""Test loading an external provider from a module during list-deps (listing=True, partial spec)."""
|
||||
from llama_stack.core.datatypes import StackConfig
|
||||
from llama_stack_api import Api
|
||||
|
||||
# No importlib patch needed, should not import module when building
|
||||
# No importlib patch needed, should not import module when listing
|
||||
config = StackConfig(
|
||||
image_name="test_image",
|
||||
apis=[],
|
||||
|
|
@ -413,7 +413,7 @@ pip_packages:
|
|||
]
|
||||
},
|
||||
)
|
||||
registry = get_provider_registry(config, building=True)
|
||||
registry = get_provider_registry(config=config, listing=True)
|
||||
assert Api.inference in registry
|
||||
assert "external_test" in registry[Api.inference]
|
||||
provider = registry[Api.inference]["external_test"]
|
||||
|
|
@ -446,7 +446,7 @@ class TestGetExternalProvidersFromModule:
|
|||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
result = get_external_providers_from_module(registry, config, listing=False)
|
||||
# Should not add anything to registry
|
||||
assert len(result[Api.inference]) == 0
|
||||
|
||||
|
|
@ -485,12 +485,12 @@ class TestGetExternalProvidersFromModule:
|
|||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
result = get_external_providers_from_module(registry, config, listing=False)
|
||||
assert "versioned_test" in result[Api.inference]
|
||||
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
|
||||
|
||||
def test_buildconfig_does_not_import_module(self, mock_providers):
|
||||
"""Test that StackConfig does not import the module when building (building=True)."""
|
||||
"""Test that StackConfig does not import the module when listing (listing=True)."""
|
||||
from llama_stack.core.datatypes import StackConfig
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
|
|
@ -509,10 +509,10 @@ class TestGetExternalProvidersFromModule:
|
|||
},
|
||||
)
|
||||
|
||||
# Should not call import_module at all when building
|
||||
# Should not call import_module at all when listing
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=True)
|
||||
result = get_external_providers_from_module(registry, config, listing=True)
|
||||
|
||||
# Verify module was NOT imported
|
||||
mock_import.assert_not_called()
|
||||
|
|
@ -543,14 +543,14 @@ class TestGetExternalProvidersFromModule:
|
|||
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=True)
|
||||
result = get_external_providers_from_module(registry, config, listing=True)
|
||||
|
||||
mock_import.assert_not_called()
|
||||
assert "provider1" in result[Api.inference]
|
||||
assert "provider2" in result[Api.inference]
|
||||
|
||||
def test_distributionspec_does_not_import_module(self, mock_providers):
|
||||
"""Test that DistributionSpec does not import the module (building=True)."""
|
||||
"""Test that DistributionSpec does not import the module (listing=True)."""
|
||||
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
|
||||
from llama_stack.core.distribution import get_external_providers_from_module
|
||||
|
||||
|
|
@ -566,10 +566,10 @@ class TestGetExternalProvidersFromModule:
|
|||
},
|
||||
)
|
||||
|
||||
# Should not call import_module at all when building
|
||||
# Should not call import_module at all when listing
|
||||
with patch("importlib.import_module") as mock_import:
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, dist_spec, building=True)
|
||||
result = get_external_providers_from_module(registry, dist_spec, listing=True)
|
||||
|
||||
# Verify module was NOT imported
|
||||
mock_import.assert_not_called()
|
||||
|
|
@ -623,7 +623,7 @@ class TestGetExternalProvidersFromModule:
|
|||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
result = get_external_providers_from_module(registry, config, listing=False)
|
||||
|
||||
# Only the matching provider_type should be added
|
||||
assert "list_test" in result[Api.inference]
|
||||
|
|
@ -671,7 +671,7 @@ class TestGetExternalProvidersFromModule:
|
|||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
result = get_external_providers_from_module(registry, config, listing=False)
|
||||
|
||||
# Only the matching provider_type should be added
|
||||
assert "wanted" in result[Api.inference]
|
||||
|
|
@ -726,7 +726,7 @@ class TestGetExternalProvidersFromModule:
|
|||
},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
result = get_external_providers_from_module(registry, config, listing=False)
|
||||
|
||||
# Both provider types should be added to registry
|
||||
assert "remote::ollama" in result[Api.inference]
|
||||
|
|
@ -760,7 +760,7 @@ class TestGetExternalProvidersFromModule:
|
|||
registry = {Api.inference: {}}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_external_providers_from_module(registry, config, building=False)
|
||||
get_external_providers_from_module(registry, config, listing=False)
|
||||
|
||||
assert "get_provider_spec not found" in str(exc_info.value)
|
||||
|
||||
|
|
@ -797,7 +797,7 @@ class TestGetExternalProvidersFromModule:
|
|||
registry = {Api.inference: {}}
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
get_external_providers_from_module(registry, config, building=False)
|
||||
get_external_providers_from_module(registry, config, listing=False)
|
||||
|
||||
assert "Something went wrong" in str(exc_info.value)
|
||||
|
||||
|
|
@ -810,7 +810,7 @@ class TestGetExternalProvidersFromModule:
|
|||
providers={},
|
||||
)
|
||||
registry = {Api.inference: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
result = get_external_providers_from_module(registry, config, listing=False)
|
||||
|
||||
# Should return registry unchanged
|
||||
assert result == registry
|
||||
|
|
@ -866,7 +866,7 @@ class TestGetExternalProvidersFromModule:
|
|||
},
|
||||
)
|
||||
registry = {Api.inference: {}, Api.safety: {}}
|
||||
result = get_external_providers_from_module(registry, config, building=False)
|
||||
result = get_external_providers_from_module(registry, config, listing=False)
|
||||
|
||||
assert "inf_test" in result[Api.inference]
|
||||
assert "safe_test" in result[Api.safety]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue