llama-stack-mirror/llama_stack/providers/utils/sqlstore/sqlstore.py
Ashwin Bharambe b380cb463f
feat: add postgres deps to starter distro (#2360)
Once we have this, we can use the starter distro for the Kubernetes
cluster demos.
2025-06-03 11:04:23 -07:00

103 lines
3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import abstractmethod
from enum import Enum
from pathlib import Path
from typing import Annotated, Literal
from pydantic import BaseModel, Field
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from .api import SqlStore
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
class SqlStoreType(Enum):
sqlite = "sqlite"
postgres = "postgres"
class SqlAlchemySqlStoreConfig(BaseModel):
@property
@abstractmethod
def engine_str(self) -> str: ...
# TODO: move this when we have a better way to specify dependencies with internal APIs
@property
def pip_packages(self) -> list[str]:
return ["sqlalchemy[asyncio]"]
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["sqlite"] = SqlStoreType.sqlite.value
db_path: str = Field(
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
description="Database path, e.g. ~/.llama/distributions/ollama/sqlstore.db",
)
@property
def engine_str(self) -> str:
return "sqlite+aiosqlite:///" + Path(self.db_path).expanduser().as_posix()
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "sqlstore.db"):
return cls(
type="sqlite",
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
)
@property
def pip_packages(self) -> list[str]:
return super().pip_packages + ["aiosqlite"]
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
type: Literal["postgres"] = SqlStoreType.postgres.value
host: str = "localhost"
port: str = "5432"
db: str = "llamastack"
user: str
password: str | None = None
@property
def engine_str(self) -> str:
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
@property
def pip_packages(self) -> list[str]:
return super().pip_packages + ["asyncpg"]
@classmethod
def sample_run_config(cls, **kwargs):
return cls(
type="postgres",
host="${env.POSTGRES_HOST:localhost}",
port="${env.POSTGRES_PORT:5432}",
db="${env.POSTGRES_DB:llamastack}",
user="${env.POSTGRES_USER:llamastack}",
password="${env.POSTGRES_PASSWORD:llamastack}",
)
SqlStoreConfig = Annotated[
SqliteSqlStoreConfig | PostgresSqlStoreConfig,
Field(discriminator="type", default=SqlStoreType.sqlite.value),
]
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
impl = SqlAlchemySqlStoreImpl(config)
else:
raise ValueError(f"Unknown sqlstore type {config.type}")
return impl