mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 01:48:05 +00:00
fix: address review comments
1. building: bool is now listing: bool 2. self.config.conf is now self.stack_config Signed-off-by: Charlie Doern <cdoern@redhat.com>
This commit is contained in:
parent
b064251f6a
commit
0f4790f531
6 changed files with 40 additions and 40 deletions
|
|
@ -46,7 +46,7 @@ def get_provider_dependencies(
|
||||||
|
|
||||||
deps = []
|
deps = []
|
||||||
external_provider_deps = []
|
external_provider_deps = []
|
||||||
registry = get_provider_registry(config, True)
|
registry = get_provider_registry(config=config, listing=True)
|
||||||
for api_str, provider_or_providers in providers.items():
|
for api_str, provider_or_providers in providers.items():
|
||||||
providers_for_api = registry[Api(api_str)]
|
providers_for_api = registry[Api(api_str)]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam
|
||||||
|
|
||||||
|
|
||||||
def get_provider_registry(
|
def get_provider_registry(
|
||||||
config: StackConfig | None = None, building: bool = False
|
config: StackConfig | None = None, listing: bool = False
|
||||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||||
"""Get the provider registry, optionally including external providers.
|
"""Get the provider registry, optionally including external providers.
|
||||||
|
|
||||||
|
|
@ -111,13 +111,13 @@ def get_provider_registry(
|
||||||
safety/
|
safety/
|
||||||
llama-guard.yaml
|
llama-guard.yaml
|
||||||
|
|
||||||
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
|
This method is overloaded in that it can be called from a variety of places: during list-deps, during run, during stack construction.
|
||||||
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
|
So when listing external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
|
||||||
There is special handling for all of the potential cases this method can be called from.
|
There is special handling for all of the potential cases this method can be called from.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Optional object containing the external providers directory path
|
config: Optional object containing the external providers directory path
|
||||||
building: Optional bool delineating whether or not this is being called from a build process
|
listing: Optional bool delineating whether or not this is being called from a list-deps process
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping APIs to their available providers
|
A dictionary mapping APIs to their available providers
|
||||||
|
|
@ -163,7 +163,7 @@ def get_provider_registry(
|
||||||
registry = get_external_providers_from_module(
|
registry = get_external_providers_from_module(
|
||||||
registry=registry,
|
registry=registry,
|
||||||
config=config,
|
config=config,
|
||||||
building=building,
|
listing=listing,
|
||||||
)
|
)
|
||||||
|
|
||||||
return registry
|
return registry
|
||||||
|
|
@ -222,7 +222,7 @@ def get_external_providers_from_dir(
|
||||||
|
|
||||||
|
|
||||||
def get_external_providers_from_module(
|
def get_external_providers_from_module(
|
||||||
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
|
registry: dict[Api, dict[str, ProviderSpec]], config, listing: bool
|
||||||
) -> dict[Api, dict[str, ProviderSpec]]:
|
) -> dict[Api, dict[str, ProviderSpec]]:
|
||||||
provider_list = None
|
provider_list = None
|
||||||
provider_list = config.providers.items()
|
provider_list = config.providers.items()
|
||||||
|
|
@ -235,14 +235,14 @@ def get_external_providers_from_module(
|
||||||
continue
|
continue
|
||||||
# get provider using module
|
# get provider using module
|
||||||
try:
|
try:
|
||||||
if not building:
|
if not listing:
|
||||||
package_name = provider.module.split("==")[0]
|
package_name = provider.module.split("==")[0]
|
||||||
module = importlib.import_module(f"{package_name}.provider")
|
module = importlib.import_module(f"{package_name}.provider")
|
||||||
# if config class is wrong you will get an error saying module could not be imported
|
# if config class is wrong you will get an error saying module could not be imported
|
||||||
spec = module.get_provider_spec()
|
spec = module.get_provider_spec()
|
||||||
else:
|
else:
|
||||||
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
|
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon list-deps and run
|
||||||
# in the case we are building we CANNOT import this module of course because it has not been installed.
|
# in the case we are listing we CANNOT import this module of course because it has not been installed.
|
||||||
spec = ProviderSpec(
|
spec = ProviderSpec(
|
||||||
api=Api(provider_api),
|
api=Api(provider_api),
|
||||||
provider_type=provider.provider_type,
|
provider_type=provider.provider_type,
|
||||||
|
|
|
||||||
|
|
@ -33,14 +33,14 @@ async def get_provider_impl(config, deps):
|
||||||
|
|
||||||
class DistributionInspectImpl(Inspect):
|
class DistributionInspectImpl(Inspect):
|
||||||
def __init__(self, config: DistributionInspectConfig, deps):
|
def __init__(self, config: DistributionInspectConfig, deps):
|
||||||
self.config = config
|
self.stack_config = config.config
|
||||||
self.deps = deps
|
self.deps = deps
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
|
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
|
||||||
config: StackConfig = self.config.config
|
config: StackConfig = self.stack_config
|
||||||
|
|
||||||
# Helper function to determine if a route should be included based on api_filter
|
# Helper function to determine if a route should be included based on api_filter
|
||||||
def should_include_route(webmethod) -> bool:
|
def should_include_route(webmethod) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -34,13 +34,13 @@ class PromptServiceImpl(Prompts):
|
||||||
"""Built-in prompt service implementation using KVStore."""
|
"""Built-in prompt service implementation using KVStore."""
|
||||||
|
|
||||||
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
|
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
|
||||||
self.config = config
|
self.stack_config = config.config
|
||||||
self.deps = deps
|
self.deps = deps
|
||||||
self.kvstore: KVStore
|
self.kvstore: KVStore
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
# Use prompts store reference from run config
|
# Use prompts store reference from run config
|
||||||
prompts_ref = self.config.config.storage.stores.prompts
|
prompts_ref = self.stack_config.storage.stores.prompts
|
||||||
if not prompts_ref:
|
if not prompts_ref:
|
||||||
raise ValueError("storage.stores.prompts must be configured in run config")
|
raise ValueError("storage.stores.prompts must be configured in run config")
|
||||||
self.kvstore = await kvstore_impl(prompts_ref)
|
self.kvstore = await kvstore_impl(prompts_ref)
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ async def get_provider_impl(config, deps):
|
||||||
|
|
||||||
class ProviderImpl(Providers):
|
class ProviderImpl(Providers):
|
||||||
def __init__(self, config, deps):
|
def __init__(self, config, deps):
|
||||||
self.config = config
|
self.stack_config = config.config
|
||||||
self.deps = deps
|
self.deps = deps
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
|
|
@ -41,7 +41,7 @@ class ProviderImpl(Providers):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def list_providers(self) -> ListProvidersResponse:
|
async def list_providers(self) -> ListProvidersResponse:
|
||||||
run_config = self.config.config
|
run_config = self.stack_config
|
||||||
safe_config = StackConfig(**redact_sensitive_fields(run_config.model_dump()))
|
safe_config = StackConfig(**redact_sensitive_fields(run_config.model_dump()))
|
||||||
providers_health = await self.get_providers_health()
|
providers_health = await self.get_providers_health()
|
||||||
ret = []
|
ret = []
|
||||||
|
|
|
||||||
|
|
@ -270,7 +270,7 @@ class TestProviderRegistry:
|
||||||
external_providers_dir="/nonexistent/dir",
|
external_providers_dir="/nonexistent/dir",
|
||||||
)
|
)
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
get_provider_registry(config)
|
get_provider_registry(config=config)
|
||||||
|
|
||||||
def test_empty_api_directory(self, api_directories, mock_providers, base_config):
|
def test_empty_api_directory(self, api_directories, mock_providers, base_config):
|
||||||
"""Test handling of empty API directory."""
|
"""Test handling of empty API directory."""
|
||||||
|
|
@ -339,7 +339,7 @@ pip_packages:
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
registry = get_provider_registry(config)
|
registry = get_provider_registry(config=config)
|
||||||
assert Api.inference in registry
|
assert Api.inference in registry
|
||||||
assert "external_test" in registry[Api.inference]
|
assert "external_test" in registry[Api.inference]
|
||||||
provider = registry[Api.inference]["external_test"]
|
provider = registry[Api.inference]["external_test"]
|
||||||
|
|
@ -368,7 +368,7 @@ pip_packages:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
get_provider_registry(config)
|
get_provider_registry(config=config)
|
||||||
assert "get_provider_spec not found" in str(exc_info.value)
|
assert "get_provider_spec not found" in str(exc_info.value)
|
||||||
|
|
||||||
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
|
def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
|
||||||
|
|
@ -391,14 +391,14 @@ pip_packages:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(AttributeError):
|
||||||
get_provider_registry(config)
|
get_provider_registry(config=config)
|
||||||
|
|
||||||
def test_external_provider_from_module_building(self, mock_providers):
|
def test_external_provider_from_module_listing(self, mock_providers):
|
||||||
"""Test loading an external provider from a module during build (building=True, partial spec)."""
|
"""Test loading an external provider from a module during list-deps (listing=True, partial spec)."""
|
||||||
from llama_stack.core.datatypes import StackConfig
|
from llama_stack.core.datatypes import StackConfig
|
||||||
from llama_stack_api import Api
|
from llama_stack_api import Api
|
||||||
|
|
||||||
# No importlib patch needed, should not import module when building
|
# No importlib patch needed, should not import module when listing
|
||||||
config = StackConfig(
|
config = StackConfig(
|
||||||
image_name="test_image",
|
image_name="test_image",
|
||||||
apis=[],
|
apis=[],
|
||||||
|
|
@ -413,7 +413,7 @@ pip_packages:
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
registry = get_provider_registry(config, building=True)
|
registry = get_provider_registry(config=config, listing=True)
|
||||||
assert Api.inference in registry
|
assert Api.inference in registry
|
||||||
assert "external_test" in registry[Api.inference]
|
assert "external_test" in registry[Api.inference]
|
||||||
provider = registry[Api.inference]["external_test"]
|
provider = registry[Api.inference]["external_test"]
|
||||||
|
|
@ -446,7 +446,7 @@ class TestGetExternalProvidersFromModule:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
result = get_external_providers_from_module(registry, config, building=False)
|
result = get_external_providers_from_module(registry, config, listing=False)
|
||||||
# Should not add anything to registry
|
# Should not add anything to registry
|
||||||
assert len(result[Api.inference]) == 0
|
assert len(result[Api.inference]) == 0
|
||||||
|
|
||||||
|
|
@ -485,12 +485,12 @@ class TestGetExternalProvidersFromModule:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
result = get_external_providers_from_module(registry, config, building=False)
|
result = get_external_providers_from_module(registry, config, listing=False)
|
||||||
assert "versioned_test" in result[Api.inference]
|
assert "versioned_test" in result[Api.inference]
|
||||||
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
|
assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"
|
||||||
|
|
||||||
def test_buildconfig_does_not_import_module(self, mock_providers):
|
def test_buildconfig_does_not_import_module(self, mock_providers):
|
||||||
"""Test that StackConfig does not import the module when building (building=True)."""
|
"""Test that StackConfig does not import the module when listing (listing=True)."""
|
||||||
from llama_stack.core.datatypes import StackConfig
|
from llama_stack.core.datatypes import StackConfig
|
||||||
from llama_stack.core.distribution import get_external_providers_from_module
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
|
@ -509,10 +509,10 @@ class TestGetExternalProvidersFromModule:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should not call import_module at all when building
|
# Should not call import_module at all when listing
|
||||||
with patch("importlib.import_module") as mock_import:
|
with patch("importlib.import_module") as mock_import:
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
result = get_external_providers_from_module(registry, config, building=True)
|
result = get_external_providers_from_module(registry, config, listing=True)
|
||||||
|
|
||||||
# Verify module was NOT imported
|
# Verify module was NOT imported
|
||||||
mock_import.assert_not_called()
|
mock_import.assert_not_called()
|
||||||
|
|
@ -543,14 +543,14 @@ class TestGetExternalProvidersFromModule:
|
||||||
|
|
||||||
with patch("importlib.import_module") as mock_import:
|
with patch("importlib.import_module") as mock_import:
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
result = get_external_providers_from_module(registry, config, building=True)
|
result = get_external_providers_from_module(registry, config, listing=True)
|
||||||
|
|
||||||
mock_import.assert_not_called()
|
mock_import.assert_not_called()
|
||||||
assert "provider1" in result[Api.inference]
|
assert "provider1" in result[Api.inference]
|
||||||
assert "provider2" in result[Api.inference]
|
assert "provider2" in result[Api.inference]
|
||||||
|
|
||||||
def test_distributionspec_does_not_import_module(self, mock_providers):
|
def test_distributionspec_does_not_import_module(self, mock_providers):
|
||||||
"""Test that DistributionSpec does not import the module (building=True)."""
|
"""Test that DistributionSpec does not import the module (listing=True)."""
|
||||||
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
|
from llama_stack.core.datatypes import BuildProvider, DistributionSpec
|
||||||
from llama_stack.core.distribution import get_external_providers_from_module
|
from llama_stack.core.distribution import get_external_providers_from_module
|
||||||
|
|
||||||
|
|
@ -566,10 +566,10 @@ class TestGetExternalProvidersFromModule:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should not call import_module at all when building
|
# Should not call import_module at all when listing
|
||||||
with patch("importlib.import_module") as mock_import:
|
with patch("importlib.import_module") as mock_import:
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
result = get_external_providers_from_module(registry, dist_spec, building=True)
|
result = get_external_providers_from_module(registry, dist_spec, listing=True)
|
||||||
|
|
||||||
# Verify module was NOT imported
|
# Verify module was NOT imported
|
||||||
mock_import.assert_not_called()
|
mock_import.assert_not_called()
|
||||||
|
|
@ -623,7 +623,7 @@ class TestGetExternalProvidersFromModule:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
result = get_external_providers_from_module(registry, config, building=False)
|
result = get_external_providers_from_module(registry, config, listing=False)
|
||||||
|
|
||||||
# Only the matching provider_type should be added
|
# Only the matching provider_type should be added
|
||||||
assert "list_test" in result[Api.inference]
|
assert "list_test" in result[Api.inference]
|
||||||
|
|
@ -671,7 +671,7 @@ class TestGetExternalProvidersFromModule:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
result = get_external_providers_from_module(registry, config, building=False)
|
result = get_external_providers_from_module(registry, config, listing=False)
|
||||||
|
|
||||||
# Only the matching provider_type should be added
|
# Only the matching provider_type should be added
|
||||||
assert "wanted" in result[Api.inference]
|
assert "wanted" in result[Api.inference]
|
||||||
|
|
@ -726,7 +726,7 @@ class TestGetExternalProvidersFromModule:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
result = get_external_providers_from_module(registry, config, building=False)
|
result = get_external_providers_from_module(registry, config, listing=False)
|
||||||
|
|
||||||
# Both provider types should be added to registry
|
# Both provider types should be added to registry
|
||||||
assert "remote::ollama" in result[Api.inference]
|
assert "remote::ollama" in result[Api.inference]
|
||||||
|
|
@ -760,7 +760,7 @@ class TestGetExternalProvidersFromModule:
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
get_external_providers_from_module(registry, config, building=False)
|
get_external_providers_from_module(registry, config, listing=False)
|
||||||
|
|
||||||
assert "get_provider_spec not found" in str(exc_info.value)
|
assert "get_provider_spec not found" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
@ -797,7 +797,7 @@ class TestGetExternalProvidersFromModule:
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as exc_info:
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
get_external_providers_from_module(registry, config, building=False)
|
get_external_providers_from_module(registry, config, listing=False)
|
||||||
|
|
||||||
assert "Something went wrong" in str(exc_info.value)
|
assert "Something went wrong" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
@ -810,7 +810,7 @@ class TestGetExternalProvidersFromModule:
|
||||||
providers={},
|
providers={},
|
||||||
)
|
)
|
||||||
registry = {Api.inference: {}}
|
registry = {Api.inference: {}}
|
||||||
result = get_external_providers_from_module(registry, config, building=False)
|
result = get_external_providers_from_module(registry, config, listing=False)
|
||||||
|
|
||||||
# Should return registry unchanged
|
# Should return registry unchanged
|
||||||
assert result == registry
|
assert result == registry
|
||||||
|
|
@ -866,7 +866,7 @@ class TestGetExternalProvidersFromModule:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
registry = {Api.inference: {}, Api.safety: {}}
|
registry = {Api.inference: {}, Api.safety: {}}
|
||||||
result = get_external_providers_from_module(registry, config, building=False)
|
result = get_external_providers_from_module(registry, config, listing=False)
|
||||||
|
|
||||||
assert "inf_test" in result[Api.inference]
|
assert "inf_test" in result[Api.inference]
|
||||||
assert "safe_test" in result[Api.safety]
|
assert "safe_test" in result[Api.safety]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue