307 lines
9.8 KiB
Python
307 lines
9.8 KiB
Python
"""
Extension for the `CommandCog`.

Loading this file via `commands.Bot.load_extension` adds `CommandCog` to the bot.
"""
|
|
|
|
import json
|
|
import logging
|
|
import validators
|
|
|
|
import aiohttp
|
|
import textwrap
|
|
import feedparser
|
|
from markdownify import markdownify
|
|
from discord import app_commands, Interaction, Embed, Colour
|
|
from discord.ext import commands, tasks
|
|
from discord.app_commands import Choice, Group, command, autocomplete
|
|
from sqlalchemy import insert, select, update, and_, or_, delete
|
|
|
|
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
|
|
from feed import get_source, Source
|
|
|
|
# Logger for this extension, named after the module's import path.
log = logging.getLogger(name=__name__)
|
|
|
|
async def get_rss_data(url: str):
    """Download the resource at ``url`` with a throw-away HTTP session.

    Parameters
    ----------
    url : str
        Address to request with an HTTP GET.

    Returns
    -------
    tuple of (str, int)
        The decoded response body followed by the HTTP status code.
    """
    async with aiohttp.ClientSession() as http:
        async with http.get(url) as resp:
            # Read the body first, then the status, mirroring the original order.
            body = await resp.text()
            status = resp.status

    return body, status
|
|
|
|
async def followup(inter: Interaction, *args, **kwargs):
    """Send a follow-up message for an already-acknowledged interaction.

    Thin convenience wrapper: every positional and keyword argument is
    forwarded unchanged to ``inter.followup.send`` and its result is discarded.

    Parameters
    ----------
    inter : Interaction
        Represents an app command interaction.
    *args, **kwargs
        Passed straight through to ``inter.followup.send``.
    """
    send = inter.followup.send
    await send(*args, **kwargs)
|
|
|
|
async def audit(cog, *args, **kwargs):
    """Record an audit entry through the bot that owns ``cog``.

    Every argument after ``cog`` is forwarded untouched to ``cog.bot.audit``;
    its return value is discarded.
    """
    bot = cog.bot
    await bot.audit(*args, **kwargs)
