60 lines
1.7 KiB
Python
60 lines
1.7 KiB
Python
"""
Database Manager
"""
|
|
|
|
import logging
from os import getenv
from urllib.parse import quote

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
|
|
|
|
# Connection settings, read once from the environment when the module is
# imported. Only the backend type and host have fallbacks (a local SQLite
# file); the remaining settings stay None unless the variable is set.
DB_TYPE = getenv("DB_TYPE", "sqlite")
DB_HOST = getenv("DB_HOST", "db.sqlite")  # file path when DB_TYPE is sqlite
DB_PORT = getenv("DB_PORT")
DB_USERNAME = getenv("DB_USERNAME")
DB_PASSWORD = getenv("DB_PASSWORD")
DB_DATABASE = getenv("DB_DATABASE")

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class DatabaseManager:
    """
    Asynchronous database context manager.

    Entering the context opens a fresh ``AsyncSession`` as ``self.session``.
    Leaving it commits when the ``async with`` body finished normally, or
    rolls back when it raised, then closes the session and disposes of the
    engine's connection pool.
    """

    # Async DBAPI driver combined with each supported DB_TYPE ("dialect+driver")
    # when building an async URL. asyncpg only speaks the PostgreSQL protocol,
    # so the MySQL-family backends need a MySQL driver (the aiomysql package).
    ASYNC_DRIVERS = {
        "sqlite": "aiosqlite",
        "mariadb": "aiomysql",
        "mysql": "aiomysql",
        "postgresql": "asyncpg",
    }

    def __init__(self):
        # Build the engine and a factory for AsyncSession objects bound to it;
        # no session exists until the context is entered.
        database_url = self.get_database_url()
        self.engine = create_async_engine(database_url, future=True)
        self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
        self.session = None

    @staticmethod
    def get_database_url(
        use_async=True,
        *,
        db_type=None,
        host=None,
        port=None,
        username=None,
        password=None,
        database=None,
    ):
        """
        Returns a connection string for the database.

        Settings default to the module-level ``DB_*`` values read from the
        environment; any keyword argument that is not ``None`` overrides the
        corresponding setting.

        :param use_async: if true, include the async driver from
            ``ASYNC_DRIVERS`` (e.g. ``postgresql+asyncpg://``) so the URL can
            be passed to ``create_async_engine``; otherwise return the plain
            ``dialect://`` form.
        :raises ValueError: if the database type is not supported.
        """
        db_type = DB_TYPE if db_type is None else db_type
        drivers = DatabaseManager.ASYNC_DRIVERS
        if db_type not in drivers:
            raise ValueError(f"Unknown Database Type: {db_type}")

        host = DB_HOST if host is None else host
        scheme = f"{db_type}+{drivers[db_type]}" if use_async else db_type

        if db_type == "sqlite":
            # For SQLite the "host" setting is the path to the database file.
            return f"{scheme}:///{host}"

        port = DB_PORT if port is None else port
        username = DB_USERNAME if username is None else username
        password = DB_PASSWORD if password is None else password
        database = DB_DATABASE if database is None else database

        # Unset or empty settings are omitted rather than rendered as the text
        # "None" or an empty ":port". Credentials are percent-encoded so that
        # characters such as '@', ':' or '/' cannot break the URL structure.
        credentials = ""
        if username:
            credentials = quote(str(username), safe="")
            if password:
                credentials += ":" + quote(str(password), safe="")
            credentials += "@"
        address = f"{host}:{port}" if port else host
        path = f"/{database}" if database else ""
        return f"{scheme}://{credentials}{address}{path}"

    async def __aenter__(self):
        """Open a new session and return this manager."""
        self.session = self.session_maker()
        log.debug("Database connection open")
        return self

    async def __aexit__(self, exc_type=None, exc_value=None, traceback=None):
        """
        Finish the unit of work and release database resources.

        Commits if the ``async with`` body completed normally and rolls back
        if it raised. The session is closed and the engine disposed even when
        commit/rollback fails; exceptions are never suppressed.
        """
        try:
            if exc_type is None:
                await self.session.commit()
            else:
                # Don't persist a partially applied unit of work.
                await self.session.rollback()
        finally:
            await self.session.close()
            self.session = None
            await self.engine.dispose()
            log.debug("Database connection closed")
|