Initial commit for postgresql dataset provider

Core implementation of postgresql dataset for llama stack

Signed-off-by: Josh Salomon <jsalomon@redhat.com>
This commit is contained in:
Josh Salomon 2025-03-26 12:15:40 +02:00
parent bdfe7fee92
commit 8e3b579df2
8 changed files with 512 additions and 0 deletions

View file

@ -176,6 +176,7 @@ class Datasets(Protocol):
"type": "uri",
"uri": "huggingface://llamastack/simpleqa?split=train"
}
TODO: Add postgresql example here
- {
"type": "rows",
"rows": [

View file

@ -430,6 +430,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
elif source.uri.startswith("postgresql"):
provider_id = "postgresql"
else:
provider_id = "localfs"
else:

View file

@ -36,4 +36,18 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
),
),
remote_provider_spec(
api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="postgresql",
pip_packages=[
# "asyncpg",
# "datasets",
"psycopg",
"psycopg[pool]",
],
module="llama_stack.providers.remote.datasetio.postgresql",
config_class="llama_stack.providers.remote.datasetio.postgresql.PostgreSQLDatasetIOConfig",
),
),
]

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import PostgreSQLDatasetIOConfig
async def get_adapter_impl(
    config: PostgreSQLDatasetIOConfig,
    _deps,
):
    """Create and initialize the PostgreSQL dataset I/O adapter.

    The implementation class is imported inside the function, so it is only
    loaded when an adapter is actually requested.

    Args:
        config: Provider configuration handed to the implementation.
        _deps: Unused by this adapter.

    Returns:
        The adapter instance, after its ``initialize()`` coroutine completed.
    """
    from .postgresql import PostgreSQLDatasetIOImpl

    adapter = PostgreSQLDatasetIOImpl(config)
    await adapter.initialize()
    return adapter

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from typing import Any, Dict
import os
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
class PostgreSQLDatasetIOConfig(BaseModel):
    """Configuration of the PostgreSQL dataset I/O provider.

    ``kvstore`` configures the key-value store used by the provider; the
    ``pg_*`` fields describe the default PostgreSQL connection.
    """

    kvstore: KVStoreConfig
    pg_host: str = Field(default="localhost")
    pg_port: int = Field(default=5432)
    # TODO - revise security implications of using env vars for user and password
    pg_user: str = Field(default="postgres")
    pg_password: str = Field(default="fail")
    pg_con_pool_size: int = Field(default=3)
    pg_database: str = Field(default="postgres")

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
        """Build a sample run configuration from the ``POSTGRES_*`` environment.

        Numeric settings are converted to ``int`` so the returned values have
        the same type whether they come from the environment (always strings)
        or from the built-in defaults.

        Raises:
            ValueError: if ``POSTGRES_PORT`` or ``POSTGRES_CONN_POOL_SIZE`` is
                set to a non-integer value.
        """
        return {
            "kvstore": SqliteKVStoreConfig.sample_run_config(
                __distro_dir__=__distro_dir__,
                db_name="postgresql_datasetio.db",
            ),
            "pg_host": os.getenv("POSTGRES_HOST", "127.0.0.1"),
            "pg_port": int(os.getenv("POSTGRES_PORT", "5432")),
            "pg_user": os.getenv("POSTGRES_USER", ""),
            "pg_password": os.getenv("POSTGRES_PASSWORD", ""),
            "pg_con_pool_size": int(os.getenv("POSTGRES_CONN_POOL_SIZE", "3")),
            "pg_database": os.getenv("POSTGRES_DATABASE", "postgres"),
        }

View file

@ -0,0 +1,258 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasets import Dataset, DatasetPurpose
from typing import Dict, List, Optional, Any
import logging
import psycopg_pool
from psycopg_pool.abc import ACT
from psycopg import sql
from urllib.parse import urlparse, parse_qs
from .config import PostgreSQLDatasetIOConfig
from typing import AsyncIterator
log = logging.getLogger(__name__)
class DatasetColumn:
    """Name and array flag of a column that a dataset table must provide."""

    def __init__(self, name: str, is_array: bool):
        # name: column name as it appears in the table.
        # is_array: flag stored alongside the name; presumably marks columns
        # holding array values -- TODO confirm once type checks are added.
        self.name, self.is_array = name, is_array
