2023-12-16 17:55:21 +00:00

307 lines
9.8 KiB
Python

"""
Extension for the `CommandCog`.
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
"""
import asyncio
import json
import logging
import textwrap

import aiohttp
import feedparser
import validators
from discord import app_commands, Interaction, Embed, Colour
from discord.app_commands import Choice, Group, command, autocomplete
from discord.ext import commands, tasks
from markdownify import markdownify
from sqlalchemy import insert, select, update, and_, or_, delete

from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
from feed import get_source, Source
log = logging.getLogger(__name__)


async def get_rss_data(url: str):
    """Download the resource at ``url`` with a single HTTP GET.

    Parameters
    ----------
    url : str
        Address to request.

    Returns
    -------
    tuple of (str, int)
        The response body decoded as text, and the HTTP status code.
    """
    # NOTE(review): ``url`` comes straight from users; no timeout or host
    # restriction is applied here beyond aiohttp's defaults.
    async with aiohttp.ClientSession() as http_session, http_session.get(url) as resp:
        body = await resp.text()
        status = resp.status
    return body, status
async def followup(inter: Interaction, *args, **kwargs):
    """Send a follow-up message on an app command interaction.

    Every positional and keyword argument is passed through unchanged to
    ``inter.followup.send``. Within this file it is used after
    ``inter.response.defer()`` has been called.

    Parameters
    ----------
    inter : Interaction
        Represents an app command interaction.
    """
    send = inter.followup.send
    await send(*args, **kwargs)
async def audit(cog, *args, **kwargs):
    """Record an audit entry via the bot that owns ``cog``.

    All arguments are forwarded unchanged to ``cog.bot.audit``.
    """
    owning_bot = cog.bot
    await owning_bot.audit(*args, **kwargs)
