complete rewrite

This commit is contained in:
Corban-Lee 2023-07-07 07:44:28 +01:00
parent ae21eab7f1
commit 8dc97e9e70
11 changed files with 922 additions and 142 deletions

16
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Discord Bot",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/src/main.py",
"console": "integratedTerminal",
"justMyCode": true
}
]
}

42
db/build.sql Normal file
View File

@ -0,0 +1,42 @@
/*
Server Channels
*/
CREATE TABLE IF NOT EXISTS 'server_channels' (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL,
news_category_id INTEGER NOT NULL,
active INTEGER NOT NULL,
FOREIGN KEY (news_category_id) REFERENCES 'news_categories' (id)
ON DELETE CASCADE
);
/*
News Articles
*/
CREATE TABLE IF NOT EXISTS 'news_articles' (
id INTEGER PRIMARY KEY AUTOINCREMENT,
url TEXT NOT NULL,
server_channel_id INTEGER NOT NULL,
FOREIGN KEY (server_channel_id) REFERENCES 'server_channels' (id)
ON DELETE CASCADE
);
/*
News Categories
*/
CREATE TABLE IF NOT EXISTS 'news_categories' (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL
);
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('all');
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('world');
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('uk');
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('north_america');
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('entertainment');
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('business');
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('tech');
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('science');
INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('top_stories');

BIN
db/db.sqlite Normal file

Binary file not shown.

View File

@ -1,6 +1,10 @@
"""The discord bot for the application"""
"""
The discord bot for the application.
"""
import os
import time
import logging
from datetime import datetime
import aiohttp
@ -10,6 +14,9 @@ from discord import Intents, Interaction, app_commands
from discord.ext import commands, tasks
from bbc_feeds import news
log = logging.getLogger(__name__)
EXTENSIONS_DIRECTORY = "src/extensions/"
class DiscordBot(commands.Bot):
@ -17,24 +24,36 @@ class DiscordBot(commands.Bot):
super().__init__(command_prefix="-", intents=Intents.all())
async def sync_app_commands(self):
"""Sync application commands"""
"""
Sync application commands.
"""
await self.wait_until_ready()
await self.tree.sync()
print("app commands synced")
log.info("Application commands successfully synced")
async def on_ready(self):
"""When the bot is ready"""
"""
When the bot is ready.
"""
await self.add_cog(CommandsCog(self))
await self.add_cog(ErrorCog(self))
await self.sync_app_commands()
async def load_extensions(self):
"""
Load any extensions found in the extensions dictionary.
"""
for filename in os.listdir(EXTENSIONS_DIRECTORY):
if filename.endswith(".py"):
await self.load_extension(f"extensions.{filename[:-3]}")
class CommandsCog(commands.Cog):
def __init__(self, bot):
self.bot = bot
self.news_task.start()
# self.news_task.start()
async def story_to_embed(self, story) -> discord.Embed:
"""
@ -97,135 +116,3 @@ class CommandsCog(commands.Cog):
embed = await self.story_to_embed(story)
channel = self.bot.get_channel(1057004889458348042)
await channel.send(embed=embed)
class ErrorCog(commands.Cog):
"""Error handling cog."""
__slots__ = ()
default_err_msg = "I'm sorry, but I've encountered an " \
"error while processing your command."
def __init__(self, bot):
super().__init__()
self.bot = bot
# Register the error handler
bot.tree.error(coro = self._dispatch_to_app_command_handler)
def trace_error(self, error: Exception):
print(f"{type(error).__name__} {error}")
raise error
async def _dispatch_to_app_command_handler(
self,
inter: Interaction,
error: app_commands.AppCommandError
):
"""Dispatches the error to the app command handler"""
self.bot.dispatch("app_command_error", inter, error)
async def _respond_to_interaction(self, inter: Interaction) -> bool:
"""Respond to an interaction with an error message"""
try:
await inter.response.send_message(
self.default_err_msg,
ephemeral=True
)
except discord.InteractionResponded:
return
@commands.Cog.listener("on_app_command_error")
async def get_app_command_error(
self,
inter: Interaction,
error: app_commands.AppCommandError
):
"""Handles the application command error.
Responds with the appropriate error message.
"""
try:
# Send the default error message and create an edit
# shorthand to add more details to the message once
# we've figured out what the error is.
print(error.with_traceback(None))
await self._respond_to_interaction(inter)
edit = lambda x: inter.edit_original_response(content=x)
raise error
except app_commands.CommandInvokeError as _err:
# The interaction has already been responded to.
if isinstance(
_err.original,
discord.InteractionResponded
):
await edit(_err.original)
return
# Some other error occurred while invoking the command.
await edit(
f"`{type(_err.original).__name__}` " \
f": {_err.original}"
)
except app_commands.CheckFailure as _err:
# The command is still on cooldown.
if isinstance(
_err,
app_commands.CommandOnCooldown
):
await edit(
f"Woah, slow down! This command is on cooldown, " \
f"wait `{str(_err).split(' ')[7]}` !"
)
return
if isinstance(
_err,
app_commands.MissingPermissions
):
await edit(
"You don't have the required permissions to " \
"run this command!"
)
return
if isinstance(
_err,
app_commands.BotMissingPermissions
):
await edit(
"I don't have the required permissions to " \
"run this command! Please ask an admin to " \
"grant me the required permissions."
)
return
# A different check has failed.
await edit(f"`{type(_err).__name__}` : {_err}")
except app_commands.CommandNotFound:
# The command could not be found.
await edit(
f"I couldn't find the command you were looking for... "
"\nThis is probably a discord bug related to " \
"desynchronization between my commands and discord's " \
"servers. Please try again later."
)
except Exception as _err:
# Caught here:
# app_commands.TransformerError
# app_commands.CommandLimitReached
# app_commands.CommandAlreadyRegistered
# app_commands.CommandSignatureMismatch
self.trace_error(_err)

27
src/db/__init__.py Normal file
View File

@ -0,0 +1,27 @@
"""
Initialize the database modules.
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .db import DATABASE_URL
from .models import NewsCategories, Base, get_or_create, DefaultNewsCategories
engine = create_engine(DATABASE_URL)
Session = sessionmaker(bind=engine)
Base.metadata.create_all(engine)
session = Session()
default_categories = [
(category.name, category.value)
for category in DefaultNewsCategories
]
for category_name, category_id in default_categories:
get_or_create(session, model=NewsCategories, id=category_id, name=category_name)
session.commit()
session.close()

