diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py index 6ebec97..17f5acf 100644 --- a/src/extensions/tasks.py +++ b/src/extensions/tasks.py @@ -7,28 +7,30 @@ import json import asyncio import logging import datetime +import traceback from os import getenv from time import perf_counter from collections import deque -import aiohttp +# import aiohttp import httpx -from aiocache import Cache +import feedparser +import discord +# from aiocache import Cache from discord import TextChannel from discord import app_commands, Interaction from discord.ext import commands, tasks from discord.errors import Forbidden -from feedparser import parse import models -from feed import RSSFeed, Subscription, RSSItem, GuildSettings -from utils import get_unparsed_feed -from filters import match_text +# from feed import RSSFeed, Subscription, RSSItem, GuildSettings +# from utils import get_unparsed_feed +# from filters import match_text from api import API log = logging.getLogger(__name__) -cache = Cache(Cache.MEMORY) +# cache = Cache(Cache.MEMORY) BATCH_SIZE = 100 @@ -86,9 +88,9 @@ class TaskCog(commands.Cog): try: await self.do_task() - except Exception as error: - log.error(error.with_traceback()) - await inter.followup.send(str(error)) + except Exception as exc: + log.exception(exc) + await inter.followup.send(str(exc) or "unknown error") finally: end_time = perf_counter() await inter.followup.send(f"completed command in {end_time - start_time:.4f} seconds") @@ -152,6 +154,7 @@ class TaskCog(commands.Cog): async def process_server(self, server: models.Server, client: httpx.AsyncClient): log.debug(f"processing server: {server.name}") + start_time = perf_counter() subscriptions = await self.get_subscriptions(server, client) for subscription in subscriptions: @@ -169,15 +172,35 @@ class TaskCog(commands.Cog): ] await asyncio.gather(*tasks) + end_time = perf_counter() + log.debug(f"Finished processing server: {server.name} in {end_time - start_time:.4f} seconds") + async def process_subscription(self, subscription: models.Subscription, client: httpx.AsyncClient): log.debug(f"processing subscription {subscription.name}") + start_time = perf_counter() - content = await client.get(subscription.url) - log.debug(content) + raw_rss_content = await subscription.get_rss_content(client) + if not raw_rss_content: + return + channels = await subscription.get_discord_channels(self.bot) + contents = models.Content.from_raw_rss(raw_rss_content, subscription) + valid_contents, invalid_contents = subscription.filter_entries(contents) + for content in valid_contents: + await self.process_content(content, channels) + tasks = [channel.send(content.item_title) for channel in channels] + asyncio.gather(*tasks) + end_time = perf_counter() + log.debug(f"Finished processing subscription: {subscription.name} in {end_time - start_time:.4f}") + async def process_valid_contents(contents: list[models.Content], channels: list[discord.TextChannel], client: httpx.AsyncClient): + semaphore = asyncio.Semaphore(5) + + async def batch_process(content: models.Content, ) + + # @group.command(name="trigger") # async def cmd_trigger_task(self, inter): diff --git a/src/models.py b/src/models.py index 3b17b8a..3ebd6cb 100644 --- a/src/models.py +++ b/src/models.py @@ -1,10 +1,16 @@ +import re import logging +import hashlib from enum import Enum from datetime import datetime - from abc import ABC, abstractmethod from dataclasses import dataclass +import httpx +import discord +import rapidfuzz +import feedparser + log = logging.getLogger(__name__) @@ -40,13 +46,13 @@ class Server(DjangoDataModel): class MatchingAlgorithm(Enum): - MATCH_NONE = 0 - MATCH_ANY = 1 - MATCH_ALL = 2 - MATCH_LITERAL = 3 - MATCH_REGEX = 4 - MATCH_FUZZY = 5 - MATCH_AUTO = 6 + NONE = 0 + ANY = 1 + ALL = 2 + LITERAL = 3 + REGEX = 4 + FUZZY = 5 + AUTO = 6 @classmethod def from_value(cls, value: int): @@ -54,7 +60,7 @@ class MatchingAlgorithm(Enum): if member.value == value: return member - raise ValueError(f"No {self.__class__.__name__} for value: {value}") + raise ValueError(f"No {cls.__class__.__name__} for value: {value}") @@ -76,6 +82,96 @@ class ContentFilter(DjangoDataModel): item["matching_algorithm"] = MatchingAlgorithm.from_value(item.pop("matching_algorithm")) return item + @property + def _regex_flags(self): + return re.IGNORECASE if self.is_insensitive else 0 + + @property + def cleaned_matching_pattern(self): + """ + Splits the pattern to individual keywords, getting rid of unnecessary + spaces and grouping quoted words together. + + """ + findterms = re.compile(r'"([^"]+)"|(\S+)').findall + normspace = re.compile(r"\s+").sub + return [ + re.escape(normspace(" ", (t[0] or t[1]).strip())).replace(r"\ ", r"\s+") + for t in findterms(self.matching_pattern) + ] + + def _match_any(self, matching_against: str): + for word in self.cleaned_matching_pattern: + if re.search(rf"\b{word}\b", matching_against, self._regex_flags): + return True + + return False + + def _match_all(self, matching_against: str): + for word in self.cleaned_matching_pattern: + if re.search(rf"\b{word}\b", matching_against, self._regex_flags): + return False + + return True + + def _match_literal(self, matching_against: str): + return bool( + re.search( + rf"\b{re.escape(self.matching_pattern)}\b", + matching_against, + self._regex_flags + ) + ) + + def _match_regex(self, matching_against: str): + try: + return bool(re.search( + re.compile(self.matching_pattern, self._regex_flags), + matching_against + )) + except re.error as exc: + log.error(f"Filter regex error: {exc}") + return False + + def _match_fuzzy(self, matching_against: str): + matching_against = re.sub(r"[^\w\s]", "", matching_against) + matching_pattern = re.sub(r"[^\w\s]", "", self.matching_pattern) + if self.is_insensitive: + matching_against = matching_against.lower() + matching_pattern = matching_pattern.lower() + + return rapidfuzz.fuzz.partial_ratio( + matching_against, + matching_pattern, + score_cutoff=90 + ) + + def _get_algorithm_func(self): + match self.matching_algorithm: + case MatchingAlgorithm.NONE: return + case MatchingAlgorithm.ANY: return self._match_any + case MatchingAlgorithm.ALL: return self._match_all + case MatchingAlgorithm.LITERAL: return self._match_literal + case MatchingAlgorithm.REGEX: return self._match_regex + case MatchingAlgorithm.FUZZY: return self._match_fuzzy + case _: return + + def matches(self, content) -> bool: + log.debug(f"applying filter: {self}") + + if not self.matching_pattern.strip(): + return False + + algorithm_func = self._get_algorithm_func() + if not algorithm_func: + log.error(f"Bad algorithm function: {self.matching_algorithm}") + return False + + match_found = algorithm_func(content.item_title) or algorithm_func(content.item_description) + log.debug(f"filter match found: {match_found}") + + return not match_found if self.is_whitelist else match_found + @dataclass(slots=True) class MessageMutator(DjangoDataModel): @@ -175,6 +271,49 @@ class Subscription(DjangoDataModel): def server(self, server: server): self._server = server + async def get_rss_content(self, client: httpx.AsyncClient) -> str: + try: + response = await client.get(self.url) + response.raise_for_status() + except httpx.HTTPError as exc: + log.error("(%s) HTTP Exception for %s - %s", type(exc), exc.request.url, exc) + return + + content_type = response.headers.get("Content-Type") + if not "text/xml" in content_type: + log.warning("Invalid 'Content-Type' header: %s (must contain 'text/xml')", content_type) + return + + return response.text + + async def get_discord_channels(self, bot) -> list: + channels = [] + + for channel_detail in self.channels: + try: + channel = bot.get_channel(channel_detail.id) + channels.append(channel or await bot.fetch_channel(channel_detail.id)) + except discord.Forbidden: + log.error(f"Forbidden channel: ({channel.name}, {channel.id}) from ({self.server.name}, {self.server.id})") + + return channels + + def filter_entries(self, contents: list) -> tuple[list, list]: + log.debug(f"filtering entries for {self.name} in {self.server.name}") + + valid_contents = [] + invalid_contents = [] + + for content in contents: + log.debug(f"filtering: '{content.item_title}'") + if any(content_filter.matches(content) for content_filter in self.filters): + invalid_contents.append(content) + else: + valid_contents.append(content) + + log.debug(f"filtered content: valid:{len(valid_contents)}, invalid:{len(invalid_contents)}") + return valid_contents, invalid_contents + @dataclass(slots=True) class Content(DjangoDataModel): @@ -184,10 +323,45 @@ class Content(DjangoDataModel): item_guid: str item_url: str item_title: str - item_content_hash: str + item_description: str + _subscription: Subscription | None = None @staticmethod def parser(item: dict) -> dict: item["id"] = item.pop("id") - item["server_id"] = item.pop("server") + item["subscription_id"] = item.pop("subscription") return item + + @classmethod + def from_raw_rss(cls, raw_rss_content: str, subscription: Subscription): + parsed_rss = feedparser.parse(raw_rss_content) + contents = [] + + for entry in parsed_rss.entries: + # content_hash = hashlib.new("sha256") + # content_hash.update(entry.get("description", "").encode()) + # content_hash.hexdigest() + + data = { + "id": -1, + "subscription": subscription.id, + "item_id": entry.get("id", ""), + "item_guid": entry.get("guid", ""), + "item_url": entry.get("link", ""), + "item_title": entry.get("title", ""), + "item_description": entry.get("description", "") + } + + content = Content.from_dict(data) + content.subscription = subscription + contents.append(content) + + return contents + + @property + def subscription(self) -> Subscription: + return self._subscription + + @subscription.setter + def subscription(self, subscription: Subscription): + self._subscription = subscription