From 5b8be819e3827206a2975ef6790f831c66baf6fd Mon Sep 17 00:00:00 2001 From: Corban-Lee Date: Thu, 21 Dec 2023 11:00:02 +0000 Subject: [PATCH] Linked feed channels to sent articles --- src/db/models.py | 10 ++++++---- src/extensions/rss.py | 10 +++++++--- src/extensions/tasks.py | 11 +++++++---- src/feed.py | 2 +- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/db/models.py b/src/db/models.py index 7a5b58e..13bd3e6 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -47,8 +47,8 @@ class SentArticleModel(Base): article_url = Column(String, nullable=False) when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 - # feed_channel_id = Column(Integer, ForeignKey("feed_channel.id"), nullable=False) - # feed_channel = relationship("FeedChannelModel", lazy="joined", cascade="all, delete") + feed_channel_id = Column(Integer, ForeignKey("feed_channel.id", ondelete="CASCADE"), nullable=False) + feed_channel = relationship("FeedChannelModel", overlaps="sent_articles", lazy="joined", cascade="all, delete") class RssSourceModel(Base): @@ -83,9 +83,11 @@ class FeedChannelModel(Base): discord_channel_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False) search_name = Column(String, nullable=False) - # created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 + created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 - rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False) + sent_articles = relationship("SentArticleModel", cascade="all, delete") + + rss_source_id = Column(Integer, ForeignKey('rss_source.id', ondelete="CASCADE"), nullable=False) rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete") # the rss source must be unique, but only within the same discord channel diff --git a/src/extensions/rss.py b/src/extensions/rss.py index e9914b5..0866127 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -78,7 +78,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed return None, feed -async def set_all_articles_as_sent(inter, channel: TextChannel, rss_url: str): +async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str): unparsed_feed = await get_unparsed_feed(rss_url) source = Source.from_parsed(parse(unparsed_feed)) articles = source.get_latest_articles() @@ -90,6 +90,7 @@ async def set_all_articles_as_sent(inter, channel: TextChannel, rss_url: str): "discord_channel_id": channel.id, "discord_message_id": -1, "article_url": article.url, + "feed_channel_id": feed_id } for article in articles ]) @@ -531,10 +532,11 @@ class FeedCog(commands.Cog): search_name=f"{nick} #{channel.name}" ) - await database.session.execute(insert_query) + insert_result = await database.session.execute(insert_query) + feed_id = insert_result.inserted_primary_key.id if prevent_spam: - await set_all_articles_as_sent(inter, channel, rss_url) + await set_all_articles_as_sent(inter, channel, feed_id, rss_url) await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}") @@ -679,6 +681,8 @@ class FeedCog(commands.Cog): "again if they appear during the next RSS feed scan." ) + audit_group + async def setup(bot): """ Setup function for this extension. diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py index 2318efe..bb77533 100644 --- a/src/extensions/tasks.py +++ b/src/extensions/tasks.py @@ -78,15 +78,17 @@ class TaskCog(commands.Cog): return for article in articles: - await self.process_article(article, channel, database) + await self.process_article(feed.id, article, channel, database) async def process_article( - self, article: Article, channel: TextChannel, database: DatabaseManager + self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager ): """Process the passed article. Will send the embed to a channel if all is valid. Parameters ---------- + feed_id : int + The feed model ID, used to log the sent article. article : Article Database model for the article. channel : TextChannel @@ -99,7 +101,7 @@ class TaskCog(commands.Cog): query = select(SentArticleModel).where(and_( SentArticleModel.article_url == article.url, - SentArticleModel.discord_channel_id == channel.id + SentArticleModel.discord_channel_id == channel.id, )) result = await database.session.execute(query) @@ -118,7 +120,8 @@ class TaskCog(commands.Cog): article_url = article.url, discord_channel_id = channel.id, discord_server_id = channel.guild.id, - discord_message_id = -1 + discord_message_id = -1, + feed_channel_id = feed_id ) await database.session.execute(query) diff --git a/src/feed.py b/src/feed.py index 0499e60..dacb0de 100644 --- a/src/feed.py +++ b/src/feed.py @@ -156,7 +156,7 @@ class Source: feed=feed ) - def get_latest_articles(self, max: int) -> list[Article]: + def get_latest_articles(self, max: int = 999) -> list[Article]: """Returns a list of Article objects. Parameters