147
src/db/db.py Normal file
View File

@ -0,0 +1,147 @@
"""
"""
import logging
import aiosqlite
from os.path import isfile
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
DB_PATH = "db/db.sqlite"
BUILD_PATH = "db/build.sql"
DATABASE_URL = "sqlite:///db/db.sqlite"
DATABASE_ASYNC_URL = "sqlite+aiosqlite:///db/db.sqlite"
log = logging.getLogger(__name__)
class DatabaseManager:
def __init__(self, database_url=DATABASE_ASYNC_URL):
self.engine = create_async_engine(database_url, future=True)
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
self.session = None
async def __aenter__(self):
self.session = self.session_maker()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.session.close()
await self.session.commit()
self.session = None
await self.engine.dispose()
# class DBConnection:
# """
# Asynchronous context manager for database connections.
# """
# def __init__(self):
# self.conn = None
# async def __aenter__(self):
# self.conn = await connect()
# return self.conn
# async def __aexit__(self, *args):
# await close(self.conn)
# async def connect():
# log.info("Opening database connection")
# return await aiosqlite.connect(DB_PATH)
# async def with_commit(func):
# """
# Wrapper to commit changes to the database.
# """
# async def inner(*args, **kwargs):
# await func(*args, **kwargs)
# await commit()
# return inner
# async def build(conn):
# """
# Build the database from the build script.
# """
# log.info("Building database from build script")
# if isfile(BUILD_PATH):
# await scriptexec(conn, BUILD_PATH)
# return
# raise ValueError('Build script not found')
# async def commit(conn):
# """
# Commit changes to the database.
# """
# log.info("Committing database changes")
# await conn.commit()
# async def close(conn):
# """
# Close the database connection.
# """
# log.debug("Closing database connection")
# await conn.close()
# async def field(conn, cmd, *vals):
# """
# Return a single field.
# """
# log.debug("Executing command for field: %s, vals:%s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# if (fetch := await cur.fetchone()) is not None:
# return fetch[0]
# async def record(conn, cmd, *vals):
# """
# Return a single record.
# """
# log.debug("Executing command for record: %s, vals: %s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# return await cur.fetchone()
# async def records(conn, cmd, *vals):
# """
# Return all records.
# """
# log.debug("Executing command for records: %s, vals: %s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# return await cur.fetchall()
# async def column(conn, cmd, *vals):
# """
# Return a single column.
# """
# log.debug("Executing command for column: %s, vals: %s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# return [item[0] for item in await cur.fetchall()]
# async def execute(conn, cmd, *vals):
# """
# Execute a command.
# """
# log.debug("Executing command: %s, vals: %s", cmd, vals)
# async with conn.execute(cmd, tuple(vals)) as cur:
# return cur
# async def multiexec(conn, cmd, valset):
# """
# Execute multiple commands.
# """
# log.debug("Executing multiple commands: %s, valset: %s", cmd, valset)
# async with conn.executemany(cmd, valset):
# pass
# async def scriptexec(conn, path):
# """
# Execute a script.
# """
# log.debug("Executing script: %s", path)
# with open(path, 'r', encoding='utf-8') as script:
# await conn.executescript(script.read())

