299 lines
10 KiB
Python
299 lines
10 KiB
Python
"""
|
|
Extension for the `TaskCog`.
|
|
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import datetime
|
|
from os import getenv
|
|
from time import perf_counter
|
|
from collections import deque
|
|
|
|
import aiohttp
|
|
from aiocache import Cache
|
|
from discord import TextChannel
|
|
from discord import app_commands
|
|
from discord.ext import commands, tasks
|
|
from discord.errors import Forbidden
|
|
from feedparser import parse
|
|
|
|
from feed import RSSFeed, Subscription, RSSItem, GuildSettings
|
|
from utils import get_unparsed_feed
|
|
from filters import match_text
|
|
from api import API
|
|
|
|
log = logging.getLogger(__name__)

# In-memory aiocache instance for this module.
cache = Cache(Cache.MEMORY)

BATCH_SIZE = 100

# Raw value exactly as read from the environment (`None` when the variable is unset).
TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES")


def _parse_task_interval(raw):
    """
    Return `raw` converted to a positive whole number of minutes.

    Raises:
        ValueError: if `raw` is missing, is not an integer, or is less than 1.
    """

    try:
        minutes = int(raw)
    except (TypeError, ValueError) as error:
        # `int(None)` would otherwise surface as a bare TypeError with no hint
        # about which setting is missing.
        raise ValueError(
            f"TASK_INTERVAL_MINUTES must be set to a whole number of minutes, got {raw!r}"
        ) from error

    if minutes < 1:
        # A zero step makes `range` raise, and a negative step yields no trigger times at all.
        raise ValueError(f"TASK_INTERVAL_MINUTES must be at least 1, got {minutes}")

    return minutes


_task_interval = _parse_task_interval(TASK_INTERVAL_MINUTES)

# One UTC trigger time per interval step, for every hour of the day.
# The step restarts at minute 0 each hour, so an interval that does not divide 60
# produces an uneven gap at the end of every hour.
subscription_task_times = [
    datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
    for hour in range(24)
    for minute in range(0, 60, _task_interval)
]
log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
|
|
|
|
|
|
class TaskCog(commands.Cog):
|
|
"""
|
|
Tasks cog for PYRSS.
|
|
"""
|
|
|
|
api: API | None = None
|
|
content_queue = deque()
|
|
|
|
def __init__(self, bot):
|
|
super().__init__()
|
|
self.bot = bot
|
|
|
|
@commands.Cog.listener()
|
|
async def on_ready(self):
|
|
"""
|
|
Instructions to execute when the cog is ready.
|
|
"""
|
|
self.subscription_task.start()
|
|
log.info("%s cog is ready", self.__class__.__name__)
|
|
|
|
@commands.Cog.listener(name="cog_unload")
|
|
async def on_unload(self):
|
|
"""
|
|
Instructions to execute before the cog is unloaded.
|
|
"""
|
|
self.subscription_task.cancel()
|
|
|
|
group = app_commands.Group(
|
|
name="task",
|
|
description="Commands for tasks",
|
|
guild_only=True
|
|
)
|
|
|
|
@group.command(name="trigger")
|
|
async def cmd_trigger_task(self, inter):
|
|
await inter.response.defer()
|
|
start_time = perf_counter()
|
|
|
|
try:
|
|
await self.subscription_task()
|
|
except Exception as error:
|
|
await inter.followup.send(str(error))
|
|
finally:
|
|
end_time = perf_counter()
|
|
await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
|
|
|
|
@tasks.loop(time=subscription_task_times)
|
|
async def subscription_task(self):
|
|
"""
|
|
Task for fetching and processing subscriptions.
|
|
"""
|
|
log.info("Running subscription task")
|
|
start_time = perf_counter()
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
self.api = API(self.bot.api_token, session)
|
|
await self.execute_task()
|
|
|
|
end_time = perf_counter()
|
|
log.debug(f"task completed in {end_time - start_time:.4f} seconds")
|
|
|
|
async def execute_task(self):
|
|
"""Execute the task directly."""
|
|
|
|
# Filter out inactive guild IDs using related settings
|
|
guild_ids = [guild.id for guild in self.bot.guilds]
|
|
guild_settings = await self.get_guild_settings(guild_ids)
|
|
active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
|
|
|
|
subscriptions = await self.get_subscriptions(active_guild_ids)
|
|
await self.process_subscriptions(subscriptions)
|
|
|
|
async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
|
|
"""Returns a list of guild settings from the Bot's guilds, if they exist."""
|
|
|
|
guild_settings = []
|
|
|
|
# Iterate infinitely taking the iter no. as `page`
|
|
# data will be empty after last page reached.
|
|
for page, _ in enumerate(iter(int, 1)):
|
|
data = await self.get_guild_settings_page(guild_ids, page)
|
|
if not data:
|
|
break
|
|
|
|
guild_settings.extend(data[0])
|
|
|
|
# Only return active guild IDs
|
|
return GuildSettings.from_list(guild_settings)
|
|
|
|
async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
|
|
"""Returns an individual page of guild settings."""
|
|
|
|
try:
|
|
return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
|
|
except aiohttp.ClientResponseError as error:
|
|
self.handle_pagination_error(error)
|
|
return []
|
|
|
|
def handle_pagination_error(self, error: aiohttp.ClientResponseError):
|
|
"""Handle the error cases from pagination attempts."""
|
|
|
|
match error.status:
|
|
case 404:
|
|
log.debug("final page reached")
|
|
case 403:
|
|
log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
|
|
self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
|
|
case _:
|
|
log.debug(error)
|
|
|
|
async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
|
|
"""Get a list of `Subscription`s matching the given `guild_ids`."""
|
|
|
|
subscriptions = []
|
|
|
|
# Iterate infinitely taking the iter no. as `page`
|
|
# data will be empty after last page reached.
|
|
for page, _ in enumerate(iter(int, 1)):
|
|
data = await self.get_subs_page(guild_ids, page)
|
|
if not data:
|
|
break
|
|
|
|
subscriptions.extend(data[0])
|
|
|
|
return Subscription.from_list(subscriptions)
|
|
|
|
async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
|
|
"""Returns an individual page of subscriptions."""
|
|
|
|
try:
|
|
return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
|
|
except aiohttp.ClientResponseError as error:
|
|
self.handle_pagination_error(error)
|
|
return []
|
|
|
|
async def process_subscriptions(self, subscriptions: list[Subscription]):
|
|
"""Process a given list of `Subscription`s."""
|
|
|
|
async def process_single_subscription(sub: Subscription):
|
|
log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
|
|
|
|
if not sub.active or not sub.channels_count:
|
|
return
|
|
|
|
unparsed_feed = await get_unparsed_feed(sub.url)
|
|
parsed_feed = parse(unparsed_feed)
|
|
|
|
rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
|
|
await self.process_items(sub, rss_feed)
|
|
|
|
semaphore = asyncio.Semaphore(10)
|
|
|
|
async def semaphore_process(sub: Subscription):
|
|
async with semaphore:
|
|
await process_single_subscription(sub)
|
|
|
|
await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
|
|
|
|
async def process_items(self, sub: Subscription, feed: RSSFeed):
|
|
log.debug("processing items")
|
|
|
|
channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
|
|
filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
|
|
|
|
for item in feed.items:
|
|
log.debug("processing item '%s'", item.guid)
|
|
|
|
if item.pub_date < sub.published_threshold:
|
|
log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
|
|
continue
|
|
|
|
blocked = any(self.filter_item(_filter, item) for _filter in filters)
|
|
mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
|
|
|
|
for channel in channels:
|
|
await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
|
|
|
|
async def fetch_or_get_channels(self, channels_data: list[dict]):
|
|
channels = []
|
|
|
|
for data in channels_data:
|
|
try:
|
|
channel = self.bot.get_channel(data.channel_id)
|
|
channels.append(channel or await self.bot.fetch_channel(data.channel_id))
|
|
except Forbidden:
|
|
log.error(f"Forbidden Channel '{data.channel_id}'")
|
|
|
|
return channels
|
|
|
|
def filter_item(self, _filter: dict, item: RSSItem) -> bool:
|
|
"""
|
|
Returns `True` if item should be ignored due to filters.
|
|
"""
|
|
|
|
match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
|
|
log.debug("filter match found? '%s'", match_found)
|
|
return match_found
|
|
|
|
async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
|
|
message_id = -1
|
|
|
|
log.debug("track and send func %s, %s", item.guid, item.title)
|
|
|
|
result = await self.api.get_tracked_content(guid=item.guid)
|
|
if result[1]:
|
|
log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
|
return
|
|
|
|
result = await self.api.get_tracked_content(url=item.link)
|
|
if result[1]:
|
|
log.debug(f"This item is already tracked, skipping '{item.guid}'")
|
|
return
|
|
|
|
if not blocked:
|
|
try:
|
|
log.debug("sending '%s', exists '%s'", item.guid, result[1])
|
|
sendable_item = mutated_item or item
|
|
message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
|
|
message_id = message.id
|
|
except Forbidden:
|
|
log.error(f"Forbidden to send to channel {channel.id}")
|
|
|
|
await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
|
|
|
|
async def process_batch(self):
|
|
pass
|
|
|
|
async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
|
|
try:
|
|
log.debug("marking as tracked")
|
|
await self.api.create_tracked_content(
|
|
guid=item.guid,
|
|
title=item.title,
|
|
url=item.link,
|
|
subscription=sub.id,
|
|
channel_id=channel_id,
|
|
message_id=message_id,
|
|
blocked=blocked
|
|
)
|
|
return True
|
|
except aiohttp.ClientResponseError as error:
|
|
if error.status == 409:
|
|
log.debug(error)
|
|
else:
|
|
log.error(error)
|
|
|
|
return False
|
|
|
|
|
|
async def setup(bot):
    """
    Extension entry point.

    Creates a `TaskCog` bound to `bot`, registers it with the bot, and logs
    the name of the cog that was added.
    """

    task_cog = TaskCog(bot)
    await bot.add_cog(task_cog)

    cog_name = task_cog.__class__.__name__
    log.info("Added %s cog", cog_name)
|