diff --git a/requirements.txt b/requirements.txt
index 1b235b9..a122e28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,16 +2,21 @@ aiocache==0.12.2
 aiohttp==3.9.3
 aiosignal==1.3.1
 aiosqlite==0.19.0
+anyio==4.6.2.post1
 async-timeout==4.0.3
 asyncpg==0.29.0
 attrs==23.2.0
 beautifulsoup4==4.12.3
 bump2version==1.0.1
+certifi==2024.8.30
 click==8.1.7
 discord.py==2.3.2
 feedparser==6.0.11
 frozenlist==1.4.1
 greenlet==3.0.3
+h11==0.14.0
+httpcore==1.0.6
+httpx==0.27.2
 idna==3.6
 markdownify==0.11.6
 multidict==6.0.5
@@ -21,6 +26,7 @@ python-dotenv==1.0.0
 rapidfuzz==3.9.4
 sgmllib3k==1.0.0
 six==1.16.0
+sniffio==1.3.1
 soupsieve==2.5
 SQLAlchemy==2.0.23
 typing_extensions==4.10.0
diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py
index ec49ee6..6c0c999 100644
--- a/src/extensions/tasks.py
+++ b/src/extensions/tasks.py
@@ -3,6 +3,7 @@ Extension for the `TaskCog`.
 Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
 """
 
+import json
 import asyncio
 import logging
 import datetime
@@ -11,9 +12,10 @@ from time import perf_counter
 from collections import deque
 
 import aiohttp
+import httpx
 from aiocache import Cache
 from discord import TextChannel
-from discord import app_commands
+from discord import app_commands, Interaction
 from discord.ext import commands, tasks
 from discord.errors import Forbidden
 from feedparser import parse
@@ -55,7 +57,7 @@ class TaskCog(commands.Cog):
         """
         Instructions to execute when the cog is ready.
         """
-        self.subscription_task.start()
+        # self.subscription_task.start()
         log.info("%s cog is ready", self.__class__.__name__)
 
     @commands.Cog.listener(name="cog_unload")
@@ -72,219 +74,297 @@ class TaskCog(commands.Cog):
     )
 
     @group.command(name="trigger")
-    async def cmd_trigger_task(self, inter):
+    async def cmd_trigger_task(self, inter: Interaction):
         await inter.response.defer()
         start_time = perf_counter()
 
         try:
-            await self.subscription_task()
+            await self.do_task()
         except Exception as error:
+            log.exception(error)
             await inter.followup.send(str(error))
         finally:
             end_time = perf_counter()
-            await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
+            await inter.followup.send(f"completed command in {end_time - start_time:.4f} seconds")
 
-    @tasks.loop(time=subscription_task_times)
-    async def subscription_task(self):
-        """
-        Task for fetching and processing subscriptions.
-        """
-        log.info("Running subscription task")
+    async def do_task(self):
+        log.info("Running task")
         start_time = perf_counter()
 
-        async with aiohttp.ClientSession() as session:
-            self.api = API(self.bot.api_token, session)
-            await self.execute_task()
+        client_options = {
+            "base_url": "http://localhost:8000/api/",
+            "headers": {"Authorization": f"Token {self.bot.api_token}"}
+        }
+
+        async with httpx.AsyncClient(**client_options) as client:
+            servers = await self.get_servers(client)
+            for server in servers:
+                await self.process_server(server, client)
 
         end_time = perf_counter()
-        log.debug(f"task completed in {end_time - start_time:.4f} seconds")
+        log.debug(f"completed task in {end_time - start_time:.4f} seconds")
 
-    async def execute_task(self):
-        """Execute the task directly."""
+    async def iterate_pages(self, client: httpx.AsyncClient, url: str, params: dict | None = None):
+        """Yield each page of `results` from a paginated endpoint until no `next` page is reported."""
+        # Work on a copy so a shared default (or the caller's dict) is never mutated between calls.
+        params = dict(params or {})
 
-        # Filter out inactive guild IDs using related settings
-        guild_ids = [guild.id for guild in self.bot.guilds]
-        guild_settings = await self.get_guild_settings(guild_ids)
-        active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
+        for page_number, _ in enumerate(iter(int, 1), start=1):
+            params.update({"page": page_number})
+            response = await client.get(url, params=params)
+            response.raise_for_status()
+            content = response.json()
 
-        subscriptions = await self.get_subscriptions(active_guild_ids)
-        await self.process_subscriptions(subscriptions)
+            yield content.get("results", [])
 
-    async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
-        """Returns a list of guild settings from the Bot's guilds, if they exist."""
-
-        guild_settings = []
-
-        # Iterate infinitely taking the iter no. as `page`
-        # data will be empty after last page reached.
-        for page, _ in enumerate(iter(int, 1)):
-            data = await self.get_guild_settings_page(guild_ids, page)
-            if not data:
+            if not content.get("next"):
                 break
 
-            guild_settings.extend(data[0])
+    async def get_servers(self, client: httpx.AsyncClient) -> list[dict]:
+        servers = []
 
-        # Only return active guild IDs
-        return GuildSettings.from_list(guild_settings)
+        async for servers_batch in self.iterate_pages(client, "servers/"):
+            if servers_batch:
+                servers.extend(servers_batch)
 
-    async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
-        """Returns an individual page of guild settings."""
-
-        try:
-            return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
-        except aiohttp.ClientResponseError as error:
-            self.handle_pagination_error(error)
-            return []
-
-    def handle_pagination_error(self, error: aiohttp.ClientResponseError):
-        """Handle the error cases from pagination attempts."""
-
-        match error.status:
-            case 404:
-                log.debug("final page reached")
-            case 403:
-                log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
-                self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
-            case _:
-                log.debug(error)
-
-    async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
-        """Get a list of `Subscription`s matching the given `guild_ids`."""
+        return servers
 
+    async def get_subscriptions(self, server: dict, client: httpx.AsyncClient) -> list[dict]:
         subscriptions = []
+        params = {"server": server.get("id")}
 
-        # Iterate infinitely taking the iter no. as `page`
-        # data will be empty after last page reached.
-        for page, _ in enumerate(iter(int, 1)):
-            data = await self.get_subs_page(guild_ids, page)
-            if not data:
-                break
+        async for subscriptions_batch in self.iterate_pages(client, "subscriptions/", params):
+            if subscriptions_batch:
+                subscriptions.extend(subscriptions_batch)
 
-            subscriptions.extend(data[0])
+        return subscriptions
 
-        return Subscription.from_list(subscriptions)
+    async def process_server(self, server: dict, client: httpx.AsyncClient):
+        log.debug(json.dumps(server, indent=4))
 
-    async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
-        """Returns an individual page of subscriptions."""
+        subscriptions = await self.get_subscriptions(server, client)
+        for subscription in subscriptions:
+            await self.process_subscription(subscription, client)
 
-        try:
-            return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
-        except aiohttp.ClientResponseError as error:
-            self.handle_pagination_error(error)
-            return []
+    async def process_subscription(self, subscription: dict, client: httpx.AsyncClient):
+        log.debug(json.dumps(subscription, indent=4))
 
-    async def process_subscriptions(self, subscriptions: list[Subscription]):
-        """Process a given list of `Subscription`s."""
+        for content_filter in subscription.get("filters_detail", []):
+            pass
 
-        async def process_single_subscription(sub: Subscription):
-            log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
-
-            if not sub.active or not sub.channels_count:
-                return
-
-            unparsed_feed = await get_unparsed_feed(sub.url)
-            parsed_feed = parse(unparsed_feed)
+    # @group.command(name="trigger")
+    # async def cmd_trigger_task(self, inter):
+    #     await inter.response.defer()
+    #     start_time = perf_counter()
 
-            rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
-            await self.process_items(sub, rss_feed)
+    #     try:
+    #         await self.subscription_task()
+    #     except Exception as error:
+    #         await inter.followup.send(str(error))
+    #     finally:
+    #         end_time = perf_counter()
+    #         await inter.followup.send(f"completed in {end_time - start_time:.4f} seconds")
 
-        semaphore = asyncio.Semaphore(10)
+    # @tasks.loop(time=subscription_task_times)
+    # async def subscription_task(self):
+    #     """
+    #     Task for fetching and processing subscriptions.
+    #     """
+    #     log.info("Running subscription task")
+    #     start_time = perf_counter()
 
-        async def semaphore_process(sub: Subscription):
-            async with semaphore:
-                await process_single_subscription(sub)
+    #     async with aiohttp.ClientSession() as session:
+    #         self.api = API(self.bot.api_token, session)
+    #         await self.execute_task()
 
-        await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
+    #     end_time = perf_counter()
+    #     log.debug(f"task completed in {end_time - start_time:.4f} seconds")
 
-    async def process_items(self, sub: Subscription, feed: RSSFeed):
-        log.debug("processing items")
+    # async def execute_task(self):
+    #     """Execute the task directly."""
 
-        channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
-        filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
+    #     # Filter out inactive guild IDs using related settings
+    #     guild_ids = [guild.id for guild in self.bot.guilds]
+    #     guild_settings = await self.get_guild_settings(guild_ids)
+    #     active_guild_ids = [settings.guild_id for settings in guild_settings if settings.active]
 
-        for item in feed.items:
-            log.debug("processing item '%s'", item.guid)
+    #     subscriptions = await self.get_subscriptions(active_guild_ids)
+    #     await self.process_subscriptions(subscriptions)
 
-            if item.pub_date < sub.published_threshold:
-                log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
-                continue
+    # async def get_guild_settings(self, guild_ids: list[int]) -> list[int]:
+    #     """Returns a list of guild settings from the Bot's guilds, if they exist."""
 
-            blocked = any(self.filter_item(_filter, item) for _filter in filters)
-            mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
+    #     guild_settings = []
 
-            for channel in channels:
-                await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
+    #     # Iterate infinitely taking the iter no. as `page`
+    #     # data will be empty after last page reached.
+    #     for page, _ in enumerate(iter(int, 1)):
+    #         data = await self.get_guild_settings_page(guild_ids, page)
+    #         if not data:
+    #             break
 
-    async def fetch_or_get_channels(self, channels_data: list[dict]):
-        channels = []
+    #         guild_settings.extend(data[0])
 
-        for data in channels_data:
-            try:
-                channel = self.bot.get_channel(data.channel_id)
-                channels.append(channel or await self.bot.fetch_channel(data.channel_id))
-            except Forbidden:
-                log.error(f"Forbidden Channel '{data.channel_id}'")
+    #     # Only return active guild IDs
+    #     return GuildSettings.from_list(guild_settings)
 
-        return channels
+    # async def get_guild_settings_page(self, guild_ids: list[int], page: int) -> list[dict]:
+    #     """Returns an individual page of guild settings."""
 
-    def filter_item(self, _filter: dict, item: RSSItem) -> bool:
-        """
-        Returns `True` if item should be ignored due to filters.
-        """
+    #     try:
+    #         return await self.api.get_guild_settings(guild_id__in=guild_ids, page=page+1)
+    #     except aiohttp.ClientResponseError as error:
+    #         self.handle_pagination_error(error)
+    #         return []
 
-        match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
-        log.debug("filter match found? '%s'", match_found)
-        return match_found
+    # def handle_pagination_error(self, error: aiohttp.ClientResponseError):
+    #     """Handle the error cases from pagination attempts."""
 
-    async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
-        message_id = -1
+    #     match error.status:
+    #         case 404:
+    #             log.debug("final page reached")
+    #         case 403:
+    #             log.critical("[403] Bot likely lacks permissions: %s", error, exc_info=True)
+    #             self.subscription_task.cancel() # can't do task without proper auth, so cancel permanently
+    #         case _:
+    #             log.debug(error)
 
-        log.debug("track and send func %s, %s", item.guid, item.title)
+    # async def get_subscriptions(self, guild_ids: list[int]) -> list[Subscription]:
+    #     """Get a list of `Subscription`s matching the given `guild_ids`."""
 
-        result = await self.api.get_tracked_content(guid=item.guid)
-        if result[1]:
-            log.debug(f"This item is already tracked, skipping '{item.guid}'")
-            return
+    #     subscriptions = []
 
-        result = await self.api.get_tracked_content(url=item.link)
-        if result[1]:
-            log.debug(f"This item is already tracked, skipping '{item.guid}'")
-            return
+    #     # Iterate infinitely taking the iter no. as `page`
+    #     # data will be empty after last page reached.
+    #     for page, _ in enumerate(iter(int, 1)):
+    #         data = await self.get_subs_page(guild_ids, page)
+    #         if not data:
+    #             break
 
-        if not blocked:
-            try:
-                log.debug("sending '%s', exists '%s'", item.guid, result[1])
-                sendable_item = mutated_item or item
-                message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
-                message_id = message.id
-            except Forbidden:
-                log.error(f"Forbidden to send to channel {channel.id}")
+    #         subscriptions.extend(data[0])
 
-        await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
+    #     return Subscription.from_list(subscriptions)
 
-    async def process_batch(self):
-        pass
+    # async def get_subs_page(self, guild_ids: list[int], page: int) -> list[Subscription]:
+    #     """Returns an individual page of subscriptions."""
 
-    async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
-        try:
-            log.debug("marking as tracked")
-            await self.api.create_tracked_content(
-                guid=item.guid,
-                title=item.title,
-                url=item.link,
-                subscription=sub.id,
-                channel_id=channel_id,
-                message_id=message_id,
-                blocked=blocked
-            )
-            return True
-        except aiohttp.ClientResponseError as error:
-            if error.status == 409:
-                log.debug(error)
-            else:
-                log.error(error)
+    #     try:
+    #         return await self.api.get_subscriptions(guild_id__in=guild_ids, page=page+1)
+    #     except aiohttp.ClientResponseError as error:
+    #         self.handle_pagination_error(error)
+    #         return []
 
-        return False
+    # async def process_subscriptions(self, subscriptions: list[Subscription]):
+    #     """Process a given list of `Subscription`s."""
+
+    #     async def process_single_subscription(sub: Subscription):
+    #         log.debug("processing subscription '%s' for '%s'", sub.id, sub.guild_id)
+
+    #         if not sub.active or not sub.channels_count:
+    #             return
+
+    #         unparsed_feed = await get_unparsed_feed(sub.url)
+    #         parsed_feed = parse(unparsed_feed)
+
+    #         rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
+    #         await self.process_items(sub, rss_feed)
+
+    #     semaphore = asyncio.Semaphore(10)
+
+    #     async def semaphore_process(sub: Subscription):
+    #         async with semaphore:
+    #             await process_single_subscription(sub)
+
+    #     await asyncio.gather(*(semaphore_process(sub) for sub in subscriptions))
+
+    # async def process_items(self, sub: Subscription, feed: RSSFeed):
+    #     log.debug("processing items")
+
+    #     channels = await self.fetch_or_get_channels(await sub.get_channels(self.api))
+    #     filters = [await self.api.get_filter(filter_id) for filter_id in sub.filters]
+
+    #     for item in feed.items:
+    #         log.debug("processing item '%s'", item.guid)
+
+    #         if item.pub_date < sub.published_threshold:
+    #             log.debug("item '%s' older than subscription threshold '%s', skipping", item.pub_date, sub.published_threshold)
+    #             continue
+
+    #         blocked = any(self.filter_item(_filter, item) for _filter in filters)
+    #         mutated_item = item.create_mutated_copy(sub.mutators) if sub.mutators else None
+
+    #         for channel in channels:
+    #             await self.track_and_send(sub, feed, item, mutated_item, channel, blocked)
+
+    # async def fetch_or_get_channels(self, channels_data: list[dict]):
+    #     channels = []
+
+    #     for data in channels_data:
+    #         try:
+    #             channel = self.bot.get_channel(data.channel_id)
+    #             channels.append(channel or await self.bot.fetch_channel(data.channel_id))
+    #         except Forbidden:
+    #             log.error(f"Forbidden Channel '{data.channel_id}'")
+
+    #     return channels
+
+    # def filter_item(self, _filter: dict, item: RSSItem) -> bool:
+    #     """
+    #     Returns `True` if item should be ignored due to filters.
+    #     """
+
+    #     match_found = match_text(_filter, item.title) or match_text(_filter, item.description)
+    #     log.debug("filter match found? '%s'", match_found)
+    #     return match_found
+
+    # async def track_and_send(self, sub: Subscription, feed: RSSFeed, item: RSSItem, mutated_item: RSSItem | None, channel: TextChannel, blocked: bool):
+    #     message_id = -1
+
+    #     log.debug("track and send func %s, %s", item.guid, item.title)
+
+    #     result = await self.api.get_tracked_content(guid=item.guid)
+    #     if result[1]:
+    #         log.debug(f"This item is already tracked, skipping '{item.guid}'")
+    #         return
+
+    #     result = await self.api.get_tracked_content(url=item.link)
+    #     if result[1]:
+    #         log.debug(f"This item is already tracked, skipping '{item.guid}'")
+    #         return
+
+    #     if not blocked:
+    #         try:
+    #             log.debug("sending '%s', exists '%s'", item.guid, result[1])
+    #             sendable_item = mutated_item or item
+    #             message = await channel.send(embed=await sendable_item.to_embed(sub, feed, self.api.session))
+    #             message_id = message.id
+    #         except Forbidden:
+    #             log.error(f"Forbidden to send to channel {channel.id}")
+
+    #     await self.mark_tracked_item(sub, item, channel.id, message_id, blocked)
+
+    # async def process_batch(self):
+    #     pass
+
+    # async def mark_tracked_item(self, sub: Subscription, item: RSSItem, channel_id: int, message_id: int, blocked: bool):
+    #     try:
+    #         log.debug("marking as tracked")
+    #         await self.api.create_tracked_content(
+    #             guid=item.guid,
+    #             title=item.title,
+    #             url=item.link,
+    #             subscription=sub.id,
+    #             channel_id=channel_id,
+    #             message_id=message_id,
+    #             blocked=blocked
+    #         )
+    #         return True
+    #     except aiohttp.ClientResponseError as error:
+    #         if error.status == 409:
+    #             log.debug(error)
+    #         else:
+    #             log.error(error)
+
+    #     return False
 
 
 async def setup(bot):