fix(mypy): resolve SQLAlchemy typing issues in sqlalchemy_sqlstore.py (#3932)

This commit is contained in:
Ashwin Bharambe 2025-10-27 21:26:24 -07:00
parent 8991b65552
commit cc84c2e4f5

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Literal from typing import Any, Literal, cast
from sqlalchemy import ( from sqlalchemy import (
JSON, JSON,
@ -55,17 +55,17 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
raise ValueError(f"Operator mapping must have a single operator, got: {value}") raise ValueError(f"Operator mapping must have a single operator, got: {value}")
op, operand = next(iter(value.items())) op, operand = next(iter(value.items()))
if op == "==" or op == "=": if op == "==" or op == "=":
return column == operand return cast(ColumnElement[Any], column == operand)
if op == ">": if op == ">":
return column > operand return cast(ColumnElement[Any], column > operand)
if op == "<": if op == "<":
return column < operand return cast(ColumnElement[Any], column < operand)
if op == ">=": if op == ">=":
return column >= operand return cast(ColumnElement[Any], column >= operand)
if op == "<=": if op == "<=":
return column <= operand return cast(ColumnElement[Any], column <= operand)
raise ValueError(f"Unsupported operator '{op}' in where mapping") raise ValueError(f"Unsupported operator '{op}' in where mapping")
return column == value return cast(ColumnElement[Any], column == value)
class SqlAlchemySqlStoreImpl(SqlStore): class SqlAlchemySqlStoreImpl(SqlStore):
@ -210,10 +210,8 @@ class SqlAlchemySqlStoreImpl(SqlStore):
query = query.limit(fetch_limit) query = query.limit(fetch_limit)
result = await session.execute(query) result = await session.execute(query)
if result.rowcount == 0: # Iterate directly - if no rows, list comprehension yields empty list
rows = [] rows = [dict(row._mapping) for row in result]
else:
rows = [dict(row._mapping) for row in result]
# Always return pagination result # Always return pagination result
has_more = False has_more = False