63
src/db/models.py Normal file
View File

@ -0,0 +1,63 @@
"""
"""
from enum import Enum, auto
from sqlalchemy import Column, Integer, String, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
def get_or_create(session, model, **kwargs):
instance = session.query(model).filter_by(**kwargs).first()
if instance:
return instance
else:
instance = model(**kwargs)
session.add(instance)
session.commit()
return instance
class ServerChannels(Base):
__tablename__ = 'server_channels'
id = Column(Integer, primary_key=True, autoincrement=True)
server_id = Column(Integer, nullable=False)
channel_id = Column(Integer, nullable=False)
news_category_id = Column(Integer, nullable=False)
active = Column(Integer, nullable=False, default=True)
news_articles = relationship('NewsArticles', cascade='all, delete')
class NewsArticles(Base):
__tablename__ = 'news_articles'
id = Column(Integer, primary_key=True, autoincrement=True)
url = Column(String, nullable=False)
server_channel_id = Column(Integer, ForeignKey('server_channels.id'), nullable=False)
class NewsCategories(Base):
__tablename__ = 'news_categories'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String, unique=True, nullable=False)
class DefaultNewsCategories(Enum):
"""
"""
ALL = auto()
WORLD = auto()
UK = auto()
NORTH_AMERICA = auto()
ENTERTAINMENT = auto()
BUSINESS = auto()
TECH = auto()
SCIENCE = auto()
TOP_STORIES = auto()

146
src/extensions/errors.py Normal file
View File

@ -0,0 +1,146 @@
"""
Extension for the `ErrorCog` cog.
Loading this file via `commands.Bot.load_extension` will add the `ErrorCog` cog to the bot.
"""
import logging
from discord import app_commands, Interaction
from discord.ext import commands
from discord.errors import InteractionResponded
log = logging.getLogger(__name__)
class ErrorCog(commands.Cog):
"""
Error handling cog.
Discordpy has problems with error handling, this cog corrects
"""
default_err_msg = "I'm sorry, but I've encountered an " \
"error while processing your command."
def __init__(self, bot):
super().__init__()
self.bot = bot
# Register the error handler
bot.tree.error(coro=self._dispatch_to_app_command_handler)
def trace_error(self, error: Exception):
log.error(f"{type(error).__name__} {error}")
raise error
async def _dispatch_to_app_command_handler(
self,
inter: Interaction,
error: app_commands.AppCommandError
):
"""
Dispatches the error to the app command handler.
"""
self.bot.dispatch("app_command_error", inter, error)
async def _respond_to_interaction(self, inter: Interaction) -> bool:
"""
Respond to an interaction with an error message.
"""
try:
await inter.response.send_message(
self.default_err_msg,
ephemeral=True
)
except InteractionResponded:
log.debug("Interaction already responded to.")
return
@commands.Cog.listener("on_app_command_error")
async def get_app_command_error(
self,
inter: Interaction,
error: app_commands.AppCommandError
):
"""
Handles the application command error and responds with the appropriate error message.
"""
try:
# Send the default error message and create an edit
# shorthand to add more details to the message once
# we've figured out what the error is.
log.error(error.with_traceback(None))
await self._respond_to_interaction(inter)
edit = lambda x: inter.edit_original_response(content=x)
raise error
except app_commands.CommandInvokeError as _err:
# The interaction has already been responded to.
if isinstance(_err.original, InteractionResponded):
await edit(_err.original)
return
# Some other error occurred while invoking the command.
await edit(f"`{type(_err.original).__name__}` : {_err.original}")
except app_commands.CheckFailure as _err:
# The command is still on cooldown.
if isinstance(_err, app_commands.CommandOnCooldown):
await edit(
f"Woah, slow down! This command is on cooldown, " \
f"wait `{str(_err).split(' ')[7]}` !"
)
return
if isinstance(_err, app_commands.MissingPermissions):
await edit(
"You don't have the required permissions to " \
"run this command!"
)
return
if isinstance(_err, app_commands.BotMissingPermissions):
await edit(
"I don't have the required permissions to " \
"run this command! Please ask an admin to " \
"grant me the required permissions."
)
return
# A different check has failed.
await edit(f"`{type(_err).__name__}` : {_err}")
except app_commands.CommandNotFound:
# The command could not be found.
await edit(
f"I couldn't find the command you were looking for... "
"\nThis is probably a discord bug related to " \
"desynchronization between my commands and discord's " \
"servers. Please try again later."
)
except Exception as _err:
# Caught here:
# app_commands.TransformerError
# app_commands.CommandLimitReached
# app_commands.CommandAlreadyRegistered
# app_commands.CommandSignatureMismatch
self.trace_error(_err)
async def setup(bot):
"""
Setup function for this extension.
Adds the `ErrorCog` cog to the bot.
"""
cog = ErrorCog(bot)
await bot.add_cog(cog)
log.info(f"Added {cog.__class__.__name__} cog")

