forked from phoenix-oss/llama-stack-mirror
feat: support postgresql inference store (#2310)
# What does this PR do? * Added support postgresql inference store * Added 'oracle' template that demos how to config postgresql stores (except for telemetry, which is not supported currently) ## Test Plan llama stack build --template oracle --image-type conda --run LLAMA_STACK_CONFIG=http://localhost:8321 pytest -s -v tests/integration/ --text-model accounts/fireworks/models/llama-v3p3-70b-instruct -k 'inference_store'
This commit is contained in:
parent
168c7113df
commit
2603f10f95
32 changed files with 516 additions and 53 deletions
|
@ -65,7 +65,7 @@ class SqliteKVStoreConfig(CommonConfig):
|
|||
class PostgresKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
port: str = "5432"
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
|
|
|
@ -19,10 +19,10 @@ from sqlalchemy import (
|
|||
Text,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
from ..api import ColumnDefinition, ColumnType, SqlStore
|
||||
from ..sqlstore import SqliteSqlStoreConfig
|
||||
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||
|
||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||
ColumnType.INTEGER: Integer,
|
||||
|
@ -35,9 +35,10 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
|
|||
}
|
||||
|
||||
|
||||
class SqliteSqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqliteSqlStoreConfig):
|
||||
self.engine = create_async_engine(config.engine_str)
|
||||
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||
self.config = config
|
||||
self.async_session = async_sessionmaker(create_async_engine(config.engine_str))
|
||||
self.metadata = MetaData()
|
||||
|
||||
async def create_table(
|
||||
|
@ -78,13 +79,14 @@ class SqliteSqlStoreImpl(SqlStore):
|
|||
|
||||
# Create the table in the database if it doesn't exist
|
||||
# checkfirst=True ensures it doesn't try to recreate if it's already there
|
||||
async with self.engine.begin() as conn:
|
||||
engine = create_async_engine(self.config.engine_str)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||
|
||||
async def insert(self, table: str, data: Mapping[str, Any]) -> None:
|
||||
async with self.engine.begin() as conn:
|
||||
await conn.execute(self.metadata.tables[table].insert(), data)
|
||||
await conn.commit()
|
||||
async with self.async_session() as session:
|
||||
await session.execute(self.metadata.tables[table].insert(), data)
|
||||
await session.commit()
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
|
@ -93,7 +95,7 @@ class SqliteSqlStoreImpl(SqlStore):
|
|||
limit: int | None = None,
|
||||
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
async with self.engine.begin() as conn:
|
||||
async with self.async_session() as session:
|
||||
query = select(self.metadata.tables[table])
|
||||
if where:
|
||||
for key, value in where.items():
|
||||
|
@ -117,7 +119,7 @@ class SqliteSqlStoreImpl(SqlStore):
|
|||
query = query.order_by(self.metadata.tables[table].c[name].desc())
|
||||
else:
|
||||
raise ValueError(f"Invalid order '{order_type}' for column '{name}'")
|
||||
result = await conn.execute(query)
|
||||
result = await session.execute(query)
|
||||
if result.rowcount == 0:
|
||||
return []
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
@ -142,20 +144,20 @@ class SqliteSqlStoreImpl(SqlStore):
|
|||
if not where:
|
||||
raise ValueError("where is required for update")
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].update()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
await conn.execute(stmt, data)
|
||||
await conn.commit()
|
||||
await session.execute(stmt, data)
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, table: str, where: Mapping[str, Any]) -> None:
|
||||
if not where:
|
||||
raise ValueError("where is required for delete")
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].delete()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
await conn.execute(stmt)
|
||||
await conn.commit()
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
@ -21,7 +22,18 @@ class SqlStoreType(Enum):
|
|||
postgres = "postgres"
|
||||
|
||||
|
||||
class SqliteSqlStoreConfig(BaseModel):
|
||||
class SqlAlchemySqlStoreConfig(BaseModel):
|
||||
@property
|
||||
@abstractmethod
|
||||
def engine_str(self) -> str: ...
|
||||
|
||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["sqlalchemy[asyncio]"]
|
||||
|
||||
|
||||
class SqliteSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal["sqlite"] = SqlStoreType.sqlite.value
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
|
||||
|
@ -39,18 +51,26 @@ class SqliteSqlStoreConfig(BaseModel):
|
|||
db_path="${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
|
||||
)
|
||||
|
||||
# TODO: move this when we have a better way to specify dependencies with internal APIs
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return ["sqlalchemy[asyncio]"]
|
||||
return super().pip_packages + ["aiosqlite"]
|
||||
|
||||
|
||||
class PostgresSqlStoreConfig(BaseModel):
|
||||
class PostgresSqlStoreConfig(SqlAlchemySqlStoreConfig):
|
||||
type: Literal["postgres"] = SqlStoreType.postgres.value
|
||||
host: str = "localhost"
|
||||
port: str = "5432"
|
||||
db: str = "llamastack"
|
||||
user: str
|
||||
password: str | None = None
|
||||
|
||||
@property
|
||||
def engine_str(self) -> str:
|
||||
return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
raise NotImplementedError("Postgres is not implemented yet")
|
||||
return super().pip_packages + ["asyncpg"]
|
||||
|
||||
|
||||
SqlStoreConfig = Annotated[
|
||||
|
@ -60,12 +80,10 @@ SqlStoreConfig = Annotated[
|
|||
|
||||
|
||||
def sqlstore_impl(config: SqlStoreConfig) -> SqlStore:
|
||||
if config.type == SqlStoreType.sqlite.value:
|
||||
from .sqlite.sqlite import SqliteSqlStoreImpl
|
||||
if config.type in [SqlStoreType.sqlite.value, SqlStoreType.postgres.value]:
|
||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
|
||||
impl = SqliteSqlStoreImpl(config)
|
||||
elif config.type == SqlStoreType.postgres.value:
|
||||
raise NotImplementedError("Postgres is not implemented yet")
|
||||
impl = SqlAlchemySqlStoreImpl(config)
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {config.type}")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue