Linked feed channels to sent articles

This commit is contained in:
Corban-Lee Jones 2023-12-21 11:00:02 +00:00
parent 45e5c6a04f
commit 5b8be819e3
4 changed files with 21 additions and 12 deletions

View File

@ -47,8 +47,8 @@ class SentArticleModel(Base):
article_url = Column(String, nullable=False) article_url = Column(String, nullable=False)
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
# feed_channel_id = Column(Integer, ForeignKey("feed_channel.id"), nullable=False) feed_channel_id = Column(Integer, ForeignKey("feed_channel.id", ondelete="CASCADE"), nullable=False)
# feed_channel = relationship("FeedChannelModel", lazy="joined", cascade="all, delete") feed_channel = relationship("FeedChannelModel", overlaps="sent_articles", lazy="joined", cascade="all, delete")
class RssSourceModel(Base): class RssSourceModel(Base):
@ -83,9 +83,11 @@ class FeedChannelModel(Base):
discord_channel_id = Column(BigInteger, nullable=False) discord_channel_id = Column(BigInteger, nullable=False)
discord_server_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False)
search_name = Column(String, nullable=False) search_name = Column(String, nullable=False)
# created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False) sent_articles = relationship("SentArticleModel", cascade="all, delete")
rss_source_id = Column(Integer, ForeignKey('rss_source.id', ondelete="CASCADE"), nullable=False)
rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete") rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete")
# the rss source must be unique, but only within the same discord channel # the rss source must be unique, but only within the same discord channel

View File

@ -78,7 +78,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
return None, feed return None, feed
async def set_all_articles_as_sent(inter, channel: TextChannel, rss_url: str): async def set_all_articles_as_sent(inter, channel: TextChannel, feed_id: int, rss_url: str):
unparsed_feed = await get_unparsed_feed(rss_url) unparsed_feed = await get_unparsed_feed(rss_url)
source = Source.from_parsed(parse(unparsed_feed)) source = Source.from_parsed(parse(unparsed_feed))
articles = source.get_latest_articles() articles = source.get_latest_articles()
@ -90,6 +90,7 @@ async def set_all_articles_as_sent(inter, channel: TextChannel, rss_url: str):
"discord_channel_id": channel.id, "discord_channel_id": channel.id,
"discord_message_id": -1, "discord_message_id": -1,
"article_url": article.url, "article_url": article.url,
"feed_channel_id": feed_id
} }
for article in articles for article in articles
]) ])
@ -531,10 +532,11 @@ class FeedCog(commands.Cog):
search_name=f"{nick} #{channel.name}" search_name=f"{nick} #{channel.name}"
) )
await database.session.execute(insert_query) insert_result = await database.session.execute(insert_query)
feed_id = insert_result.inserted_primary_key.id
if prevent_spam: if prevent_spam:
await set_all_articles_as_sent(inter, channel, rss_url) await set_all_articles_as_sent(inter, channel, feed_id, rss_url)
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}") await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
@ -679,6 +681,8 @@ class FeedCog(commands.Cog):
"again if they appear during the next RSS feed scan." "again if they appear during the next RSS feed scan."
) )
audit_group
async def setup(bot): async def setup(bot):
""" """
Setup function for this extension. Setup function for this extension.

View File

@ -78,15 +78,17 @@ class TaskCog(commands.Cog):
return return
for article in articles: for article in articles:
await self.process_article(article, channel, database) await self.process_article(feed.id, article, channel, database)
async def process_article( async def process_article(
self, article: Article, channel: TextChannel, database: DatabaseManager self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager
): ):
"""Process the passed article. Will send the embed to a channel if all is valid. """Process the passed article. Will send the embed to a channel if all is valid.
Parameters Parameters
---------- ----------
feed_id : int
The feed model ID, used to log the sent article.
article : Article article : Article
Database model for the article. Database model for the article.
channel : TextChannel channel : TextChannel
@ -99,7 +101,7 @@ class TaskCog(commands.Cog):
query = select(SentArticleModel).where(and_( query = select(SentArticleModel).where(and_(
SentArticleModel.article_url == article.url, SentArticleModel.article_url == article.url,
SentArticleModel.discord_channel_id == channel.id SentArticleModel.discord_channel_id == channel.id,
)) ))
result = await database.session.execute(query) result = await database.session.execute(query)
@ -118,7 +120,8 @@ class TaskCog(commands.Cog):
article_url = article.url, article_url = article.url,
discord_channel_id = channel.id, discord_channel_id = channel.id,
discord_server_id = channel.guild.id, discord_server_id = channel.guild.id,
discord_message_id = -1 discord_message_id = -1,
feed_channel_id = feed_id
) )
await database.session.execute(query) await database.session.execute(query)

View File

@ -156,7 +156,7 @@ class Source:
feed=feed feed=feed
) )
def get_latest_articles(self, max: int) -> list[Article]: def get_latest_articles(self, max: int = 999) -> list[Article]:
"""Returns a list of Article objects. """Returns a list of Article objects.
Parameters Parameters