From 7ddfe09e4d0c2eb6d635646850babfb8f6810111 Mon Sep 17 00:00:00 2001 From: corbz Date: Sat, 16 Dec 2023 23:54:12 +0000 Subject: [PATCH] Updated implementation of interaction commands --- src/extensions/cmd.py | 194 +++++++++++++++++++++++++++--------------- 1 file changed, 127 insertions(+), 67 deletions(-) diff --git a/src/extensions/cmd.py b/src/extensions/cmd.py index 0b4724e..220e449 100644 --- a/src/extensions/cmd.py +++ b/src/extensions/cmd.py @@ -3,24 +3,29 @@ Extension for the `CommandCog`. Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot. """ -import json import logging import validators +from typing import Tuple import aiohttp import textwrap import feedparser from markdownify import markdownify -from discord import app_commands, Interaction, Embed, Colour -from discord.ext import commands, tasks -from discord.app_commands import Choice, Group, command, autocomplete -from sqlalchemy import insert, select, update, and_, or_, delete +from discord import Interaction, Embed, Colour +from discord.ext import commands +from discord.app_commands import Choice, Group, autocomplete, choices +from sqlalchemy import insert, select, and_, delete -from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel from feed import get_source, Source +from db import DatabaseManager, SentArticleModel, RssSourceModel log = logging.getLogger(__name__) +rss_list_sort_choices = [ + Choice(name="Nickname", value=0), + Choice(name="Date Added", value=1) +] + async def get_rss_data(url: str): async with aiohttp.ClientSession() as session: async with session.get(url) as response: @@ -44,6 +49,49 @@ async def audit(cog, *args, **kwargs): await cog.bot.audit(*args, **kwargs) +# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. +# Run a period task to check this. +async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]: + """Validate a provided RSS source. + + Parameters + ---------- + nickname : str + Nickname of the source. Must not contain URL. + url : str + URL of the source. Must be URL with valid status code and be an RSS feed. + + Returns + ------- + str or None + String invalid message if invalid, NoneType if valid. + FeedParserDict or None + The feed parsed from the given URL or None if invalid. + """ + + # Ensure the URL is valid + if not validators.url(url): + return f"The URL you have entered is malformed or invalid:\n`{url=}`", None + + # Check the nickname is not a URL + if validators.url(nickname): + return "It looks like the nickname you have entered is a URL.\n" \ + f"For security reasons, this is not allowed.\n`{nickname=}`", None + + + feed_data, status_code = await get_rss_data(url) + + # Check the URL status code is valid + if status_code != 200: + return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None + + # Check the contents is actually an RSS feed. + feed = feedparser.parse(feed_data) + if not feed.version: + return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None + + return None, feed + class CommandCog(commands.Cog): """ @@ -88,61 +136,35 @@ class CommandCog(commands.Cog): return sources + # All RSS commands belong to this group. rss_group = Group( name="rss", description="Commands for rss sources.", - guild_only=True + guild_only=True # We store guild IDs in the database, so guild only = True ) @rss_group.command(name="add") - async def add_rss_source(self, inter: Interaction, url: str, nickname: str): + async def add_rss_source(self, inter: Interaction, nickname: str, url: str): """Add a new RSS source. Parameters ---------- inter : Interaction Represents an app command interaction. - url : str - The RSS feed URL. nickname : str A name used to identify the RSS source. + url : str + The RSS feed URL. """ await inter.response.defer() - # Ensure the URL is valid - if not validators.url(url): - await followup(inter, - f"The URL you have entered is malformed or invalid:\n`{url=}`", - suppress_embeds=True - ) + illegal_message, feed = await validate_rss_source(nickname, url) + if illegal_message: + await followup(inter, illegal_message, suppress_embeds=True) return - # Check the nickname is not a URL - if validators.url(nickname): - await followup(inter, - "It looks like the nickname you have entered is a URL.\n" - f"For security reasons, this is not allowed.\n`{nickname=}`", - suppress_embeds=True - ) - return - - # Check the URL points to an RSS feed. - feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this. - if status_code != 200: - await followup(inter, - f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", - suppress_embeds=True - ) - return - - feed = feedparser.parse(feed_data) - if not feed.version: - await followup(inter, - f"The provided URL '{url}' does not seem to be a valid RSS feed.", - suppress_embeds=True - ) - return + log.debug("RSS feed added") async with DatabaseManager() as database: query = insert(RssSourceModel).values( @@ -157,68 +179,72 @@ class CommandCog(commands.Cog): inter.user.id, database=database ) - embed = Embed(title="RSS Feed Added", colour=Colour.from_str("#59ff00")) + embed = Embed(title="RSS Feed Added", colour=Colour.dark_green()) embed.add_field(name="Nickname", value=nickname) embed.add_field(name="URL", value=url) embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href")) - # , f"RSS source added [{nickname}]({url})", suppress_embeds=True await followup(inter, embed=embed) @rss_group.command(name="remove") - @autocomplete(source=source_autocomplete) - async def remove_rss_source(self, inter: Interaction, source: str): + @autocomplete(url=source_autocomplete) + async def remove_rss_source(self, inter: Interaction, url: str): """Delete an existing RSS source. Parameters ---------- inter : Interaction Represents an app command interaction. - source : str + url : str The RSS source to be removed. Autocomplete or enter the URL. """ await inter.response.defer() - log.debug(f"Attempting to remove RSS source ({source=})") + log.debug(f"Attempting to remove RSS source ({url=})") async with DatabaseManager() as database: select_result = await database.session.execute( select(RssSourceModel).filter( and_( RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == source + RssSourceModel.rss_url == url ) ) ) - rss_source = select_result.fetchone() + rss_source = select_result.scalars().one() + nickname = rss_source.nick delete_result = await database.session.execute( delete(RssSourceModel).filter( and_( RssSourceModel.discord_server_id == inter.guild_id, - RssSourceModel.rss_url == source + RssSourceModel.rss_url == url ) ) ) - nickname, rss_url = rss_source.nick, rss_source.rss_url - - # TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works) - - if delete_result.rowcount: - await followup(inter, - f"RSS source deleted successfully\n**[{nickname}]({rss_url})**", - suppress_embeds=True + await audit(self, + f"Added RSS source ({nickname=}, {url=})", + inter.user.id, database=database ) + + if not delete_result.rowcount: + await followup(inter, "Couldn't find any RSS sources with this name.") return - await followup(inter, "Couldn't find any RSS sources with this name.") + source = get_source(url) - # potential_matches = await self.source_autocomplete(inter, source) + embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red()) + embed.add_field(name="Nickname", value=nickname) + embed.add_field(name="URL", value=url) + embed.set_thumbnail(url=source.icon_url) + + await followup(inter, embed=embed) @rss_group.command(name="list") - async def list_rss_sources(self, inter: Interaction): + @choices(sort=rss_list_sort_choices) + async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False): """Provides a with a list of RSS sources available for the current server. Parameters @@ -229,9 +255,32 @@ class CommandCog(commands.Cog): await inter.response.defer() + # Default to the first choice if not specified. + if type(sort) is Choice: + description = "Sort by " + description += "Nickname " if sort.value == 0 else "Date Added " + description += '\U000025BC' if sort_reverse else '\U000025B2' + else: + sort = rss_list_sort_choices[0] + description = "" + + sort = sort if type(sort) == Choice else rss_list_sort_choices[0] + + match sort.value, sort_reverse: + case 0, False: + order_by = RssSourceModel.nick.asc() + case 0, True: + order_by = RssSourceModel.nick.desc() + case 1, False: # NOTE: + order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest + case 1, True: # date first, not the oldest as it would sort otherwise. + order_by = RssSourceModel.created.asc() + case _, _: + raise ValueError("Unknown sort: %s" % sort) + async with DatabaseManager() as database: whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) - query = select(RssSourceModel).where(whereclause) + query = select(RssSourceModel).where(whereclause).order_by(order_by) result = await database.session.execute(query) rss_sources = result.scalars().all() @@ -240,10 +289,15 @@ class CommandCog(commands.Cog): await followup(inter, "It looks like you have no rss sources.") return - output = "## Available RSS Sources\n" - output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources]) + output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)]) - await followup(inter, output, suppress_embeds=True) + embed = Embed( + title="Saved RSS Feeds", + description=f"{description}\n\n{output}", + colour=Colour.lighter_grey() + ) + + await followup(inter, embed=embed) @rss_group.command(name="fetch") @autocomplete(rss=source_autocomplete) @@ -256,7 +310,12 @@ class CommandCog(commands.Cog): followup(inter, "It looks like you have requested too many articles.\nThe limit is 5") return - source = get_source(rss) + invalid_message, feed = await validate_rss_source("", rss) + if invalid_message: + await followup(inter, invalid_message) + return + + source = Source.from_parsed(feed) articles = source.get_latest_articles(max) embeds = [] @@ -269,6 +328,7 @@ class CommandCog(commands.Cog): description=article_description, url=article.url, timestamp=article.published, + colour=Colour.brand_red() ) embed.set_thumbnail(url=source.icon_url) embed.set_image(url=await article.get_thumbnail_url()) @@ -290,7 +350,7 @@ class CommandCog(commands.Cog): for article in articles ]) await database.session.execute(query) - await audit(self, f"User is requesting {max} articles", inter.user.id, database=database) + await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database) await followup(inter, embeds=embeds)