Working on commands
This commit is contained in:
parent
85b6f118bc
commit
0770fb3f6f
15
src/bot.py
15
src/bot.py
@ -46,9 +46,16 @@ class DiscordBot(commands.Bot):
|
||||
if path.suffix == ".py":
|
||||
await self.load_extension(f"extensions.{path.stem}")
|
||||
|
||||
async def audit(self, message: str, user_id: int):
|
||||
async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
|
||||
|
||||
message = f"Requesting latest article"
|
||||
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
|
||||
|
||||
if database:
|
||||
await database.session.execute(query)
|
||||
return
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
message = f"Requesting latest article"
|
||||
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
|
||||
await database.session.execute(query)
|
||||
await database.session.execute(query)
|
||||
|
||||
log.debug("Audit logged")
|
||||
|
@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`.
|
||||
|
||||
from enum import Enum, auto
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, BigInteger
|
||||
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
@ -24,7 +24,6 @@ class AuditModel(Base):
|
||||
discord_user_id = Column(BigInteger, nullable=False)
|
||||
message = Column(String, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
active = Column(Integer, default=True, nullable=False)
|
||||
|
||||
|
||||
class SentArticleModel(Base):
|
||||
@ -39,7 +38,6 @@ class SentArticleModel(Base):
|
||||
discord_channel_id = Column(BigInteger, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
article_url = Column(String, nullable=False)
|
||||
active = Column(Integer, default=True, nullable=False)
|
||||
|
||||
|
||||
class RssSourceModel(Base):
|
||||
@ -50,9 +48,14 @@ class RssSourceModel(Base):
|
||||
__tablename__ = "rss_source"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
nick = Column(String, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
rss_url = Column(String, nullable=False)
|
||||
active = Column(Integer, default=True, nullable=False)
|
||||
|
||||
# the nickname must be unique, but only within the same discord server
|
||||
__table_args__ = (
|
||||
UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'),
|
||||
)
|
||||
|
||||
|
||||
class FeedChannelModel(Base):
|
||||
@ -64,4 +67,3 @@ class FeedChannelModel(Base):
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
discord_channel_id = Column(BigInteger, nullable=False)
|
||||
active = Column(Integer, default=True, nullable=False)
|
||||
|
@ -12,7 +12,8 @@ import feedparser
|
||||
from markdownify import markdownify
|
||||
from discord import app_commands, Interaction, Embed
|
||||
from discord.ext import commands, tasks
|
||||
from sqlalchemy import insert, select, update, and_, or_
|
||||
from discord.app_commands import Choice, Group, command, autocomplete
|
||||
from sqlalchemy import insert, select, update, and_, or_, delete
|
||||
|
||||
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
|
||||
from feed import Feeds, get_source
|
||||
@ -26,6 +27,22 @@ async def get_rss_data(url: str):
|
||||
|
||||
return items
|
||||
|
||||
async def followup(inter: Interaction, *args, **kwargs):
|
||||
"""Shorthand for following up on an interaction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
|
||||
await inter.followup.send(*args, **kwargs)
|
||||
|
||||
async def audit(cog, *args, **kwargs):
|
||||
"""Shorthand for auditing an interaction."""
|
||||
|
||||
await cog.bot.audit(*args, **kwargs)
|
||||
|
||||
|
||||
class CommandCog(commands.Cog):
|
||||
"""
|
||||
@ -40,38 +57,87 @@ class CommandCog(commands.Cog):
|
||||
async def on_ready(self):
|
||||
log.info(f"{self.__class__.__name__} cog is ready")
|
||||
|
||||
rss_group = app_commands.Group(
|
||||
async def source_autocomplete(self, inter: Interaction, nickname: str):
|
||||
"""Provides RSS source autocomplete functionality for commands.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
nickname : str
|
||||
_description_
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of app_commands.Choice
|
||||
_description_
|
||||
"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.nick.ilike(f"%{nickname}%")
|
||||
)
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
sources = [
|
||||
Choice(name=rss.nick, value=rss.rss_url)
|
||||
for rss in result.scalars().all()
|
||||
]
|
||||
|
||||
return sources
|
||||
|
||||
rss_group = Group(
|
||||
name="rss",
|
||||
description="Commands for rss sources.",
|
||||
guild_only=True
|
||||
)
|
||||
|
||||
@rss_group.command(name="add")
|
||||
async def add_rss_source(self, inter: Interaction, url: str):
|
||||
async def add_rss_source(self, inter: Interaction, url: str, nickname: str):
|
||||
"""Add a new RSS source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
url : str
|
||||
The RSS feed URL.
|
||||
nickname : str
|
||||
A name used to identify the RSS source.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
# validate the input
|
||||
# Ensure the URL is valid
|
||||
if not validators.url(url):
|
||||
await inter.followup.send(
|
||||
"The URL you have entered is malformed or invalid:\n"
|
||||
f"`{url=}`",
|
||||
await followup(inter,
|
||||
f"The URL you have entered is malformed or invalid:\n`{url=}`",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
feed_data, status_code = await get_rss_data(url)
|
||||
# Check the nickname is not a URL
|
||||
if validators.url(nickname):
|
||||
await followup(inter,
|
||||
"It looks like the nickname you have entered is a URL.\n"
|
||||
f"For security reasons, this is not allowed.\n`{nickname=}`",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
# Check the URL points to an RSS feed.
|
||||
feed_data, status_code = await get_rss_data(url) # TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource. Run a period task to check this.
|
||||
if status_code != 200:
|
||||
await inter.followup.send(
|
||||
f"The URL provided returned an invalid status code:\n"
|
||||
f"{url=}, {status_code=}",
|
||||
await followup(inter,
|
||||
f"The URL provided returned an invalid status code:\n{url=}, {status_code=}",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
feed = feedparser.parse(feed_data)
|
||||
if not feed.version:
|
||||
await inter.followup.send(
|
||||
await followup(inter,
|
||||
f"The provided URL '{url}' does not seem to be a valid RSS feed.",
|
||||
suppress_embeds=True
|
||||
)
|
||||
@ -80,97 +146,147 @@ class CommandCog(commands.Cog):
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(RssSourceModel).values(
|
||||
discord_server_id = inter.guild_id,
|
||||
rss_url = url
|
||||
rss_url = url,
|
||||
nick=nickname
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
await inter.followup.send("RSS source added")
|
||||
await audit(self,
|
||||
f"Added RSS source ({nickname=}, {url=})",
|
||||
inter.user.id, database=database
|
||||
)
|
||||
|
||||
await followup(inter, f"RSS source added [{nickname}]({url})", suppress_embeds=True)
|
||||
|
||||
@rss_group.command(name="remove")
|
||||
async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None):
|
||||
@autocomplete(source=source_autocomplete)
|
||||
async def remove_rss_source(self, inter: Interaction, source: str):
|
||||
"""Delete an existing RSS source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
source : str
|
||||
The RSS source to be removed. Autocomplete or enter the URL.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
def exists(item) -> bool:
|
||||
"""
|
||||
Shorthand for `is not None`. Cant just use `if not number` because 0 int will pass.
|
||||
Ironically with this func & comment the code is longer, but at least I can read it ...
|
||||
"""
|
||||
log.debug(f"Attempting to remove RSS source ({source=})")
|
||||
|
||||
return item is not None
|
||||
async with DatabaseManager() as database:
|
||||
rss_source = (await database.session.execute(
|
||||
select(RssSourceModel).filter(
|
||||
and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == source
|
||||
)
|
||||
)
|
||||
)).fetchone()
|
||||
|
||||
url_exists = exists(url)
|
||||
num_exists = exists(number)
|
||||
|
||||
if (url_exists and num_exists) or (not url_exists and not num_exists):
|
||||
await inter.followup.send(
|
||||
"Please only specify either the existing rss number or url, "
|
||||
"enter at least one of these, but don't enter both."
|
||||
result = await database.session.execute(
|
||||
delete(RssSourceModel).filter(
|
||||
and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == source
|
||||
)
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if url_exists and not validators.url(url):
|
||||
await inter.followup.send(
|
||||
"The URL you have entered is malformed or invalid:\n"
|
||||
f"`{url=}`",
|
||||
# TODO: `if not result.rowcount` then show unique message and possible matches if any (like how the autocomplete works)
|
||||
|
||||
if result.rowcount:
|
||||
await followup(inter,
|
||||
f"RSS source deleted successfully\n**[{rss_source.nick}]({rss_source.rss_url})**",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
query = update(RssSourceModel).where(whereclause).values(active=False)
|
||||
result = await database.session.execute(query)
|
||||
await followup(inter, "Couldn't find any RSS sources with this name.")
|
||||
|
||||
await inter.followup.send(f"I've updated {result.rowcount} rows")
|
||||
# potential_matches = await self.source_autocomplete(inter, source)
|
||||
|
||||
@rss_group.command(name="list")
|
||||
@app_commands.choices(filter=[
|
||||
app_commands.Choice(name="Active Only [default]", value=1),
|
||||
app_commands.Choice(name="Inactive Only", value=0),
|
||||
app_commands.Choice(name="All", value=2),
|
||||
])
|
||||
async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]):
|
||||
async def list_rss_sources(self, inter: Interaction):
|
||||
"""Provides a with a list of RSS sources available for the current server.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
if filter.value == 2:
|
||||
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
||||
else:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.active == filter.value # should result to 0 or 1
|
||||
)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
rss_sources = result.scalars().all()
|
||||
embed_fields = [{
|
||||
"name": f"[{i}]",
|
||||
"value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}"
|
||||
} for i, rss in enumerate(rss_sources)]
|
||||
|
||||
if not embed_fields:
|
||||
await inter.followup.send("It looks like you have no rss sources.")
|
||||
if not rss_sources:
|
||||
await followup(inter, "It looks like you have no rss sources.")
|
||||
return
|
||||
|
||||
output = "## Available RSS Sources\n"
|
||||
output += "\n".join([f"**[{rss.nick}]({rss.rss_url})** " for rss in rss_sources])
|
||||
|
||||
await followup(inter, output, suppress_embeds=True)
|
||||
|
||||
@rss_group.command(name="fetch")
|
||||
@autocomplete(rss=source_autocomplete)
|
||||
async def fetch_rss(self, inter: Interaction, rss: str, max: int=1):
|
||||
# """"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
if max > 5:
|
||||
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
|
||||
return
|
||||
|
||||
embed = Embed(
|
||||
title="RSS Sources",
|
||||
description="Here are your rss sources:"
|
||||
)
|
||||
source = get_source(rss)
|
||||
articles = source.get_latest_articles(max)
|
||||
|
||||
embeds = []
|
||||
for article in articles:
|
||||
md_description = markdownify(article.description, strip=("img",))
|
||||
article_description = textwrap.shorten(md_description, 4096)
|
||||
|
||||
embed = Embed(
|
||||
title=article.title,
|
||||
description=article_description,
|
||||
url=article.url,
|
||||
timestamp=article.published,
|
||||
)
|
||||
embed.set_thumbnail(url=source.icon_url)
|
||||
embed.set_image(url=await article.get_thumbnail_url())
|
||||
embed.set_footer(text=article.author)
|
||||
embed.set_author(
|
||||
name=source.name,
|
||||
url=source.url,
|
||||
)
|
||||
embeds.append(embed)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(SentArticleModel).values([
|
||||
{
|
||||
"discord_server_id": inter.guild_id,
|
||||
"discord_channel_id": inter.channel_id,
|
||||
"discord_message_id": inter.id,
|
||||
"article_url": article.url,
|
||||
}
|
||||
for article in articles
|
||||
])
|
||||
await database.session.execute(query)
|
||||
await audit(self, f"User is requesting {max} articles", inter.user.id, database=database)
|
||||
|
||||
await followup(inter, embeds=embeds)
|
||||
|
||||
|
||||
for field in embed_fields:
|
||||
embed.add_field(**field, inline=False)
|
||||
|
||||
# output = "Your rss sources:\n\n"
|
||||
# output += "\n".join([f"[{i+1}] {rss.rss_url=} {bool(rss.active)=}" for i, rss in enumerate(rss_sources)])
|
||||
|
||||
await inter.followup.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
|
29
src/feed.py
29
src/feed.py
@ -37,8 +37,18 @@ class Source:
|
||||
feed=feed
|
||||
)
|
||||
|
||||
def get_latest_article(self):
|
||||
return Article.from_parsed(self.feed)
|
||||
def get_latest_articles(self, max: int) -> list:
|
||||
""""""
|
||||
|
||||
articles = []
|
||||
|
||||
for i, entry in enumerate(self.feed.entries):
|
||||
if i >= max:
|
||||
break
|
||||
|
||||
articles.append(Article.from_entry(entry))
|
||||
|
||||
return articles
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -66,6 +76,21 @@ class Article:
|
||||
author = entry.get("author")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_entry(cls, entry:FeedParserDict):
|
||||
|
||||
published_parsed = entry.get("published_parsed")
|
||||
published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
|
||||
|
||||
return cls(
|
||||
title=entry.get("title"),
|
||||
description=entry.get("description"),
|
||||
url=entry.get("link"),
|
||||
published=published,
|
||||
author = entry.get("author")
|
||||
)
|
||||
|
||||
|
||||
async def get_thumbnail_url(self):
|
||||
"""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user