|
|
|
|
|
|
class CommandCog(commands.Cog):
    """
    Command cog providing the ``/rss`` slash-command group
    (``add``, ``remove``, ``list`` and ``fetch``).
    """

    def __init__(self, bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        # Log once the bot reports ready so cog loading is visible in the logs.
        log.info(f"{self.__class__.__name__} cog is ready")

    async def source_autocomplete(self, inter: Interaction, nickname: str):
        """Provides RSS source autocomplete functionality for commands.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        nickname : str
            Text typed so far; matched case-insensitively as a substring of
            each source's nickname (via ``ilike``).

        Returns
        -------
        list of app_commands.Choice
            At most 25 choices for the current server, named after the source
            nickname and carrying the source URL as the value.
        """

        async with DatabaseManager() as database:
            whereclause = and_(
                RssSourceModel.discord_server_id == inter.guild_id,
                RssSourceModel.nick.ilike(f"%{nickname}%")
            )
            # Discord rejects an autocomplete response with more than 25 choices,
            # so cap the result set in the query itself.
            query = select(RssSourceModel).where(whereclause).limit(25)
            result = await database.session.execute(query)

            # Discord limits choice names to 100 characters.
            # NOTE(review): choice *values* share that limit, so a source whose URL
            # is longer than 100 characters may still break the response — confirm.
            sources = [
                Choice(name=rss.nick[:100], value=rss.rss_url)
                for rss in result.scalars().all()
            ]

        return sources

    rss_group = Group(
        name="rss",
        description="Commands for rss sources.",
        guild_only=True
    )

    @rss_group.command(name="add")
    async def add_rss_source(self, inter: Interaction, url: str, nickname: str):
        """Add a new RSS source.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        url : str
            The RSS feed URL.
        nickname : str
            A name used to identify the RSS source.
        """

        await inter.response.defer()

        # Ensure the URL is valid
        if not validators.url(url):
            await followup(inter,
                f"The URL you have entered is malformed or invalid:\n`{url=}`",
                suppress_embeds=True
            )
            return

        # Check the nickname is not a URL
        if validators.url(nickname):
            await followup(inter,
                "It looks like the nickname you have entered is a URL.\n"
                f"For security reasons, this is not allowed.\n`{nickname=}`",
                suppress_embeds=True
            )
            return

        # Check the URL points to an RSS feed.
        # TODO SECURITY: a potential attack is that the user submits an rss feed then
        # changes the target resource. Run a period task to check this.
        # NOTE(review): this fetches an arbitrary user-supplied URL from the bot's host
        # (possible SSRF against internal addresses) — consider restricting targets.
        try:
            feed_data, status_code = await get_rss_data(url)
        except (aiohttp.ClientError, UnicodeDecodeError) as error:
            # Without this, connection/DNS failures or an undecodable body escape as
            # an unhandled exception and the deferred interaction never gets a reply.
            log.warning(f"Failed to fetch RSS source ({url=}): {error!r}")
            await followup(inter,
                f"The URL provided could not be fetched:\n`{url=}`",
                suppress_embeds=True
            )
            return

        if status_code != 200:
            await followup(inter,
                f"The URL provided returned an invalid status code:\n{url=}, {status_code=}",
                suppress_embeds=True
            )
            return

        feed = feedparser.parse(feed_data)
        # feedparser leaves `version` empty when it cannot identify a feed format.
        if not feed.version:
            await followup(inter,
                f"The provided URL '{url}' does not seem to be a valid RSS feed.",
                suppress_embeds=True
            )
            return

        async with DatabaseManager() as database:
            query = insert(RssSourceModel).values(
                discord_server_id=inter.guild_id,
                rss_url=url,
                nick=nickname
            )
            await database.session.execute(query)

            await audit(self,
                f"Added RSS source ({nickname=}, {url=})",
                inter.user.id, database=database
            )

        embed = Embed(title="RSS Feed Added", colour=Colour.from_str("#59ff00"))
        embed.add_field(name="Nickname", value=nickname)
        embed.add_field(name="URL", value=url)
        # The feed may not declare an image; a None URL simply leaves the thumbnail unset.
        embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))

        await followup(inter, embed=embed)

    @rss_group.command(name="remove")
    @autocomplete(source=source_autocomplete)
    async def remove_rss_source(self, inter: Interaction, source: str):
        """Delete an existing RSS source.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        source : str
            The RSS source to be removed. Autocomplete or enter the URL.
        """

        await inter.response.defer()

        log.debug(f"Attempting to remove RSS source ({source=})")

        # Sources are matched by URL within the current server; the same clause
        # drives both the lookup and the delete.
        whereclause = and_(
            RssSourceModel.discord_server_id == inter.guild_id,
            RssSourceModel.rss_url == source
        )

        deleted = False
        async with DatabaseManager() as database:
            select_result = await database.session.execute(
                select(RssSourceModel).where(whereclause)
            )
            # `.scalars()` yields the model instance itself. A raw Row from
            # `.fetchone()` has no `.nick` attribute and is None when nothing matches.
            rss_source = select_result.scalars().first()

            if rss_source is not None:
                # Read the fields before deleting so they remain available for the reply.
                nickname, rss_url = rss_source.nick, rss_source.rss_url

                delete_result = await database.session.execute(
                    delete(RssSourceModel).where(whereclause)
                )
                deleted = bool(delete_result.rowcount)

        if deleted:
            await followup(inter,
                f"RSS source deleted successfully\n**[{nickname}]({rss_url})**",
                suppress_embeds=True
            )
            return

        # TODO: show possible matches if any (like how the autocomplete works),
        # e.g. via `await self.source_autocomplete(inter, source)`.
        await followup(inter, "Couldn't find any RSS sources with this name.")

    @rss_group.command(name="list")
    async def list_rss_sources(self, inter: Interaction):
        """Send a list of the RSS sources registered for the current server.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        """

        await inter.response.defer()

        async with DatabaseManager() as database:
            query = select(RssSourceModel).where(
                RssSourceModel.discord_server_id == inter.guild_id
            )
            result = await database.session.execute(query)
            rss_sources = result.scalars().all()

            # Render while the session is still open, in case leaving
            # DatabaseManager expires the loaded models (assumption — confirm).
            lines = [f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources]

        if not lines:
            await followup(inter, "It looks like you have no rss sources.")
            return

        output = "## Available RSS Sources\n" + "\n".join(lines)

        # NOTE(review): Discord caps message content at 2000 characters; a server with
        # many sources will exceed this and the send will fail — consider paginating.
        await followup(inter, output, suppress_embeds=True)

    @rss_group.command(name="fetch")
    @autocomplete(rss=source_autocomplete)
    async def fetch_rss(self, inter: Interaction, rss: str, max: int=1):
        """Fetch the latest articles from an RSS source.

        Parameters
        ----------
        inter : Interaction
            Represents an app command interaction.
        rss : str
            The RSS source to fetch from. Autocomplete or enter the URL.
        max : int
            How many of the latest articles to fetch (1 to 5).
        """

        await inter.response.defer()

        # `max` keeps its name because it is the slash-command option users see,
        # even though it shadows the builtin inside this method.
        if max > 5:
            await followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
            return

        if max < 1:
            await followup(inter, "You must request at least 1 article.")
            return

        source = get_source(rss)
        # Materialise once: the articles are iterated twice below (embeds, then DB rows).
        articles = list(source.get_latest_articles(max))

        if not articles:
            # Discord rejects a follow-up with neither content nor embeds, and an
            # empty multi-row insert has nothing to write.
            await followup(inter, "No articles were found for this RSS source.")
            return

        embeds = []
        for article in articles:
            md_description = markdownify(article.description, strip=("img",))
            # 4096 is Discord's per-embed description limit.
            # NOTE(review): textwrap.shorten collapses all whitespace, so markdown
            # line breaks and lists are flattened onto one line.
            article_description = textwrap.shorten(md_description, 4096)

            embed = Embed(
                title=article.title,
                description=article_description,
                url=article.url,
                timestamp=article.published,
            )
            embed.set_thumbnail(url=source.icon_url)
            embed.set_image(url=await article.get_thumbnail_url())
            embed.set_footer(text=article.author)
            embed.set_author(
                name=source.name,
                url=source.url,
            )
            embeds.append(embed)

        async with DatabaseManager() as database:
            # NOTE(review): `inter.id` is the interaction's id, not the id of the
            # message sent below — confirm this is what `discord_message_id` expects.
            query = insert(SentArticleModel).values([
                {
                    "discord_server_id": inter.guild_id,
                    "discord_channel_id": inter.channel_id,
                    "discord_message_id": inter.id,
                    "article_url": article.url,
                }
                for article in articles
            ])
            await database.session.execute(query)
            await audit(self, f"User is requesting {max} articles", inter.user.id, database=database)

        # NOTE(review): Discord also limits the combined text of all embeds in one
        # message (6000 characters); several long descriptions may exceed it.
        await followup(inter, embeds=embeds)
|
|
|
|
|
|
async def setup(bot):
    """
    Entry point invoked by `commands.Bot.load_extension`.

    Creates a `CommandCog` bound to ``bot`` and registers it with the bot.
    """
    command_cog = CommandCog(bot)
    await bot.add_cog(command_cog)
    log.info(f"Added {type(command_cog).__name__} cog")
|