Working on commands

This commit is contained in:
Corban-Lee Jones 2023-12-15 23:13:39 +00:00
parent 85b6f118bc
commit 0770fb3f6f
4 changed files with 231 additions and 81 deletions

View File

@ -46,9 +46,16 @@ class DiscordBot(commands.Bot):
if path.suffix == ".py": if path.suffix == ".py":
await self.load_extension(f"extensions.{path.stem}") await self.load_extension(f"extensions.{path.stem}")
async def audit(self, message: str, user_id: int): async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
message = f"Requesting latest article"
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
if database:
await database.session.execute(query)
return
async with DatabaseManager() as database: async with DatabaseManager() as database:
message = f"Requesting latest article" await database.session.execute(query)
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
await database.session.execute(query) log.debug("Audit logged")

View File

@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`.
from enum import Enum, auto from enum import Enum, auto
from sqlalchemy import Column, Integer, String, DateTime, BigInteger from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint
from sqlalchemy.sql import func from sqlalchemy.sql import func
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
@ -24,7 +24,6 @@ class AuditModel(Base):
discord_user_id = Column(BigInteger, nullable=False) discord_user_id = Column(BigInteger, nullable=False)
message = Column(String, nullable=False) message = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
active = Column(Integer, default=True, nullable=False)
class SentArticleModel(Base): class SentArticleModel(Base):
@ -39,7 +38,6 @@ class SentArticleModel(Base):
discord_channel_id = Column(BigInteger, nullable=False) discord_channel_id = Column(BigInteger, nullable=False)
discord_server_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False)
article_url = Column(String, nullable=False) article_url = Column(String, nullable=False)
active = Column(Integer, default=True, nullable=False)
class RssSourceModel(Base): class RssSourceModel(Base):
@ -50,9 +48,14 @@ class RssSourceModel(Base):
__tablename__ = "rss_source" __tablename__ = "rss_source"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
nick = Column(String, nullable=False)
discord_server_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False)
rss_url = Column(String, nullable=False) rss_url = Column(String, nullable=False)
active = Column(Integer, default=True, nullable=False)
# the nickname must be unique, but only within the same discord server
__table_args__ = (
UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'),
)
class FeedChannelModel(Base): class FeedChannelModel(Base):
@ -64,4 +67,3 @@ class FeedChannelModel(Base):
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
discord_channel_id = Column(BigInteger, nullable=False) discord_channel_id = Column(BigInteger, nullable=False)
active = Column(Integer, default=True, nullable=False)

View File

@ -12,7 +12,8 @@ import feedparser
from markdownify import markdownify from markdownify import markdownify
from discord import app_commands, Interaction, Embed from discord import app_commands, Interaction, Embed
from discord.ext import commands, tasks from discord.ext import commands, tasks
from sqlalchemy import insert, select, update, and_, or_ from discord.app_commands import Choice, Group, command, autocomplete
from sqlalchemy import insert, select, update, and_, or_, delete
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
from feed import Feeds, get_source from feed import Feeds, get_source
@ -26,6 +27,22 @@ async def get_rss_data(url: str):
return items return items
async def followup(inter: Interaction, *args, **kwargs):
"""Shorthand for following up on an interaction.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
await inter.followup.send(*args, **kwargs)
async def audit(cog, *args, **kwargs):
"""Shorthand for auditing an interaction."""
await cog.bot.audit(*args, **kwargs)
class CommandCog(commands.Cog): class CommandCog(commands.Cog):
""" """
@ -40,38 +57,87 @@ class CommandCog(commands.Cog):
async def on_ready(self): async def on_ready(self):
log.info(f"{self.__class__.__name__} cog is ready") log.info(f"{self.__class__.__name__} cog is ready")
rss_group = app_commands.Group( async def source_autocomplete(self, inter: Interaction, nickname: str):
"""Provides RSS source autocomplete functionality for commands.
Parameters
----------
inter : Interaction
Represents an app command interaction.
nickname : str
_description_
Returns
-------
list of app_commands.Choice
_description_
"""
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.nick.ilike(f"%{nickname}%")
)
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
sources = [
Choice(name=rss.nick, value=rss.rss_url)
for rss in result.scalars().all()
]
return sources
rss_group = Group(
name="rss", name="rss",
description="Commands for rss sources.", description="Commands for rss sources.",
guild_only=True guild_only=True
) )
@rss_group.command(name="add") @rss_group.command(name="add")
async def add_rss_source(self, inter: Interaction, url: str): async def add_rss_source(self, inter: Interaction, url: str, nickname: str):
"""Add a new RSS source.
Parameters
----------
inter : Interaction
Represents an app command interaction.
url : str
The RSS feed URL.
nickname : str
A name used to identify the RSS source.
"""
await inter.response.defer() await inter.response.defer()
# validate the input # Ensure the URL is valid
if not validators.url(url): if not validators.url(url):
await inter.followup.send( await followup(inter,
"The URL you have entered is malformed or invalid:\n" f"The URL you have entered is malformed or invalid:\n`{url=}`",
f"`{url=}`",
suppress_embeds=True suppress_embeds=True
) )
return return
feed_data, status_code = await get_rss_data(url) # Check the nickname is not a URL
if validators.url(nickname):
await followup(inter,
"It looks like the nickname you have entered is a URL.\n"
f"For security reasons, this is not allowed.\n`{nickname=}`",
suppress_embeds=True
)
return
# Check the URL points to an RSS feed.
feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this.
if status_code != 200: if status_code != 200:
await inter.followup.send( await followup(inter,
f"The URL provided returned an invalid status code:\n" f"The URL provided returned an invalid status code:\n{url=}, {status_code=}",
f"{url=}, {status_code=}",
suppress_embeds=True suppress_embeds=True
) )
return return
feed = feedparser.parse(feed_data) feed = feedparser.parse(feed_data)
if not feed.version: if not feed.version:
await inter.followup.send( await followup(inter,
f"The provided URL '{url}' does not seem to be a valid RSS feed.", f"The provided URL '{url}' does not seem to be a valid RSS feed.",
suppress_embeds=True suppress_embeds=True
) )
@ -80,97 +146,147 @@ class CommandCog(commands.Cog):
async with DatabaseManager() as database: async with DatabaseManager() as database:
query = insert(RssSourceModel).values( query = insert(RssSourceModel).values(
discord_server_id = inter.guild_id, discord_server_id = inter.guild_id,
rss_url = url rss_url = url,
nick=nickname
) )
await database.session.execute(query) await database.session.execute(query)
await inter.followup.send("RSS source added") await audit(self,
f"Added RSS source ({nickname=}, {url=})",
inter.user.id, database=database
)
await followup(inter, f"RSS source added [{nickname}]({url})", suppress_embeds=True)
@rss_group.command(name="remove") @rss_group.command(name="remove")
async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None): @autocomplete(source=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, source: str):
"""Delete an existing RSS source.
Parameters
----------
inter : Interaction
Represents an app command interaction.
source : str
The RSS source to be removed. Autocomplete or enter the URL.
"""
await inter.response.defer() await inter.response.defer()
def exists(item) -> bool: log.debug(f"Attempting to remove RSS source ({source=})")
"""
Shorthand for `is not None`. Cant just use `if not number` because 0 int will pass.
Ironically with this func & comment the code is longer, but at least I can read it ...
"""
return item is not None async with DatabaseManager() as database:
rss_source = (await database.session.execute(
select(RssSourceModel).filter(
and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == source
)
)
)).fetchone()
url_exists = exists(url) result = await database.session.execute(
num_exists = exists(number) delete(RssSourceModel).filter(
and_(
if (url_exists and num_exists) or (not url_exists and not num_exists): RssSourceModel.discord_server_id == inter.guild_id,
await inter.followup.send( RssSourceModel.rss_url == source
"Please only specify either the existing rss number or url, " )
"enter at least one of these, but don't enter both." )
) )
return
if url_exists and not validators.url(url): # TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works)
await inter.followup.send(
"The URL you have entered is malformed or invalid:\n" if result.rowcount:
f"`{url=}`", await followup(inter,
f"RSS source deleted successfully\n**[{rss_source.nick}]({rss_source.rss_url})**",
suppress_embeds=True suppress_embeds=True
) )
return return
async with DatabaseManager() as database: await followup(inter, "Couldn't find any RSS sources with this name.")
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == url
)
query = update(RssSourceModel).where(whereclause).values(active=False)
result = await database.session.execute(query)
await inter.followup.send(f"I've updated {result.rowcount} rows") # potential_matches = await self.source_autocomplete(inter, source)
@rss_group.command(name="list") @rss_group.command(name="list")
@app_commands.choices(filter=[ async def list_rss_sources(self, inter: Interaction):
app_commands.Choice(name="Active Only [default]", value=1), """Provides a with a list of RSS sources available for the current server.
app_commands.Choice(name="Inactive Only", value=0),
app_commands.Choice(name="All", value=2), Parameters
]) ----------
async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]): inter : Interaction
Represents an app command interaction.
"""
await inter.response.defer() await inter.response.defer()
if filter.value == 2:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
else:
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.active == filter.value # should result to 0 or 1
)
async with DatabaseManager() as database: async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
query = select(RssSourceModel).where(whereclause) query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query) result = await database.session.execute(query)
rss_sources = result.scalars().all() rss_sources = result.scalars().all()
embed_fields = [{
"name": f"[{i}]",
"value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}"
} for i, rss in enumerate(rss_sources)]
if not embed_fields: if not rss_sources:
await inter.followup.send("It looks like you have no rss sources.") await followup(inter, "It looks like you have no rss sources.")
return
output = "## Available RSS Sources\n"
output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources])
await followup(inter, output, suppress_embeds=True)
@rss_group.command(name="fetch")
@autocomplete(rss=source_autocomplete)
async def fetch_rss(self, inter: Interaction, rss: str, max: int=1):
# """"""
await inter.response.defer()
if max > 5:
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
return return
embed = Embed( source = get_source(rss)
title="RSS Sources", articles = source.get_latest_articles(max)
description="Here are your rss sources:"
) embeds = []
for article in articles:
md_description = markdownify(article.description, strip=("img",))
article_description = textwrap.shorten(md_description, 4096)
embed = Embed(
title=article.title,
description=article_description,
url=article.url,
timestamp=article.published,
)
embed.set_thumbnail(url=source.icon_url)
embed.set_image(url=await article.get_thumbnail_url())
embed.set_footer(text=article.author)
embed.set_author(
name=source.name,
url=source.url,
)
embeds.append(embed)
async with DatabaseManager() as database:
query = insert(SentArticleModel).values([
{
"discord_server_id": inter.guild_id,
"discord_channel_id": inter.channel_id,
"discord_message_id": inter.id,
"article_url": article.url,
}
for article in articles
])
await database.session.execute(query)
await audit(self, f"User is requesting {max} articles", inter.user.id, database=database)
await followup(inter, embeds=embeds)
for field in embed_fields:
embed.add_field(**field, inline=False)
# output = "Your rss sources:\n\n"
# output += "\n".join([f"[{i+1}] {rss.rss_url=} {bool(rss.active)=}" for i, rss in enumerate(rss_sources)])
await inter.followup.send(embed=embed)
async def setup(bot): async def setup(bot):

View File

@ -37,8 +37,18 @@ class Source:
feed=feed feed=feed
) )
def get_latest_article(self): def get_latest_articles(self, max: int) -> list:
return Article.from_parsed(self.feed) """"""
articles = []
for i, entry in enumerate(self.feed.entries):
if i >= max:
break
articles.append(Article.from_entry(entry))
return articles
@dataclass @dataclass
@ -66,6 +76,21 @@ class Article:
author = entry.get("author") author = entry.get("author")
) )
@classmethod
def from_entry(cls, entry:FeedParserDict):
published_parsed = entry.get("published_parsed")
published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
return cls(
title=entry.get("title"),
description=entry.get("description"),
url=entry.get("link"),
published=published,
author = entry.get("author")
)
async def get_thumbnail_url(self): async def get_thumbnail_url(self):
""" """