335
src/extensions/news.py Normal file
View File

@ -0,0 +1,335 @@
"""
Extension for the `NewsCog` cog.
Loading this file via `commands.Bot.load_extension` will add the `NewsCog` cog to the bot.
"""
import asyncio
import logging
from datetime import datetime
import aiohttp
import discord
from discord import app_commands, Interaction
from discord.ext import commands, tasks
from sqlalchemy import select, insert, delete, and_
from bbc_feeds import news as news_api, feedparser
from bs4 import BeautifulSoup as bs4
from db.db import DatabaseManager
from db.models import ServerChannels, DefaultNewsCategories, NewsArticles
log = logging.getLogger(__name__)
category_choices = [
app_commands.Choice(name=category.name.replace("_", " "), value=category.value)
for category in DefaultNewsCategories
]
class NewsStoryType:
"""
Type hinting class for news stories.
"""
id: str
title: str
title_detail: str
summary: str
summary_detail: str
link: str
links: list[str]
guidislink: str
published: str
published_parsed: datetime
class NewsCog(commands.Cog):
"""
News cog.
Delivers embeds of news articles to discord channels.
"""
def __init__(self, bot):
super().__init__()
self.bot = bot
self.news_task.start()
def fetch_articles(self, category_id: int) -> list[NewsStoryType]:
"""
Fetch the latest news articles from the category matching the given `category_id`.
Arguments
---------
category_id : int
The ID of the category to fetch news articles from.
Raises
------
ValueError
The given `category_id` doesn't match any known categories.
"""
n = news_api()
# IMPORTANT:
# The order of items in this list matters,
# Items follow the order of items in the Enum: DefaultNewsCategories
category_methods = [
n.all,
n.world,
n.uk,
n.north_america,
n.entertainment,
n.business,
n.tech,
n.science,
n.top_stories
]
try:
return category_methods[category_id - 1](limit=1)
except IndexError as err:
raise ValueError(f"Invalid category_id: {category_id}") from err
async def story_to_embed(self, story: NewsStoryType, category_name: str) -> discord.Embed:
"""
Returns a discord.Embed object representing the given story.
Parameters
----------
story : NewsStoryType
Parsed details on the news story.
Returns
-------
discord.Embed
A `discord.Embed` object updated with information on the news story.
"""
# Fetch web data for the thumbnail
async with aiohttp.ClientSession() as session:
async with session.get(story.link) as response:
html = await response.text()
# Parse the thumbnail for the news story
soup = bs4(html, "html.parser")
image_src = soup.select_one("meta[property='og:image']").get("content")
category_name = category_name.replace("_", " ").title()
embed = discord.Embed(
colour=discord.Colour.from_str("#FFFFFF"),
title=story.title,
description=story.summary,
url=story.link,
)
embed.set_image(url=image_src)
embed.set_author(name=f"BBC News • {category_name}")
return embed
async def followup_with_articles(self, inter: Interaction, category: app_commands.Choice[int]):
"""
Collects articles and follows up an interaction with embeds for these articles.
Arguments
---------
inter : discord.Interaction
The interaction between the bot and the user.
category : app_commands.Choice[int]
The category chosen by the user, represented as a `Choice` object.
"""
stories = self.fetch_articles(category.value)
if not stories:
await inter.followup.send("No articles found")
return
for story in stories:
embed = await self.story_to_embed(story, category.name)
await inter.followup.send(embed=embed)
async def _get_or_fetch_channel(self, channel_id: int) -> discord.TextChannel:
"""
Returns a `discord.TextChannel` object based on the given `channel_id`.
Will try to return from cache if present, otherwise will make an API call for the channel.
"""
channel = self.bot.get_channel(channel_id)
return channel or await self.bot.fetch_channel(channel_id)
@tasks.loop(minutes=5)
async def news_task(self):
"""
Task that checks for the latest news and shares it to discord.
"""
log.info("Doing news task.")
async with DatabaseManager() as database:
for category in DefaultNewsCategories:
# Find channels that accept this category of news
whereclause = and_(ServerChannels.news_category_id == category.value)
query = select(ServerChannels).where(whereclause)
result = await database.session.execute(query)
server_channels = result.scalars().all()
if not server_channels:
continue
stories = self.fetch_articles(category.value)
if not stories:
continue
story = stories[0]
embed = None
for item in server_channels:
channel_id = item.channel_id
# Check the article hasn't already been shared with this channel.
query = select(NewsArticles).where(and_(NewsArticles.server_channel_id == channel_id, NewsArticles.url == story.link))
result = await database.session.execute(query)
existing_items = result.scalars().all()
if existing_items:
log_string = (
"Existing items not sent"
f"\nchannel id: {channel_id}"
f"\narticle url: {story.link}"
f"\ncategory: {category.name}\n"
)
log.info(log_string)
continue
log_string = (
"Items sent"
f"\nchannel id: {channel_id}"
f"\narticle url: {story.link}"
f"\ncategory: {category.name}\n"
)
log.info(log_string)
# Add this article as shown to this channel, prevents it from being shown here again.
query = insert(NewsArticles).values(server_channel_id=channel_id, url=story.link)
await database.session.execute(query)
embed = embed or await self.story_to_embed(story, category.name)
channel = await self._get_or_fetch_channel(channel_id)
await channel.send(embed=embed)
await database.session.commit() # commit will terminate the connection for some reason.
base_group = app_commands.Group(name="bbc", description="BBC News related commands")
channels_group = app_commands.Group(parent=base_group, name="channels", description="Channel commands")
@base_group.command(name="news")
@app_commands.choices(category=category_choices)
async def get_news(self, inter: Interaction, category: app_commands.Choice[int]):
"""
Get the latest article from BBC news.
"""
await inter.response.defer()
await self.followup_with_articles(inter, category)
# @base_group.command(name="ping-me")
# async def ping_me(self, inter: Interaction):
# """
# Ping you when a new news article has been posted. Use again to remove ping.
# """
# await inter.response.send_message("response OK")
@channels_group.command(name="add-category")
@app_commands.choices(category=category_choices)
async def add_category(self, inter: Interaction, channel: discord.TextChannel, category: app_commands.Choice[int]):
"""
Assign a category to this channel. Articles under this category will be sent here.
"""
await inter.response.defer()
async with DatabaseManager() as database:
query = insert(ServerChannels).values(server_id=inter.guild_id, channel_id=channel.id, news_category_id=category.value)
await database.session.execute(query)
await database.session.commit()
category_name = category.name.replace("_", " ").title()
embed = discord.Embed(
title=f"Category Added • {category_name}",
description=f"{channel.mention} will now receive news from **{category_name}**",
colour=discord.Colour.from_str("#FFFFFF")
)
await inter.followup.send(embed=embed)
@channels_group.command(name="del-category")
@app_commands.choices(category=category_choices)
async def delete_category(self, inter: Interaction, channel: discord.TextChannel, category: app_commands.Choice[int]):
"""
Remove a category from this channel.
"""
await inter.response.defer()
async with DatabaseManager() as database:
whereclause = and_(
ServerChannels.channel_id == channel.id,
ServerChannels.news_category_id == category.value
)
query = delete(ServerChannels).where(whereclause)
await database.session.execute(query)
await database.session.commit()
category_name = category.name.replace("_", " ").title()
embed = discord.Embed(
title=f"Category Removed • {category_name}",
description=f"{channel.mention} will no longer receive news from **{category_name}**",
colour=discord.Colour.from_str("#FFFFFF")
)
await inter.followup.send(embed=embed)
@channels_group.command(name="lst-category")
async def list_category(self, inter: Interaction):
"""
List the categories assigned to this channel.
"""
await inter.response.defer()
async with DatabaseManager() as database:
query = select(ServerChannels)
result = await database.session.execute(query)
server_channels = result.scalars().all()
output = ""
for item in server_channels:
if item.server_id != inter.guild_id:
continue
channel = inter.guild.get_channel(item.channel_id)
category_name = DefaultNewsCategories(item.news_category_id).name.replace("_", " ")
output += f"{channel.mention} - {category_name}\n"
output = output or "No categories set"
await inter.followup.send(output)
async def setup(bot):
"""
Setup function for this extension.
Adds the `ErrorCog` cog to the bot.
"""
cog = NewsCog(bot)
await bot.add_cog(cog)
log.info(f"Added {cog.__class__.__name__} cog")

