
"""
Extension for the `TaskCog`.
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
"""
import re
import json
import asyncio
import logging
import datetime
import urllib.parse
from os import getenv
from time import perf_counter
from collections import deque

import aiohttp
from aiocache import Cache
from discord import TextChannel, Embed, Colour
from discord import app_commands
from discord.ext import commands, tasks
from discord.errors import Forbidden
from feedparser import parse

from feed import RSSFeed, Subscription, RSSItem
from utils import get_unparsed_feed
from filters import match_text
from api import API

log = logging.getLogger(__name__)
cache = Cache(Cache.MEMORY)

BATCH_SIZE = 100
# Interval (in minutes) between runs of the subscription task, read from the
# environment. It should be an integer that divides an hour evenly.
TASK_INTERVAL_MINUTES = int(getenv("TASK_INTERVAL_MINUTES"))

subscription_task_times = [
    datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
    for hour in range(24)
    for minute in range(0, 60, TASK_INTERVAL_MINUTES)
]

log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
class TaskCog(commands.Cog):
    """
    Tasks cog for PYRSS.
    """

    api: API | None = None
    content_queue = deque()

    def __init__(self, bot):
        super().__init__()
        self.bot = bot
    @commands.Cog.listener()
    async def on_ready(self):
        """
        Instructions to execute when the cog is ready.
        """
        self.subscription_task.start()
        log.info("%s cog is ready", self.__class__.__name__)
    async def cog_unload(self):
        """
        Instructions to execute before the cog is unloaded.

        `cog_unload` is a special method on `commands.Cog`, not a dispatched
        event, so it is overridden directly rather than registered as a listener.
        """
        self.subscription_task.cancel()
    @app_commands.command(name="debug-trigger-task")
    async def debug_trigger_task(self, inter):
        """Manually run the subscription task and report how long it took."""
        await inter.response.defer()

        start_time = perf_counter()
        await self.subscription_task()
        end_time = perf_counter()

        await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
    @tasks.loop(time=subscription_task_times)
    async def subscription_task(self):
        """
        Task for fetching and processing subscriptions.
        """
        log.info("Running subscription task")
        start_time = perf_counter()

        async with aiohttp.ClientSession() as session:
            self.api = API(self.bot.api_token, session)
            subscriptions = await self.get_subscriptions()
            await self.process_subscriptions(subscriptions)

        end_time = perf_counter()
        log.debug("task completed in %.4f seconds", end_time - start_time)
    async def get_subscriptions(self) -> list[Subscription]:
        """
        Fetch subscriptions for every guild the bot is in, one API page at a
        time, until the API reports that there are no further pages.
        """
        guild_ids = [guild.id for guild in self.bot.guilds]
        sub_data = []
        page = 0

        while True:
            page += 1
            try:
                log.debug("fetching page '%s'", page)
                sub_data.extend(
                    (await self.api.get_subscriptions(server__in=guild_ids, page=page))[0]
                )
            except aiohttp.ClientResponseError as error:
                match error.status:
                    case 404:
                        log.debug("final page reached '%s'", page - 1)
                        break
                    case 403:
                        log.critical(error)
                        self.subscription_task.cancel()
                        return []  # returning an empty list should gracefully end the task
                    case _:
                        log.error(error)
                        break
            except Exception as error:
                log.error("Exception while gathering page data %s", error)
                break

        return Subscription.from_list(sub_data)
    async def process_subscriptions(self, subscriptions: list[Subscription]):
        """Process subscriptions concurrently, bounded by a semaphore."""

        async def process_single_subscription(sub: Subscription):
            log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)

            if not sub.active or not sub.channels_count:
                return

            unparsed_feed = await get_unparsed_feed(sub.url)
            parsed_feed = parse(unparsed_feed)
            rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
            await self.process_items(sub, rss_feed)

        # Allow at most 10 subscriptions to be processed at the same time.
        semaphore = asyncio.Semaphore(10)

        async def semaphore_process(sub: Subscription):
            async with semaphore:
                await process_single_subscription(sub)

        await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
    async def process_items(self, sub: Subscription, feed: RSSFeed):
        log.debug("processing items")

        channels = [
            self.bot.get_channel(channel.channel_id)
            for channel in await sub.get_channels(self.api)
        ]
        # `Bot.get_channel` returns None for channels the bot can no longer see.
        channels = [channel for channel in channels if channel is not None]
        filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]

        for item in feed.items:
            log.debug("processing item '%s'", item.guid)

            if item.pub_date < sub.published_threshold:
                log.debug(
                    "item '%s' older than subscription threshold '%s', skipping",
                    item.pub_date, sub.published_threshold
                )
                continue

            blocked = any(self.filter_item(_filter, item) for _filter in filters)
            mutated_item = item.create_mutated_copy(sub.mutators)

            for channel in channels:
                await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
    def filter_item(self, _filter: dict, item: RSSItem) -> bool:
        """
        Returns `True` if the item should be ignored due to filters.
        """
        match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
        log.debug("filter match found? '%s'", match_found)
        return match_found
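
    # Rough usage sketch for `filter_item` (the filter's shape is hypothetical,
    # assumed to be a dict understood by `filters.match_text`; only the call
    # pattern is taken from this module):
    #
    #     >>> cog.filter_item({"keyword": "spoiler"}, item)
    #     True   # the item's title or description matched, so it is not sent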
    async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem, channel: TextChannel, blocked: bool):
        message_id = -1

        log.debug("track and send func %s, %s", item.guid, item.title)

        # Skip items that are already tracked, whether matched by GUID or by URL.
        result = await self.api.get_tracked_content(guid=item.guid)
        if result[1]:
            log.debug("This item is already tracked, skipping '%s'", item.guid)
            return

        result = await self.api.get_tracked_content(url=item.link)
        if result[1]:
            log.debug("This item is already tracked, skipping '%s'", item.guid)
            return

        if not blocked:
            try:
                log.debug("sending '%s', exists '%s'", item.guid, result[1])
                message = await channel.send(embed=await mutated_item.to_embed(sub, feed, self.api.session))
                message_id = message.id
            except Forbidden as error:
                log.error(error)

        await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
    async def process_batch(self):
        pass
    async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
        try:
            log.debug("marking as tracked")
            await self.api.create_tracked_content(
                guid=item.guid,
                title=item.title,
                url=item.link,
                subscription=sub.id,
                channel_id=channel_id,
                message_id=message_id,
                blocked=blocked
            )
            return True
        except aiohttp.ClientResponseError as error:
            # A 409 means the content is already tracked, so it is only logged
            # at debug level; any other status is a real error.
            if error.status == 409:
                log.debug(error)
            else:
                log.error(error)
            return False
async def setup(bot):
    """
    Setup function for this extension.
    Adds `TaskCog` to the bot.
    """
    cog = TaskCog(bot)
    await bot.add_cog(cog)
    log.info("Added %s cog", cog.__class__.__name__)