mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: add deps dynamically based on metastore config (#2405)
# What does this PR do? ## Test Plan changed metastore in one of the templates, rerun distro gen, observe change in build.yaml
This commit is contained in:
parent
92b59a3377
commit
446893f791
6 changed files with 36 additions and 18 deletions
|
@ -43,23 +43,12 @@ def get_provider_dependencies(
|
||||||
config: BuildConfig | DistributionTemplate,
|
config: BuildConfig | DistributionTemplate,
|
||||||
) -> tuple[list[str], list[str]]:
|
) -> tuple[list[str], list[str]]:
|
||||||
"""Get normal and special dependencies from provider configuration."""
|
"""Get normal and special dependencies from provider configuration."""
|
||||||
# Extract providers based on config type
|
|
||||||
if isinstance(config, DistributionTemplate):
|
if isinstance(config, DistributionTemplate):
|
||||||
providers = config.providers
|
config = config.build_config()
|
||||||
|
|
||||||
|
providers = config.distribution_spec.providers
|
||||||
|
additional_pip_packages = config.additional_pip_packages
|
||||||
|
|
||||||
# TODO: This is a hack to get the dependencies for internal APIs into build
|
|
||||||
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
|
||||||
# and providers, with a way to specify dependencies for them.
|
|
||||||
run_configs = config.run_configs
|
|
||||||
additional_pip_packages: list[str] = []
|
|
||||||
if run_configs:
|
|
||||||
for run_config in run_configs.values():
|
|
||||||
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
|
|
||||||
if run_config_.inference_store:
|
|
||||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
|
||||||
elif isinstance(config, BuildConfig):
|
|
||||||
providers = config.distribution_spec.providers
|
|
||||||
additional_pip_packages = config.additional_pip_packages
|
|
||||||
deps = []
|
deps = []
|
||||||
registry = get_provider_registry(config)
|
registry = get_provider_registry(config)
|
||||||
for api_str, provider_or_providers in providers.items():
|
for api_str, provider_or_providers in providers.items():
|
||||||
|
@ -87,8 +76,7 @@ def get_provider_dependencies(
|
||||||
else:
|
else:
|
||||||
normal_deps.append(package)
|
normal_deps.append(package)
|
||||||
|
|
||||||
if additional_pip_packages:
|
normal_deps.extend(additional_pip_packages or [])
|
||||||
normal_deps.extend(additional_pip_packages)
|
|
||||||
|
|
||||||
return list(set(normal_deps)), list(set(special_deps))
|
return list(set(normal_deps)), list(set(special_deps))
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
"pandas",
|
"pandas",
|
||||||
"scikit-learn",
|
"scikit-learn",
|
||||||
]
|
]
|
||||||
+ kvstore_dependencies(),
|
+ kvstore_dependencies(), # TODO make this dynamic based on the kvstore config
|
||||||
module="llama_stack.providers.inline.agents.meta_reference",
|
module="llama_stack.providers.inline.agents.meta_reference",
|
||||||
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
|
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
|
||||||
api_dependencies=[
|
api_dependencies=[
|
||||||
|
|
|
@ -36,6 +36,10 @@ class RedisKVStoreConfig(CommonConfig):
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
return f"redis://{self.host}:{self.port}"
|
return f"redis://{self.host}:{self.port}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pip_packages(self) -> list[str]:
|
||||||
|
return ["redis"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls):
|
def sample_run_config(cls):
|
||||||
return {
|
return {
|
||||||
|
@ -53,6 +57,10 @@ class SqliteKVStoreConfig(CommonConfig):
|
||||||
description="File path for the sqlite database",
|
description="File path for the sqlite database",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pip_packages(self) -> list[str]:
|
||||||
|
return ["aiosqlite"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
|
def sample_run_config(cls, __distro_dir__: str, db_name: str = "kvstore.db"):
|
||||||
return {
|
return {
|
||||||
|
@ -100,6 +108,10 @@ class PostgresKVStoreConfig(CommonConfig):
|
||||||
raise ValueError("Table name must be less than 63 characters")
|
raise ValueError("Table name must be less than 63 characters")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pip_packages(self) -> list[str]:
|
||||||
|
return ["psycopg2-binary"]
|
||||||
|
|
||||||
|
|
||||||
class MongoDBKVStoreConfig(CommonConfig):
|
class MongoDBKVStoreConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
|
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
|
||||||
|
@ -110,6 +122,10 @@ class MongoDBKVStoreConfig(CommonConfig):
|
||||||
password: str | None = None
|
password: str | None = None
|
||||||
collection_name: str = "llamastack_kvstore"
|
collection_name: str = "llamastack_kvstore"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pip_packages(self) -> list[str]:
|
||||||
|
return ["pymongo"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
|
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -10,6 +10,13 @@ from .config import KVStoreConfig, KVStoreType
|
||||||
|
|
||||||
|
|
||||||
def kvstore_dependencies():
|
def kvstore_dependencies():
|
||||||
|
"""
|
||||||
|
Returns all possible kvstore dependencies for registry/provider specifications.
|
||||||
|
|
||||||
|
NOTE: For specific kvstore implementations, use config.pip_packages instead.
|
||||||
|
This function returns the union of all dependencies for cases where the specific
|
||||||
|
kvstore type is not known at declaration time (e.g., provider registries).
|
||||||
|
"""
|
||||||
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
|
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,4 +21,5 @@ distribution_spec:
|
||||||
image_type: conda
|
image_type: conda
|
||||||
additional_pip_packages:
|
additional_pip_packages:
|
||||||
- asyncpg
|
- asyncpg
|
||||||
|
- psycopg2-binary
|
||||||
- sqlalchemy[asyncio]
|
- sqlalchemy[asyncio]
|
||||||
|
|
|
@ -186,8 +186,14 @@ class DistributionTemplate(BaseModel):
|
||||||
additional_pip_packages: list[str] = []
|
additional_pip_packages: list[str] = []
|
||||||
for run_config in self.run_configs.values():
|
for run_config in self.run_configs.values():
|
||||||
run_config_ = run_config.run_config(self.name, self.providers, self.container_image)
|
run_config_ = run_config.run_config(self.name, self.providers, self.container_image)
|
||||||
|
|
||||||
|
# TODO: This is a hack to get the dependencies for internal APIs into build
|
||||||
|
# We should have a better way to do this by formalizing the concept of "internal" APIs
|
||||||
|
# and providers, with a way to specify dependencies for them.
|
||||||
if run_config_.inference_store:
|
if run_config_.inference_store:
|
||||||
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
|
||||||
|
if run_config_.metadata_store:
|
||||||
|
additional_pip_packages.extend(run_config_.metadata_store.pip_packages)
|
||||||
|
|
||||||
if self.additional_pip_packages:
|
if self.additional_pip_packages:
|
||||||
additional_pip_packages.extend(self.additional_pip_packages)
|
additional_pip_packages.extend(self.additional_pip_packages)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue