From 624d47b74ef9d9fc4b296bc3e520ae77a1e87b3e Mon Sep 17 00:00:00 2001
From: Corban-Lee Jones
Date: Mon, 15 Jul 2024 21:28:37 +0100
Subject: [PATCH] working on improving task speed

---
 requirements.txt        |  1 +
 src/api.py              | 10 +++++
 src/extensions/tasks.py | 87 +++++++++++++++++++++++++++++------------
 src/feed.py             | 17 ++++++--
 4 files changed, 87 insertions(+), 28 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 5af2777..50ee102 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+aiocache==0.12.2
 aiohttp==3.9.3
 aiosignal==1.3.1
 aiosqlite==0.19.0
diff --git a/src/api.py b/src/api.py
index bd88690..fb2baa8 100644
--- a/src/api.py
+++ b/src/api.py
@@ -74,6 +74,7 @@ class API:
 
         status = response.status
 
+        log.debug("request to '%s', response '%s', kwargs '%s'", url, status, kwargs)
         return {"json": json, "text": text, "status": status}
 
     async def _post_data(self, url: str, data: dict | aiohttp.FormData) -> dict:
@@ -139,6 +140,15 @@ class API:
 
         return await self._post_data(self.API_ENDPOINT + "tracked-content/", data)
 
+    async def get_tracked_content(self, **filters) -> tuple[list[dict], int]:
+        """
+        Return a list of tracked content matching the given filters.
+        """
+
+        log.debug("getting tracked content")
+
+        return await self._get_many(self.API_ENDPOINT + "tracked-content/", filters)
+
     async def get_filter(self, filter_id: int) -> dict:
         """
         Get an instance of Filter.
diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py
index 2e1ce21..eb7bf80 100644
--- a/src/extensions/tasks.py
+++ b/src/extensions/tasks.py
@@ -5,12 +5,16 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
 
 import re
 import json
+import asyncio
 import logging
 import datetime
+import urllib.parse
 from os import getenv
-from time import process_time
+from time import perf_counter
+from collections import deque
 
 import aiohttp
+from aiocache import Cache
 from discord import TextChannel, Embed, Colour
 from discord import app_commands
 from discord.ext import commands, tasks
@@ -24,6 +28,10 @@ from api import API
 
 log = logging.getLogger(__name__)
 
+cache = Cache(Cache.MEMORY)
+
+BATCH_SIZE = 100
+
 TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES")
 subscription_task_times = [
     datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
@@ -38,6 +46,9 @@ class TaskCog(commands.Cog):
     Tasks cog for PYRSS.
     """
 
+    api: API | None = None
+    content_queue = deque()
+
     def __init__(self, bot):
         super().__init__()
         self.bot = bot
@@ -60,8 +71,10 @@ class TaskCog(commands.Cog):
     @app_commands.command(name="debug-trigger-task")
     async def debug_trigger_task(self, inter):
         await inter.response.defer()
+        start_time = perf_counter()
         await self.subscription_task()
-        await inter.followup.send("done")
+        end_time = perf_counter()
+        await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
 
     @tasks.loop(time=subscription_task_times)
     async def subscription_task(self):
         """
         Task for fetching and processing subscriptions.
         """
         log.info("Running subscription task")
+        start_time = perf_counter()
 
         async with aiohttp.ClientSession() as session:
-            api = API(self.bot.api_token, session)
-            subscriptions = await self.get_subscriptions(api)
-            await self.process_subscriptions(api, subscriptions)
+            self.api = API(self.bot.api_token, session)
+            subscriptions = await self.get_subscriptions()
+            await self.process_subscriptions(subscriptions)
 
-    async def get_subscriptions(self, api: API) -> list[Subscription]:
+        end_time = perf_counter()
+        log.debug(f"task completed in {end_time - start_time:.4f} seconds")
+
+    async def get_subscriptions(self) -> list[Subscription]:
         guild_ids = [guild.id for guild in self.bot.guilds]
 
         sub_data = []
@@ -83,7 +100,7 @@ class TaskCog(commands.Cog):
             try:
                 log.debug("fetching page '%s'", page + 1)
                 sub_data.extend(
-                    (await api.get_subscriptions(server__in=guild_ids, page=page+1))[0]
+                    (await self.api.get_subscriptions(server__in=guild_ids, page=page+1))[0]
                 )
             except aiohttp.ClientResponseError as error:
                 match error.status:
@@ -105,26 +122,32 @@ class TaskCog(commands.Cog):
 
         return Subscription.from_list(sub_data)
 
-    async def process_subscriptions(self, api: API, subscriptions: list[Subscription]):
-        for sub in subscriptions:
-            log.debug("processing subscription '%s'", sub.id)
+    async def process_subscriptions(self, subscriptions: list[Subscription]):
+        async def process_single_subscription(sub: Subscription):
+            log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
             if not sub.active or not sub.channels_count:
-                continue
+                return
 
-            unparsed_feed = await get_unparsed_feed(sub.url, api.session)
+            unparsed_feed = await get_unparsed_feed(sub.url)
 
             parsed_feed = parse(unparsed_feed)
             rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
 
-            await self.process_items(api, sub, rss_feed)
+            await self.process_items(sub, rss_feed)
 
-    async def process_items(self, api: API, sub: Subscription, feed: RSSFeed):
+        semaphore = asyncio.Semaphore(10)
+
+        async def semaphore_process(sub: Subscription):
+            async with semaphore:
+                await process_single_subscription(sub)
+
+        await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
+
+    async def process_items(self, sub: Subscription, feed: RSSFeed):
         log.debug("processing items")
 
-        channels = [self.bot.get_channel(channel.channel_id) for channel in await sub.get_channels(api)]
-        filters = [await api.get_filter(filter_id) for filter_id in sub.filters]
-
-        log.debug(json.dumps(filters, indent=4))
+        channels = [self.bot.get_channel(channel.channel_id) for channel in await sub.get_channels(self.api)]
+        filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
 
         for item in feed.items:
             log.debug("processing item '%s'", item.guid)
@@ -137,7 +160,7 @@ class TaskCog(commands.Cog):
             mutated_item = item.create_mutated_copy(sub.mutators)
 
             for channel in channels:
-                await self.track_and_send(api, sub, feed, item, mutated_item, channel, blocked)
+                await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
 
     def filter_item(self, _filter: dict, item: RSSItem) -> bool:
         """
@@ -148,21 +171,37 @@ class TaskCog(commands.Cog):
         log.debug("filter match found? '%s'", match_found)
         return match_found
 
-    async def track_and_send(self, api: API, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem, channel: TextChannel, blocked: bool):
+    async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem, channel: TextChannel, blocked: bool):
         message_id = -1
 
+        log.debug("track and send func %s, %s", item.guid, item.title)
+
+        result = await self.api.get_tracked_content(guid=item.guid)
+        if result[1]:
+            log.debug(f"This item is already tracked, skipping '{item.guid}'")
+            return
+
+        result = await self.api.get_tracked_content(url=item.link)
+        if result[1]:
+            log.debug(f"This item is already tracked, skipping '{item.guid}'")
+            return
+
         try:
-            message = await channel.send(embed=await mutated_item.to_embed(sub, feed, api.session))
+            log.debug("sending '%s', exists '%s'", item.guid, result[1])
+            message = await channel.send(embed=await mutated_item.to_embed(sub, feed, self.api.session))
             message_id = message.id
         except Forbidden as error:
            log.error(error)
         finally:
-            await self.mark_tracked_item(api, sub, item, channel.id, message_id, blocked)
+            await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
 
-    async def mark_tracked_item(self, api: API, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
+    async def process_batch(self):
+        pass
+
+    async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
         try:
             log.debug("marking as tracked")
-            await api.create_tracked_content(
+            await self.api.create_tracked_content(
                 guid=item.guid,
                 title=item.title,
                 url=item.link,
diff --git a/src/feed.py b/src/feed.py
index e2cfb1f..460782f 100644
--- a/src/feed.py
+++ b/src/feed.py
@@ -33,7 +33,9 @@ class RSSItem:
     title: str
     description: str
     pub_date: datetime
-    image_url: str
+    content_image_url: str
+    thumb_image_url: str
+    entry: FeedParserDict
 
     @classmethod
     def from_parsed_entry(cls, entry: FeedParserDict) -> RSSItem:
@@ -57,9 +59,10 @@ class RSSItem:
         pub_date = entry.get('published_parsed', None)
         pub_date = datetime(*pub_date[0:6], tzinfo=timezone.utc)
 
-        image_url = entry.get("media_content", [{}])[0].get("url")
+        content_image_url = entry.get("media_content", [{}])[0].get("url")
+        thumb_image_url = entry.get("media_thumbnail", [{}])[0].get("url")
 
-        return cls(guid, link, title, description, pub_date, image_url)
+        return cls(guid, link, title, description, pub_date, content_image_url, thumb_image_url, entry)
 
     def create_mutated_copy(self, mutators: dict[str, dict[str, str]]) -> RSSItem:
         """Returns a copy of `self` with the specified `mutations`.
@@ -125,6 +128,8 @@ class RSSItem:
         discord.Embed
         """
 
+        log.debug(json.dumps(self.entry, indent=4))
+
         # Replace HTML with Markdown, and shorten text.
         title = shorten(markdownify(self.title, strip=["img", "a"]), 256)
         desc = shorten(markdownify(self.description, strip=["img"]), 4096)
@@ -145,8 +150,12 @@ class RSSItem:
         )
 
         if sub.article_fetch_image:
-            embed.set_image(url=self.image_url or await self.get_thumbnail_url(session))
+            img_url = self.content_image_url if validators.url(self.content_image_url) else await self.get_thumbnail_url(session)
+            img_url = self.thumb_image_url if not img_url and validators.url(self.thumb_image_url) else img_url
+            embed.set_image(url=img_url)
             embed.set_thumbnail(url=feed.image_href if validators.url(feed.image_href) else None)
+            # log.debug("embed image check %s, %s", self.image_url, validators.url(self.image_url))
+            # embed.set_image(url=self.image_url if validators.url(self.image_url) else None)
 
         embed.set_author(name=author, url=feed.link)
         embed.set_footer(text=sub.name)
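Reviewer note: the core speed-up is that `process_subscriptions` no longer awaits each subscription in sequence; it fans the work out with `asyncio.gather` while an `asyncio.Semaphore(10)` caps how many feeds are in flight at once. A minimal, self-contained sketch of that pattern follows (the `fetch_feed` body and URLs are stand-ins, not the bot's real code):

```python
import asyncio
import random


async def fetch_feed(url: str) -> str:
    """Stand-in for the real feed download; sleeps to simulate network I/O."""
    await asyncio.sleep(random.uniform(0.1, 0.3))
    return f"<rss for {url}>"


async def fetch_all(urls: list[str], limit: int = 10) -> list[str]:
    # At most `limit` coroutines hold the semaphore at any moment,
    # so slow feeds cannot pile up an unbounded number of connections.
    semaphore = asyncio.Semaphore(limit)

    async def bounded(url: str) -> str:
        async with semaphore:
            return await fetch_feed(url)

    return await asyncio.gather(*(bounded(url) for url in urls))


if __name__ == "__main__":
    feeds = [f"https://example.invalid/feed/{n}.xml" for n in range(40)]
    print(len(asyncio.run(fetch_all(feeds))))
```

Because `asyncio.gather` is used without `return_exceptions=True`, the first subscription that raises will surface as the whole task's exception; whether to trade that for per-subscription error isolation is a judgement call for review.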
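On the import swap from `time.process_time` to `time.perf_counter`: `process_time` only counts CPU time consumed by the process, so it sits near zero while the task is blocked awaiting network responses, whereas `perf_counter` is a monotonic wall-clock timer, which is what the new timing log and the debug command actually want to report. A quick illustration:

```python
import asyncio
from time import perf_counter, process_time


async def main() -> None:
    wall, cpu = perf_counter(), process_time()
    await asyncio.sleep(1)  # stands in for awaited network I/O
    print(f"perf_counter elapsed: {perf_counter() - wall:.2f}s")  # ~1.00
    print(f"process_time elapsed: {process_time() - cpu:.2f}s")   # ~0.00


asyncio.run(main())
```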
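The commit also pins `aiocache` and builds `cache = Cache(Cache.MEMORY)` in tasks.py, but nothing reads or writes it yet (likewise `BATCH_SIZE`, `content_queue` and `process_batch` are placeholders). One way the cache could eventually back the duplicate checks in `track_and_send` is sketched below; the helper name, key scheme and TTL are assumptions, not something this patch implements:

```python
from aiocache import Cache

cache = Cache(Cache.MEMORY)  # same construction as in tasks.py


async def is_tracked(api, guid: str) -> bool:
    """Return True if the GUID is already tracked, remembering positive hits in memory."""
    key = f"tracked:{guid}"
    if await cache.exists(key):
        return True

    _, count = await api.get_tracked_content(guid=guid)  # method added in this patch
    if count:
        # Cache only positive results so brand-new items are never masked.
        await cache.set(key, True, ttl=3600)
        return True
    return False
```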