reworking task
Some checks failed
Build and Push Docker Image / build (push) Failing after 6m55s

This commit is contained in:
Corban-Lee Jones 2024-10-29 23:44:49 +00:00
parent cc06d3e09f
commit ccfa35adda
2 changed files with 245 additions and 159 deletions

View File

@ -2,16 +2,21 @@ aiocache==0.12.2
aiohttp==3.9.3 aiohttp==3.9.3
aiosignal==1.3.1 aiosignal==1.3.1
aiosqlite==0.19.0 aiosqlite==0.19.0
anyio==4.6.2.post1
async-timeout==4.0.3 async-timeout==4.0.3
asyncpg==0.29.0 asyncpg==0.29.0
attrs==23.2.0 attrs==23.2.0
beautifulsoup4==4.12.3 beautifulsoup4==4.12.3
bump2version==1.0.1 bump2version==1.0.1
certifi==2024.8.30
click==8.1.7 click==8.1.7
discord.py==2.3.2 discord.py==2.3.2
feedparser==6.0.11 feedparser==6.0.11
frozenlist==1.4.1 frozenlist==1.4.1
greenlet==3.0.3 greenlet==3.0.3
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
idna==3.6 idna==3.6
markdownify==0.11.6 markdownify==0.11.6
multidict==6.0.5 multidict==6.0.5
@ -21,6 +26,7 @@ python-dotenv==1.0.0
rapidfuzz==3.9.4 rapidfuzz==3.9.4
sgmllib3k==1.0.0 sgmllib3k==1.0.0
six==1.16.0 six==1.16.0
sniffio==1.3.1
soupsieve==2.5 soupsieve==2.5
SQLAlchemy==2.0.23 SQLAlchemy==2.0.23
typing_extensions==4.10.0 typing_extensions==4.10.0

View File

@ -3,6 +3,7 @@ Extension for the `TaskCog`.
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot. Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
""" """
import json
import asyncio import asyncio
import logging import logging
import datetime import datetime
@ -11,9 +12,10 @@ from time import perf_counter
from collections import deque from collections import deque
import aiohttp import aiohttp
import httpx
from aiocache import Cache from aiocache import Cache
from discord import TextChannel from discord import TextChannel
from discord import app_commands from discord import app_commands, Interaction
from discord.ext import commands, tasks from discord.ext import commands, tasks
from discord.errors import Forbidden from discord.errors import Forbidden
from feedparser import parse from feedparser import parse
@ -55,7 +57,7 @@ class TaskCog(commands.Cog):
""" """
Instructions to execute when the cog is ready. Instructions to execute when the cog is ready.
""" """
self.subscription_task.start() # self.subscription_task.start()
log.info("%s cog is ready", self.__class__.__name__) log.info("%s cog is ready", self.__class__.__name__)
@commands.Cog.listener(name="cog_unload") @commands.Cog.listener(name="cog_unload")
@ -72,219 +74,297 @@ class TaskCog(commands.Cog):
) )
@group.command(name="trigger") @group.command(name="trigger")
async def cmd_trigger_task(self, inter: Interaction):
    """
    Run the task once on demand, in response to the `trigger` command.

    The interaction response is deferred before the task starts. If
    `do_task` raises, the error is logged with its traceback and its text
    is sent as a follow-up message. A follow-up reporting the elapsed time
    is always sent, whether or not the task succeeded.
    """
    await inter.response.defer()
    start_time = perf_counter()

    try:
        await self.do_task()
    except Exception as error:
        # `BaseException.with_traceback` requires a traceback argument, so
        # calling it bare raised TypeError inside this handler and the
        # follow-up below was never sent. `exc_info=True` attaches the
        # active traceback to the log record instead.
        log.error(error, exc_info=True)
        await inter.followup.send(str(error))
    finally:
        end_time = perf_counter()
        await inter.followup.send(f"completed command in {end_time - start_time:.4f} seconds")
