Working on commands

This commit is contained in:
Corban-Lee Jones 2023-12-15 23:13:39 +00:00
parent 85b6f118bc
commit 0770fb3f6f
4 changed files with 231 additions and 81 deletions

View File

@ -46,9 +46,16 @@ class DiscordBot(commands.Bot):
if path.suffix == ".py":
await self.load_extension(f"extensions.{path.stem}")
async def audit(self, message: str, user_id: int):
async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
message = f"Requesting latest article"
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
if database:
await database.session.execute(query)
return
async with DatabaseManager() as database:
message = f"Requesting latest article"
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
await database.session.execute(query)
await database.session.execute(query)
log.debug("Audit logged")

View File

@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`.
from enum import Enum, auto
from sqlalchemy import Column, Integer, String, DateTime, BigInteger
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base
@ -24,7 +24,6 @@ class AuditModel(Base):
discord_user_id = Column(BigInteger, nullable=False)
message = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
active = Column(Integer, default=True, nullable=False)
class SentArticleModel(Base):
@ -39,7 +38,6 @@ class SentArticleModel(Base):
discord_channel_id = Column(BigInteger, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
article_url = Column(String, nullable=False)
active = Column(Integer, default=True, nullable=False)
class RssSourceModel(Base):
@ -50,9 +48,14 @@ class RssSourceModel(Base):
__tablename__ = "rss_source"
id = Column(Integer, primary_key=True, autoincrement=True)
nick = Column(String, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
rss_url = Column(String, nullable=False)
active = Column(Integer, default=True, nullable=False)
# the nickname must be unique, but only within the same discord server
__table_args__ = (
UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'),
)
class FeedChannelModel(Base):
@ -64,4 +67,3 @@ class FeedChannelModel(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
discord_channel_id = Column(BigInteger, nullable=False)
active = Column(Integer, default=True, nullable=False)

View File

@ -12,7 +12,8 @@ import feedparser
from markdownify import markdownify
from discord import app_commands, Interaction, Embed
from discord.ext import commands, tasks
from sqlalchemy import insert, select, update, and_, or_
from discord.app_commands import Choice, Group, command, autocomplete
from sqlalchemy import insert, select, update, and_, or_, delete
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
from feed import Feeds, get_source
@ -26,6 +27,22 @@ async def get_rss_data(url: str):
return items
async def followup(inter: Interaction, *args, **kwargs):
"""Shorthand for following up on an interaction.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
await inter.followup.send(*args, **kwargs)
async def audit(cog, *args, **kwargs):
"""Shorthand for auditing an interaction."""
await cog.bot.audit(*args, **kwargs)
class CommandCog(commands.Cog):
"""
@ -40,38 +57,87 @@ class CommandCog(commands.Cog):
async def on_ready(self):
log.info(f"{self.__class__.__name__} cog is ready")
rss_group = app_commands.Group(
async def source_autocomplete(self, inter: Interaction, nickname: str):
"""Provides RSS source autocomplete functionality for commands.
Parameters
----------
inter : Interaction
Represents an app command interaction.
nickname : str
_description_
Returns
-------
list of app_commands.Choice
_description_
"""
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.nick.ilike(f"%{nickname}%")
)
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
sources = [
Choice(name=rss.nick, value=rss.rss_url)
for rss in result.scalars().all()
]
return sources
rss_group = Group(
name="rss",
description="Commands for rss sources.",
guild_only=True
)
@rss_group.command(name="add")
async def add_rss_source(self, inter: Interaction, url: str):
async def add_rss_source(self, inter: Interaction, url: str, nickname: str):
"""Add a new RSS source.
Parameters
----------
inter : Interaction
Represents an app command interaction.
url : str
The RSS feed URL.
nickname : str
A name used to identify the RSS source.
"""
await inter.response.defer()
# validate the input
# Ensure the URL is valid
if not validators.url(url):
await inter.followup.send(
"The URL you have entered is malformed or invalid:\n"
f"`{url=}`",
await followup(inter,
f"The URL you have entered is malformed or invalid:\n`{url=}`",
suppress_embeds=True
)
return
feed_data, status_code = await get_rss_data(url)
# Check the nickname is not a URL
if validators.url(nickname):
await followup(inter,
"It looks like the nickname you have entered is a URL.\n"
f"For security reasons, this is not allowed.\n`{nickname=}`",
suppress_embeds=True
)
return
# Check the URL points to an RSS feed.
feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this.
if status_code != 200:
await inter.followup.send(
f"The URL provided returned an invalid status code:\n"
f"{url=}, {status_code=}",
await followup(inter,
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}",
suppress_embeds=True
)
return
feed = feedparser.parse(feed_data)
if not feed.version:
await inter.followup.send(
await followup(inter,
f"The provided URL '{url}' does not seem to be a valid RSS feed.",
suppress_embeds=True
)
@ -80,97 +146,147 @@ class CommandCog(commands.Cog):
async with DatabaseManager() as database:
query = insert(RssSourceModel).values(
discord_server_id = inter.guild_id,
rss_url = url
rss_url = url,
nick=nickname
)
await database.session.execute(query)
await inter.followup.send("RSS source added")
await audit(self,
f"Added RSS source ({nickname=}, {url=})",
inter.user.id, database=database
)
await followup(inter, f"RSS source added [{nickname}]({url})", suppress_embeds=True)
@rss_group.command(name="remove")
async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None):
@autocomplete(source=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, source: str):
"""Delete an existing RSS source.
Parameters
----------
inter : Interaction
Represents an app command interaction.
source : str
The RSS source to be removed. Autocomplete or enter the URL.
"""
await inter.response.defer()
def exists(item) -> bool:
"""
Shorthand for `is not None`. Cant just use `if not number` because 0 int will pass.
Ironically with this func & comment the code is longer, but at least I can read it ...
"""
log.debug(f"Attempting to remove RSS source ({source=})")
return item is not None
async with DatabaseManager() as database:
rss_source = (await database.session.execute(
select(RssSourceModel).filter(
and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == source
)
)
)).fetchone()
url_exists = exists(url)
num_exists = exists(number)
if (url_exists and num_exists) or (not url_exists and not num_exists):
await inter.followup.send(
"Please only specify either the existing rss number or url, "
"enter at least one of these, but don't enter both."
result = await database.session.execute(
delete(RssSourceModel).filter(
and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == source
)
)
)
return
if url_exists and not validators.url(url):
await inter.followup.send(
"The URL you have entered is malformed or invalid:\n"
f"`{url=}`",
# TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works)
if result.rowcount:
await followup(inter,
f"RSS source deleted successfully\n**[{rss_source.nick}]({rss_source.rss_url})**",
suppress_embeds=True
)
return
async with DatabaseManager() as database:
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == url
)
query = update(RssSourceModel).where(whereclause).values(active=False)
result = await database.session.execute(query)
await followup(inter, "Couldn't find any RSS sources with this name.")
await inter.followup.send(f"I've updated {result.rowcount} rows")
# potential_matches = await self.source_autocomplete(inter, source)
@rss_group.command(name="list")
@app_commands.choices(filter=[
app_commands.Choice(name="Active Only [default]", value=1),
app_commands.Choice(name="Inactive Only", value=0),
app_commands.Choice(name="All", value=2),
])
async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]):
async def list_rss_sources(self, inter: Interaction):
"""Provides a with a list of RSS sources available for the current server.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
await inter.response.defer()
if filter.value == 2:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
else:
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.active == filter.value # should result to 0 or 1
)
async with DatabaseManager() as database:
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
query = select(RssSourceModel).where(whereclause)
result = await database.session.execute(query)
rss_sources = result.scalars().all()
embed_fields = [{
"name": f"[{i}]",
"value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}"
} for i, rss in enumerate(rss_sources)]
if not embed_fields:
await inter.followup.send("It looks like you have no rss sources.")
if not rss_sources:
await followup(inter, "It looks like you have no rss sources.")
return
output = "## Available RSS Sources\n"
output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources])
await followup(inter, output, suppress_embeds=True)
@rss_group.command(name="fetch")
@autocomplete(rss=source_autocomplete)
async def fetch_rss(self, inter: Interaction, rss: str, max: int=1):
# """"""
await inter.response.defer()
if max > 5:
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
return
embed = Embed(
title="RSS Sources",
description="Here are your rss sources:"
)
source = get_source(rss)
articles = source.get_latest_articles(max)
embeds = []
for article in articles:
md_description = markdownify(article.description, strip=("img",))
article_description = textwrap.shorten(md_description, 4096)
embed = Embed(
title=article.title,
description=article_description,
url=article.url,
timestamp=article.published,
)
embed.set_thumbnail(url=source.icon_url)
embed.set_image(url=await article.get_thumbnail_url())
embed.set_footer(text=article.author)
embed.set_author(
name=source.name,
url=source.url,
)
embeds.append(embed)
async with DatabaseManager() as database:
query = insert(SentArticleModel).values([
{
"discord_server_id": inter.guild_id,
"discord_channel_id": inter.channel_id,
"discord_message_id": inter.id,
"article_url": article.url,
}
for article in articles
])
await database.session.execute(query)
await audit(self, f"User is requesting {max} articles", inter.user.id, database=database)
await followup(inter, embeds=embeds)
for field in embed_fields:
embed.add_field(**field, inline=False)
# output = "Your rss sources:\n\n"
# output += "\n".join([f"[{i+1}] {rss.rss_url=} {bool(rss.active)=}" for i, rss in enumerate(rss_sources)])
await inter.followup.send(embed=embed)
async def setup(bot):

View File

@ -37,8 +37,18 @@ class Source:
feed=feed
)
def get_latest_article(self):
return Article.from_parsed(self.feed)
def get_latest_articles(self, max: int) -> list:
""""""
articles = []
for i, entry in enumerate(self.feed.entries):
if i >= max:
break
articles.append(Article.from_entry(entry))
return articles
@dataclass
@ -66,6 +76,21 @@ class Article:
author = entry.get("author")
)
@classmethod
def from_entry(cls, entry:FeedParserDict):
published_parsed = entry.get("published_parsed")
published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
return cls(
title=entry.get("title"),
description=entry.get("description"),
url=entry.get("link"),
published=published,
author = entry.get("author")
)
async def get_thumbnail_url(self):
"""