This commit is contained in:
Corban-Lee Jones 2023-12-13 01:46:41 +00:00
parent fd962a9e76
commit cec1db209b
6 changed files with 130 additions and 1 deletions

BIN
db.sqlite Normal file

Binary file not shown.

14
requirements.txt Normal file
View File

@ -0,0 +1,14 @@
aiohttp==3.9.1
aiosignal==1.3.1
aiosqlite==0.19.0
attrs==23.1.0
discord.py==2.3.2
frozenlist==1.4.0
greenlet==3.0.2
idna==3.6
multidict==6.0.4
psycopg2==2.9.9
python-dotenv==1.0.0
SQLAlchemy==2.0.23
typing_extensions==4.9.0
yarl==1.9.4

20
src/db/__init__.py Normal file
View File

@ -0,0 +1,20 @@
"""
Initialize the database modules, create the database tables and default data.
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .models import Base, AuditModel
from .db import DatabaseManager
# Initialise a database session
engine = create_engine(DatabaseManager.get_database_url(use_async=False))
session = sessionmaker(bind=engine)()
# Create tables if not exists
Base.metadata.create_all(engine)
session.commit()
session.close()

59
src/db/db.py Normal file
View File

@ -0,0 +1,59 @@
"""
Database Manager
"""
import logging
from os import getenv
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
DB_TYPE = getenv("DB_TYPE", default="sqlite")
DB_HOST = getenv("DB_HOST", default="db.sqlite")
DB_PORT = getenv("DB_PORT")
DB_USERNAME = getenv("DB_USERNAME")
DB_PASSWORD = getenv("DB_PASSWORD")
DB_DATABASE = getenv("DB_DATABASE")
log = logging.getLogger(__name__)
class DatabaseManager:
"""
Asynchronous database context manager.
"""
def __init__(self):
database_url = self.get_database_url()
self.engine = create_async_engine(database_url, future=True)
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
self.session = None
@staticmethod
def get_database_url(use_async=True):
"""
Returns a connection string for the database.
"""
if DB_TYPE not in ("sqlite", "mariadb", "mysql", "postgresql"):
raise ValueError(f"Unknown Database Type: {DB_TYPE}")
is_sqlite = DB_TYPE == "sqlite"
url = f"sqlite:///{DB_HOST}" if is_sqlite else f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}"
url = url.replace(":/", "+aiosqlite:/" if is_sqlite else "+asyncpg:/") if use_async else url
return url
async def __aenter__(self):
self.session = self.session_maker()
log.debug("Database connection open")
return self
async def __aexit__(self, *_):
await self.session.commit()
await self.session.close()
self.session = None
await self.engine.dispose()
log.debug("Database connection closed")

27
src/db/models.py Normal file
View File

@ -0,0 +1,27 @@
"""
Models and Enums for the database.
All table classes should be suffixed with `Model`.
"""
from enum import Enum, auto
from sqlalchemy import Column, Integer, String, DateTime, BigInteger
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class AuditModel(Base):
"""
Table for taking audits.
"""
__tablename__ = "audit"
id = Column(Integer, primary_key=True, autoincrement=True)
discord_user_id = Column(BigInteger, nullable=False)
message = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
active = Column(Integer, default=True, nullable=False)

View File

@ -7,6 +7,9 @@ import logging
from discord import app_commands, Interaction
from discord.ext import commands, tasks
from sqlalchemy import insert, select
from db import DatabaseManager, AuditModel
log = logging.getLogger(__name__)
@ -27,7 +30,13 @@ class Test(commands.Cog):
@app_commands.command(name="test-command")
async def test_command(self, inter: Interaction):
await inter.response.send_message("test")
async with DatabaseManager() as database:
message = f"Test command has been invoked successfully!"
query = insert(AuditModel).values(discord_user_id=inter.user.id, message=message)
await database.session.execute(query)
await inter.response.send_message("the audit log test was successful")
async def setup(bot):