async def do_task(self):
    """
    Run one full pass of the task.

    Opens a single `httpx.AsyncClient` configured with the API base URL and
    the bot's token, fetches every server through it, then processes the
    servers one after another. The elapsed time is logged at debug level
    once the pass completes.
    """
    log.info("Running task")
    start_time = perf_counter()

    client_options = {
        "base_url": "http://localhost:8000/api/",
        "headers": {"Authorization": f"Token {self.bot.api_token}"}
    }

    # One client is shared by every request made during this pass.
    async with httpx.AsyncClient(**client_options) as client:
        servers = await self.get_servers(client)
        for server in servers:
            await self.process_server(server, client)

    end_time = perf_counter()
    # Lazy %-formatting: the message is only built when debug is enabled.
    log.debug("completed task in %.4f seconds", end_time - start_time)
async def iterate_pages(self, client: httpx.AsyncClient, url: str, params: dict | None = None):
    """
    Asynchronously yield the `results` of each page of a paginated endpoint.

    Pages are requested starting at `page=1`. After each page is yielded,
    iteration stops if the response body has no truthy `next` value;
    otherwise the next page number is requested.

    Args:
        client: client used to issue the GET requests.
        url: endpoint to request, passed to `client.get` unchanged.
        params: extra query parameters sent with every request. The dict
            is copied, so the caller's object is never modified.

    Yields:
        The value of the `results` key of each page, or an empty list when
        the key is absent.

    Raises:
        Whatever `response.raise_for_status()` raises for an error response.
    """
    # Work on a private copy. The previous `params={}` default was one dict
    # shared by every call and was mutated in place, which also wrote the
    # `page` key into any dict the caller passed in.
    query = dict(params) if params else {}
    page_number = 1

    while True:
        query["page"] = page_number
        response = await client.get(url, params=query)
        response.raise_for_status()
        content = response.json()

        yield content.get("results", [])

        if not content.get("next"):
            break

        page_number += 1
async def get_servers(self, client: httpx.AsyncClient) -> list[dict]:
    """
    Collect every item from every page of the `servers/` endpoint.

    Args:
        client: client handed to `iterate_pages` for the requests.

    Returns:
        A flat list of the items from all pages, in page order.
    """
    # `or ()` keeps the original behaviour of skipping falsy batches.
    return [
        server
        async for servers_batch in self.iterate_pages(client, "servers/")
        for server in servers_batch or ()
    ]
"""Returns an individual page of guild settings."""
try:
return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
except aiohttp.ClientResponseError as error:
self.handle_pagination_error(error)
return []
def handle_pagination_error(self, error: aiohttp.ClientResponseError):
"""Handle the error cases from pagination attempts."""
match error.status:
case 404:
log.debug("final page reached")
case 403:
log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
case _:
log.debug(error)
async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
"""Get a list of `Subscription`s matching the given `guild_ids`."""
async def get_subscriptions(self, server: dict, client: httpx.AsyncClient) -> list[dict]:
    """
    Collect every item from every page of `subscriptions/` for one server.

    Args:
        server: server record; its `id` value is sent as the `server`
            query parameter.
        client: client handed to `iterate_pages` for the requests.

    Returns:
        A flat list of the items from all pages, in page order.
    """
    params = {"server": server.get("id")}

    # `or ()` keeps the original behaviour of skipping falsy batches.
    return [
        subscription
        async for subscriptions_batch in self.iterate_pages(client, "subscriptions/", params)
        for subscription in subscriptions_batch or ()
    ]
async def process_server(self, server: dict, client: httpx.AsyncClient):
    """
    Process every subscription belonging to one server.

    The server record is dumped to the debug log, then its subscriptions
    are fetched and processed one after another.

    Args:
        server: server record, as returned by `get_servers`.
        client: client used for the API requests.
    """
    log.debug(json.dumps(server, indent=4))

    subscriptions = await self.get_subscriptions(server, client)
    for subscription in subscriptions:
        await self.process_subscription(subscription, client)
