Updated implementation of interaction commands

This commit is contained in:
Corban-Lee Jones 2023-12-16 23:54:12 +00:00
parent 866660ef8b
commit 7ddfe09e4d

View File

@ -3,24 +3,29 @@ Extension for the `CommandCog`.
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot. Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
""" """
import json
import logging import logging
import validators import validators
from typing import Tuple
import aiohttp import aiohttp
import textwrap import textwrap
import feedparser import feedparser
from markdownify import markdownify from markdownify import markdownify
from discord import app_commands, Interaction, Embed, Colour from discord import Interaction, Embed, Colour
from discord.ext import commands, tasks from discord.ext import commands
from discord.app_commands import Choice, Group, command, autocomplete from discord.app_commands import Choice, Group, autocomplete, choices
from sqlalchemy import insert, select, update, and_, or_, delete from sqlalchemy import insert, select, and_, delete
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
from feed import get_source, Source from feed import get_source, Source
from db import DatabaseManager, SentArticleModel, RssSourceModel
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
rss_list_sort_choices = [
Choice(name="Nickname", value=0),
Choice(name="Date Added", value=1)
]
async def get_rss_data(url: str): async def get_rss_data(url: str):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(url) as response: async with session.get(url) as response:
@ -44,6 +49,49 @@ async def audit(cog, *args, **kwargs):
await cog.bot.audit(*args, **kwargs) await cog.bot.audit(*args, **kwargs)
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource.
# Run a period task to check this.
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]:
"""Validate a provided RSS source.
Parameters
----------
nickname : str
Nickname of the source. Must not contain URL.
url : str
URL of the source. Must be URL with valid status code and be an RSS feed.
Returns
-------
str or None
String invalid message if invalid, NoneType if valid.
FeedParserDict or None
The feed parsed from the given URL or None if invalid.
"""
# Ensure the URL is valid
if not validators.url(url):
return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
# Check the nickname is not a URL
if validators.url(nickname):
return "It looks like the nickname you have entered is a URL.\n" \
f"For security reasons, this is not allowed.\n`{nickname=}`", None
feed_data, status_code = await get_rss_data(url)
# Check the URL status code is valid
if status_code != 200:
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
# Check the contents is actually an RSS feed.
feed = feedparser.parse(feed_data)
if not feed.version:
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
return None, feed
class CommandCog(commands.Cog): class CommandCog(commands.Cog):
""" """
@ -88,61 +136,35 @@ class CommandCog(commands.Cog):
return sources return sources
# All RSS commands belong to this group.
rss_group = Group( rss_group = Group(
name="rss", name="rss",
description="Commands for rss sources.", description="Commands for rss sources.",
guild_only=True guild_only=True # We store guild IDs in the database, so guild only = True
) )
@rss_group.command(name="add") @rss_group.command(name="add")
async def add_rss_source(self, inter: Interaction, url: str, nickname: str): async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
"""Add a new RSS source. """Add a new RSS source.
Parameters Parameters
---------- ----------
inter : Interaction inter : Interaction
Represents an app command interaction. Represents an app command interaction.
url : str
The RSS feed URL.
nickname : str nickname : str
A name used to identify the RSS source. A name used to identify the RSS source.
url : str
The RSS feed URL.
""" """
await inter.response.defer() await inter.response.defer()
# Ensure the URL is valid illegal_message, feed = await validate_rss_source(nickname, url)
if not validators.url(url): if illegal_message:
await followup(inter, await followup(inter, illegal_message, suppress_embeds=True)
f"The URL you have entered is malformed or invalid:\n`{url=}`",
suppress_embeds=True
)
return return
# Check the nickname is not a URL log.debug("RSS feed added")
if validators.url(nickname):
await followup(inter,
"It looks like the nickname you have entered is a URL.\n"
f"For security reasons, this is not allowed.\n`{nickname=}`",
suppress_embeds=True
)
return
# Check the URL points to an RSS feed.
feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this.
if status_code != 200:
await followup(inter,
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}",
suppress_embeds=True
)
return
feed = feedparser.parse(feed_data)
if not feed.version:
await followup(inter,
f"The provided URL '{url}' does not seem to be a valid RSS feed.",
suppress_embeds=True
)
return
async with DatabaseManager() as database: async with DatabaseManager() as database:
query = insert(RssSourceModel).values( query = insert(RssSourceModel).values(
@ -157,68 +179,72 @@ class CommandCog(commands.Cog):
inter.user.id, database=database inter.user.id, database=database
) )
embed = Embed(title="RSS Feed Added", colour=Colour.from_str("#59ff00")) embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
embed.add_field(name="Nickname", value=nickname) embed.add_field(name="Nickname", value=nickname)
embed.add_field(name="URL", value=url) embed.add_field(name="URL", value=url)
embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href")) embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
# , f"RSS source added [{nickname}]({url})", suppress_embeds=True
await followup(inter, embed=embed) await followup(inter, embed=embed)
@rss_group.command(name="remove") @rss_group.command(name="remove")
@autocomplete(source=source_autocomplete) @autocomplete(url=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, source: str): async def remove_rss_source(self, inter: Interaction, url: str):
"""Delete an existing RSS source. """Delete an existing RSS source.
Parameters Parameters
---------- ----------
inter : Interaction inter : Interaction
Represents an app command interaction. Represents an app command interaction.
source : str url : str
The RSS source to be removed. Autocomplete or enter the URL. The RSS source to be removed. Autocomplete or enter the URL.
""" """
await inter.response.defer() await inter.response.defer()
log.debug(f"Attempting to remove RSS source ({source=})") log.debug(f"Attempting to remove RSS source ({url=})")
async with DatabaseManager() as database: async with DatabaseManager() as database:
select_result = await database.session.execute( select_result = await database.session.execute(
select(RssSourceModel).filter( select(RssSourceModel).filter(
and_( and_(
RssSourceModel.discord_server_id == inter.guild_id, RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == source RssSourceModel.rss_url == url
) )
) )
) )
rss_source = select_result.fetchone() rss_source = select_result.scalars().one()
nickname = rss_source.nick
delete_result = await database.session.execute( delete_result = await database.session.execute(
delete(RssSourceModel).filter( delete(RssSourceModel).filter(
and_( and_(
RssSourceModel.discord_server_id == inter.guild_id, RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == source RssSourceModel.rss_url == url
) )
) )
) )
nickname, rss_url = rss_source.nick, rss_source.rss_url await audit(self,
f"Added RSS source ({nickname=}, {url=})",
# TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works) inter.user.id, database=database
if delete_result.rowcount:
await followup(inter,
f"RSS source deleted successfully\n**[{nickname}]({rss_url})**",
suppress_embeds=True
) )
if not delete_result.rowcount:
await followup(inter, "Couldn't find any RSS sources with this name.")
return return
await followup(inter, "Couldn't find any RSS sources with this name.") source = get_source(url)
# potential_matches = await self.source_autocomplete(inter, source) embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
embed.add_field(name="Nickname", value=nickname)
embed.add_field(name="URL", value=url)
embed.set_thumbnail(url=source.icon_url)
await followup(inter, embed=embed)
@rss_group.command(name="list") @rss_group.command(name="list")
async def list_rss_sources(self, inter: Interaction): @choices(sort=rss_list_sort_choices)
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
"""Provides a with a list of RSS sources available for the current server. """Provides a with a list of RSS sources available for the current server.
Parameters Parameters
@ -229,9 +255,32 @@ class CommandCog(commands.Cog):
await inter.response.defer() await inter.response.defer()
# Default to the first choice if not specified.
if type(sort) is Choice:
description = "Sort by "
description += "Nickname " if sort.value == 0 else "Date Added "
description += '\U000025BC' if sort_reverse else '\U000025B2'
else:
sort = rss_list_sort_choices[0]
description = ""
sort = sort if type(sort) == Choice else rss_list_sort_choices[0]
match sort.value, sort_reverse:
case 0, False:
order_by = RssSourceModel.nick.asc()
case 0, True:
order_by = RssSourceModel.nick.desc()
case 1, False: # NOTE:
order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest
case 1, True: # date first, not the oldest as it would sort otherwise.
order_by = RssSourceModel.created.asc()
case _, _:
raise ValueError("Unknown sort: %s" % sort)
async with DatabaseManager() as database: async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
query = select(RssSourceModel).where(whereclause) query = select(RssSourceModel).where(whereclause).order_by(order_by)
result = await database.session.execute(query) result = await database.session.execute(query)
rss_sources = result.scalars().all() rss_sources = result.scalars().all()
@ -240,10 +289,15 @@ class CommandCog(commands.Cog):
await followup(inter, "It looks like you have no rss sources.") await followup(inter, "It looks like you have no rss sources.")
return return
output = "## Available RSS Sources\n" output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)])
output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources])
await followup(inter, output, suppress_embeds=True) embed = Embed(
title="Saved RSS Feeds",
description=f"{description}\n\n{output}",
colour=Colour.lighter_grey()
)
await followup(inter, embed=embed)
@rss_group.command(name="fetch") @rss_group.command(name="fetch")
@autocomplete(rss=source_autocomplete) @autocomplete(rss=source_autocomplete)
@ -256,7 +310,12 @@ class CommandCog(commands.Cog):
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5") followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
return return
source = get_source(rss) invalid_message, feed = await validate_rss_source("", rss)
if invalid_message:
await followup(inter, invalid_message)
return
source = Source.from_parsed(feed)
articles = source.get_latest_articles(max) articles = source.get_latest_articles(max)
embeds = [] embeds = []
@ -269,6 +328,7 @@ class CommandCog(commands.Cog):
description=article_description, description=article_description,
url=article.url, url=article.url,
timestamp=article.published, timestamp=article.published,
colour=Colour.brand_red()
) )
embed.set_thumbnail(url=source.icon_url) embed.set_thumbnail(url=source.icon_url)
embed.set_image(url=await article.get_thumbnail_url()) embed.set_image(url=await article.get_thumbnail_url())
@ -290,7 +350,7 @@ class CommandCog(commands.Cog):
for article in articles for article in articles
]) ])
await database.session.execute(query) await database.session.execute(query)
await audit(self, f"User is requesting {max} articles", inter.user.id, database=database) await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database)
await followup(inter, embeds=embeds) await followup(inter, embeds=embeds)