95
src/logs.py Normal file
View File

@ -0,0 +1,95 @@
"""
Handle async logging for the project.
"""
import sys
import queue
import logging
from logging.handlers import QueueHandler, QueueListener
from datetime import datetime, timedelta
from itertools import count
from typing import TextIO
from pathlib import Path
LOGS_DIRECTORY = "logs/"
LOG_FILENAME_FORMAT_PREFIX = "%Y-%m-%d %H-%M-%S"
MAX_LOGFILE_AGE_DAYS = 7
log = logging.getLogger(__name__)
def _open_file() -> TextIO:
"""
Returns a file object for the current log file.
"""
# Create the logs directory if it doesnt exist
Path(LOGS_DIRECTORY).mkdir(exist_ok=True)
# Create a generator to generate a unique filename
timestamp = datetime.now().strftime(LOG_FILENAME_FORMAT_PREFIX)
filenames = (f'{timestamp}.log' if i == 0 else f'{timestamp}_({i}).log' for i in count())
# Find a filename that doesn't already exist and return it
for filename in filenames:
try:
return (Path(f'{LOGS_DIRECTORY}/{filename}').open('x', encoding='utf-8'))
except FileExistsError:
continue
def _delete_old_logs():
"""
Search through the logs directory and delete any expired log files.
"""
for path in Path(LOGS_DIRECTORY).glob('*.txt'):
prefix = path.stem.split('_')[0]
try:
log_date = datetime.strptime(prefix, LOG_FILENAME_FORMAT_PREFIX)
except ValueError:
log.warning(f'{path.parent} contains a problematic filename: {path.name}')
continue
age = datetime.now() - log_date
if age >= timedelta(days=MAX_LOGFILE_AGE_DAYS):
log.info(f'Removing expired log file: {path.name}')
path.unlink()
def update_log_levels(logger_names:tuple[str], level:int):
"""
Quick way to update the log level of multiple loggers at once.
"""
for name in logger_names:
logger=logging.getLogger(name)
logger.setLevel(level)
def setup_logs(log_level:int=logging.DEBUG) -> str:
"""
Setup a logging queue handler and queue listener.
Also creates a new log file for the current session and deletes old log files.
"""
# Create a queue to pass log records to the listener
log_queue = queue.Queue()
queue_handler = QueueHandler(log_queue)
# Configure the root logger to use the queue
logging.basicConfig(
level=log_level,
handlers=(queue_handler,),
format='[%(asctime)s] [%(levelname)-8s] [%(name)-17s]: %(message)s'
)
# Create a new log file
file = _open_file()
file_handler = logging.StreamHandler(file) # Stream logs to the log file
sys_handler = logging.StreamHandler(sys.stdout) # Stream logs to the console
# Create a listener to handle the queue
queue_listener = QueueListener(log_queue, file_handler, sys_handler)
queue_listener.start()
# Clear up old log files
_delete_old_logs()
return file.name

View File

@ -1,16 +1,38 @@
"""Entry point for the application."""
"""
Entry point for the application.
Run this file to get started.
"""
import logging
import asyncio
from bot import DiscordBot
from bot import DiscordBot
from logs import setup_logs, update_log_levels
async def main():
"""
Entry point function for the application.
Run this function to get started.
"""
# Grab the token before anything else, because if there is no token
# available then the bot cannot be started anyways.
with open("TOKEN", "r") as token_file:
token = token_file.read()
await DiscordBot().start(token)
if not token:
raise ValueError("Token file is empty")
setup_logs()
update_log_levels(
('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'),
level=logging.WARNING
)
async with DiscordBot() as bot:
await bot.load_extensions()
await bot.start(token)
if __name__ == "__main__":
asyncio.run(main())