Updated implementation of interaction commands

This commit is contained in:
Corban-Lee Jones 2023-12-16 23:54:12 +00:00
parent 866660ef8b
commit 7ddfe09e4d

View File

@ -3,24 +3,29 @@ Extension for the `CommandCog`.
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
"""
import json
import logging
import validators
from typing import Tuple
import aiohttp
import textwrap
import feedparser
from markdownify import markdownify
from discord import app_commands, Interaction, Embed, Colour
from discord.ext import commands, tasks
from discord.app_commands import Choice, Group, command, autocomplete
from sqlalchemy import insert, select, update, and_, or_, delete
from discord import Interaction, Embed, Colour
from discord.ext import commands
from discord.app_commands import Choice, Group, autocomplete, choices
from sqlalchemy import insert, select, and_, delete
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
from feed import get_source, Source
from db import DatabaseManager, SentArticleModel, RssSourceModel
log = logging.getLogger(__name__)
rss_list_sort_choices = [
Choice(name="Nickname", value=0),
Choice(name="Date Added", value=1)
]
async def get_rss_data(url: str):
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
@ -44,6 +49,49 @@ async def audit(cog, *args, **kwargs):
await cog.bot.audit(*args, **kwargs)
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource.
# Run a period task to check this.
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]:
"""Validate a provided RSS source.
Parameters
----------
nickname : str
Nickname of the source. Must not contain URL.
url : str
URL of the source. Must be URL with valid status code and be an RSS feed.
Returns
-------
str or None
String invalid message if invalid, NoneType if valid.
FeedParserDict or None
The feed parsed from the given URL or None if invalid.
"""
# Ensure the URL is valid
if not validators.url(url):
return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
# Check the nickname is not a URL
if validators.url(nickname):
return "It looks like the nickname you have entered is a URL.\n" \
f"For security reasons, this is not allowed.\n`{nickname=}`", None
feed_data, status_code = await get_rss_data(url)
# Check the URL status code is valid
if status_code != 200:
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
# Check the contents is actually an RSS feed.
feed = feedparser.parse(feed_data)
if not feed.version:
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
return None, feed
class CommandCog(commands.Cog):
"""
@ -88,61 +136,35 @@ class CommandCog(commands.Cog):
return sources
# All RSS commands belong to this group.
rss_group = Group(
name="rss",
description="Commands for rss sources.",
guild_only=True
guild_only=True # We store guild IDs in the database, so guild only = True
)
@rss_group.command(name="add")
async def add_rss_source(self, inter: Interaction, url: str, nickname: str):
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
"""Add a new RSS source.
Parameters
----------
inter : Interaction
Represents an app command interaction.
url : str
The RSS feed URL.
nickname : str
A name used to identify the RSS source.
url : str
The RSS feed URL.
"""
await inter.response.defer()
# Ensure the URL is valid
if not validators.url(url):
await followup(inter,
f"The URL you have entered is malformed or invalid:\n`{url=}`",
suppress_embeds=True
)
illegal_message, feed = await validate_rss_source(nickname, url)
if illegal_message:
await followup(inter, illegal_message, suppress_embeds=True)
return
# Check the nickname is not a URL
if validators.url(nickname):
await followup(inter,
"It looks like the nickname you have entered is a URL.\n"
f"For security reasons, this is not allowed.\n`{nickname=}`",
suppress_embeds=True
)
return
# Check the URL points to an RSS feed.
feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this.
if status_code != 200:
await followup(inter,
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}",
suppress_embeds=True
)
return
feed = feedparser.parse(feed_data)
if not feed.version:
await followup(inter,
f"The provided URL '{url}' does not seem to be a valid RSS feed.",
suppress_embeds=True
)
return
log.debug("RSS feed added")
async with DatabaseManager() as database:
query = insert(RssSourceModel).values(
@ -157,68 +179,72 @@ class CommandCog(commands.Cog):
inter.user.id, database=database
)
embed = Embed(title="RSS Feed Added", colour=Colour.from_str("#59ff00"))
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
embed.add_field(name="Nickname", value=nickname)
embed.add_field(name="URL", value=url)
embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
# , f"RSS source added [{nickname}]({url})", suppress_embeds=True
await followup(inter, embed=embed)
@rss_group.command(name="remove")
@autocomplete(source=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, source: str):
@autocomplete(url=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, url: str):
"""Delete an existing RSS source.
Parameters
----------
inter : Interaction
Represents an app command interaction.
source : str
url : str
The RSS source to be removed. Autocomplete or enter the URL.
"""
await inter.response.defer()
log.debug(f"Attempting to remove RSS source ({source=})")
log.debug(f"Attempting to remove RSS source ({url=})")
async with DatabaseManager() as database:
select_result = await database.session.execute(
select(RssSourceModel).filter(
and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == source
RssSourceModel.rss_url == url
)
)
)
rss_source = select_result.fetchone()
rss_source = select_result.scalars().one()
nickname = rss_source.nick
delete_result = await database.session.execute(
delete(RssSourceModel).filter(
and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == source
RssSourceModel.rss_url == url
)
)
)
nickname, rss_url = rss_source.nick, rss_source.rss_url
# TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works)
if delete_result.rowcount:
await followup(inter,
f"RSS source deleted successfully\n**[{nickname}]({rss_url})**",
suppress_embeds=True
await audit(self,
f"Added RSS source ({nickname=}, {url=})",
inter.user.id, database=database
)
if not delete_result.rowcount:
await followup(inter, "Couldn't find any RSS sources with this name.")
return
await followup(inter, "Couldn't find any RSS sources with this name.")
source = get_source(url)
# potential_matches = await self.source_autocomplete(inter, source)
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
embed.add_field(name="Nickname", value=nickname)
embed.add_field(name="URL", value=url)
embed.set_thumbnail(url=source.icon_url)
await followup(inter, embed=embed)
@rss_group.command(name="list")
async def list_rss_sources(self, inter: Interaction):
@choices(sort=rss_list_sort_choices)
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
"""Provides a with a list of RSS sources available for the current server.
Parameters
@ -229,9 +255,32 @@ class CommandCog(commands.Cog):
await inter.response.defer()
# Default to the first choice if not specified.
if type(sort) is Choice:
description = "Sort by "
description += "Nickname " if sort.value == 0 else "Date Added "
description += '\U000025BC' if sort_reverse else '\U000025B2'
else:
sort = rss_list_sort_choices[0]
description = ""
sort = sort if type(sort) == Choice else rss_list_sort_choices[0]
match sort.value, sort_reverse:
case 0, False:
order_by = RssSourceModel.nick.asc()
case 0, True:
order_by = RssSourceModel.nick.desc()
case 1, False: # NOTE:
order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest
case 1, True: # date first, not the oldest as it would sort otherwise.
order_by = RssSourceModel.created.asc()
case _, _:
raise ValueError("Unknown sort: %s" % sort)
async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
query = select(RssSourceModel).where(whereclause)
query = select(RssSourceModel).where(whereclause).order_by(order_by)
result = await database.session.execute(query)
rss_sources = result.scalars().all()
@ -240,10 +289,15 @@ class CommandCog(commands.Cog):
await followup(inter, "It looks like you have no rss sources.")
return
output = "## Available RSS Sources\n"
output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources])
output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)])
await followup(inter, output, suppress_embeds=True)
embed = Embed(
title="Saved RSS Feeds",
description=f"{description}\n\n{output}",
colour=Colour.lighter_grey()
)
await followup(inter, embed=embed)
@rss_group.command(name="fetch")
@autocomplete(rss=source_autocomplete)
@ -256,7 +310,12 @@ class CommandCog(commands.Cog):
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
return
source = get_source(rss)
invalid_message, feed = await validate_rss_source("", rss)
if invalid_message:
await followup(inter, invalid_message)
return
source = Source.from_parsed(feed)
articles = source.get_latest_articles(max)
embeds = []
@ -269,6 +328,7 @@ class CommandCog(commands.Cog):
description=article_description,
url=article.url,
timestamp=article.published,
colour=Colour.brand_red()
)
embed.set_thumbnail(url=source.icon_url)
embed.set_image(url=await article.get_thumbnail_url())
@ -290,7 +350,7 @@ class CommandCog(commands.Cog):
for article in articles
])
await database.session.execute(query)
await audit(self, f"User is requesting {max} articles", inter.user.id, database=database)
await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database)
await followup(inter, embeds=embeds)