diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..0aba14d --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Discord Bot", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/src/main.py", + "console": "integratedTerminal", + "justMyCode": true + } + ] +} \ No newline at end of file diff --git a/db/build.sql b/db/build.sql new file mode 100644 index 0000000..ffe3f2f --- /dev/null +++ b/db/build.sql @@ -0,0 +1,42 @@ + +/* + Server Channels +*/ +CREATE TABLE IF NOT EXISTS 'server_channels' ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + channel_id INTEGER NOT NULL, + news_category_id INTEGER NOT NULL, + active INTEGER NOT NULL, + FOREIGN KEY (news_category_id) REFERENCES 'news_categories' (id) + ON DELETE CASCADE +); + + +/* + News Articles +*/ +CREATE TABLE IF NOT EXISTS 'news_articles' ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + url TEXT NOT NULL, + server_channel_id INTEGER NOT NULL, + FOREIGN KEY (server_channel_id) REFERENCES 'server_channels' (id) + ON DELETE CASCADE +); + + +/* + News Categories +*/ +CREATE TABLE IF NOT EXISTS 'news_categories' ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL +); +INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('all'); +INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('world'); +INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('uk'); +INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('north_america'); +INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('entertainment'); +INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('business'); +INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('tech'); +INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('science'); +INSERT OR IGNORE INTO 'news_categories' (name) VALUES ('top_stories'); diff --git a/db/db.sqlite b/db/db.sqlite new file mode 100644 index 0000000..c8dfd5f Binary files /dev/null and b/db/db.sqlite differ diff --git a/src/bot.py b/src/bot.py index 502c127..aa4da6d 100644 --- a/src/bot.py +++ b/src/bot.py @@ -1,6 +1,10 @@ -"""The discord bot for the application""" +""" +The discord bot for the application. +""" +import os import time +import logging from datetime import datetime import aiohttp @@ -10,6 +14,9 @@ from discord import Intents, Interaction, app_commands from discord.ext import commands, tasks from bbc_feeds import news +log = logging.getLogger(__name__) +EXTENSIONS_DIRECTORY = "src/extensions/" + class DiscordBot(commands.Bot): @@ -17,24 +24,36 @@ class DiscordBot(commands.Bot): super().__init__(command_prefix="-", intents=Intents.all()) async def sync_app_commands(self): - """Sync application commands""" + """ + Sync application commands. + """ await self.wait_until_ready() await self.tree.sync() - print("app commands synced") + log.info("Application commands successfully synced") async def on_ready(self): - """When the bot is ready""" + """ + When the bot is ready. + """ - await self.add_cog(CommandsCog(self)) - await self.add_cog(ErrorCog(self)) await self.sync_app_commands() + async def load_extensions(self): + """ + Load any extensions found in the extensions dictionary. + """ + + for filename in os.listdir(EXTENSIONS_DIRECTORY): + if filename.endswith(".py"): + await self.load_extension(f"extensions.{filename[:-3]}") + + class CommandsCog(commands.Cog): def __init__(self, bot): self.bot = bot - self.news_task.start() + # self.news_task.start() async def story_to_embed(self, story) -> discord.Embed: """ @@ -97,135 +116,3 @@ class CommandsCog(commands.Cog): embed = await self.story_to_embed(story) channel = self.bot.get_channel(1057004889458348042) await channel.send(embed=embed) - - -class ErrorCog(commands.Cog): - """Error handling cog.""" - - __slots__ = () - default_err_msg = "I'm sorry, but I've encountered an " \ - "error while processing your command." - - def __init__(self, bot): - super().__init__() - self.bot = bot - - # Register the error handler - bot.tree.error(coro = self._dispatch_to_app_command_handler) - - def trace_error(self, error: Exception): - print(f"{type(error).__name__} {error}") - raise error - - async def _dispatch_to_app_command_handler( - self, - inter: Interaction, - error: app_commands.AppCommandError - ): - """Dispatches the error to the app command handler""" - - self.bot.dispatch("app_command_error", inter, error) - - async def _respond_to_interaction(self, inter: Interaction) -> bool: - """Respond to an interaction with an error message""" - - try: - await inter.response.send_message( - self.default_err_msg, - ephemeral=True - ) - except discord.InteractionResponded: - return - - @commands.Cog.listener("on_app_command_error") - async def get_app_command_error( - self, - inter: Interaction, - error: app_commands.AppCommandError - ): - """Handles the application command error. - - Responds with the appropriate error message. - """ - - try: - # Send the default error message and create an edit - # shorthand to add more details to the message once - # we've figured out what the error is. - print(error.with_traceback(None)) - await self._respond_to_interaction(inter) - edit = lambda x: inter.edit_original_response(content=x) - - raise error - - except app_commands.CommandInvokeError as _err: - - # The interaction has already been responded to. - if isinstance( - _err.original, - discord.InteractionResponded - ): - await edit(_err.original) - return - - # Some other error occurred while invoking the command. - await edit( - f"`{type(_err.original).__name__}` " \ - f": {_err.original}" - ) - - except app_commands.CheckFailure as _err: - - # The command is still on cooldown. - if isinstance( - _err, - app_commands.CommandOnCooldown - ): - await edit( - f"Woah, slow down! This command is on cooldown, " \ - f"wait `{str(_err).split(' ')[7]}` !" - ) - return - - if isinstance( - _err, - app_commands.MissingPermissions - ): - await edit( - "You don't have the required permissions to " \ - "run this command!" - ) - return - - if isinstance( - _err, - app_commands.BotMissingPermissions - ): - await edit( - "I don't have the required permissions to " \ - "run this command! Please ask an admin to " \ - "grant me the required permissions." - ) - return - - # A different check has failed. - await edit(f"`{type(_err).__name__}` : {_err}") - - except app_commands.CommandNotFound: - - # The command could not be found. - await edit( - f"I couldn't find the command you were looking for... " - "\nThis is probably a discord bug related to " \ - "desynchronization between my commands and discord's " \ - "servers. Please try again later." - ) - - except Exception as _err: - # Caught here: - # app_commands.TransformerError - # app_commands.CommandLimitReached - # app_commands.CommandAlreadyRegistered - # app_commands.CommandSignatureMismatch - - self.trace_error(_err) diff --git a/src/db/__init__.py b/src/db/__init__.py new file mode 100644 index 0000000..580dd9b --- /dev/null +++ b/src/db/__init__.py @@ -0,0 +1,27 @@ +""" +Initialize the database modules. +""" + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from .db import DATABASE_URL +from .models import NewsCategories, Base, get_or_create, DefaultNewsCategories + +engine = create_engine(DATABASE_URL) +Session = sessionmaker(bind=engine) + +Base.metadata.create_all(engine) + +session = Session() + +default_categories = [ + (category.name, category.value) + for category in DefaultNewsCategories +] + +for category_name, category_id in default_categories: + get_or_create(session, model=NewsCategories, id=category_id, name=category_name) + +session.commit() +session.close() diff --git a/src/db/db.py b/src/db/db.py new file mode 100644 index 0000000..3983a3f --- /dev/null +++ b/src/db/db.py @@ -0,0 +1,147 @@ +""" + +""" + +import logging +import aiosqlite +from os.path import isfile +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker + +DB_PATH = "db/db.sqlite" +BUILD_PATH = "db/build.sql" +DATABASE_URL = "sqlite:///db/db.sqlite" +DATABASE_ASYNC_URL = "sqlite+aiosqlite:///db/db.sqlite" + +log = logging.getLogger(__name__) + + +class DatabaseManager: + def __init__(self, database_url=DATABASE_ASYNC_URL): + self.engine = create_async_engine(database_url, future=True) + self.session_maker = sessionmaker(self.engine, class_=AsyncSession) + self.session = None + + async def __aenter__(self): + self.session = self.session_maker() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.session.close() + await self.session.commit() + self.session = None + await self.engine.dispose() + + +# class DBConnection: +# """ +# Asynchronous context manager for database connections. +# """ + +# def __init__(self): +# self.conn = None + +# async def __aenter__(self): +# self.conn = await connect() +# return self.conn + +# async def __aexit__(self, *args): +# await close(self.conn) + + +# async def connect(): +# log.info("Opening database connection") +# return await aiosqlite.connect(DB_PATH) + +# async def with_commit(func): +# """ +# Wrapper to commit changes to the database. +# """ +# async def inner(*args, **kwargs): +# await func(*args, **kwargs) +# await commit() + +# return inner + +# async def build(conn): +# """ +# Build the database from the build script. +# """ +# log.info("Building database from build script") + +# if isfile(BUILD_PATH): +# await scriptexec(conn, BUILD_PATH) +# return + +# raise ValueError('Build script not found') + +# async def commit(conn): +# """ +# Commit changes to the database. +# """ +# log.info("Committing database changes") +# await conn.commit() + +# async def close(conn): +# """ +# Close the database connection. +# """ +# log.debug("Closing database connection") +# await conn.close() + +# async def field(conn, cmd, *vals): +# """ +# Return a single field. +# """ +# log.debug("Executing command for field: %s, vals:%s", cmd, vals) +# async with conn.execute(cmd, tuple(vals)) as cur: +# if (fetch := await cur.fetchone()) is not None: +# return fetch[0] + +# async def record(conn, cmd, *vals): +# """ +# Return a single record. +# """ +# log.debug("Executing command for record: %s, vals: %s", cmd, vals) +# async with conn.execute(cmd, tuple(vals)) as cur: +# return await cur.fetchone() + +# async def records(conn, cmd, *vals): +# """ +# Return all records. +# """ +# log.debug("Executing command for records: %s, vals: %s", cmd, vals) +# async with conn.execute(cmd, tuple(vals)) as cur: +# return await cur.fetchall() + +# async def column(conn, cmd, *vals): +# """ +# Return a single column. +# """ +# log.debug("Executing command for column: %s, vals: %s", cmd, vals) +# async with conn.execute(cmd, tuple(vals)) as cur: +# return [item[0] for item in await cur.fetchall()] + +# async def execute(conn, cmd, *vals): +# """ +# Execute a command. +# """ +# log.debug("Executing command: %s, vals: %s", cmd, vals) +# async with conn.execute(cmd, tuple(vals)) as cur: +# return cur + +# async def multiexec(conn, cmd, valset): +# """ +# Execute multiple commands. +# """ +# log.debug("Executing multiple commands: %s, valset: %s", cmd, valset) +# async with conn.executemany(cmd, valset): +# pass + +# async def scriptexec(conn, path): +# """ +# Execute a script. +# """ +# log.debug("Executing script: %s", path) +# with open(path, 'r', encoding='utf-8') as script: +# await conn.executescript(script.read()) diff --git a/src/db/models.py b/src/db/models.py new file mode 100644 index 0000000..25c0992 --- /dev/null +++ b/src/db/models.py @@ -0,0 +1,63 @@ +""" + +""" + +from enum import Enum, auto + +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +def get_or_create(session, model, **kwargs): + instance = session.query(model).filter_by(**kwargs).first() + if instance: + return instance + else: + instance = model(**kwargs) + session.add(instance) + session.commit() + return instance + + +class ServerChannels(Base): + __tablename__ = 'server_channels' + + id = Column(Integer, primary_key=True, autoincrement=True) + server_id = Column(Integer, nullable=False) + channel_id = Column(Integer, nullable=False) + news_category_id = Column(Integer, nullable=False) + active = Column(Integer, nullable=False, default=True) + news_articles = relationship('NewsArticles', cascade='all, delete') + + +class NewsArticles(Base): + __tablename__ = 'news_articles' + + id = Column(Integer, primary_key=True, autoincrement=True) + url = Column(String, nullable=False) + server_channel_id = Column(Integer, ForeignKey('server_channels.id'), nullable=False) + + +class NewsCategories(Base): + __tablename__ = 'news_categories' + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, unique=True, nullable=False) + + +class DefaultNewsCategories(Enum): + """ + + """ + + ALL = auto() + WORLD = auto() + UK = auto() + NORTH_AMERICA = auto() + ENTERTAINMENT = auto() + BUSINESS = auto() + TECH = auto() + SCIENCE = auto() + TOP_STORIES = auto() diff --git a/src/extensions/errors.py b/src/extensions/errors.py new file mode 100644 index 0000000..9d7eac0 --- /dev/null +++ b/src/extensions/errors.py @@ -0,0 +1,146 @@ +""" +Extension for the `ErrorCog` cog. +Loading this file via `commands.Bot.load_extension` will add the `ErrorCog` cog to the bot. +""" + +import logging + +from discord import app_commands, Interaction +from discord.ext import commands +from discord.errors import InteractionResponded + +log = logging.getLogger(__name__) + + +class ErrorCog(commands.Cog): + """ + Error handling cog. + Discordpy has problems with error handling, this cog corrects + """ + + default_err_msg = "I'm sorry, but I've encountered an " \ + "error while processing your command." + + def __init__(self, bot): + super().__init__() + self.bot = bot + + # Register the error handler + bot.tree.error(coro=self._dispatch_to_app_command_handler) + + def trace_error(self, error: Exception): + log.error(f"{type(error).__name__} {error}") + raise error + + async def _dispatch_to_app_command_handler( + self, + inter: Interaction, + error: app_commands.AppCommandError + ): + """ + Dispatches the error to the app command handler. + """ + + self.bot.dispatch("app_command_error", inter, error) + + async def _respond_to_interaction(self, inter: Interaction) -> bool: + """ + Respond to an interaction with an error message. + """ + + try: + await inter.response.send_message( + self.default_err_msg, + ephemeral=True + ) + except InteractionResponded: + log.debug("Interaction already responded to.") + return + + @commands.Cog.listener("on_app_command_error") + async def get_app_command_error( + self, + inter: Interaction, + error: app_commands.AppCommandError + ): + """ + Handles the application command error and responds with the appropriate error message. + """ + + try: + # Send the default error message and create an edit + # shorthand to add more details to the message once + # we've figured out what the error is. + log.error(error.with_traceback(None)) + await self._respond_to_interaction(inter) + edit = lambda x: inter.edit_original_response(content=x) + + raise error + + except app_commands.CommandInvokeError as _err: + + # The interaction has already been responded to. + if isinstance(_err.original, InteractionResponded): + await edit(_err.original) + return + + # Some other error occurred while invoking the command. + await edit(f"`{type(_err.original).__name__}` : {_err.original}") + + except app_commands.CheckFailure as _err: + + # The command is still on cooldown. + if isinstance(_err, app_commands.CommandOnCooldown): + await edit( + f"Woah, slow down! This command is on cooldown, " \ + f"wait `{str(_err).split(' ')[7]}` !" + ) + return + + if isinstance(_err, app_commands.MissingPermissions): + await edit( + "You don't have the required permissions to " \ + "run this command!" + ) + return + + if isinstance(_err, app_commands.BotMissingPermissions): + await edit( + "I don't have the required permissions to " \ + "run this command! Please ask an admin to " \ + "grant me the required permissions." + ) + return + + # A different check has failed. + await edit(f"`{type(_err).__name__}` : {_err}") + + except app_commands.CommandNotFound: + + # The command could not be found. + await edit( + f"I couldn't find the command you were looking for... " + "\nThis is probably a discord bug related to " \ + "desynchronization between my commands and discord's " \ + "servers. Please try again later." + ) + + except Exception as _err: + # Caught here: + # app_commands.TransformerError + # app_commands.CommandLimitReached + # app_commands.CommandAlreadyRegistered + # app_commands.CommandSignatureMismatch + + self.trace_error(_err) + + +async def setup(bot): + """ + Setup function for this extension. + Adds the `ErrorCog` cog to the bot. + """ + + cog = ErrorCog(bot) + await bot.add_cog(cog) + log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/extensions/news.py b/src/extensions/news.py new file mode 100644 index 0000000..e05ea12 --- /dev/null +++ b/src/extensions/news.py @@ -0,0 +1,335 @@ +""" +Extension for the `NewsCog` cog. +Loading this file via `commands.Bot.load_extension` will add the `NewsCog` cog to the bot. +""" + +import asyncio +import logging +from datetime import datetime + +import aiohttp +import discord +from discord import app_commands, Interaction +from discord.ext import commands, tasks +from sqlalchemy import select, insert, delete, and_ +from bbc_feeds import news as news_api, feedparser +from bs4 import BeautifulSoup as bs4 + +from db.db import DatabaseManager +from db.models import ServerChannels, DefaultNewsCategories, NewsArticles + +log = logging.getLogger(__name__) + +category_choices = [ + app_commands.Choice(name=category.name.replace("_", " "), value=category.value) + for category in DefaultNewsCategories +] + + +class NewsStoryType: + """ + Type hinting class for news stories. + """ + + id: str + title: str + title_detail: str + summary: str + summary_detail: str + link: str + links: list[str] + guidislink: str + published: str + published_parsed: datetime + + +class NewsCog(commands.Cog): + """ + News cog. + Delivers embeds of news articles to discord channels. + """ + + def __init__(self, bot): + super().__init__() + self.bot = bot + self.news_task.start() + + def fetch_articles(self, category_id: int) -> list[NewsStoryType]: + """ + Fetch the latest news articles from the category matching the given `category_id`. + + Arguments + --------- + category_id : int + The ID of the category to fetch news articles from. + + Raises + ------ + ValueError + The given `category_id` doesn't match any known categories. + """ + + n = news_api() + + # IMPORTANT: + # The order of items in this list matters, + # Items follow the order of items in the Enum: DefaultNewsCategories + category_methods = [ + n.all, + n.world, + n.uk, + n.north_america, + n.entertainment, + n.business, + n.tech, + n.science, + n.top_stories + ] + + try: + return category_methods[category_id - 1](limit=1) + except IndexError as err: + raise ValueError(f"Invalid category_id: {category_id}") from err + + async def story_to_embed(self, story: NewsStoryType, category_name: str) -> discord.Embed: + """ + Returns a discord.Embed object representing the given story. + + Parameters + ---------- + story : NewsStoryType + Parsed details on the news story. + + Returns + ------- + discord.Embed + A `discord.Embed` object updated with information on the news story. + """ + + # Fetch web data for the thumbnail + async with aiohttp.ClientSession() as session: + async with session.get(story.link) as response: + html = await response.text() + + # Parse the thumbnail for the news story + soup = bs4(html, "html.parser") + image_src = soup.select_one("meta[property='og:image']").get("content") + + category_name = category_name.replace("_", " ").title() + + embed = discord.Embed( + colour=discord.Colour.from_str("#FFFFFF"), + title=story.title, + description=story.summary, + url=story.link, + ) + embed.set_image(url=image_src) + embed.set_author(name=f"BBC News • {category_name}") + + return embed + + async def followup_with_articles(self, inter: Interaction, category: app_commands.Choice[int]): + """ + Collects articles and follows up an interaction with embeds for these articles. + + Arguments + --------- + inter : discord.Interaction + The interaction between the bot and the user. + + category : app_commands.Choice[int] + The category chosen by the user, represented as a `Choice` object. + """ + + stories = self.fetch_articles(category.value) + + if not stories: + await inter.followup.send("No articles found") + return + + for story in stories: + embed = await self.story_to_embed(story, category.name) + await inter.followup.send(embed=embed) + + async def _get_or_fetch_channel(self, channel_id: int) -> discord.TextChannel: + """ + Returns a `discord.TextChannel` object based on the given `channel_id`. + Will try to return from cache if present, otherwise will make an API call for the channel. + """ + + channel = self.bot.get_channel(channel_id) + return channel or await self.bot.fetch_channel(channel_id) + + @tasks.loop(minutes=5) + async def news_task(self): + """ + Task that checks for the latest news and shares it to discord. + """ + + log.info("Doing news task.") + + async with DatabaseManager() as database: + for category in DefaultNewsCategories: + + # Find channels that accept this category of news + whereclause = and_(ServerChannels.news_category_id == category.value) + query = select(ServerChannels).where(whereclause) + result = await database.session.execute(query) + server_channels = result.scalars().all() + + if not server_channels: + continue + + stories = self.fetch_articles(category.value) + if not stories: + continue + + story = stories[0] + embed = None + + for item in server_channels: + + channel_id = item.channel_id + + # Check the article hasn't already been shared with this channel. + query = select(NewsArticles).where(and_(NewsArticles.server_channel_id == channel_id, NewsArticles.url == story.link)) + result = await database.session.execute(query) + existing_items = result.scalars().all() + + if existing_items: + log_string = ( + "Existing items not sent" + f"\nchannel id: {channel_id}" + f"\narticle url: {story.link}" + f"\ncategory: {category.name}\n" + ) + log.info(log_string) + continue + + log_string = ( + "Items sent" + f"\nchannel id: {channel_id}" + f"\narticle url: {story.link}" + f"\ncategory: {category.name}\n" + ) + log.info(log_string) + + # Add this article as shown to this channel, prevents it from being shown here again. + query = insert(NewsArticles).values(server_channel_id=channel_id, url=story.link) + await database.session.execute(query) + + embed = embed or await self.story_to_embed(story, category.name) + + channel = await self._get_or_fetch_channel(channel_id) + await channel.send(embed=embed) + + await database.session.commit() # commit will terminate the connection for some reason. + + base_group = app_commands.Group(name="bbc", description="BBC News related commands") + channels_group = app_commands.Group(parent=base_group, name="channels", description="Channel commands") + + @base_group.command(name="news") + @app_commands.choices(category=category_choices) + async def get_news(self, inter: Interaction, category: app_commands.Choice[int]): + """ + Get the latest article from BBC news. + """ + + await inter.response.defer() + await self.followup_with_articles(inter, category) + + # @base_group.command(name="ping-me") + # async def ping_me(self, inter: Interaction): + # """ + # Ping you when a new news article has been posted. Use again to remove ping. + # """ + + # await inter.response.send_message("response OK") + + @channels_group.command(name="add-category") + @app_commands.choices(category=category_choices) + async def add_category(self, inter: Interaction, channel: discord.TextChannel, category: app_commands.Choice[int]): + """ + Assign a category to this channel. Articles under this category will be sent here. + """ + + await inter.response.defer() + + async with DatabaseManager() as database: + query = insert(ServerChannels).values(server_id=inter.guild_id, channel_id=channel.id, news_category_id=category.value) + await database.session.execute(query) + await database.session.commit() + + category_name = category.name.replace("_", " ").title() + embed = discord.Embed( + title=f"Category Added • {category_name}", + description=f"{channel.mention} will now receive news from **{category_name}**", + colour=discord.Colour.from_str("#FFFFFF") + ) + + await inter.followup.send(embed=embed) + + @channels_group.command(name="del-category") + @app_commands.choices(category=category_choices) + async def delete_category(self, inter: Interaction, channel: discord.TextChannel, category: app_commands.Choice[int]): + """ + Remove a category from this channel. + """ + + await inter.response.defer() + + async with DatabaseManager() as database: + whereclause = and_( + ServerChannels.channel_id == channel.id, + ServerChannels.news_category_id == category.value + ) + query = delete(ServerChannels).where(whereclause) + await database.session.execute(query) + await database.session.commit() + + category_name = category.name.replace("_", " ").title() + embed = discord.Embed( + title=f"Category Removed • {category_name}", + description=f"{channel.mention} will no longer receive news from **{category_name}**", + colour=discord.Colour.from_str("#FFFFFF") + ) + + await inter.followup.send(embed=embed) + + @channels_group.command(name="lst-category") + async def list_category(self, inter: Interaction): + """ + List the categories assigned to this channel. + """ + + await inter.response.defer() + + async with DatabaseManager() as database: + query = select(ServerChannels) + result = await database.session.execute(query) + server_channels = result.scalars().all() + + output = "" + + for item in server_channels: + if item.server_id != inter.guild_id: + continue + + channel = inter.guild.get_channel(item.channel_id) + category_name = DefaultNewsCategories(item.news_category_id).name.replace("_", " ") + + output += f"{channel.mention} - {category_name}\n" + + output = output or "No categories set" + await inter.followup.send(output) + + +async def setup(bot): + """ + Setup function for this extension. + Adds the `ErrorCog` cog to the bot. + """ + + cog = NewsCog(bot) + await bot.add_cog(cog) + log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/logs.py b/src/logs.py new file mode 100644 index 0000000..9b91491 --- /dev/null +++ b/src/logs.py @@ -0,0 +1,95 @@ +""" +Handle async logging for the project. +""" + +import sys +import queue +import logging +from logging.handlers import QueueHandler, QueueListener +from datetime import datetime, timedelta +from itertools import count +from typing import TextIO +from pathlib import Path + +LOGS_DIRECTORY = "logs/" +LOG_FILENAME_FORMAT_PREFIX = "%Y-%m-%d %H-%M-%S" +MAX_LOGFILE_AGE_DAYS = 7 + +log = logging.getLogger(__name__) + +def _open_file() -> TextIO: + """ + Returns a file object for the current log file. + """ + + # Create the logs directory if it doesnt exist + Path(LOGS_DIRECTORY).mkdir(exist_ok=True) + + # Create a generator to generate a unique filename + timestamp = datetime.now().strftime(LOG_FILENAME_FORMAT_PREFIX) + filenames = (f'{timestamp}.log' if i == 0 else f'{timestamp}_({i}).log' for i in count()) + + # Find a filename that doesn't already exist and return it + for filename in filenames: + try: + return (Path(f'{LOGS_DIRECTORY}/{filename}').open('x', encoding='utf-8')) + except FileExistsError: + continue + +def _delete_old_logs(): + """ + Search through the logs directory and delete any expired log files. + """ + + for path in Path(LOGS_DIRECTORY).glob('*.txt'): + prefix = path.stem.split('_')[0] + try: + log_date = datetime.strptime(prefix, LOG_FILENAME_FORMAT_PREFIX) + except ValueError: + log.warning(f'{path.parent} contains a problematic filename: {path.name}') + continue + + age = datetime.now() - log_date + if age >= timedelta(days=MAX_LOGFILE_AGE_DAYS): + log.info(f'Removing expired log file: {path.name}') + path.unlink() + +def update_log_levels(logger_names:tuple[str], level:int): + """ + Quick way to update the log level of multiple loggers at once. + """ + for name in logger_names: + logger=logging.getLogger(name) + logger.setLevel(level) + +def setup_logs(log_level:int=logging.DEBUG) -> str: + """ + Setup a logging queue handler and queue listener. + Also creates a new log file for the current session and deletes old log files. + """ + + # Create a queue to pass log records to the listener + log_queue = queue.Queue() + queue_handler = QueueHandler(log_queue) + + # Configure the root logger to use the queue + logging.basicConfig( + level=log_level, + handlers=(queue_handler,), + format='[%(asctime)s] [%(levelname)-8s] [%(name)-17s]: %(message)s' + ) + + # Create a new log file + file = _open_file() + + file_handler = logging.StreamHandler(file) # Stream logs to the log file + sys_handler = logging.StreamHandler(sys.stdout) # Stream logs to the console + + # Create a listener to handle the queue + queue_listener = QueueListener(log_queue, file_handler, sys_handler) + queue_listener.start() + + # Clear up old log files + _delete_old_logs() + + return file.name diff --git a/src/main.py b/src/main.py index 2b944ef..eb4e23b 100644 --- a/src/main.py +++ b/src/main.py @@ -1,16 +1,38 @@ -"""Entry point for the application.""" +""" +Entry point for the application. +Run this file to get started. +""" +import logging import asyncio -from bot import DiscordBot +from bot import DiscordBot +from logs import setup_logs, update_log_levels async def main(): + """ + Entry point function for the application. + Run this function to get started. + """ + # Grab the token before anything else, because if there is no token + # available then the bot cannot be started anyways. with open("TOKEN", "r") as token_file: token = token_file.read() - await DiscordBot().start(token) + if not token: + raise ValueError("Token file is empty") + setup_logs() + update_log_levels( + ('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'), + level=logging.WARNING + ) + + async with DiscordBot() as bot: + await bot.load_extensions() + await bot.start(token) if __name__ == "__main__": + asyncio.run(main())