""" Extension for the `TaskCog`. Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot. """ import asyncio import logging import datetime from os import getenv from time import perf_counter from collections import deque import aiohttp from aiocache import Cache from discord import TextChannel from discord import app_commands from discord.ext import commands, tasks from discord.errors import Forbidden from feedparser import parse from feed import RSSFeed, Subscription, RSSItem, GuildSettings from utils import get_unparsed_feed from filters import match_text from api import API log = logging.getLogger(__name__) cache = Cache(Cache.MEMORY) BATCH_SIZE = 100 TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES") subscription_task_times = [ datetime.time(hour, minute, tzinfo=datetime.timezone.utc) for hour in range(24) for minute in range(0, 60, int(TASK_INTERVAL_MINUTES)) ] log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES) class TaskCog(commands.Cog): """ Tasks cog for PYRSS. """ api: API | None = None content_queue = deque() def __init__(self, bot): super().__init__() self.bot = bot @commands.Cog.listener() async def on_ready(self): """ Instructions to execute when the cog is ready. """ self.subscription_task.start() log.info("%s cog is ready", self.__class__.__name__) @commands.Cog.listener(name="cog_unload") async def on_unload(self): """ Instructions to execute before the cog is unloaded. """ self.subscription_task.cancel() group = app_commands.Group( name="task", description="Commands for tasks", guild_only=True ) @group.command(name="trigger") async def cmd_trigger_task(self, inter): await inter.response.defer() start_time = perf_counter() try: await self.subscription_task() except Exception as error: await inter.followup.send(str(error)) finally: end_time = perf_counter() await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds") @tasks.loop(time=subscription_task_times) async def subscription_task(self): """ Task for fetching and processing subscriptions. """ log.info("Running subscription task") start_time = perf_counter() async with aiohttp.ClientSession() as session: self.api = API(self.bot.api_token, session) await self.execute_task() end_time = perf_counter() log.debug(f"task completed in {end_time - start_time:.4f} seconds") async def execute_task(self): """Execute the task directly.""" # Filter out inactive guild IDs using related settings guild_ids = [guild.id for guild in self.bot.guilds] guild_settings = await self.get_guild_settings(guild_ids) active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active] subscriptions = await self.get_subscriptions(active_guild_ids) await self.process_subscriptions(subscriptions) async def get_guild_settings(self, guild_ids: list[int]) -> list[int]: """Returns a list of guild settings from the Bot's guilds, if they exist.""" guild_settings = [] # Iterate infinitely taking the iter no. as `page` # data will be empty after last page reached. for page, _ in enumerate(iter(int, 1)): data = await self.get_guild_settings_page(guild_ids, page) if not data: break guild_settings.extend(data[0]) # Only return active guild IDs return GuildSettings.from_list(guild_settings) async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]: """Returns an individual page of guild settings.""" try: return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1) except aiohttp.ClientResponseError as error: self.handle_pagination_error(error) return [] def handle_pagination_error(self, error: aiohttp.ClientResponseError): """Handle the error cases from pagination attempts.""" match error.status: case 404: log.debug("final page reached") case 403: log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True) self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently case _: log.debug(error) async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]: """Get a list of `Subscription`s matching the given `guild_ids`.""" subscriptions = [] # Iterate infinitely taking the iter no. as `page` # data will be empty after last page reached. for page, _ in enumerate(iter(int, 1)): data = await self.get_subs_page(guild_ids, page) if not data: break subscriptions.extend(data[0]) return Subscription.from_list(subscriptions) async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]: """Returns an individual page of subscriptions.""" try: return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1) except aiohttp.ClientResponseError as error: self.handle_pagination_error(error) return [] async def process_subscriptions(self, subscriptions: list[Subscription]): """Process a given list of `Subscription`s.""" async def process_single_subscription(sub: Subscription): log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id) if not sub.active or not sub.channels_count: return unparsed_feed = await get_unparsed_feed(sub.url) parsed_feed = parse(unparsed_feed) rss_feed = RSSFeed.from_parsed_feed(parsed_feed) await self.process_items(sub, rss_feed) semaphore = asyncio.Semaphore(10) async def semaphore_process(sub: Subscription): async with semaphore: await process_single_subscription(sub) await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions)) async def process_items(self, sub: Subscription, feed: RSSFeed): log.debug("processing items") channels = await self.fetch_or_get_channels(await sub.get_channels(self.api)) filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters] for item in feed.items: log.debug("processing item '%s'", item.guid) if item.pub_date < sub.published_threshold: log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold) continue blocked = any(self.filter_item(_filter, item) for _filter in filters) mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None for channel in channels: await self.track_and_send(sub, feed, item, mutated_item, channel, blocked) async def fetch_or_get_channels(self, channels_data: list[dict]): channels = [] for data in channels_data: try: channel = self.bot.get_channel(data.channel_id) channels.append(channel or await self.bot.fetch_channel(data.channel_id)) except Forbidden: log.error(f"Forbidden Channel '{data.channel_id}'") return channels def filter_item(self, _filter: dict, item: RSSItem) -> bool: """ Returns `True` if item should be ignored due to filters. """ match_found = match_text(_filter, item.title) or match_text(_filter, item.description) log.debug("filter match found? '%s'", match_found) return match_found async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool): message_id = -1 log.debug("track and send func %s, %s", item.guid, item.title) result = await self.api.get_tracked_content(guid=item.guid) if result[1]: log.debug(f"This item is already tracked, skipping '{item.guid}'") return result = await self.api.get_tracked_content(url=item.link) if result[1]: log.debug(f"This item is already tracked, skipping '{item.guid}'") return if not blocked: try: log.debug("sending '%s', exists '%s'", item.guid, result[1]) sendable_item = mutated_item or item message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session)) message_id = message.id except Forbidden: log.error(f"Forbidden to send to channel {channel.id}") await self.mark_tracked_item(sub, item, channel.id, message_id, blocked) async def process_batch(self): pass async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool): try: log.debug("marking as tracked") await self.api.create_tracked_content( guid=item.guid, title=item.title, url=item.link, subscription=sub.id, channel_id=channel_id, message_id=message_id, blocked=blocked ) return True except aiohttp.ClientResponseError as error: if error.status == 409: log.debug(error) else: log.error(error) return False async def setup(bot): """ Setup function for this extension. Adds `TaskCog` to the bot. """ cog = TaskCog(bot) await bot.add_cog(cog) log.info("Added %s cog", cog.__class__.__name__)