# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from collections.abc import Mapping from typing import Any, Literal from sqlalchemy import ( JSON, Boolean, Column, DateTime, Float, Integer, MetaData, String, Table, Text, select, ) from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from .api import ColumnDefinition, ColumnType, SqlStore from .sqlstore import SqlAlchemySqlStoreConfig TYPE_MAPPING: dict[ColumnType, Any] = { ColumnType.INTEGER: Integer, ColumnType.STRING: String, ColumnType.FLOAT: Float, ColumnType.BOOLEAN: Boolean, ColumnType.DATETIME: DateTime, ColumnType.TEXT: Text, ColumnType.JSON: JSON, } class SqlAlchemySqlStoreImpl(SqlStore): def __init__(self, config: SqlAlchemySqlStoreConfig): self.config = config self.async_session = async_sessionmaker(create_async_engine(config.engine_str)) self.metadata = MetaData() async def create_table( self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition], ) -> None: if not schema: raise ValueError(f"No columns defined for table '{table}'.") sqlalchemy_columns: list[Column] = [] for col_name, col_props in schema.items(): col_type = None is_primary_key = False is_nullable = True # Default to nullable if isinstance(col_props, ColumnType): col_type = col_props elif isinstance(col_props, ColumnDefinition): col_type = col_props.type is_primary_key = col_props.primary_key is_nullable = col_props.nullable sqlalchemy_type = TYPE_MAPPING.get(col_type) if not sqlalchemy_type: raise ValueError(f"Unsupported column type '{col_type}' for column '{col_name}'.") sqlalchemy_columns.append( Column(col_name, sqlalchemy_type, primary_key=is_primary_key, nullable=is_nullable) ) # Check if table already exists in metadata, otherwise define it if table not in self.metadata.tables: sqlalchemy_table = Table(table, self.metadata, *sqlalchemy_columns) else: sqlalchemy_table = self.metadata.tables[table] # Create the table in the database if it doesn't exist # checkfirst=True ensures it doesn't try to recreate if it's already there engine = create_async_engine(self.config.engine_str) async with engine.begin() as conn: await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) async def insert(self, table: str, data: Mapping[str, Any]) -> None: async with self.async_session() as session: await session.execute(self.metadata.tables[table].insert(), data) await session.commit() async def fetch_all( self, table: str, where: Mapping[str, Any] | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> list[dict[str, Any]]: async with self.async_session() as session: query = select(self.metadata.tables[table]) if where: for key, value in where.items(): query = query.where(self.metadata.tables[table].c[key] == value) if limit: query = query.limit(limit) if order_by: if not isinstance(order_by, list): raise ValueError( f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" ) for order in order_by: if not isinstance(order, tuple): raise ValueError( f"order_by must be a list of tuples (column, order={['asc', 'desc']}), got {order_by}" ) name, order_type = order if order_type == "asc": query = query.order_by(self.metadata.tables[table].c[name].asc()) elif order_type == "desc": query = query.order_by(self.metadata.tables[table].c[name].desc()) else: raise ValueError(f"Invalid order '{order_type}' for column '{name}'") result = await session.execute(query) if result.rowcount == 0: return [] return [dict(row._mapping) for row in result] async def fetch_one( self, table: str, where: Mapping[str, Any] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> dict[str, Any] | None: rows = await self.fetch_all(table, where, limit=1, order_by=order_by) if not rows: return None return rows[0] async def update( self, table: str, data: Mapping[str, Any], where: Mapping[str, Any], ) -> None: if not where: raise ValueError("where is required for update") async with self.async_session() as session: stmt = self.metadata.tables[table].update() for key, value in where.items(): stmt = stmt.where(self.metadata.tables[table].c[key] == value) await session.execute(stmt, data) await session.commit() async def delete(self, table: str, where: Mapping[str, Any]) -> None: if not where: raise ValueError("where is required for delete") async with self.async_session() as session: stmt = self.metadata.tables[table].delete() for key, value in where.items(): stmt = stmt.where(self.metadata.tables[table].c[key] == value) await session.execute(stmt) await session.commit()