167 lines
4.9 KiB
Python

"""
Extension for the `TaskCog`.
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
"""
import datetime
import logging
from os import getenv
from time import perf_counter, process_time

from discord import TextChannel
from discord.errors import Forbidden, HTTPException
from discord.ext import commands, tasks
from feedparser import parse
from sqlalchemy import and_, insert, select

from db import (
    DatabaseManager,
    FeedChannelModel,
    RssSourceModel,
    SentArticleModel,
)
from feed import Article, Source
from utils import get_unparsed_feed
log = logging.getLogger(__name__)

# Polling interval in minutes, read from the environment as a string.
# Falls back to "10" (the interval the task loop previously hard-coded) so an
# unset variable no longer crashes the extension load with int(None).
TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES", "10")

try:
    _interval_minutes = int(TASK_INTERVAL_MINUTES)
except ValueError as err:
    raise ValueError(
        f"TASK_INTERVAL_MINUTES must be an integer, got {TASK_INTERVAL_MINUTES!r}"
    ) from err

if _interval_minutes < 1:
    # range() rejects a step of 0, and a negative step yields no run times.
    raise ValueError(
        f"TASK_INTERVAL_MINUTES must be a positive integer, got {_interval_minutes}"
    )

# Minutes past each hour at which the task fires: 0, interval, 2*interval, ...
_minutes = range(0, 60, _interval_minutes)
if 60 % _interval_minutes:
    # The schedule restarts at minute 0 every hour, so an interval that does
    # not divide 60 gives uneven spacing (and anything above 60 means hourly).
    log.warning(
        "TASK_INTERVAL_MINUTES=%s does not divide 60 evenly; the task will fire "
        "at minutes %s of every hour",
        TASK_INTERVAL_MINUTES,
        list(_minutes),
    )

# Every UTC wall-clock time of day at which the task should run.
times = [
    datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
    for hour in range(24)
    for minute in _minutes
]
log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
class TaskCog(commands.Cog):
    """
    Tasks cog.

    Owns the periodic ``rss_task`` loop, which fetches every configured RSS
    feed and posts articles that have not been sent yet to the feed's Discord
    channel, recording each sent article in the database.
    """

    def __init__(self, bot):
        super().__init__()
        self.bot = bot
        self.time = None  # NOTE(review): not read anywhere in this file

    @commands.Cog.listener()
    async def on_ready(self):
        """Instructions to execute when the cog is ready.

        Starts ``rss_task`` unless the bot runs in developing mode. ``on_ready``
        may fire again after a reconnect and starting an already running loop
        raises ``RuntimeError``, so the task is only started when idle.
        """
        if not self.bot.developing and not self.rss_task.is_running():
            self.rss_task.start()
        log.info("%s cog is ready", self.__class__.__name__)

    async def cog_unload(self):
        """Instructions to execute before the cog is unloaded.

        discord.py calls ``cog_unload`` itself when the cog is removed. It is
        not a dispatched event, so it cannot be received through a listener.
        """
        self.rss_task.cancel()

    # Previous name of the unload hook, kept for any existing references.
    on_unload = cog_unload

    @tasks.loop(time=times)
    async def rss_task(self):
        """Automated task responsible for processing rss feeds.

        Fires at each UTC time in the module-level ``times`` schedule, which is
        derived from ``TASK_INTERVAL_MINUTES``. A failure while processing one
        feed is logged and does not stop the other feeds or the loop itself.
        """
        log.info("Running rss task")
        # Wall-clock timer: process_time() counts only CPU time and would
        # ignore the time spent awaiting network and database I/O.
        start = perf_counter()

        async with DatabaseManager() as database:
            query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
            result = await database.session.execute(query)
            # scalars() keeps only the first selected entity (FeedChannelModel).
            # NOTE(review): RssSourceModel is still selected, presumably so that
            # feed.rss_source is already loaded when process_feed reads it — confirm.
            feeds = result.scalars().all()

            for feed in feeds:
                try:
                    await self.process_feed(feed, database)
                except Exception:
                    # An unhandled exception would stop the tasks.loop for good,
                    # so isolate each feed and continue with the remaining ones.
                    log.exception("Failed to process feed: %s", feed.id)

        log.info("Finished rss task, time elapsed: %s", perf_counter() - start)

    async def process_feed(self, feed: FeedChannelModel, database: DatabaseManager):
        """Process the passed feed. Will also call process for each article found in the feed.

        Fetches ``feed.rss_source.rss_url``, parses it and passes every article
        returned by ``Source.get_latest_articles(5)`` to ``process_article``.
        The feed is skipped when its Discord channel cannot be resolved.

        Parameters
        ----------
        feed : FeedChannelModel
            Database model for the feed.
        database : DatabaseManager
            Database connection handler, must be open.
        """
        log.debug("Processing feed: %s", feed.id)
        channel = self.bot.get_channel(feed.discord_channel_id)
        if channel is None:
            # get_channel() returns None when the channel is not in the bot's
            # cache (e.g. it was deleted or is no longer visible to the bot).
            # Bail out before fetching anything; sending would fail anyway.
            log.warning(
                "Channel %s for feed %s not found, skipping",
                feed.discord_channel_id,
                feed.id,
            )
            return

        # TODO: integrate the `validate_feed` code into here, also do on list command and show errors.
        unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url)
        parsed_feed = parse(unparsed_content)
        source = Source.from_parsed(parsed_feed)
        articles = source.get_latest_articles(5)
        if not articles:
            log.info("No articles to process")
            return

        for article in articles:
            await self.process_article(feed.id, article, channel, database)

    async def process_article(
        self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager
    ):
        """Process the passed article. Will send the embed to a channel if all is valid.

        Articles already recorded as sent to ``channel`` are skipped. After a
        successful send, a ``SentArticleModel`` row is inserted, including the
        id of the message that was sent.

        Parameters
        ----------
        feed_id : int
            The feed model ID, used to log the sent article.
        article : Article
            Article parsed from the RSS source (not a database model).
        channel : TextChannel
            Where the article will be sent to.
        database : DatabaseManager
            Database connection handler, must be open.
        """
        log.debug("Processing article: %s", article.url)

        query = select(SentArticleModel).where(and_(
            SentArticleModel.article_url == article.url,
            SentArticleModel.discord_channel_id == channel.id,
        ))
        result = await database.session.execute(query)
        # Only existence matters, so avoid materialising every matching row.
        if result.scalars().first() is not None:
            log.debug("Article already processed: %s", article.url)
            return

        embed = await article.to_embed()
        try:
            message = await channel.send(embed=embed)
        except Forbidden:
            log.error("Can't send article to channel: %s · %s", channel.name, channel.id)
            return
        except HTTPException:
            # Any other Discord rejection (e.g. an invalid embed): skip only
            # this article so the rest of the feed is still delivered.
            log.exception("Failed to send article %s to channel %s", article.url, channel.id)
            return

        query = insert(SentArticleModel).values(
            article_url=article.url,
            discord_channel_id=channel.id,
            discord_server_id=channel.guild.id,
            # Record the real message id instead of a -1 placeholder.
            discord_message_id=message.id,
            feed_channel_id=feed_id,
        )
        await database.session.execute(query)
        log.debug("new Article processed: %s", article.url)
async def setup(bot):
    """
    Extension entry point invoked by ``commands.Bot.load_extension``.
    Creates one ``TaskCog`` bound to *bot* and registers it.
    """
    task_cog = TaskCog(bot)
    await bot.add_cog(task_cog)
    log.info("Added %s cog", type(task_cog).__name__)