class PgConnectionInfo:
    """Connection parameters for a single PostgreSQL database.

    ``str(info)`` renders the parameters as a libpq keyword/value connection
    string. The rendered string contains the password in clear text, so it
    must not be written to logs.
    """

    def __init__(self, host: str, port: int, user: str, password: str, database: str):
        self.host = host
        self.port = port
        self.database = database
        self.user = user
        self.password = password

    @staticmethod
    def _quote(value: object) -> str:
        """Render one value following the libpq keyword/value quoting rules.

        Plain values are emitted unchanged. Empty values and values containing
        whitespace, single quotes or backslashes are wrapped in single quotes
        with backslash escaping; otherwise libpq would mis-parse the string
        (an empty value would swallow the following ``key=value`` pair).
        """
        text = str(value)
        if text and not any(ch.isspace() or ch in "'\\" for ch in text):
            return text
        escaped = text.replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def __str__(self):
        fields = (
            ("host", self.host),
            ("port", self.port),
            ("dbname", self.database),
            ("user", self.user),
            ("password", self.password),
        )
        return " ".join(f"{key}={self._quote(val)}" for key, val in fields)
def get_mandatory_cols(purpose: DatasetPurpose) -> list[DatasetColumn]:
    """Return the columns a table must contain for the given dataset purpose.

    A fresh list is built on every call.

    Raises:
        ValueError: if ``purpose`` is not one of the handled purposes.
    """
    # Each entry is (column name, is_array flag).
    if purpose == DatasetPurpose.post_training_messages:
        spec = (("messages", True),)
    elif purpose == DatasetPurpose.eval_question_answer:
        spec = (("question", False), ("answer", False))
    elif purpose == DatasetPurpose.eval_messages_answer:
        spec = (("messages", True), ("answer", False))
    else:
        raise ValueError(f"Unknown purpose: {purpose}")
    return [DatasetColumn(col_name, array_flag) for col_name, array_flag in spec]
def get_config_from_uri(uri: str, config: PostgreSQLDatasetIOConfig) -> tuple[PgConnectionInfo, str | None]:
    """Extract connection info and the optional table name from a dataset URI.

    Expected shape: ``postgresql://user:password@host:port/database?table=name``.
    Every component is optional; components missing from the URI fall back to
    the corresponding ``pg_*`` value of ``config``.

    Args:
        uri: Dataset source URI; its scheme must be ``postgresql``.
        config: Provider configuration supplying the fallback values.

    Returns:
        A ``(connection_info, table)`` tuple. ``table`` is ``None`` when the
        URI has no ``table`` query parameter.

    Raises:
        ValueError: if the scheme is not ``postgresql`` (or the port is not a
            valid number, as reported by ``urlparse``).
    """
    # Local import: the file-level import only brings in urlparse and parse_qs.
    from urllib.parse import unquote

    parsed = urlparse(uri)
    if parsed.scheme != "postgresql":
        raise ValueError(f"Unsupported scheme: {parsed.scheme} (uri: {uri})")

    # URI info has precedence over config info. urlparse() leaves
    # percent-escapes untouched, so credentials and database name are decoded
    # here (e.g. "p%40ss" -> "p@ss").
    username = unquote(parsed.username) if parsed.username else config.pg_user
    password = unquote(parsed.password) if parsed.password else config.pg_password
    host = parsed.hostname if parsed.hostname else config.pg_host
    port = parsed.port if parsed.port else config.pg_port
    database = unquote(parsed.path.lstrip("/"))  # drop the leading "/"
    database = database if database else config.pg_database

    # Tolerate URIs that use "?" again instead of "&" between query parameters.
    raw_query = parsed.query.replace("?", "&")
    query_params = parse_qs(raw_query)
    table = query_params.get("table", [None])[0]  # first value, if any
    # TODO: read from metadata here?

    connection_info = PgConnectionInfo(
        user=username,
        password=password,
        host=host,
        port=port,
        database=database,
    )
    return connection_info, table
