Updated implementation of interaction commands
This commit is contained in:
parent
866660ef8b
commit
7ddfe09e4d
@ -3,24 +3,29 @@ Extension for the `CommandCog`.
|
|||||||
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
|
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import validators
|
import validators
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import textwrap
|
import textwrap
|
||||||
import feedparser
|
import feedparser
|
||||||
from markdownify import markdownify
|
from markdownify import markdownify
|
||||||
from discord import app_commands, Interaction, Embed, Colour
|
from discord import Interaction, Embed, Colour
|
||||||
from discord.ext import commands, tasks
|
from discord.ext import commands
|
||||||
from discord.app_commands import Choice, Group, command, autocomplete
|
from discord.app_commands import Choice, Group, autocomplete, choices
|
||||||
from sqlalchemy import insert, select, update, and_, or_, delete
|
from sqlalchemy import insert, select, and_, delete
|
||||||
|
|
||||||
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
|
|
||||||
from feed import get_source, Source
|
from feed import get_source, Source
|
||||||
|
from db import DatabaseManager, SentArticleModel, RssSourceModel
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
rss_list_sort_choices = [
|
||||||
|
Choice(name="Nickname", value=0),
|
||||||
|
Choice(name="Date Added", value=1)
|
||||||
|
]
|
||||||
|
|
||||||
async def get_rss_data(url: str):
|
async def get_rss_data(url: str):
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(url) as response:
|
async with session.get(url) as response:
|
||||||
@ -44,6 +49,49 @@ async def audit(cog, *args, **kwargs):
|
|||||||
|
|
||||||
await cog.bot.audit(*args, **kwargs)
|
await cog.bot.audit(*args, **kwargs)
|
||||||
|
|
||||||
|
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource.
|
||||||
|
# Run a period task to check this.
|
||||||
|
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]:
|
||||||
|
"""Validate a provided RSS source.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
nickname : str
|
||||||
|
Nickname of the source. Must not contain URL.
|
||||||
|
url : str
|
||||||
|
URL of the source. Must be URL with valid status code and be an RSS feed.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str or None
|
||||||
|
String invalid message if invalid, NoneType if valid.
|
||||||
|
FeedParserDict or None
|
||||||
|
The feed parsed from the given URL or None if invalid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Ensure the URL is valid
|
||||||
|
if not validators.url(url):
|
||||||
|
return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
|
||||||
|
|
||||||
|
# Check the nickname is not a URL
|
||||||
|
if validators.url(nickname):
|
||||||
|
return "It looks like the nickname you have entered is a URL.\n" \
|
||||||
|
f"For security reasons, this is not allowed.\n`{nickname=}`", None
|
||||||
|
|
||||||
|
|
||||||
|
feed_data, status_code = await get_rss_data(url)
|
||||||
|
|
||||||
|
# Check the URL status code is valid
|
||||||
|
if status_code != 200:
|
||||||
|
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
||||||
|
|
||||||
|
# Check the contents is actually an RSS feed.
|
||||||
|
feed = feedparser.parse(feed_data)
|
||||||
|
if not feed.version:
|
||||||
|
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
||||||
|
|
||||||
|
return None, feed
|
||||||
|
|
||||||
|
|
||||||
class CommandCog(commands.Cog):
|
class CommandCog(commands.Cog):
|
||||||
"""
|
"""
|
||||||
@ -88,61 +136,35 @@ class CommandCog(commands.Cog):
|
|||||||
|
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
|
# All RSS commands belong to this group.
|
||||||
rss_group = Group(
|
rss_group = Group(
|
||||||
name="rss",
|
name="rss",
|
||||||
description="Commands for rss sources.",
|
description="Commands for rss sources.",
|
||||||
guild_only=True
|
guild_only=True # We store guild IDs in the database, so guild only = True
|
||||||
)
|
)
|
||||||
|
|
||||||
@rss_group.command(name="add")
|
@rss_group.command(name="add")
|
||||||
async def add_rss_source(self, inter: Interaction, url: str, nickname: str):
|
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
|
||||||
"""Add a new RSS source.
|
"""Add a new RSS source.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
inter : Interaction
|
inter : Interaction
|
||||||
Represents an app command interaction.
|
Represents an app command interaction.
|
||||||
url : str
|
|
||||||
The RSS feed URL.
|
|
||||||
nickname : str
|
nickname : str
|
||||||
A name used to identify the RSS source.
|
A name used to identify the RSS source.
|
||||||
|
url : str
|
||||||
|
The RSS feed URL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await inter.response.defer()
|
await inter.response.defer()
|
||||||
|
|
||||||
# Ensure the URL is valid
|
illegal_message, feed = await validate_rss_source(nickname, url)
|
||||||
if not validators.url(url):
|
if illegal_message:
|
||||||
await followup(inter,
|
await followup(inter, illegal_message, suppress_embeds=True)
|
||||||
f"The URL you have entered is malformed or invalid:\n`{url=}`",
|
|
||||||
suppress_embeds=True
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check the nickname is not a URL
|
log.debug("RSS feed added")
|
||||||
if validators.url(nickname):
|
|
||||||
await followup(inter,
|
|
||||||
"It looks like the nickname you have entered is a URL.\n"
|
|
||||||
f"For security reasons, this is not allowed.\n`{nickname=}`",
|
|
||||||
suppress_embeds=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check the URL points to an RSS feed.
|
|
||||||
feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this.
|
|
||||||
if status_code != 200:
|
|
||||||
await followup(inter,
|
|
||||||
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}",
|
|
||||||
suppress_embeds=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
feed = feedparser.parse(feed_data)
|
|
||||||
if not feed.version:
|
|
||||||
await followup(inter,
|
|
||||||
f"The provided URL '{url}' does not seem to be a valid RSS feed.",
|
|
||||||
suppress_embeds=True
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
async with DatabaseManager() as database:
|
async with DatabaseManager() as database:
|
||||||
query = insert(RssSourceModel).values(
|
query = insert(RssSourceModel).values(
|
||||||
@ -157,68 +179,72 @@ class CommandCog(commands.Cog):
|
|||||||
inter.user.id, database=database
|
inter.user.id, database=database
|
||||||
)
|
)
|
||||||
|
|
||||||
embed = Embed(title="RSS Feed Added", colour=Colour.from_str("#59ff00"))
|
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
|
||||||
embed.add_field(name="Nickname", value=nickname)
|
embed.add_field(name="Nickname", value=nickname)
|
||||||
embed.add_field(name="URL", value=url)
|
embed.add_field(name="URL", value=url)
|
||||||
embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
|
embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
|
||||||
|
|
||||||
# , f"RSS source added [{nickname}]({url})", suppress_embeds=True
|
|
||||||
await followup(inter, embed=embed)
|
await followup(inter, embed=embed)
|
||||||
|
|
||||||
@rss_group.command(name="remove")
|
@rss_group.command(name="remove")
|
||||||
@autocomplete(source=source_autocomplete)
|
@autocomplete(url=source_autocomplete)
|
||||||
async def remove_rss_source(self, inter: Interaction, source: str):
|
async def remove_rss_source(self, inter: Interaction, url: str):
|
||||||
"""Delete an existing RSS source.
|
"""Delete an existing RSS source.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
inter : Interaction
|
inter : Interaction
|
||||||
Represents an app command interaction.
|
Represents an app command interaction.
|
||||||
source : str
|
url : str
|
||||||
The RSS source to be removed. Autocomplete or enter the URL.
|
The RSS source to be removed. Autocomplete or enter the URL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await inter.response.defer()
|
await inter.response.defer()
|
||||||
|
|
||||||
log.debug(f"Attempting to remove RSS source ({source=})")
|
log.debug(f"Attempting to remove RSS source ({url=})")
|
||||||
|
|
||||||
async with DatabaseManager() as database:
|
async with DatabaseManager() as database:
|
||||||
select_result = await database.session.execute(
|
select_result = await database.session.execute(
|
||||||
select(RssSourceModel).filter(
|
select(RssSourceModel).filter(
|
||||||
and_(
|
and_(
|
||||||
RssSourceModel.discord_server_id == inter.guild_id,
|
RssSourceModel.discord_server_id == inter.guild_id,
|
||||||
RssSourceModel.rss_url == source
|
RssSourceModel.rss_url == url
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
rss_source = select_result.fetchone()
|
rss_source = select_result.scalars().one()
|
||||||
|
nickname = rss_source.nick
|
||||||
|
|
||||||
delete_result = await database.session.execute(
|
delete_result = await database.session.execute(
|
||||||
delete(RssSourceModel).filter(
|
delete(RssSourceModel).filter(
|
||||||
and_(
|
and_(
|
||||||
RssSourceModel.discord_server_id == inter.guild_id,
|
RssSourceModel.discord_server_id == inter.guild_id,
|
||||||
RssSourceModel.rss_url == source
|
RssSourceModel.rss_url == url
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
nickname, rss_url = rss_source.nick, rss_source.rss_url
|
await audit(self,
|
||||||
|
f"Added RSS source ({nickname=}, {url=})",
|
||||||
# TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works)
|
inter.user.id, database=database
|
||||||
|
|
||||||
if delete_result.rowcount:
|
|
||||||
await followup(inter,
|
|
||||||
f"RSS source deleted successfully\n**[{nickname}]({rss_url})**",
|
|
||||||
suppress_embeds=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not delete_result.rowcount:
|
||||||
|
await followup(inter, "Couldn't find any RSS sources with this name.")
|
||||||
return
|
return
|
||||||
|
|
||||||
await followup(inter, "Couldn't find any RSS sources with this name.")
|
source = get_source(url)
|
||||||
|
|
||||||
# potential_matches = await self.source_autocomplete(inter, source)
|
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
|
||||||
|
embed.add_field(name="Nickname", value=nickname)
|
||||||
|
embed.add_field(name="URL", value=url)
|
||||||
|
embed.set_thumbnail(url=source.icon_url)
|
||||||
|
|
||||||
|
await followup(inter, embed=embed)
|
||||||
|
|
||||||
@rss_group.command(name="list")
|
@rss_group.command(name="list")
|
||||||
async def list_rss_sources(self, inter: Interaction):
|
@choices(sort=rss_list_sort_choices)
|
||||||
|
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
|
||||||
"""Provides a with a list of RSS sources available for the current server.
|
"""Provides a with a list of RSS sources available for the current server.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -229,9 +255,32 @@ class CommandCog(commands.Cog):
|
|||||||
|
|
||||||
await inter.response.defer()
|
await inter.response.defer()
|
||||||
|
|
||||||
|
# Default to the first choice if not specified.
|
||||||
|
if type(sort) is Choice:
|
||||||
|
description = "Sort by "
|
||||||
|
description += "Nickname " if sort.value == 0 else "Date Added "
|
||||||
|
description += '\U000025BC' if sort_reverse else '\U000025B2'
|
||||||
|
else:
|
||||||
|
sort = rss_list_sort_choices[0]
|
||||||
|
description = ""
|
||||||
|
|
||||||
|
sort = sort if type(sort) == Choice else rss_list_sort_choices[0]
|
||||||
|
|
||||||
|
match sort.value, sort_reverse:
|
||||||
|
case 0, False:
|
||||||
|
order_by = RssSourceModel.nick.asc()
|
||||||
|
case 0, True:
|
||||||
|
order_by = RssSourceModel.nick.desc()
|
||||||
|
case 1, False: # NOTE:
|
||||||
|
order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest
|
||||||
|
case 1, True: # date first, not the oldest as it would sort otherwise.
|
||||||
|
order_by = RssSourceModel.created.asc()
|
||||||
|
case _, _:
|
||||||
|
raise ValueError("Unknown sort: %s" % sort)
|
||||||
|
|
||||||
async with DatabaseManager() as database:
|
async with DatabaseManager() as database:
|
||||||
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
||||||
query = select(RssSourceModel).where(whereclause)
|
query = select(RssSourceModel).where(whereclause).order_by(order_by)
|
||||||
result = await database.session.execute(query)
|
result = await database.session.execute(query)
|
||||||
|
|
||||||
rss_sources = result.scalars().all()
|
rss_sources = result.scalars().all()
|
||||||
@ -240,10 +289,15 @@ class CommandCog(commands.Cog):
|
|||||||
await followup(inter, "It looks like you have no rss sources.")
|
await followup(inter, "It looks like you have no rss sources.")
|
||||||
return
|
return
|
||||||
|
|
||||||
output = "## Available RSS Sources\n"
|
output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)])
|
||||||
output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources])
|
|
||||||
|
|
||||||
await followup(inter, output, suppress_embeds=True)
|
embed = Embed(
|
||||||
|
title="Saved RSS Feeds",
|
||||||
|
description=f"{description}\n\n{output}",
|
||||||
|
colour=Colour.lighter_grey()
|
||||||
|
)
|
||||||
|
|
||||||
|
await followup(inter, embed=embed)
|
||||||
|
|
||||||
@rss_group.command(name="fetch")
|
@rss_group.command(name="fetch")
|
||||||
@autocomplete(rss=source_autocomplete)
|
@autocomplete(rss=source_autocomplete)
|
||||||
@ -256,7 +310,12 @@ class CommandCog(commands.Cog):
|
|||||||
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
|
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
|
||||||
return
|
return
|
||||||
|
|
||||||
source = get_source(rss)
|
invalid_message, feed = await validate_rss_source("", rss)
|
||||||
|
if invalid_message:
|
||||||
|
await followup(inter, invalid_message)
|
||||||
|
return
|
||||||
|
|
||||||
|
source = Source.from_parsed(feed)
|
||||||
articles = source.get_latest_articles(max)
|
articles = source.get_latest_articles(max)
|
||||||
|
|
||||||
embeds = []
|
embeds = []
|
||||||
@ -269,6 +328,7 @@ class CommandCog(commands.Cog):
|
|||||||
description=article_description,
|
description=article_description,
|
||||||
url=article.url,
|
url=article.url,
|
||||||
timestamp=article.published,
|
timestamp=article.published,
|
||||||
|
colour=Colour.brand_red()
|
||||||
)
|
)
|
||||||
embed.set_thumbnail(url=source.icon_url)
|
embed.set_thumbnail(url=source.icon_url)
|
||||||
embed.set_image(url=await article.get_thumbnail_url())
|
embed.set_image(url=await article.get_thumbnail_url())
|
||||||
@ -290,7 +350,7 @@ class CommandCog(commands.Cog):
|
|||||||
for article in articles
|
for article in articles
|
||||||
])
|
])
|
||||||
await database.session.execute(query)
|
await database.session.execute(query)
|
||||||
await audit(self, f"User is requesting {max} articles", inter.user.id, database=database)
|
await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database)
|
||||||
|
|
||||||
await followup(inter, embeds=embeds)
|
await followup(inter, embeds=embeds)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user