async def process_subscription(self, subscription: dict, client: httpx.AsyncClient):
    """
    Process a single subscription record.

    Currently this only dumps the record to the debug log and walks its
    `filters_detail` entries without acting on them.

    Args:
        subscription: subscription record, as returned by `get_subscriptions`.
        client: client for API requests (not used yet).
    """
    log.debug(json.dumps(subscription, indent=4))

    # `.get` returns None when the key is absent or null, and iterating
    # None raises TypeError; treat that case as "no filters".
    for content_filter in subscription.get("filters_detail") or ():
        # TODO: filter handling is not implemented yet.
        pass
async def process_single_subscription(sub: Subscription):
log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
if not sub.active or not sub.channels_count:
return
unparsed_feed = await get_unparsed_feed(sub.url) # @group.command(name="trigger")
parsed_feed = parse(unparsed_feed) # async def cmd_trigger_task(self, inter):
# await inter.response.defer()
# start_time = perf_counter()
rss_feed = RSSFeed.from_parsed_feed(parsed_feed) # try:
await self.process_items(sub, rss_feed) # await self.subscription_task()
# except Exception as error:
# await inter.followup.send(str(error))
# finally:
# end_time = perf_counter()
# await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
semaphore = asyncio.Semaphore(10) # @tasks.loop(time=subscription_task_times)
# async def subscription_task(self):
# """
# Task for fetching and processing subscriptions.
# """
# log.info("Running subscription task")
# start_time = perf_counter()
async def semaphore_process(sub: Subscription): # async with aiohttp.ClientSession() as session:
async with semaphore: # self.api = API(self.bot.api_token, session)
await process_single_subscription(sub) # await self.execute_task()
await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions)) # end_time = perf_counter()
# log.debug(f"task completed in {end_time - start_time:.4f} seconds")
async def process_items(self, sub: Subscription, feed: RSSFeed): # async def execute_task(self):
log.debug("processing items") # """Execute the task directly."""
channels = await self.fetch_or_get_channels(await sub.get_channels(self.api)) # # Filter out inactive guild IDs using related settings
filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters] # guild_ids = [guild.id for guild in self.bot.guilds]
# guild_settings = await self.get_guild_settings(guild_ids)
# active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
for item in feed.items: # subscriptions = await self.get_subscriptions(active_guild_ids)
log.debug("processing item '%s'", item.guid) # await self.process_subscriptions(subscriptions)
if item.pub_date < sub.published_threshold: # async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold) # """Returns a list of guild settings from the Bot's guilds, if they exist."""
continue
blocked = any(self.filter_item(_filter, item) for _filter in filters) # guild_settings = []
mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
for channel in channels: # # Iterate infinitely taking the iter no. as `page`
await self.track_and_send(sub, feed, item, mutated_item, channel, blocked) # # data will be empty after last page reached.
# for page, _ in enumerate(iter(int, 1)):
# data = await self.get_guild_settings_page(guild_ids, page)
# if not data:
# break
async def fetch_or_get_channels(self, channels_data: list[dict]): # guild_settings.extend(data[0])
channels = []
for data in channels_data: # # Only return active guild IDs
try: # return GuildSettings.from_list(guild_settings)
channel = self.bot.get_channel(data.channel_id)
channels.append(channel or await self.bot.fetch_channel(data.channel_id))
except Forbidden:
log.error(f"Forbidden Channel '{data.channel_id}'")
return channels # async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
# """Returns an individual page of guild settings."""
def filter_item(self, _filter: dict, item: RSSItem) -> bool: # try:
""" # return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
Returns `True` if item should be ignored due to filters. # except aiohttp.ClientResponseError as error:
""" # self.handle_pagination_error(error)
# return []
match_found = match_text(_filter, item.title) or match_text(_filter, item.description) # def handle_pagination_error(self, error: aiohttp.ClientResponseError):
log.debug("filter match found? '%s'", match_found) # """Handle the error cases from pagination attempts."""
return match_found
async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool): # match error.status:
message_id = -1 # case 404:
# log.debug("final page reached")
# case 403:
# log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
# self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
# case _:
# log.debug(error)
log.debug("track and send func %s, %s", item.guid, item.title) # async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
# """Get a list of `Subscription`s matching the given `guild_ids`."""
result = await self.api.get_tracked_content(guid=item.guid) # subscriptions = []
if result[1]:
log.debug(f"This item is already tracked, skipping '{item.guid}'")
return
result = await self.api.get_tracked_content(url=item.link) # # Iterate infinitely taking the iter no. as `page`
if result[1]: # # data will be empty after last page reached.
log.debug(f"This item is already tracked, skipping '{item.guid}'") # for page, _ in enumerate(iter(int, 1)):
return # data = await self.get_subs_page(guild_ids, page)
# if not data:
# break
if not blocked: # subscriptions.extend(data[0])
try:
log.debug("sending '%s', exists '%s'", item.guid, result[1])
sendable_item = mutated_item or item
message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
message_id = message.id
except Forbidden:
log.error(f"Forbidden to send to channel {channel.id}")
await self.mark_tracked_item(sub, item, channel.id, message_id, blocked) # return Subscription.from_list(subscriptions)
async def process_batch(self): # async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
pass # """Returns an individual page of subscriptions."""
async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool): # try:
try: # return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
log.debug("marking as tracked") # except aiohttp.ClientResponseError as error:
await self.api.create_tracked_content( # self.handle_pagination_error(error)
guid=item.guid, # return []
title=item.title,
url=item.link,
subscription=sub.id,
channel_id=channel_id,
message_id=message_id,
blocked=blocked
)
return True
except aiohttp.ClientResponseError as error:
if error.status == 409:
log.debug(error)
else:
log.error(error)
return False # async def process_subscriptions(self, subscriptions: list[Subscription]):
# """Process a given list of `Subscription`s."""
# async def process_single_subscription(sub: Subscription):
# log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
# if not sub.active or not sub.channels_count:
# return
# unparsed_feed = await get_unparsed_feed(sub.url)
# parsed_feed = parse(unparsed_feed)
# rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
# await self.process_items(sub, rss_feed)
# semaphore = asyncio.Semaphore(10)
# async def semaphore_process(sub: Subscription):
# async with semaphore:
# await process_single_subscription(sub)
# await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
# async def process_items(self, sub: Subscription, feed: RSSFeed):
# log.debug("processing items")
# channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
# filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
# for item in feed.items:
# log.debug("processing item '%s'", item.guid)
# if item.pub_date < sub.published_threshold:
# log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
# continue
# blocked = any(self.filter_item(_filter, item) for _filter in filters)
# mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
# for channel in channels:
# await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
# async def fetch_or_get_channels(self, channels_data: list[dict]):
# channels = []
# for data in channels_data:
# try:
# channel = self.bot.get_channel(data.channel_id)
# channels.append(channel or await self.bot.fetch_channel(data.channel_id))
# except Forbidden:
# log.error(f"Forbidden Channel '{data.channel_id}'")
# return channels
# def filter_item(self, _filter: dict, item: RSSItem) -> bool:
# """
# Returns `True` if item should be ignored due to filters.
# """
# match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
# log.debug("filter match found? '%s'", match_found)
# return match_found
# async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
# message_id = -1
# log.debug("track and send func %s, %s", item.guid, item.title)
# result = await self.api.get_tracked_content(guid=item.guid)
# if result[1]:
# log.debug(f"This item is already tracked, skipping '{item.guid}'")
# return
# result = await self.api.get_tracked_content(url=item.link)
# if result[1]:
# log.debug(f"This item is already tracked, skipping '{item.guid}'")
# return
# if not blocked:
# try:
# log.debug("sending '%s', exists '%s'", item.guid, result[1])
# sendable_item = mutated_item or item
# message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
# message_id = message.id
# except Forbidden:
# log.error(f"Forbidden to send to channel {channel.id}")
# await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
# async def process_batch(self):
# pass
# async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
# try:
# log.debug("marking as tracked")
# await self.api.create_tracked_content(
# guid=item.guid,
# title=item.title,
# url=item.link,
# subscription=sub.id,
# channel_id=channel_id,
# message_id=message_id,
# blocked=blocked
# )
# return True
# except aiohttp.ClientResponseError as error:
# if error.status == 409:
# log.debug(error)
# else:
# log.error(error)
# return False
async def setup(bot): async def setup(bot):