diff --git a/src/bot.py b/src/bot.py index 69e3a1c..e78422d 100644 --- a/src/bot.py +++ b/src/bot.py @@ -47,8 +47,18 @@ class DiscordBot(commands.Bot): await self.load_extension(f"extensions.{path.stem}") async def audit(self, message: str, user_id: int, database: DatabaseManager=None): + """Shorthand for auditing an action. + + Parameters + ---------- + message : str + The message to be audited. + user_id : int + Discord ID of the user being audited. + database : DatabaseManager, optional + An existing database connection to be used if specified, by default None + """ - message = f"Requesting latest article" query = insert(AuditModel).values(discord_user_id=user_id, message=message) if database: diff --git a/src/db/models.py b/src/db/models.py index 66a12ca..7a5b58e 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -47,6 +47,9 @@ class SentArticleModel(Base): article_url = Column(String, nullable=False) when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 + # feed_channel_id = Column(Integer, ForeignKey("feed_channel.id"), nullable=False) + # feed_channel = relationship("FeedChannelModel", lazy="joined", cascade="all, delete") + class RssSourceModel(Base): """ @@ -80,8 +83,9 @@ class FeedChannelModel(Base): discord_channel_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False) search_name = Column(String, nullable=False) - rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False) + # created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 + rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False) rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete") # the rss source must be unique, but only within the same discord channel diff --git a/src/extensions/rss.py b/src/extensions/rss.py index c911b32..e9914b5 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -15,7 +15,7 @@ from sqlalchemy import insert, select, and_, delete from sqlalchemy.exc import NoResultFound from utils import get_rss_data, followup, audit, followup_error # pylint: disable=E0401 -from feed import get_source, Source # pylint: disable=E0401 +from feed import get_source, get_unparsed_feed, Source # pylint: disable=E0401 from db import ( # pylint: disable=E0401 DatabaseManager, SentArticleModel, @@ -78,6 +78,23 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed return None, feed +async def set_all_articles_as_sent(inter, channel: TextChannel, rss_url: str): + unparsed_feed = await get_unparsed_feed(rss_url) + source = Source.from_parsed(parse(unparsed_feed)) + articles = source.get_latest_articles() + + async with DatabaseManager() as database: + query = insert(SentArticleModel).values([ + { + "discord_server_id": inter.guild_id, + "discord_channel_id": channel.id, + "discord_message_id": -1, + "article_url": article.url, + } + for article in articles + ]) + await database.session.execute(query) + class FeedCog(commands.Cog): """ @@ -221,7 +238,7 @@ class FeedCog(commands.Cog): inter.user.id, database=database ) - source = get_source(url) + source = get_source(url) # TODO: replace with async function embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red()) embed.add_field(name="Nickname", value=nickname) @@ -478,14 +495,16 @@ class FeedCog(commands.Cog): @feed_group.command(name="assign") @rename(rss="feed") @autocomplete(rss=autocomplete_rss_sources) - async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None): + async def include_feed( + self, inter: Interaction, rss: int, channel: TextChannel = None, prevent_spam: bool = True + ): """Include a feed within the specified channel. Parameters ---------- inter : Interaction Represents an app command interaction. - rss : str + rss : int The RSS feed to include. channel : TextChannel The channel to include the feed in. @@ -514,6 +533,8 @@ class FeedCog(commands.Cog): await database.session.execute(insert_query) + if prevent_spam: + await set_all_articles_as_sent(inter, channel, rss_url) await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}") @@ -628,6 +649,35 @@ class FeedCog(commands.Cog): await followup(inter, embed=embed) + admin_group = Group( + parent=feed_group, + name="admin", + description="Administration tasks" + ) + + @admin_group.command(name="clear-sent-articles") + async def clear_sent_articles(self, inter: Interaction): + """Clear the database of all sent articles. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ + + await inter.response.defer() + + async with DatabaseManager() as database: + query = delete(SentArticleModel).where(and_( + SentArticleModel.discord_server_id == inter.guild_id + )) + result = await database.session.execute(query) + + await followup(inter, + f"{result.rowcount} sent articles have been cleared from the database. " + "I will no longer recognise these articles as sent, and will send them " + "again if they appear during the next RSS feed scan." + ) async def setup(bot): """