class CommandCog(commands.Cog):
    """
    Command cog.

    Defines the ``rss`` slash-command group (``add``, ``remove``, ``list`` and
    ``fetch``) for managing and reading the RSS sources of a Discord server.
    """

    def __init__(self, bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        log.info(f"{self.__class__.__name__} cog is ready")

    async def source_autocomplete(self, inter: Interaction, nickname: str):
        """Provides RSS source autocomplete functionality for commands.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        nickname : str
            The text typed so far; matched case-insensitively (``ILIKE``)
            anywhere within the nicknames of this server's RSS sources.

        Returns
        -------
        list of app_commands.Choice
            At most 25 choices, each named after a source's nickname with the
            source's URL as its value.
        """
        async with DatabaseManager() as database:
            whereclause = and_(
                RssSourceModel.discord_server_id == inter.guild_id,
                RssSourceModel.nick.ilike(f"%{nickname}%")
            )
            # Discord rejects autocomplete responses with more than 25 choices.
            query = select(RssSourceModel).where(whereclause).limit(25)
            result = await database.session.execute(query)
            # NOTE(review): Discord also caps choice names/values at 100
            # characters; very long nicknames or URLs may still be rejected.
            sources = [
                Choice(name=rss.nick, value=rss.rss_url)
                for rss in result.scalars().all()
            ]
        return sources

    rss_group = Group(
        name="rss",
        description="Commands for rss sources.",
        guild_only=True
    )

    @rss_group.command(name="add")
    async def add_rss_source(self, inter: Interaction, url: str, nickname: str):
        """Add a new RSS source.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        url : str
            The RSS feed URL.
        nickname : str
            A name used to identify the RSS source.
        """
        await inter.response.defer()

        # Ensure the URL is valid
        if not validators.url(url):
            await followup(inter,
                f"The URL you have entered is malformed or invalid:\n`{url=}`",
                suppress_embeds=True
            )
            return

        # Check the nickname is not a URL
        if validators.url(nickname):
            await followup(inter,
                "It looks like the nickname you have entered is a URL.\n"
                f"For security reasons, this is not allowed.\n`{nickname=}`",
                suppress_embeds=True
            )
            return

        # Check the URL points to an RSS feed.
        # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this.
        # NOTE(review): this fetches an arbitrary user-supplied URL from the
        # bot's host; consider restricting targets (e.g. private address ranges).
        try:
            feed_data, status_code = await get_rss_data(url)
        except (aiohttp.ClientError, asyncio.TimeoutError) as error:
            # Without this, an unreachable host would leave the deferred
            # interaction without any reply.
            log.warning("Failed to fetch RSS source %r: %s", url, error)
            await followup(inter,
                f"The URL provided could not be reached:\n{url=}",
                suppress_embeds=True
            )
            return

        if status_code != 200:
            await followup(inter,
                f"The URL provided returned an invalid status code:\n{url=}, {status_code=}",
                suppress_embeds=True
            )
            return

        feed = feedparser.parse(feed_data)
        # feedparser leaves `version` empty when it cannot identify a feed format.
        if not feed.version:
            await followup(inter,
                f"The provided URL '{url}' does not seem to be a valid RSS feed.",
                suppress_embeds=True
            )
            return

        async with DatabaseManager() as database:
            query = insert(RssSourceModel).values(
                discord_server_id=inter.guild_id,
                rss_url=url,
                nick=nickname
            )
            await database.session.execute(query)
            await audit(self,
                f"Added RSS source ({nickname=}, {url=})",
                inter.user.id, database=database
            )

        embed = Embed(title="RSS Feed Added", colour=Colour.from_str("#59ff00"))
        embed.add_field(name="Nickname", value=nickname)
        embed.add_field(name="URL", value=url)
        embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
        await followup(inter, embed=embed)

    @rss_group.command(name="remove")
    @autocomplete(source=source_autocomplete)
    async def remove_rss_source(self, inter: Interaction, source: str):
        """Delete an existing RSS source.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        source : str
            The RSS source to be removed. Autocomplete or enter the URL.
        """
        await inter.response.defer()
        log.debug(f"Attempting to remove RSS source ({source=})")

        async with DatabaseManager() as database:
            whereclause = and_(
                RssSourceModel.discord_server_id == inter.guild_id,
                RssSourceModel.rss_url == source
            )
            select_result = await database.session.execute(
                select(RssSourceModel).where(whereclause)
            )
            # `.scalars()` yields model instances; a raw Row from `fetchone()`
            # would expose the entity as `row.RssSourceModel`, not `row.nick`.
            rss_source = select_result.scalars().first()

            if rss_source is None:
                # TODO: show possible matches if any (like how the autocomplete works)
                await followup(inter, "Couldn't find any RSS sources with this name.")
                return

            # Capture the fields before the row is deleted.
            nickname, url = rss_source.nick, rss_source.rss_url

            await database.session.execute(
                delete(RssSourceModel).where(whereclause)
            )
            await audit(self,
                f"Removed RSS source ({nickname=}, {url=})",
                inter.user.id, database=database
            )

        await followup(inter,
            f"RSS source deleted successfully\n**[{nickname}]({url})**",
            suppress_embeds=True
        )

    @rss_group.command(name="list")
    async def list_rss_sources(self, inter: Interaction):
        """Provides a list of RSS sources available for the current server.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        """
        await inter.response.defer()

        async with DatabaseManager() as database:
            query = select(RssSourceModel).where(
                RssSourceModel.discord_server_id == inter.guild_id
            )
            result = await database.session.execute(query)
            rss_sources = result.scalars().all()

            if not rss_sources:
                await followup(inter, "It looks like you have no rss sources.")
                return

            # NOTE(review): Discord caps message content at 2000 characters;
            # a server with many sources may overflow this single message.
            output = "## Available RSS Sources\n"
            output += "\n".join(f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources)

        await followup(inter, output, suppress_embeds=True)

    @rss_group.command(name="fetch")
    @autocomplete(rss=source_autocomplete)
    async def fetch_rss(self, inter: Interaction, rss: str, max: int=1):
        """Fetch the latest articles from an RSS source.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        rss : str
            The RSS source to fetch from. Autocomplete or enter the URL.
        max : int
            How many of the latest articles to show (1 to 5).
        """
        # `max` shadows the builtin, but it is the user-facing option name.
        await inter.response.defer()

        if max > 5:
            await followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
            return
        if max < 1:
            await followup(inter, "You need to request at least 1 article.")
            return

        source = get_source(rss)
        # Materialise once: the articles are iterated twice below.
        articles = list(source.get_latest_articles(max))
        if not articles:
            # Discord rejects an empty message, and an INSERT with no rows is invalid.
            await followup(inter, "Couldn't find any articles for this RSS source.")
            return

        embeds = []
        for article in articles:
            md_description = markdownify(article.description, strip=("img",))
            # shorten() collapses all whitespace and truncates on word
            # boundaries to fit Discord's 4096-character description limit.
            article_description = textwrap.shorten(md_description, 4096)
            embed = Embed(
                title=article.title,
                description=article_description,
                url=article.url,
                timestamp=article.published,
            )
            embed.set_thumbnail(url=source.icon_url)
            embed.set_image(url=await article.get_thumbnail_url())
            embed.set_footer(text=article.author)
            embed.set_author(
                name=source.name,
                url=source.url,
            )
            embeds.append(embed)
        # NOTE(review): Discord limits the combined text of all embeds in one
        # message to 6000 characters; several long descriptions may exceed it.

        async with DatabaseManager() as database:
            # NOTE(review): `inter.id` is the interaction's ID, not the ID of
            # the follow-up message sent below -- confirm this is intended.
            query = insert(SentArticleModel).values([
                {
                    "discord_server_id": inter.guild_id,
                    "discord_channel_id": inter.channel_id,
                    "discord_message_id": inter.id,
                    "article_url": article.url,
                }
                for article in articles
            ])
            await database.session.execute(query)
            await audit(self, f"User is requesting {max} articles", inter.user.id, database=database)

        await followup(inter, embeds=embeds)
async def setup(bot):
    """
    Extension entry point.

    Builds a `CommandCog` bound to ``bot``, registers it with
    ``bot.add_cog`` and logs that it was added.
    """
    await bot.add_cog(CommandCog(bot))
    log.info(f"Added {CommandCog.__name__} cog")