diff --git a/README.md b/README.md index 1102892..bdd0618 100644 --- a/README.md +++ b/README.md @@ -7,3 +7,11 @@ Plans - Multiple news providers - Choose how much of each provider should be delivered - Check for duplicate articles between providers, and only deliver preferred provider article + + +## Dev Notes: + +For the sake of development, the following defintions apply: + +- Feed - An RSS feed stored within the database, submitted by a user. +- Assigned Feed - A discord channel set to receive content from a Feed. \ No newline at end of file diff --git a/src/db/models.py b/src/db/models.py index 2fd5b2d..66a12ca 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -18,6 +18,7 @@ from sqlalchemy import ( Base = declarative_base() +# back in wed, thu, fri, off new year day then back in after class AuditModel(Base): """ @@ -81,7 +82,7 @@ class FeedChannelModel(Base): search_name = Column(String, nullable=False) rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False) - rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined") + rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete") # the rss source must be unique, but only within the same discord channel __table_args__ = ( diff --git a/src/extensions/rss.py b/src/extensions/rss.py index ceb1d98..c911b32 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -12,10 +12,11 @@ from discord.ext import commands from discord import Interaction, Embed, Colour, TextChannel from discord.app_commands import Choice, Group, autocomplete, choices, rename from sqlalchemy import insert, select, and_, delete +from sqlalchemy.exc import NoResultFound -from utils import get_rss_data, followup, audit # pylint: disable=E0401 -from feed import get_source, Source # pylint: disable=E0401 -from db import ( # pylint: disable=E0401 +from utils import get_rss_data, followup, audit, followup_error # pylint: disable=E0401 +from feed import get_source, Source # pylint: disable=E0401 +from db import ( # pylint: disable=E0401 DatabaseManager, SentArticleModel, RssSourceModel, @@ -28,6 +29,11 @@ rss_list_sort_choices = [ Choice(name="Nickname", value=0), Choice(name="Date Added", value=1) ] +channels_list_sort_choices=[ + Choice(name="Feed Nickname", value=0), + Choice(name="Channel ID", value=1), + Choice(name="Date Added", value=2) +] # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the # target resource. Run a period task to check this. @@ -169,6 +175,7 @@ class FeedCog(commands.Cog): await followup(inter, embed=embed) @feed_group.command(name="remove") + @rename(url="option") @autocomplete(url=source_autocomplete) async def remove_rss_source(self, inter: Interaction, url: str): """Delete an existing RSS source. @@ -186,35 +193,34 @@ class FeedCog(commands.Cog): log.debug("Attempting to remove RSS source (url=%s)", url) async with DatabaseManager() as database: - select_result = await database.session.execute( - select(RssSourceModel).filter( - and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == url - ) - ) + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url == url ) - rss_source = select_result.scalars().one() + + # We will select the item first, so we can reference it's nickname later. + select_query = select(RssSourceModel).filter(whereclause) + select_result = await database.session.execute(select_query) + + try: + rss_source = select_result.scalars().one() + except NoResultFound: + await followup_error(inter, + title="Error Deleting Feed", + message=f"I couldn't find anything for `{url}`" + ) + return + nickname = rss_source.nick - delete_result = await database.session.execute( - delete(RssSourceModel).filter( - and_( - RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == url - ) - ) - ) + delete_query = delete(RssSourceModel).filter(whereclause) + delete_result = await database.session.execute(delete_query) await audit(self, - f"Added RSS source ({nickname=}, {url=})", + f"Deleted RSS source ({nickname=}, {url=})", inter.user.id, database=database ) - if not delete_result.rowcount: - await followup(inter, "Couldn't find any RSS sources with this name.") - return - source = get_source(url) embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red()) @@ -269,7 +275,10 @@ class FeedCog(commands.Cog): rowcount = len(rss_sources) if not rss_sources: - await followup(inter, "It looks like you have no rss sources.") + await followup_error(inter, + title="No Feeds Found", + message="I couldn't find any Feeds for this server." + ) return output = "\n".join([ @@ -286,9 +295,9 @@ class FeedCog(commands.Cog): await followup(inter, embed=embed) - @feed_group.command(name="fetch") - @rename(max_="max") - @autocomplete(rss=source_autocomplete) + # @feed_group.command(name="fetch") + # @rename(max_="max") + # @autocomplete(rss=source_autocomplete) async def fetch_rss(self, inter: Interaction, rss: str, max_: int=1): """Fetch an item from the specified RSS feed. @@ -467,6 +476,7 @@ class FeedCog(commands.Cog): # ) @feed_group.command(name="assign") + @rename(rss="feed") @autocomplete(rss=autocomplete_rss_sources) async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None): """Include a feed within the specified channel. @@ -531,18 +541,17 @@ class FeedCog(commands.Cog): result = await database.session.execute(query) if not result.rowcount: - await followup(inter, "I couldn't find any items under that ID (placeholder response)") + await followup_error(inter, + title="Assigned Feed Not Found", + message=f"I couldn't find any assigned feeds for the option: {option}" + ) return await followup(inter, "I've removed this item (placeholder response)") @feed_group.command(name="channels") - # @choices(sort=[ - # Choice(name="RSS Nickname", value=0), - # Choice(name="Channel ID", value=1), - # Choice(name="Date Added", value=2) - # ]) - async def list_feeds(self, inter: Interaction): # sort: int + @choices(sort=channels_list_sort_choices) + async def list_feeds(self, inter: Interaction, sort: Choice[int] = 0, sort_reverse: bool = False): """List all of the channels and their respective included feeds. Parameters @@ -553,6 +562,34 @@ class FeedCog(commands.Cog): await inter.response.defer() + description = "Sort By " + + if isinstance(sort, Choice): + match sort.value, sort_reverse: + case 0, False: + order_by = RssSourceModel.nick.asc() + description += "Nickname " + case 0, True: + order_by = RssSourceModel.nick.desc() + description += "Nickname " + case 1, False: + order_by = FeedChannelModel.discord_channel_id.asc() + description += "Channel ID " + case 1, True: + order_by = FeedChannelModel.discord_channel_id.desc() + description += "Channel ID " + case 2, False: + order_by = RssSourceModel.created.desc() + description += "Date Added " + case 2, True: + order_by = RssSourceModel.created.asc() + description += "Date Added " + case _, _: + raise ValueError(f"Unknown sort: {sort}") + else: + order_by = FeedChannelModel.discord_channel_id.asc() + description = "" + async with DatabaseManager() as database: whereclause = and_( FeedChannelModel.discord_server_id == inter.guild_id, @@ -562,7 +599,7 @@ class FeedCog(commands.Cog): select(FeedChannelModel, RssSourceModel) .where(whereclause) .join(RssSourceModel) - .order_by(FeedChannelModel.discord_channel_id) + .order_by(order_by) ) result = await database.session.execute(query) @@ -570,8 +607,9 @@ class FeedCog(commands.Cog): rowcount = len(feed_channels) if not feed_channels: - await followup(inter, - "It looks like there are no feed channels available." + await followup_error(inter, + title="No Assigned Feeds Found", + message="Assign a channel to receive feed content with `/feed assign`." ) return @@ -583,7 +621,7 @@ class FeedCog(commands.Cog): embed = Embed( title="Saved Feed Channels", - description=f"{output}", + description=f"{description}\n{output}", colour=Colour.blue() ) embed.set_footer(text=f"Showing {rowcount} results") diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py index becd030..2318efe 100644 --- a/src/extensions/tasks.py +++ b/src/extensions/tasks.py @@ -4,17 +4,16 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo """ import logging -import async_timeout +from time import process_time -import aiohttp from feedparser import parse from sqlalchemy import insert, select, and_ -from discord import Interaction, app_commands, TextChannel +from discord import Interaction, TextChannel from discord.ext import commands, tasks from discord.errors import Forbidden -from feed import Source, Article, get_unparsed_feed -from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel +from feed import Source, Article, get_unparsed_feed # pylint disable=E0401 +from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401 log = logging.getLogger(__name__) @@ -27,6 +26,7 @@ class TaskCog(commands.Cog): def __init__(self, bot): super().__init__() self.bot = bot + self.time = None @commands.Cog.listener() async def on_ready(self): @@ -36,15 +36,12 @@ class TaskCog(commands.Cog): log.info("%s cog is ready", self.__class__.__name__) - # @app_commands.command(name="test-trigger-task") - async def test_trigger_task(self, inter: Interaction): - await inter.response.defer() - await self.rss_task() - await inter.followup.send("done") - @tasks.loop(minutes=10) async def rss_task(self): + """Automated task responsible for processing rss feeds.""" + log.info("Running rss task") + time = process_time() async with DatabaseManager() as database: query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel) @@ -54,10 +51,21 @@ class TaskCog(commands.Cog): for feed in feeds: await self.process_feed(feed, database) - log.info("Finished rss task") + log.info("Finished rss task, time elapsed: %s", process_time() - time) + + async def process_feed(self, feed: FeedChannelModel, database: DatabaseManager): + """Process the passed feed. Will also call process for each article found in the feed. + + Parameters + ---------- + feed : FeedChannelModel + Database model for the feed. + database : DatabaseManager + Database connection handler, must be open. + """ + + log.debug("Processing feed: %s", feed.id) - async def process_feed(self, feed: FeedChannelModel, database): - log.info("Processing feed: %s", feed.id) channel = self.bot.get_channel(feed.discord_channel_id) unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url) @@ -72,8 +80,22 @@ class TaskCog(commands.Cog): for article in articles: await self.process_article(article, channel, database) - async def process_article(self, article: Article, channel: TextChannel, database): - log.info("Processing article: %s", article.url) + async def process_article( + self, article: Article, channel: TextChannel, database: DatabaseManager + ): + """Process the passed article. Will send the embed to a channel if all is valid. + + Parameters + ---------- + article : Article + Database model for the article. + channel : TextChannel + Where the article will be sent to. + database : DatabaseManager + Database connection handler, must be open. + """ + + log.debug("Processing article: %s", article.url) query = select(SentArticleModel).where(and_( SentArticleModel.article_url == article.url, @@ -82,14 +104,14 @@ class TaskCog(commands.Cog): result = await database.session.execute(query) if result.scalars().all(): - log.info("Article already processed: %s", article.url) + log.debug("Article already processed: %s", article.url) return embed = await article.to_embed() try: await channel.send(embed=embed) except Forbidden: - log.error("Forbidden: %s · %s", channel.name, channel.id) + log.error("Can't send article to channel: %s · %s", channel.name, channel.id) return query = insert(SentArticleModel).values( @@ -100,8 +122,7 @@ class TaskCog(commands.Cog): ) await database.session.execute(query) - log.info("new Article processed: %s", article.url) - + log.debug("new Article processed: %s", article.url) async def setup(bot): diff --git a/src/feed.py b/src/feed.py index 217ebb4..0499e60 100644 --- a/src/feed.py +++ b/src/feed.py @@ -1,6 +1,3 @@ -""" - -""" import json import logging diff --git a/src/logs.py b/src/logs.py index f09bf92..6c83d30 100644 --- a/src/logs.py +++ b/src/logs.py @@ -19,7 +19,7 @@ log = logging.getLogger(__name__) class LogSetup: - + def __init__(self, BASE_DIR: Path): self.BASE_DIR = BASE_DIR self.LOGS_DIR = BASE_DIR / "logs/" @@ -100,4 +100,4 @@ class LogSetup: # Clear up old log files self._delete_old_logs() - return file.name \ No newline at end of file + return file.name diff --git a/src/main.py b/src/main.py index 25e1818..1194781 100644 --- a/src/main.py +++ b/src/main.py @@ -33,7 +33,7 @@ async def main(): # Setup logging settings and mute spammy loggers logsetup = LogSetup(BASE_DIR) - logsetup.setup_logs(logging.INFO) + logsetup.setup_logs(logging.DEBUG) logsetup.update_log_levels( ('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'), level=logging.WARNING diff --git a/src/utils.py b/src/utils.py index d0c14bb..bdeaf26 100644 --- a/src/utils.py +++ b/src/utils.py @@ -3,7 +3,7 @@ import aiohttp import logging -from discord import Interaction +from discord import Interaction, Embed, Colour log = logging.getLogger(__name__) @@ -29,3 +29,23 @@ async def audit(cog, *args, **kwargs): """Shorthand for auditing an interaction.""" await cog.bot.audit(*args, **kwargs) + +async def followup_error(inter: Interaction, title: str, message: str, *args, **kwargs): + """Shorthand for following up on an interaction, except returns an embed styled in + error colours. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ + + await inter.followup.send( + *args, + embed=Embed( + title=title, + description=message, + colour=Colour.red() + ), + **kwargs + )