async def create_connection_pool(
    max_connections: int,
    info: PgConnectionInfo,
    min_connections: int = 1,
) -> psycopg_pool.AsyncConnectionPool:
    """Create and open an async connection pool for the given database.

    Args:
        max_connections: Upper bound of the pool size.
        info: Connection parameters; ``str(info)`` is used as connection string.
        min_connections: Lower bound of the pool size.

    Returns:
        The opened pool.

    Raises:
        Exception: whatever pool construction or opening raised. The error is
            logged and a partially created pool is closed before re-raising.
    """
    # Bound before the try block so the error path never hits an unbound name
    # when the constructor itself fails.
    pool = None
    try:
        pool = psycopg_pool.AsyncConnectionPool(
            str(info), min_size=min_connections, max_size=max_connections, open=False
        )
        await pool.open(wait=True, timeout=10.0)
    except Exception as e:
        log.error(f"Failed to create connection pool: {e}")
        if pool is not None:
            await pool.close()
        raise
    return pool
async def check_table_exists(conn: AsyncIterator[ACT], table_name: str) -> bool:
    """Tell whether ``information_schema.tables`` lists a table with this name.

    The lookup is by table name only; no schema filter is applied.

    Args:
        conn: Open database connection.
        table_name: Name to look up; passed as a bound query parameter.

    Raises:
        Exception: any database error, after it has been logged.
    """
    sql_stmnt = "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = %s)"
    try:
        # The context manager closes the cursor on every path and, unlike a
        # manual close in ``finally``, cannot fail on an unbound cursor when
        # the query itself raises.
        async with conn.cursor() as cur:
            await cur.execute(sql_stmnt, (table_name,))
            row = await cur.fetchone()
        return row is not None and bool(row[0])
    except Exception as e:
        log.error(f"Error: {e}")
        raise
async def _get_table_columns(conn: AsyncIterator[ACT], table_name: str) -> List[str]:
    """Return the column names of ``table_name`` in table-definition order.

    Callers pair these names positionally with ``SELECT *`` result rows, so
    the query orders by ``ordinal_position`` instead of relying on whatever
    order the catalog view happens to return.

    Args:
        conn: Open database connection.
        table_name: Table to inspect; passed as a bound query parameter.

    Raises:
        Exception: any database error, after it has been logged.
    """
    query = (
        "SELECT column_name FROM information_schema.columns "
        "WHERE table_name = %s ORDER BY ordinal_position"
    )
    try:
        async with conn.cursor() as cur:
            await cur.execute(query, (table_name,))
            table_cols = await cur.fetchall()
        return [col[0] for col in table_cols]
    except Exception as e:
        log.error(f"Error: {e}")
        raise
async def check_schema(conn: AsyncIterator[ACT], table_name: str, purpose: DatasetPurpose) -> None:
    """Verify that the table provides every column mandatory for ``purpose``.

    Only column presence is verified; each missing column is logged.

    Raises:
        ValueError: if at least one mandatory column is missing, or if
            ``purpose`` is not supported.
        Exception: any error raised while reading the table columns.
    """
    try:
        present = await _get_table_columns(conn, table_name)
        required = get_mandatory_cols(purpose)
        absent = []
        for column in required:
            if column.name in present:
                # TODO: check type compatibility
                continue
            log.error(f"Failed to find column {column.name} in table {table_name} (purpose {purpose})")
            absent.append(column.name)
        if absent:
            raise ValueError(f"Could not find column(s) {absent} in table {table_name} (purpose {purpose})")
    except Exception as e:
        log.error(f"Error: {e}")
        raise
def get_table_name(dataset: Dataset) -> str:
    """Return the table name stored under ``dataset.metadata["table"]``.

    Raises:
        ValueError: if the metadata has no table entry, or if the name
            contains a single or double quote (rejected for security reasons).
    """
    # Test for a missing entry BEFORE converting to str: str(None) is the
    # string "None", which would otherwise be taken for a real table name.
    raw_name = dataset.metadata.get("table", None)
    if raw_name is None:
        log.error(f"No table defined for dataset: {dataset.provider_id}({dataset.identifier})")
        raise ValueError(f"No table defined for dataset: {dataset.identifier}")

    table_name = str(raw_name)
    if "'" in table_name or '"' in table_name:
        log.error(f"Table name {table_name} contains quotes - this is ignored for security reasons")
        raise ValueError(f"Table name {table_name} contains quotes - registration fails for security reasons")
    return table_name
async def check_table_and_schema(ds: Dataset, conn: AsyncIterator[ACT], provider_type: str) -> None:
    """Validate that the dataset's table exists and has the columns its purpose needs.

    Args:
        ds: Dataset whose metadata names the table and whose purpose selects
            the mandatory columns.
        conn: Open database connection.
        provider_type: Not used by the checks.

    Raises:
        ValueError: if the table name is missing/invalid, the table does not
            exist, or mandatory columns are missing.
        Exception: any database error raised by the checks (logged first).
    """
    table_name = get_table_name(ds)
    try:
        # Check table existence
        if not await check_table_exists(conn, table_name):
            log.error(f'Table "{table_name}" does not exist')
            raise ValueError(
                f"Table '{table_name}' does not exist in the database, dataset '{ds.identifier}' cannot be registered"
            )
        # get and check table schema
        await check_schema(conn, table_name, ds.purpose)
    except Exception as e:
        log.error(f"Error: {e}")
        raise
def build_select_statement(
    dataset: Dataset,
    conn: AsyncIterator[ACT],
    start_index: Optional[int] = None,
    limit: Optional[int] = None,
) -> str:
    """Compose the SELECT statement that reads rows of the dataset's table.

    The table name is taken from ``dataset.metadata["table"]`` and quoted as an
    identifier; ``limit`` and ``start_index`` become ``LIMIT`` / ``OFFSET``
    clauses when they are not ``None``. ``conn`` is not used.

    NOTE(review): a ``filter`` metadata entry is inserted through
    ``sql.Literal``, i.e. as a quoted value after ``WHERE`` rather than as a
    condition -- confirm this is intended.
    NOTE(review): the value returned is whatever ``sql.SQL(...).format(...)``
    produces, although the annotation says ``str``.
    """
    fragments = ["SELECT * from {} "]
    values = [sql.Identifier(dataset.metadata["table"])]

    row_filter = dataset.metadata.get("filter", None)
    if row_filter:
        fragments.append(" WHERE {}")
        values.append(sql.Literal(row_filter))

    for clause, value in ((" LIMIT {}", limit), (" OFFSET {}", start_index)):
        if value is not None:
            fragments.append(clause)
            values.append(sql.Literal(value))

    return sql.SQL("".join(fragments)).format(*values)
async def get_row_count(
    conn: AsyncIterator[ACT],
    table_name: str,
) -> int:
    """Count the rows of ``table_name``.

    Best effort: any failure is logged and reported as a count of 0 instead
    of being raised.
    """
    try:
        count_query = sql.SQL("SELECT COUNT(*) FROM {}").format(sql.Identifier(table_name))
        async with conn.cursor() as cur:
            await cur.execute(count_query)
            first_row = await cur.fetchone()
            return int(first_row[0])
    except Exception as e:
        log.error(f"Error: {e}")
        return 0
async def rows_to_iterrows_response(
    rows: List[Any],
    conn: AsyncIterator[ACT],
    table_name: str,
) -> List[Dict[str, Any]]:
    """Turn positional database rows into one ``{column: value}`` dict per row.

    Column names are read from the database for ``table_name`` and matched to
    the row values by position.
    """
    column_names = await _get_table_columns(conn, table_name)
    return [
        {name: record[position] for position, name in enumerate(column_names)}
        for record in rows
    ]

View file

@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.common.provider_utils import get_provider_type
from llama_stack.providers.utils.kvstore import KVStore
# from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl
from psycopg_pool import AsyncConnectionPool
from .config import PostgreSQLDatasetIOConfig
# from .pg_tools import get_config_from_uri, check_table_and_schema, create_connection_pool
from .pg_tools import get_config_from_uri, check_table_and_schema, create_connection_pool
from .pg_tools import build_select_statement, get_row_count, rows_to_iterrows_response
from .pg_tools import get_table_name
log = logging.getLogger(__name__)
DATASETS_PREFIX = "datasets:"
class PostgreSQLDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
    """Dataset I/O provider that reads dataset rows from PostgreSQL tables.

    Registered datasets are persisted in the configured kvstore under
    ``DATASETS_PREFIX`` and every dataset gets its own connection pool, keyed
    by the dataset identifier.
    """

    def __init__(self, config: PostgreSQLDatasetIOConfig) -> None:
        self.config = config
        # local registry for keeping track of datasets within the provider
        self.dataset_infos: Dict[str, Dataset] = {}
        # one pool per dataset; an entry is absent when pool creation failed
        self.conn_pools: Dict[str, AsyncConnectionPool] = {}
        # cached table row counts, used to compute next_start_index
        self.row_counts: Dict[str, int] = {}
        self.kvstore: KVStore | None = None

    async def initialize(self) -> None:
        """Load persisted datasets and try to open a connection pool for each.

        A pool that cannot be created is logged and skipped; ``iterrows``
        retries the creation on first use.
        """
        self.kvstore = await kvstore_impl(self.config.kvstore)
        # Load existing datasets from kvstore
        start_key = DATASETS_PREFIX
        end_key = f"{DATASETS_PREFIX}\xff"
        stored_datasets = await self.kvstore.range(start_key, end_key)

        for ds in stored_datasets:
            dataset = Dataset.model_validate_json(ds)
            pg_config_info, _ = get_config_from_uri(dataset.source.uri, self.config)
            self.dataset_infos[dataset.identifier] = dataset
            try:
                conn_pool = await create_connection_pool(self.config.pg_con_pool_size, pg_config_info)
                self.conn_pools[dataset.identifier] = conn_pool
            except Exception as e:
                log.error(f"Failed to create connection pool for dataset on initialization {dataset.identifier}: {e}")

    async def shutdown(self) -> None:
        """Close every open connection pool; close failures are only logged."""
        for dataset_id, pool in list(self.conn_pools.items()):
            try:
                await pool.close()
            except Exception as e:
                log.error(f"Failed to close connection pool for dataset {dataset_id}: {e}")
        self.conn_pools.clear()

    async def register_dataset(
        self,
        dataset_def: Dataset,
    ) -> None:
        """Validate a dataset against its table, then persist and register it.

        The table name comes from the ``table`` query parameter of the source
        URI or, failing that, from ``metadata["table"]``; the URI wins when
        both are given.

        Raises:
            ValueError: if the dataset already exists, no table is defined, or
                the table/schema validation fails.
            Exception: any connection or storage error (logged first).
        """
        provider_type = get_provider_type(self.__module__)
        if self.dataset_infos.get(dataset_def.identifier):
            log.error(
                f"Failed to register dataset {dataset_def.identifier}. " + "Dataset with this name alreadt exists"
            )
            raise ValueError(f"Dataset {dataset_def.identifier} already exists")

        pg_connection_info, table = get_config_from_uri(dataset_def.source.uri, self.config)
        tbmd = dataset_def.metadata.get("table", None)
        if table is not None:
            # Uri setting overrides metadata setting
            dataset_def.metadata["table"] = table  # logging table for future use.
            if tbmd and tbmd != table:
                log.warning(
                    f"Table name mismatch for dataset {dataset_def.identifier}: metadata:{tbmd} != uri:{table}, using {table}"
                )
        elif tbmd is None:
            # Neither the URI nor the metadata names a table.
            log.error(
                f"No table defined for dataset: {provider_type}::{dataset_def.provider_id}({dataset_def.identifier})"
            )
            raise ValueError(f"No table defined for dataset {dataset_def.identifier}")

        pool = None
        try:
            pool = await create_connection_pool(self.config.pg_con_pool_size, pg_connection_info)
            async with pool.connection() as conn:
                await check_table_and_schema(dataset_def, conn, provider_type)
            # Store in kvstore
            key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
            await self.kvstore.set(
                key=key,
                value=dataset_def.model_dump_json(),
            )
        except Exception as e:
            if not isinstance(e, ValueError):
                # ValueErrors are already logged in check_table_and_schema
                log.error(f"Error: {e}")
            # Do not leak the pool of a dataset that was never registered.
            if pool is not None:
                await pool.close()
            raise

        self.dataset_infos[dataset_def.identifier] = dataset_def
        self.conn_pools[dataset_def.identifier] = pool

    async def unregister_dataset(self, dataset_id: str) -> None:
        """Close the dataset's pool and drop it from the kvstore and registry.

        Missing in-memory entries are tolerated, e.g. when the pool could not
        be created during ``initialize``.
        """
        pool = self.conn_pools.pop(dataset_id, None)
        if pool is not None:
            await pool.close()

        key = f"{DATASETS_PREFIX}{dataset_id}"
        await self.kvstore.delete(key=key)
        self.dataset_infos.pop(dataset_id, None)
        # Drop the cached count so a later dataset with the same id starts clean.
        self.row_counts.pop(dataset_id, None)

    async def iterrows(
        self,
        dataset_id: str,
        start_index: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> IterrowsResponse:
        """Read a page of rows from the dataset's table.

        Args:
            dataset_id: Identifier of a registered dataset.
            start_index: Row offset of the first row to return; must not be
                negative.
            limit: Maximum number of rows to return.

        Returns:
            The rows as dicts plus ``next_start_index``, which is ``None``
            once the end of the table (per the cached row count) is reached.

        Raises:
            ValueError: if ``start_index`` is negative.
            KeyError: if ``dataset_id`` is not registered.
            Exception: any connection or query error (logged first).
        """
        if start_index is not None and start_index < 0:
            raise ValueError(f"start_index ({start_index}) must be a non-negative integer")

        dataset_def = self.dataset_infos[dataset_id]
        # .get(): the entry is absent when pool creation failed at initialize
        pool = self.conn_pools.get(dataset_id)
        if pool is None:  # Retry to create connection pool
            try:
                pg_config_info, _ = get_config_from_uri(dataset_def.source.uri, self.config)
                pool = await create_connection_pool(self.config.pg_con_pool_size, pg_config_info)
                self.conn_pools[dataset_def.identifier] = pool
            except Exception as e:
                log.error(f"Failed to create connection pool for dataset {dataset_def.identifier}: {e}")
                raise

        try:
            async with pool.connection() as conn:
                # NOTE(review): _check_connection is a private psycopg_pool
                # method; kept as-is, consider the pool's public check hook.
                await pool._check_connection(conn)
                table_name = get_table_name(dataset_def)
                stmnt = build_select_statement(dataset_def, conn, start_index=start_index, limit=limit)
                if self.row_counts.get(dataset_def.identifier) is None or start_index is None or start_index < 3:
                    # get row count only once per iteration
                    self.row_counts[dataset_def.identifier] = await get_row_count(conn, table_name)
                async with conn.cursor() as cur:
                    await cur.execute(stmnt)
                    rows = await cur.fetchall()
                data = await rows_to_iterrows_response(rows, conn, table_name)
        except Exception as e:
            log.error(f"Error: {e}")
            raise

        begin = 0 if start_index is None else start_index
        end = begin + len(data)

        return IterrowsResponse(
            data=data,
            next_start_index=end if end < self.row_counts[dataset_def.identifier] else None,
        )

    # TODO: Implement filtering

    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        """Not supported: always raises ``NotImplementedError``."""
        # This interface is not implemented in the DatasetsResource class and is not
        # accessible via the client.
        raise NotImplementedError("Uploading to postgresql dataset is not supported yet")

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
def get_provider_type(module: str) -> str:
    """Extract the provider type from a provider module path.

    Args:
        module: Dotted module name of the form
            ``llama_stack.providers.<inline|remote>...``.

    Returns:
        ``"inline"`` or ``"remote"``.

    Raises:
        ValueError: if the name does not start with ``llama_stack.providers``
            followed by ``inline`` or ``remote``. Names with fewer than three
            components are included, so they no longer fail with IndexError.
    """
    parts = module.split(".")
    is_provider_module = len(parts) >= 3 and parts[:2] == ["llama_stack", "providers"]
    if is_provider_module and parts[2] in ("inline", "remote"):
        return parts[2]
    raise ValueError(f"Invalid module name <{module}>")