# 130 lines · 4.7 KiB · Python
from urllib.parse import quote_plus

from sqlalchemy import delete
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.pool import NullPool

import config
from utils import logger
# Connection URL for the asyncpg driver. The user name and password are
# percent-encoded because SQLAlchemy URL-decodes them when parsing, so a raw
# "@", ":", "/" or "%" in a credential would otherwise corrupt the URL.
# NOTE(review): if this string is also passed to Alembic's configparser-backed
# ``set_main_option``, any "%" produced here must additionally be doubled.
DATABASE_URL = (
    "postgresql+asyncpg://"
    f"{quote_plus(str(config.DB_USER))}:{quote_plus(str(config.DB_PASSWORD))}"
    f"@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
)

# NullPool does not pool connections: each checkout opens a fresh DB-API
# connection and closes it again on release.
engine = create_async_engine(DATABASE_URL, poolclass=NullPool)

# Factory producing AsyncSession objects bound to ``engine``.
SessionLocal = async_sessionmaker(engine)

# Declarative base class that ORM models inherit from.
Base = declarative_base()
class CRUD:
    """Async CRUD helpers; every call runs in its own short-lived session.

    Each method opens a session from ``SessionLocal`` with ``async with``,
    does its work and leaves the block before returning, so any ORM objects
    handed back are detached from their session. Errors are logged and
    reported through the return value (``None`` / ``False``), not raised.
    """

    @staticmethod
    async def create(db_data, refresh: bool = False):
        """Insert one mapped instance or a list of them and commit.

        Args:
            db_data: A mapped instance, or a ``list`` of mapped instances.
            refresh: When true, reload the rows after the commit and return
                them.

        Returns:
            The persisted instance(s) when ``refresh`` is true, otherwise
            ``None``. ``None`` is also returned on any error (logged). If a
            list had to be merged (see below), the returned list holds the
            merged copies owned by this session, not the caller's objects.
        """
        try:
            is_lst = isinstance(db_data, list)
            async with SessionLocal() as db:
                if is_lst:
                    logger.debug(f"Создаю {len(db_data)} записей")
                    try:
                        db.add_all(db_data)
                        persisted = db_data
                    except InvalidRequestError:
                        # add_all() may have attached some items before it
                        # failed. merge() on an already-pending object makes a
                        # second pending copy (a duplicate INSERT), so detach
                        # everything first and merge from a clean session.
                        db.expunge_all()
                        persisted = [await db.merge(data) for data in db_data]
                else:
                    logger.debug("Создаю запись")
                    db.add(db_data)
                    persisted = db_data

                await db.commit()

                if refresh:
                    if is_lst:
                        logger.debug(f"Обновляю {len(db_data)} записей")
                        # Refresh the instances that belong to this session;
                        # on the merge path the caller's originals are not
                        # attached here and refresh() would reject them.
                        for data in persisted:
                            await db.refresh(data)
                    else:
                        logger.debug("Обновляю запись")
                        await db.refresh(persisted)

                logger.debug("Запись создана")
                return persisted if refresh else None
        except Exception as e:
            logger.error(f"Ошибка создания: {str(e)}", exc_info=True)
            return None

    @staticmethod
    async def read(query, all: bool = False):
        """Execute ``query`` and return ORM scalar results.

        Args:
            query: A statement accepted by ``AsyncSession.execute`` (typically
                a ``select()``).
            all: When true return every row as a list; otherwise return only
                the first row, or ``None`` if there is none.

        Returns:
            A list, a single object or ``None``; ``None`` also on error (logged).
        """
        try:
            async with SessionLocal() as db:
                logger.debug(f"Чтение записей. Все: {all}")
                results = await db.execute(query)
                logger.debug("Чтение завершено")
                # unique() de-duplicates rows produced by joined eager loading
                # of collections; it is a no-op for plain selects.
                scalars = results.unique().scalars()
                return scalars.all() if all else scalars.first()
        except Exception as e:
            logger.error(f"Ошибка чтения: {str(e)}", exc_info=True)
            return None

    @staticmethod
    async def delete(db_data) -> bool:
        """Delete one persisted instance or a list of them by primary key.

        One ``DELETE`` statement is issued per instance and all of them are
        committed together; any failure rolls the whole batch back.

        Args:
            db_data: A mapped instance, or a ``list`` of mapped instances, that
                has a database identity (i.e. was flushed/loaded before).

        Returns:
            ``True`` once the statements were executed and committed (even if
            no row matched), ``False`` on error (logged).
        """
        from sqlalchemy import delete as sa_delete
        from sqlalchemy import inspect

        def _pk_criteria(instance):
            # Compare EVERY primary-key column with the value stored in the
            # identity key. This handles composite keys, attributes named
            # differently from their column, and detached instances whose
            # attributes were expired by an earlier commit.
            state = inspect(instance)
            if state.identity is None:
                raise ValueError(
                    f"{type(instance).__name__} instance has no database "
                    "identity (it was never persisted)"
                )
            return [
                column == value
                for column, value in zip(state.mapper.primary_key, state.identity)
            ]

        async def _delete_one(instance, db):
            # Issue a DELETE for the single row identified by ``instance``.
            query = sa_delete(type(instance)).where(*_pk_criteria(instance))
            await db.execute(query)

        async with SessionLocal() as db:
            try:
                if isinstance(db_data, list):
                    logger.debug(f"Удаляю записей: {len(db_data)}")
                    for data in db_data:
                        await _delete_one(data, db)
                else:
                    logger.debug("Удаляю запись")
                    await _delete_one(db_data, db)
                await db.commit()
                logger.debug("Запись удалена")
                return True
            except Exception as e:
                await db.rollback()
                logger.error(f"Ошибка удаления: {str(e)}", exc_info=True)
                return False

    @staticmethod
    async def update(model, id: int, **kwargs):
        """Set ``kwargs`` on the row of ``model`` whose ``id`` equals ``id``.

        Args:
            model: Mapped class. Assumed to expose an ``id`` attribute that is
                also its primary key, since the row is re-read with
                ``db.get`` (a primary-key lookup). TODO confirm for all models.
            id: Value matched against ``model.id``.
            **kwargs: Attribute values to SET.

        Returns:
            The instance re-loaded after the commit, ``None`` if no such row
            exists, or ``None`` on error (logged).
        """
        from sqlalchemy import update as sa_update

        async with SessionLocal() as db:
            try:
                query = (
                    sa_update(model)
                    .where(model.id == id)
                    .values(**kwargs)
                    .execution_options(synchronize_session="fetch")
                )
                await db.execute(query)
                await db.commit()

                logger.debug("Запись обновлена")

                # Loaded after the commit, so the object carries the new values
                # even though the session is closed right afterwards.
                return await db.get(model, id)
            except Exception as e:
                await db.rollback()
                logger.error(f"Ошибка обновления: {str(e)}", exc_info=True)
                return None