mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: add deps dynamically based on metastore config (#2405)
# What does this PR do? ## Test Plan changed metastore in one of the templates, rerun distro gen, observe change in build.yaml
This commit is contained in:
parent
92b59a3377
commit
446893f791
6 changed files with 36 additions and 18 deletions
|
@ -43,23 +43,12 @@ def get_provider_dependencies(
|
|||
config: BuildConfig | DistributionTemplate,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
# Extract providers based on config type
|
||||
if isinstance(config, DistributionTemplate):
|
||||
providers = config.providers
|
||||
config = config.build_config()
|
||||
|
||||
providers = config.distribution_spec.providers
|
||||
additional_pip_packages = config.additional_pip_packages
|
||||
|
||||
# TODO: This is a hack to get the dependencies for internal APIs into build
|
||||
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
||||
# and providers, with a way to specify dependencies for them.
|
||||
run_configs = config.run_configs
|
||||
additional_pip_packages: list[str] = []
|
||||
if run_configs:
|
||||
for run_config in run_configs.values():
|
||||
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
|
||||
if run_config_.inference_store:
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
||||
elif isinstance(config, BuildConfig):
|
||||
providers = config.distribution_spec.providers
|
||||
additional_pip_packages = config.additional_pip_packages
|
||||
deps = []
|
||||
registry = get_provider_registry(config)
|
||||
for api_str, provider_or_providers in providers.items():
|
||||
|
@ -87,8 +76,7 @@ def get_provider_dependencies(
|
|||
else:
|
||||
normal_deps.append(package)
|
||||
|
||||
if additional_pip_packages:
|
||||
normal_deps.extend(additional_pip_packages)
|
||||
normal_deps.extend(additional_pip_packages or [])
|
||||
|
||||
return list(set(normal_deps)), list(set(special_deps))
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
"pandas",
|
||||
"scikit-learn",
|
||||
]
|
||||
+ kvstore_dependencies(),
|
||||
+ kvstore_dependencies(), # TODO make this dynamic based on the kvstore config
|
||||
module="llama_stack.providers.inline.agents.meta_reference",
|
||||
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -36,6 +36,10 @@ class RedisKVStoreConfig(CommonConfig):
|
|||
def url(self) -> str:
|
||||
return f"redis://{self.host}:{self.port}"
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["redis"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls):
|
||||
return {
|
||||
|
@ -53,6 +57,10 @@ class SqliteKVStoreConfig(CommonConfig):
|
|||
description="File path for the sqlite database",
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["aiosqlite"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
|
||||
return {
|
||||
|
@ -100,6 +108,10 @@ class PostgresKVStoreConfig(CommonConfig):
|
|||
raise ValueError("Table name must be less than 63 characters")
|
||||
return v
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["psycopg2-binary"]
|
||||
|
||||
|
||||
class MongoDBKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
|
||||
|
@ -110,6 +122,10 @@ class MongoDBKVStoreConfig(CommonConfig):
|
|||
password: str | None = None
|
||||
collection_name: str = "llamastack_kvstore"
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["pymongo"]
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
|
||||
return {
|
||||
|
|
|
@ -10,6 +10,13 @@ from .config import KVStoreConfig, KVStoreType
|
|||
|
||||
|
||||
def kvstore_dependencies():
|
||||
"""
|
||||
Returns all possible kvstore dependencies for registry/provider specifications.
|
||||
|
||||
NOTE: For specific kvstore implementations, use config.pip_packages instead.
|
||||
This function returns the union of all dependencies for cases where the specific
|
||||
kvstore type is not known at declaration time (e.g., provider registries).
|
||||
"""
|
||||
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
|
||||
|
||||
|
||||
|
|
|
@ -21,4 +21,5 @@ distribution_spec:
|
|||
image_type: conda
|
||||
additional_pip_packages:
|
||||
- asyncpg
|
||||
- psycopg2-binary
|
||||
- sqlalchemy[asyncio]
|
||||
|
|
|
@ -186,8 +186,14 @@ class DistributionTemplate(BaseModel):
|
|||
additional_pip_packages: list[str] = []
|
||||
for run_config in self.run_configs.values():
|
||||
run_config_ = run_config.run_config(self.name, self.providers, self.container_image)
|
||||
|
||||
# TODO: This is a hack to get the dependencies for internal APIs into build
|
||||
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
||||
# and providers, with a way to specify dependencies for them.
|
||||
if run_config_.inference_store:
|
||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
||||
if run_config_.metadata_store:
|
||||
additional_pip_packages.extend(run_config_.metadata_store.pip_packages)
|
||||
|
||||
if self.additional_pip_packages:
|
||||
additional_pip_packages.extend(self.additional_pip_packages)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue