diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
index 32ccde144..8ff3e7d15 100644
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ -176,6 +176,7 @@ class Datasets(Protocol):
             "type": "uri",
             "uri": "huggingface://llamastack/simpleqa?split=train"
         }
+        TODO: Add postgresql example here
         - {
             "type": "rows",
             "rows": [
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index d444b03a3..87984c646 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -430,6 +430,8 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
             # infer provider from uri
             if source.uri.startswith("huggingface"):
                 provider_id = "huggingface"
+            elif source.uri.startswith("postgresql"):
+                provider_id = "postgresql"
             else:
                 provider_id = "localfs"
         else:
diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py
index f83dcbc60..d9f23a4ef 100644
--- a/llama_stack/providers/registry/datasetio.py
+++ b/llama_stack/providers/registry/datasetio.py
@@ -36,4 +36,16 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
         ),
     ),
+    remote_provider_spec(
+        api=Api.datasetio,
+        adapter=AdapterSpec(
+            adapter_type="postgresql",
+            pip_packages=[
+                "psycopg",
+                "psycopg[pool]",
+            ],
+            module="llama_stack.providers.remote.datasetio.postgresql",
+            config_class="llama_stack.providers.remote.datasetio.postgresql.PostgreSQLDatasetIOConfig",
+        ),
+    ),
 ]
diff --git a/llama_stack/providers/remote/datasetio/postgresql/__init__.py b/llama_stack/providers/remote/datasetio/postgresql/__init__.py
new file mode 100644
index 000000000..aef086427
--- /dev/null
+++ b/llama_stack/providers/remote/datasetio/postgresql/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import PostgreSQLDatasetIOConfig
+
+
+async def get_adapter_impl(
+    config: PostgreSQLDatasetIOConfig,
+    _deps,
+):
+    from .postgresql import PostgreSQLDatasetIOImpl
+
+    impl = PostgreSQLDatasetIOImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/remote/datasetio/postgresql/config.py b/llama_stack/providers/remote/datasetio/postgresql/config.py
new file mode 100644
index 000000000..0eaef9f92
--- /dev/null
+++ b/llama_stack/providers/remote/datasetio/postgresql/config.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
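+
+# A sample (hypothetical) run.yaml entry for this provider, mirroring
+# sample_run_config below; the exact provider_type string is an assumption:
+#
+#   datasetio:
+#     - provider_id: postgresql
+#       provider_type: remote::postgresql
+#       config:
+#         pg_host: ${env.POSTGRES_HOST:127.0.0.1}
+#         pg_port: ${env.POSTGRES_PORT:5432}
+#         pg_user: ${env.POSTGRES_USER:}
+#         pg_password: ${env.POSTGRES_PASSWORD:}
+#         pg_database: ${env.POSTGRES_DATABASE:postgres}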
+
+import os
+from typing import Any, Dict
+
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
+
+
+class PostgreSQLDatasetIOConfig(BaseModel):
+    kvstore: KVStoreConfig
+
+    pg_host: str = Field(default="localhost")
+    pg_port: int = Field(default=5432)
+    # TODO - revise security implications of using env vars for user and password
+    pg_user: str = Field(default="postgres")
+    pg_password: str = Field(default="fail")
+    pg_con_pool_size: int = Field(default=3)
+    pg_database: str = Field(default="postgres")
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+        return {
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="postgresql_datasetio.db",
+            ),
+            "pg_host": os.getenv("POSTGRES_HOST", "127.0.0.1"),
+            "pg_port": os.getenv("POSTGRES_PORT", 5432),
+            "pg_user": os.getenv("POSTGRES_USER", ""),
+            "pg_password": os.getenv("POSTGRES_PASSWORD", ""),
+            "pg_con_pool_size": os.getenv("POSTGRES_CONN_POOL_SIZE", 3),
+            "pg_database": os.getenv("POSTGRES_DATABASE", "postgres"),
+        }
diff --git a/llama_stack/providers/remote/datasetio/postgresql/pg_tools.py b/llama_stack/providers/remote/datasetio/postgresql/pg_tools.py
new file mode 100644
index 000000000..da1f9984a
--- /dev/null
+++ b/llama_stack/providers/remote/datasetio/postgresql/pg_tools.py
@@ -0,0 +1,258 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from typing import Any, Dict, List, Optional
+from urllib.parse import parse_qs, urlparse
+
+import psycopg_pool
+from psycopg import AsyncConnection, sql
+
+from llama_stack.apis.datasets import Dataset, DatasetPurpose
+
+from .config import PostgreSQLDatasetIOConfig
+
+log = logging.getLogger(__name__)
+
+
+class DatasetColumn:
+    def __init__(self, name: str, is_array: bool):
+        self.name = name
+        self.is_array = is_array
+
+
+class PgConnectionInfo:
+    def __init__(self, host: str, port: int, user: str, password: str, database: str):
+        self.host = host
+        self.port = port
+        self.database = database
+        self.user = user
+        self.password = password
+
+    def __str__(self):
+        return f"host={self.host} port={self.port} dbname={self.database} user={self.user} password={self.password}"
+
+
+def get_mandatory_cols(purpose: DatasetPurpose) -> list[DatasetColumn]:
+    if purpose == DatasetPurpose.post_training_messages:
+        return [DatasetColumn("messages", True)]
+    elif purpose == DatasetPurpose.eval_question_answer:
+        return [DatasetColumn("question", False), DatasetColumn("answer", False)]
+    elif purpose == DatasetPurpose.eval_messages_answer:
+        return [DatasetColumn("messages", True), DatasetColumn("answer", False)]
+    else:
+        raise ValueError(f"Unknown purpose: {purpose}")
+
+
+def get_config_from_uri(uri: str, config: PostgreSQLDatasetIOConfig) -> tuple[PgConnectionInfo, str | None]:
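+    """
+    Parse a dataset URI into connection info and an optional table name.
+
+    Expected shape (example values are illustrative only):
+
+        postgresql://user:password@localhost:5432/mydb?table=my_table
+
+    Components present in the URI take precedence over the provider config.
+    """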
+    parsed = urlparse(uri)
+    # Extract main components
+    if parsed.scheme != "postgresql":
+        raise ValueError(f"Unsupported scheme: {parsed.scheme} (uri: {uri})")
+
+    # uri info has precedence over config info
+    username = parsed.username if parsed.username else config.pg_user
+    password = parsed.password if parsed.password else config.pg_password
+    host = parsed.hostname if parsed.hostname else config.pg_host
+    port = parsed.port if parsed.port else config.pg_port
+    database = parsed.path.lstrip("/")  # Remove leading "/"
+    database = database if database else config.pg_database
+
+    # Extract query parameters, tolerating URIs with multiple question marks
+    raw_query = parsed.query.replace("?", "&")
+    query_params = parse_qs(raw_query)
+
+    table = query_params.get("table", [None])[0]  # Extract first value if it exists
+    # TODO: read the table name from the dataset metadata here?
+
+    return PgConnectionInfo(
+        user=username,
+        password=password,
+        host=host,
+        port=port,
+        database=database,
+    ), table
+
+
+async def create_connection_pool(
+    max_connections: int,
+    info: PgConnectionInfo,
+    min_connections: int = 1,
+) -> psycopg_pool.AsyncConnectionPool:
+    pool = None
+    try:
+        pool = psycopg_pool.AsyncConnectionPool(
+            str(info), min_size=min_connections, max_size=max_connections, open=False
+        )
+        await pool.open(wait=True, timeout=10.0)
+        return pool
+    except Exception as e:
+        log.error(f"Failed to create connection pool: {e}")
+        if pool is not None:
+            await pool.close()
+        raise
+
+
+async def check_table_exists(conn: AsyncConnection, table_name: str) -> bool:
+    try:
+        sql_stmnt = "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = %s)"
+        async with conn.cursor() as cur:
+            await cur.execute(sql_stmnt, (table_name,))
+            row = await cur.fetchone()
+        return bool(row[0])
+    except Exception as e:
+        log.error(f"Error: {e}")
+        raise
+
+
+async def _get_table_columns(conn: AsyncConnection, table_name: str) -> List[str]:
+    try:
+        query = sql.SQL("SELECT column_name FROM information_schema.columns WHERE table_name = {table}").format(
+            table=sql.Literal(table_name)
+        )
+        async with conn.cursor() as cur:
+            await cur.execute(query)
+            table_cols = await cur.fetchall()
+        return [col[0] for col in table_cols]
+    except Exception as e:
+        log.error(f"Error: {e}")
+        raise
+
+
+async def check_schema(conn: AsyncConnection, table_name: str, purpose: DatasetPurpose) -> None:
+    try:
+        table_cols = await _get_table_columns(conn, table_name)
+        schema_cols = get_mandatory_cols(purpose)
+        missing_col_names = []
+        for sc in schema_cols:
+            if not any(tc == sc.name for tc in table_cols):
+                log.error(f"Failed to find column {sc.name} in table {table_name} (purpose {purpose})")
+                missing_col_names.append(sc.name)
+            else:
+                # TODO: check type compatibility
+                pass
+
+        if len(missing_col_names) > 0:
+            raise ValueError(
+                f"Could not find column(s) {missing_col_names} in table {table_name} (purpose {purpose})"
+            )
+    except Exception as e:
+        log.error(f"Error: {e}")
+        raise
+
+
+def get_table_name(dataset: Dataset) -> str:
+    table_name = dataset.metadata.get("table", None)
+    if table_name is None:
+        log.error(f"No table defined for dataset: {dataset.provider_id}({dataset.identifier})")
+        raise ValueError(f"No table defined for dataset: {dataset.identifier}")
+    table_name = str(table_name)
+    if "'" in table_name or '"' in table_name:
+        log.error(f"Table name {table_name} contains quotes - this is rejected for security reasons")
+        raise ValueError(f"Table name {table_name} contains quotes - registration fails for security reasons")
+    return table_name
+
+
+async def check_table_and_schema(ds: Dataset, conn: AsyncConnection, provider_type: str) -> None:
+    table_name = get_table_name(ds)
+    # Check table existence
+    try:
+        exists = await check_table_exists(conn, table_name)
+        if not exists:
+            log.error(f'Table "{table_name}" does not exist')
+            raise ValueError(
+                f"Table '{table_name}' does not exist in the database, dataset '{ds.identifier}' cannot be registered"
+            )
+    except Exception as e:
+        log.error(f"Error: {e}")
+        raise
+
+    # get and check table schema
+    try:
+        await check_schema(conn, table_name, ds.purpose)
+    except Exception as e:
+        log.error(f"Error: {e}")
+        raise
+
+
+def build_select_statement(
+    dataset: Dataset,
+    conn: AsyncConnection,
+    start_index: Optional[int] = None,
+    limit: Optional[int] = None,
+) -> sql.Composed:
+    """
+    Build a SELECT statement for the dataset's table, with optional LIMIT/OFFSET.
+    """
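+    # For example (illustrative), with metadata table "qa", limit=10 and
+    # start_index=20 this composes roughly: SELECT * from "qa"  LIMIT 10 OFFSET 20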
+    params = []
+    stmnt = "SELECT * from {} "
+    params.append(sql.Identifier(dataset.metadata["table"]))
+
+    if dataset.metadata.get("filter", None):
+        stmnt += " WHERE {}"
+        params.append(sql.Literal(dataset.metadata["filter"]))
+
+    if limit is not None:
+        stmnt += " LIMIT {}"
+        params.append(sql.Literal(limit))
+
+    if start_index is not None:
+        stmnt += " OFFSET {}"
+        params.append(sql.Literal(start_index))
+
+    return sql.SQL(stmnt).format(*params)
+
+
+async def get_row_count(
+    conn: AsyncConnection,
+    table_name: str,
+) -> int:
+    """
+    Get the number of rows in the table
+    """
+    try:
+        sql_stmnt = sql.SQL("SELECT COUNT(*) FROM {}").format(sql.Identifier(table_name))
+        async with conn.cursor() as cur:
+            await cur.execute(sql_stmnt)
+            row_count = await cur.fetchone()
+        return int(row_count[0])
+    except Exception as e:
+        log.error(f"Error: {e}")
+        return 0
+
+
+async def rows_to_iterrows_response(
+    rows: List[Any],
+    conn: AsyncConnection,
+    table_name: str,
+) -> List[Dict[str, Any]]:
+    """
+    Convert raw database rows into the list-of-dicts shape used by IterrowsResponse.
+    """
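+    # For example (illustrative), with columns ["question", "answer"] the row
+    # ("What is 2+2?", "4") becomes {"question": "What is 2+2?", "answer": "4"}.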
+    res = []
+    cols = await _get_table_columns(conn, table_name)
+    for row in rows:
+        res_row = {}
+        for i, col in enumerate(cols):
+            res_row[col] = row[i]
+        res.append(res_row)
+    return res
diff --git a/llama_stack/providers/remote/datasetio/postgresql/postgresql.py b/llama_stack/providers/remote/datasetio/postgresql/postgresql.py
new file mode 100644
index 000000000..617fde92f
--- /dev/null
+++ b/llama_stack/providers/remote/datasetio/postgresql/postgresql.py
@@ -0,0 +1,164 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from psycopg_pool import AsyncConnectionPool
+
+from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
+from llama_stack.apis.datasets import Dataset
+from llama_stack.providers.datatypes import DatasetsProtocolPrivate
+from llama_stack.providers.utils.common.provider_utils import get_provider_type
+from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+
+from .config import PostgreSQLDatasetIOConfig
+from .pg_tools import (
+    build_select_statement,
+    check_table_and_schema,
+    create_connection_pool,
+    get_config_from_uri,
+    get_row_count,
+    get_table_name,
+    rows_to_iterrows_response,
+)
+
+log = logging.getLogger(__name__)
+
+DATASETS_PREFIX = "datasets:"
+
+
+class PostgreSQLDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
+    def __init__(self, config: PostgreSQLDatasetIOConfig) -> None:
+        self.config = config
+        # local registry for keeping track of datasets within the provider
+        self.dataset_infos: Dict[str, Dataset] = {}
+        self.conn_pools: Dict[str, AsyncConnectionPool] = {}
+        self.row_counts: Dict[str, int] = {}
+        self.kvstore: KVStore | None = None
+
+    async def initialize(self) -> None:
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        # Load existing datasets from kvstore
+        start_key = DATASETS_PREFIX
+        end_key = f"{DATASETS_PREFIX}\xff"
+        stored_datasets = await self.kvstore.range(start_key, end_key)
+
+        for ds in stored_datasets:
+            dataset = Dataset.model_validate_json(ds)
+            pg_config_info, _ = get_config_from_uri(dataset.source.uri, self.config)
+            self.dataset_infos[dataset.identifier] = dataset
+            try:
+                conn_pool = await create_connection_pool(3, pg_config_info)
+                self.conn_pools[dataset.identifier] = conn_pool
+            except Exception as e:
+                log.error(f"Failed to create connection pool for dataset on initialization {dataset.identifier}: {e}")
+
+    async def shutdown(self) -> None: ...
+
+    async def register_dataset(
+        self,
+        dataset_def: Dataset,
+    ) -> None:
+        provider_type = get_provider_type(self.__module__)
+        if self.dataset_infos.get(dataset_def.identifier):
+            log.error(f"Failed to register dataset {dataset_def.identifier}. Dataset with this name already exists")
+            raise ValueError(f"Dataset {dataset_def.identifier} already exists")
+
+        pg_connection_info, table = get_config_from_uri(dataset_def.source.uri, self.config)
+        tbmd = dataset_def.metadata.get("table", None)
+        if table is not None:
+            # The uri setting overrides the metadata setting; record it for later use
+            dataset_def.metadata["table"] = table
+            if tbmd and tbmd != table:
+                log.warning(
+                    f"Table name mismatch for dataset {dataset_def.identifier}: metadata:{tbmd} != uri:{table}, using {table}"
+                )
+        elif tbmd is None:  # neither uri nor metadata define a table
+            log.error(
+                f"No table defined for dataset: {provider_type}::{dataset_def.provider_id}({dataset_def.identifier})"
+            )
+            raise ValueError(f"No table defined for dataset {dataset_def.identifier}")
+
+        try:
+            pool = await create_connection_pool(3, pg_connection_info)
+            async with pool.connection() as conn:
+                await check_table_and_schema(dataset_def, conn, provider_type)
+        except ValueError:
+            # these are already logged in check_table_and_schema
+            raise
+        except Exception as e:
+            log.error(f"Error: {e}")
+            raise
+
+        # Store in kvstore
+        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
+        await self.kvstore.set(
+            key=key,
+            value=dataset_def.model_dump_json(),
+        )
+        self.dataset_infos[dataset_def.identifier] = dataset_def
+        self.conn_pools[dataset_def.identifier] = pool
+
+    async def unregister_dataset(self, dataset_id: str) -> None:
+        pool = self.conn_pools.pop(dataset_id, None)
+        if pool is not None:
+            await pool.close()
+        key = f"{DATASETS_PREFIX}{dataset_id}"
+        await self.kvstore.delete(key=key)
+        del self.dataset_infos[dataset_id]
+
+    async def iterrows(
+        self,
+        dataset_id: str,
+        start_index: Optional[int] = None,
+        limit: Optional[int] = None,
+    ) -> IterrowsResponse:
+        if start_index is not None and start_index < 0:
+            raise ValueError(f"start_index ({start_index}) must be a non-negative integer")
+
+        dataset_def = self.dataset_infos[dataset_id]
+        pool = self.conn_pools.get(dataset_id)
+        if pool is None:  # Retry creating the connection pool
+            try:
+                pg_config_info, _ = get_config_from_uri(dataset_def.source.uri, self.config)
+                pool = await create_connection_pool(3, pg_config_info)
+                self.conn_pools[dataset_def.identifier] = pool
+            except Exception as e:
+                log.error(f"Failed to create connection pool for dataset {dataset_def.identifier}: {e}")
+                raise
+
+        try:
+            async with pool.connection() as conn:
+                stmnt = build_select_statement(dataset_def, conn, start_index=start_index, limit=limit)
+                if self.row_counts.get(dataset_def.identifier) is None or start_index is None or start_index < 3:
+                    # refresh the row count only at the start of an iteration
+                    self.row_counts[dataset_def.identifier] = await get_row_count(conn, get_table_name(dataset_def))
+                async with conn.cursor() as cur:
+                    await cur.execute(stmnt)
+                    rows = await cur.fetchall()
+                data = await rows_to_iterrows_response(rows, conn, get_table_name(dataset_def))
+        except Exception as e:
+            log.error(f"Error: {e}")
+            raise
+
+        begin = 0 if start_index is None else start_index
+        end = begin + len(data)
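+        # Pagination sketch (illustrative): with 25 total rows, start_index=20 and
+        # limit=10 returns 5 rows with next_start_index=None; start_index=0 and
+        # limit=10 returns 10 rows with next_start_index=10.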
+        return IterrowsResponse(
+            data=data,
+            next_start_index=end if end < self.row_counts[dataset_def.identifier] else None,
+        )
+
+    # TODO: Implement filtering
+
+    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
+        # This interface is not implemented in the DatasetsResource class and is not
+        # accessible via the client.
+        raise NotImplementedError("Uploading to postgresql dataset is not supported yet")
diff --git a/llama_stack/providers/utils/common/provider_utils.py b/llama_stack/providers/utils/common/provider_utils.py
new file mode 100644
index 000000000..8a4131f51
--- /dev/null
+++ b/llama_stack/providers/utils/common/provider_utils.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+def get_provider_type(module: str) -> str:
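+    """
+    Return "inline" or "remote" for a provider module path, e.g.
+    "llama_stack.providers.remote.datasetio.postgresql.postgresql" -> "remote".
+    """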
+    parts = module.split(".")
+    if parts[0] != "llama_stack" or parts[1] != "providers":
+        raise ValueError(f"Invalid module name <{module}>")
+    if parts[2] in ("inline", "remote"):
+        return parts[2]
+    raise ValueError(f"Invalid module name <{module}>")
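
Usage sketch (not part of the patch): a minimal client-side flow, assuming the
standard llama-stack client API; the URI, table name, and credentials are
illustrative, and the table must already exist with the columns the purpose
requires (here: "question" and "answer").

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")

    # The router infers the "postgresql" provider from the URI scheme
    # (see the routing_tables.py hunk above).
    client.datasets.register(
        purpose="eval/question-answer",
        source={
            "type": "uri",
            "uri": "postgresql://user:password@localhost:5432/mydb?table=qa_pairs",
        },
        dataset_id="my-qa-dataset",
    )

    # Page through rows; each row is a dict keyed by column name.
    page = client.datasets.iterrows(dataset_id="my-qa-dataset", limit=10)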