PYRSS-Bot/src/db/db.py

78 lines
2.1 KiB
Python

"""
Database Manager
"""
import logging
from os import getenv
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
# Connection settings, read from the environment once at import time.
# Only the dialect and the host carry fallback values; every other setting is
# None when its variable is unset.
DB_TYPE = getenv("DB_TYPE", "sqlite")
DB_HOST = getenv("DB_HOST", "db.sqlite")
DB_PORT = getenv("DB_PORT")
DB_USERNAME = getenv("DB_USERNAME")
DB_PASSWORD = getenv("DB_PASSWORD")
DB_DATABASE = getenv("DB_DATABASE")

# Module-level logger named after this module.
log = logging.getLogger(__name__)
class DatabaseManager:
    """
    Asynchronous database context manager.

    Every instance builds its own async engine and session factory. Entering
    the context opens a session, exposed as ``self.session``. Leaving it
    commits (unless ``no_commit`` is set or the body raised), closes the
    session and disposes of the engine.
    """

    # Async driver appended to the URL scheme for each supported DB_TYPE.
    _ASYNC_DRIVERS = {"sqlite": "aiosqlite", "postgresql": "asyncpg"}

    def __init__(self, no_commit: bool = False):
        """
        Build the engine and session factory for one unit of work.

        Args:
            no_commit: When true, leaving the context never commits.

        Raises:
            ValueError: If ``DB_TYPE`` is not a supported database type.
        """
        # NOTE: the URL and the engine are rebuilt for every instance and the
        # engine is disposed again in __aexit__. The settings never change
        # after import, so sharing one engine may be worth considering.
        database_url = self.get_database_url()
        self.engine = create_async_engine(database_url, future=True)
        self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
        self.session = None
        self.no_commit = no_commit

    @staticmethod
    def get_database_url(use_async=True):
        """
        Returns a connection string for the database.

        The string is assembled from the module-level ``DB_*`` settings.

        Args:
            use_async: Include the async driver (``aiosqlite`` / ``asyncpg``)
                in the URL scheme. When false the plain dialect name is used.

        Returns:
            The connection URL as a string.

        Raises:
            ValueError: If ``DB_TYPE`` is not a supported database type.
        """
        return DatabaseManager._build_database_url(
            DB_TYPE,
            DB_USERNAME,
            DB_PASSWORD,
            DB_HOST,
            DB_PORT,
            DB_DATABASE,
            use_async=use_async,
        )

    @staticmethod
    def _build_database_url(
        db_type, username, password, host, port, database, use_async=True
    ):
        """
        Assemble a connection URL from explicit settings.

        Unset (``None`` or empty) parts are left out instead of being rendered
        as the text ``None``. Values are inserted verbatim, so credentials
        containing URL-special characters must already be percent-encoded.

        Raises:
            ValueError: If ``db_type`` is not a supported database type.
        """
        drivers = DatabaseManager._ASYNC_DRIVERS
        if db_type not in drivers:
            raise ValueError(f"Unknown Database Type: {db_type}")

        # Build the scheme directly; patching the driver in afterwards with
        # str.replace would also rewrite any later ":/" in the URL.
        scheme = f"{db_type}+{drivers[db_type]}" if use_async else db_type

        if db_type == "sqlite":
            # SQLite URLs carry only a file path: no credentials, host or port.
            # NOTE(review): the path is taken from DB_HOST because its default
            # is "db.sqlite" - confirm this is the intended setting.
            return f"{scheme}:///{host or ''}"

        credentials = ""
        if username:
            credentials = username
            if password:
                credentials += f":{password}"
            credentials += "@"

        location = host or ""
        if port:
            location += f":{port}"

        url = f"{scheme}://{credentials}{location}"
        if database:
            url += f"/{database}"
        return url

    async def __aenter__(self):
        """Open a new session and return this manager."""
        self.session = self.session_maker()
        log.debug("Database connection open")
        return self

    async def __aexit__(self, *exc_info):
        """
        Commit if appropriate, then always release the session and engine.

        Nothing is committed when the ``async with`` body raised, so a failed
        unit of work is not persisted half-done. The exception itself is never
        suppressed.
        """
        body_failed = bool(exc_info) and exc_info[0] is not None
        try:
            if not body_failed and not self.no_commit:
                await self.session.commit()
        finally:
            # Runs even if the commit raised, so neither the session nor the
            # engine is leaked. Closing without a commit leaves pending work
            # uncommitted (see SQLAlchemy's Session.close() semantics).
            try:
                await self.session.close()
            finally:
                self.session = None
                await self.engine.dispose()
                log.debug("Database connection closed")