ci: tests

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-06-12 16:39:08 +02:00
parent 5fdd4952a9
commit 67b6f79715
No known key found for this signature in database
25 changed files with 245 additions and 224 deletions

View file

@ -14,6 +14,7 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
@ -105,6 +106,11 @@ def build_image(
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.external_apis_dir:
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")

View file

@ -423,6 +423,10 @@ class BuildConfig(BaseModel):
default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod

View file

@ -152,10 +152,15 @@ def get_provider_registry(
try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
raise ImportError(
f"Failed to import external API module {name}. Is the external API package installed? {e}"
) from e
except (ImportError, AttributeError) as e:
# Populate the registry with an empty dict to avoid breaking the provider registry
# This assume that the in-tree provider(s) are not available for this API which means
# that users will need to use external providers for this API.
registry[api] = {}
logger.error(
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
"Install the API package to load any in-tree providers for this API."
)
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:

View file

@ -8,23 +8,21 @@
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
def load_external_apis(config=None) -> dict[Api, ExternalApiSpec]:
def load_external_apis(config: StackRunConfig | BuildConfig) -> dict[Api, ExternalApiSpec]:
"""Load external API specifications from the configured directory.
Args:
config: StackRunConfig containing the external APIs directory path
config: StackRunConfig or BuildConfig containing the external APIs directory path
Returns:
A dictionary mapping API names to their specifications
"""
if not config:
return {}
if not hasattr(config, "external_apis_dir"):
return {}
@ -51,9 +49,9 @@ def load_external_apis(config=None) -> dict[Api, ExternalApiSpec]:
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load external API spec from {yaml_path}: {e}")
raise e
raise
except Exception:
logger.exception(f"Failed to load external API spec from {yaml_path}")
raise
return external_apis

View file

@ -100,8 +100,8 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
api_class = getattr(module, api_spec.protocol)
protocols[api] = api_class
except (ImportError, AttributeError) as e:
logger.warning(f"Failed to load external API {api_spec.name}: {e}")
except (ImportError, AttributeError):
logger.exception(f"Failed to load external API {api_spec.name}")
return protocols

View file

@ -37,8 +37,6 @@ def get_all_api_routes(
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
# Lazy import to avoid circular dependency
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():