Updated implementation of interaction commands
This commit is contained in:
parent
866660ef8b
commit
7ddfe09e4d
@ -3,24 +3,29 @@ Extension for the `CommandCog`.
|
||||
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import validators
|
||||
from typing import Tuple
|
||||
|
||||
import aiohttp
|
||||
import textwrap
|
||||
import feedparser
|
||||
from markdownify import markdownify
|
||||
from discord import app_commands, Interaction, Embed, Colour
|
||||
from discord.ext import commands, tasks
|
||||
from discord.app_commands import Choice, Group, command, autocomplete
|
||||
from sqlalchemy import insert, select, update, and_, or_, delete
|
||||
from discord import Interaction, Embed, Colour
|
||||
from discord.ext import commands
|
||||
from discord.app_commands import Choice, Group, autocomplete, choices
|
||||
from sqlalchemy import insert, select, and_, delete
|
||||
|
||||
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
|
||||
from feed import get_source, Source
|
||||
from db import DatabaseManager, SentArticleModel, RssSourceModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
rss_list_sort_choices = [
|
||||
Choice(name="Nickname", value=0),
|
||||
Choice(name="Date Added", value=1)
|
||||
]
|
||||
|
||||
async def get_rss_data(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
@ -44,6 +49,49 @@ async def audit(cog, *args, **kwargs):
|
||||
|
||||
await cog.bot.audit(*args, **kwargs)
|
||||
|
||||
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource.
|
||||
# Run a period task to check this.
|
||||
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]:
|
||||
"""Validate a provided RSS source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nickname : str
|
||||
Nickname of the source. Must not contain URL.
|
||||
url : str
|
||||
URL of the source. Must be URL with valid status code and be an RSS feed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
String invalid message if invalid, NoneType if valid.
|
||||
FeedParserDict or None
|
||||
The feed parsed from the given URL or None if invalid.
|
||||
"""
|
||||
|
||||
# Ensure the URL is valid
|
||||
if not validators.url(url):
|
||||
return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
|
||||
|
||||
# Check the nickname is not a URL
|
||||
if validators.url(nickname):
|
||||
return "It looks like the nickname you have entered is a URL.\n" \
|
||||
f"For security reasons, this is not allowed.\n`{nickname=}`", None
|
||||
|
||||
|
||||
feed_data, status_code = await get_rss_data(url)
|
||||
|
||||
# Check the URL status code is valid
|
||||
if status_code != 200:
|
||||
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
||||
|
||||
# Check the contents is actually an RSS feed.
|
||||
feed = feedparser.parse(feed_data)
|
||||
if not feed.version:
|
||||
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
||||
|
||||
return None, feed
|
||||
|
||||
|
||||
class CommandCog(commands.Cog):
|
||||
"""
|
||||
@ -88,61 +136,35 @@ class CommandCog(commands.Cog):
|
||||
|
||||
return sources
|
||||
|
||||
# All RSS commands belong to this group.
|
||||
rss_group = Group(
|
||||
name="rss",
|
||||
description="Commands for rss sources.",
|
||||
guild_only=True
|
||||
guild_only=True # We store guild IDs in the database, so guild only = True
|
||||
)
|
||||
|
||||
@rss_group.command(name="add")
|
||||
async def add_rss_source(self, inter: Interaction, url: str, nickname: str):
|
||||
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
|
||||
"""Add a new RSS source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
url : str
|
||||
The RSS feed URL.
|
||||
nickname : str
|
||||
A name used to identify the RSS source.
|
||||
url : str
|
||||
The RSS feed URL.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
# Ensure the URL is valid
|
||||
if not validators.url(url):
|
||||
await followup(inter,
|
||||
f"The URL you have entered is malformed or invalid:\n`{url=}`",
|
||||
suppress_embeds=True
|
||||
)
|
||||
illegal_message, feed = await validate_rss_source(nickname, url)
|
||||
if illegal_message:
|
||||
await followup(inter, illegal_message, suppress_embeds=True)
|
||||
return
|
||||
|
||||
# Check the nickname is not a URL
|
||||
if validators.url(nickname):
|
||||
await followup(inter,
|
||||
"It looks like the nickname you have entered is a URL.\n"
|
||||
f"For security reasons, this is not allowed.\n`{nickname=}`",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
# Check the URL points to an RSS feed.
|
||||
feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this.
|
||||
if status_code != 200:
|
||||
await followup(inter,
|
||||
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
feed = feedparser.parse(feed_data)
|
||||
if not feed.version:
|
||||
await followup(inter,
|
||||
f"The provided URL '{url}' does not seem to be a valid RSS feed.",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
log.debug("RSS feed added")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(RssSourceModel).values(
|
||||
@ -157,68 +179,72 @@ class CommandCog(commands.Cog):
|
||||
inter.user.id, database=database
|
||||
)
|
||||
|
||||
embed = Embed(title="RSS Feed Added", colour=Colour.from_str("#59ff00"))
|
||||
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
|
||||
embed.add_field(name="Nickname", value=nickname)
|
||||
embed.add_field(name="URL", value=url)
|
||||
embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
|
||||
|
||||
# , f"RSS source added [{nickname}]({url})", suppress_embeds=True
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@rss_group.command(name="remove")
|
||||
@autocomplete(source=source_autocomplete)
|
||||
async def remove_rss_source(self, inter: Interaction, source: str):
|
||||
@autocomplete(url=source_autocomplete)
|
||||
async def remove_rss_source(self, inter: Interaction, url: str):
|
||||
"""Delete an existing RSS source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
source : str
|
||||
url : str
|
||||
The RSS source to be removed. Autocomplete or enter the URL.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
log.debug(f"Attempting to remove RSS source ({source=})")
|
||||
log.debug(f"Attempting to remove RSS source ({url=})")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
select_result = await database.session.execute(
|
||||
select(RssSourceModel).filter(
|
||||
and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == source
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
)
|
||||
)
|
||||
rss_source = select_result.fetchone()
|
||||
rss_source = select_result.scalars().one()
|
||||
nickname = rss_source.nick
|
||||
|
||||
delete_result = await database.session.execute(
|
||||
delete(RssSourceModel).filter(
|
||||
and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == source
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
nickname, rss_url = rss_source.nick, rss_source.rss_url
|
||||
|
||||
# TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works)
|
||||
|
||||
if delete_result.rowcount:
|
||||
await followup(inter,
|
||||
f"RSS source deleted successfully\n**[{nickname}]({rss_url})**",
|
||||
suppress_embeds=True
|
||||
await audit(self,
|
||||
f"Added RSS source ({nickname=}, {url=})",
|
||||
inter.user.id, database=database
|
||||
)
|
||||
|
||||
if not delete_result.rowcount:
|
||||
await followup(inter, "Couldn't find any RSS sources with this name.")
|
||||
return
|
||||
|
||||
await followup(inter, "Couldn't find any RSS sources with this name.")
|
||||
source = get_source(url)
|
||||
|
||||
# potential_matches = await self.source_autocomplete(inter, source)
|
||||
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
|
||||
embed.add_field(name="Nickname", value=nickname)
|
||||
embed.add_field(name="URL", value=url)
|
||||
embed.set_thumbnail(url=source.icon_url)
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@rss_group.command(name="list")
|
||||
async def list_rss_sources(self, inter: Interaction):
|
||||
@choices(sort=rss_list_sort_choices)
|
||||
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
|
||||
"""Provides a with a list of RSS sources available for the current server.
|
||||
|
||||
Parameters
|
||||
@ -229,9 +255,32 @@ class CommandCog(commands.Cog):
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
# Default to the first choice if not specified.
|
||||
if type(sort) is Choice:
|
||||
description = "Sort by "
|
||||
description += "Nickname " if sort.value == 0 else "Date Added "
|
||||
description += '\U000025BC' if sort_reverse else '\U000025B2'
|
||||
else:
|
||||
sort = rss_list_sort_choices[0]
|
||||
description = ""
|
||||
|
||||
sort = sort if type(sort) == Choice else rss_list_sort_choices[0]
|
||||
|
||||
match sort.value, sort_reverse:
|
||||
case 0, False:
|
||||
order_by = RssSourceModel.nick.asc()
|
||||
case 0, True:
|
||||
order_by = RssSourceModel.nick.desc()
|
||||
case 1, False: # NOTE:
|
||||
order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest
|
||||
case 1, True: # date first, not the oldest as it would sort otherwise.
|
||||
order_by = RssSourceModel.created.asc()
|
||||
case _, _:
|
||||
raise ValueError("Unknown sort: %s" % sort)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
query = select(RssSourceModel).where(whereclause).order_by(order_by)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
rss_sources = result.scalars().all()
|
||||
@ -240,10 +289,15 @@ class CommandCog(commands.Cog):
|
||||
await followup(inter, "It looks like you have no rss sources.")
|
||||
return
|
||||
|
||||
output = "## Available RSS Sources\n"
|
||||
output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources])
|
||||
output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)])
|
||||
|
||||
await followup(inter, output, suppress_embeds=True)
|
||||
embed = Embed(
|
||||
title="Saved RSS Feeds",
|
||||
description=f"{description}\n\n{output}",
|
||||
colour=Colour.lighter_grey()
|
||||
)
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@rss_group.command(name="fetch")
|
||||
@autocomplete(rss=source_autocomplete)
|
||||
@ -256,7 +310,12 @@ class CommandCog(commands.Cog):
|
||||
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
|
||||
return
|
||||
|
||||
source = get_source(rss)
|
||||
invalid_message, feed = await validate_rss_source("", rss)
|
||||
if invalid_message:
|
||||
await followup(inter, invalid_message)
|
||||
return
|
||||
|
||||
source = Source.from_parsed(feed)
|
||||
articles = source.get_latest_articles(max)
|
||||
|
||||
embeds = []
|
||||
@ -269,6 +328,7 @@ class CommandCog(commands.Cog):
|
||||
description=article_description,
|
||||
url=article.url,
|
||||
timestamp=article.published,
|
||||
colour=Colour.brand_red()
|
||||
)
|
||||
embed.set_thumbnail(url=source.icon_url)
|
||||
embed.set_image(url=await article.get_thumbnail_url())
|
||||
@ -290,7 +350,7 @@ class CommandCog(commands.Cog):
|
||||
for article in articles
|
||||
])
|
||||
await database.session.execute(query)
|
||||
await audit(self, f"User is requesting {max} articles", inter.user.id, database=database)
|
||||
await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database)
|
||||
|
||||
await followup(inter, embeds=embeds)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user