From b8f1ffb8d9c6e33f0f965a6d65d3f08097ba60aa Mon Sep 17 00:00:00 2001 From: Corban-Lee Date: Wed, 30 Oct 2024 17:01:07 +0000 Subject: [PATCH] models for task & task work --- src/extensions/tasks.py | 70 ++++++++++----- src/models.py | 193 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 242 insertions(+), 21 deletions(-) create mode 100644 src/models.py diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py index 6c0c999..6ebec97 100644 --- a/src/extensions/tasks.py +++ b/src/extensions/tasks.py @@ -20,6 +20,7 @@ from discord.ext import commands, tasks from discord.errors import Forbidden from feedparser import parse +import models from feed import RSSFeed, Subscription, RSSItem, GuildSettings from utils import get_unparsed_feed from filters import match_text @@ -48,9 +49,14 @@ class TaskCog(commands.Cog): api: API | None = None content_queue = deque() + api_base_url: str + api_headers: dict + def __init__(self, bot): super().__init__() self.bot = bot + self.api_base_url = "http://localhost:8000/api/" + self.api_headers = {"Authorization": f"Token {self.bot.api_token}"} @commands.Cog.listener() async def on_ready(self): @@ -91,15 +97,9 @@ class TaskCog(commands.Cog): log.info("Running task") start_time = perf_counter() - client_options = { - "base_url": "http://localhost:8000/api/", - "headers": {"Authorization": f"Token {self.bot.api_token}"} - } - - async with httpx.AsyncClient(**client_options) as client: + async with httpx.AsyncClient() as client: servers = await self.get_servers(client) - for server in servers: - await self.process_server(server, client) + await self.process_servers(servers, client) end_time = perf_counter() log.debug(f"completed task in {end_time - start_time:.4f} seconds") @@ -108,7 +108,11 @@ class TaskCog(commands.Cog): for page_number, _ in enumerate(iterable=iter(int, 1), start=1): params.update({"page": page_number}) - response = await client.get(url, params=params) + response = await client.get( + self.api_base_url + url, + headers=self.api_headers, + params=params + ) response.raise_for_status() content = response.json() @@ -117,37 +121,61 @@ class TaskCog(commands.Cog): if not content.get("next"): break - async def get_servers(self, client: httpx.AsyncClient) -> list[dict]: + async def get_servers(self, client: httpx.AsyncClient) -> list[models.Server]: servers = [] async for servers_batch in self.iterate_pages(client, "servers/"): if servers_batch: servers.extend(servers_batch) - return servers + return models.Server.from_list(servers) - async def get_subscriptions(self, server: dict, client: httpx.AsyncClient) -> list[dict]: + async def get_subscriptions(self, server: models.Server, client: httpx.AsyncClient) -> list[models.Subscription]: subscriptions = [] - params = {"server": server.get("id")} + params = {"server": server.id} async for subscriptions_batch in self.iterate_pages(client, "subscriptions/", params): if subscriptions_batch: subscriptions.extend(subscriptions_batch) - return subscriptions + return models.Subscription.from_list(subscriptions) - async def process_server(self, server: dict, client: httpx.AsyncClient): - log.debug(json.dumps(server, indent=4)) + async def process_servers(self, servers: list[models.Server], client: httpx.AsyncClient): + + semaphore = asyncio.Semaphore(10) + + async def batch_process(server: models.Server, client: httpx.AsyncClient): + async with semaphore: await self.process_server(server, client) + + tasks = [batch_process(server, client) for server in servers if server.active] + await asyncio.gather(*tasks) + + async def process_server(self, server: models.Server, client: httpx.AsyncClient): + log.debug(f"processing server: {server.name}") subscriptions = await self.get_subscriptions(server, client) for subscription in subscriptions: - await self.process_subscription(subscription, client) + subscription.server = server + + semaphore = asyncio.Semaphore(10) + + async def batch_process(subscription: models.Subscription, client: httpx.AsyncClient): + async with semaphore: await self.process_subscription(subscription, client) + + tasks = [ + batch_process(subscription, client) + for subscription in subscriptions + if subscription.active + ] + await asyncio.gather(*tasks) + + async def process_subscription(self, subscription: models.Subscription, client: httpx.AsyncClient): + log.debug(f"processing subscription {subscription.name}") + + content = await client.get(subscription.url) + log.debug(content) - async def process_subscription(self, subscription: dict, client: httpx.AsyncClient): - log.debug(json.dumps(subscription, indent=4)) - for content_filter in subscription.get("filters_detail"): - pass diff --git a/src/models.py b/src/models.py new file mode 100644 index 0000000..3b17b8a --- /dev/null +++ b/src/models.py @@ -0,0 +1,193 @@ +import logging +from enum import Enum +from datetime import datetime + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +log = logging.getLogger(__name__) + + +@dataclass +class DjangoDataModel(ABC): + + @staticmethod + @abstractmethod + def parser(item: dict) -> dict: + return item + + @classmethod + def from_list(cls, data: list[dict]) -> list: + return [cls(**cls.parser(item)) for item in data] + + @classmethod + def from_dict(cls, data: dict): + return cls(**cls.parser(data)) + + +@dataclass(slots=True) +class Server(DjangoDataModel): + id: int + name: str + icon_hash: str + is_bot_operational: bool + active: bool + + @staticmethod + def parser(item: dict) -> dict: + item["id"] = int(item.pop("id")) + return item + + +class MatchingAlgorithm(Enum): + MATCH_NONE = 0 + MATCH_ANY = 1 + MATCH_ALL = 2 + MATCH_LITERAL = 3 + MATCH_REGEX = 4 + MATCH_FUZZY = 5 + MATCH_AUTO = 6 + + @classmethod + def from_value(cls, value: int): + for member in cls: + if member.value == value: + return member + + raise ValueError(f"No {self.__class__.__name__} for value: {value}") + + + +@dataclass(slots=True) +class ContentFilter(DjangoDataModel): + id: int + server_id: int + name: str + matching_pattern: str + matching_algorithm: MatchingAlgorithm + is_insensitive: bool + is_whitelist: bool + + @staticmethod + def parser(item: dict) -> dict: + item["id"] = item.pop("id") + item["server_id"] = item.pop("server") + item["matching_pattern"] = item.pop("match") + item["matching_algorithm"] = MatchingAlgorithm.from_value(item.pop("matching_algorithm")) + return item + + +@dataclass(slots=True) +class MessageMutator(DjangoDataModel): + id: int + name: str + value: str + + @staticmethod + def parser(item: dict) -> dict: + item["id"] = item.pop("id") + return item + +@dataclass(slots=True) +class MessageStyle(DjangoDataModel): + id: int + server_id: int + name: str + colour: str + is_embed: bool + is_hyperlinked: bool + show_author: bool + show_timestamp: bool + show_images: bool + fetch_images: bool + title_mutator: dict | None + description_mutator: dict | None + auto_created: bool + + @staticmethod + def parser(item: dict) -> dict: + item["id"] = int(item.pop("id")) + item["server_id"] = int(item.pop("server")) + item["title_mutator"] = item.pop("title_mutator_detail") + item["description_mutator"] = item.pop("description_mutator_detail") + return item + + +@dataclass(slots=True) +class UniqueContentRule(DjangoDataModel): + id: int + name: str + value: str + + @staticmethod + def parser(item: dict) -> dict: + item["id"] = int(item.pop("id")) + return item + + +@dataclass(slots=True) +class DiscordChannel(DjangoDataModel): + id: int + name: str + is_nsfw: bool + + @staticmethod + def parser(item: dict) -> dict: + item["id"] = int(item.pop("id")) + return item + + +@dataclass(slots=True) +class Subscription(DjangoDataModel): + id: int + server_id: int + name: str + url: str + created_at: datetime + updated_at: datetime + extra_notes: str + active: bool + publish_threshold: datetime + channels: list[DiscordChannel] + filters: list[ContentFilter] + message_style: MessageStyle + unique_rules: UniqueContentRule + _server: Server | None = None + + @staticmethod + def parser(item: dict) -> dict: + item["id"] = int(item.pop("id")) + item["server_id"] = int(item.pop("server")) + item["created_at"] = datetime.strptime(item.pop("created_at"), "%Y-%m-%dT%H:%M:%S.%f%z") + item["updated_at"] = datetime.strptime(item.pop("updated_at"), "%Y-%m-%dT%H:%M:%S.%f%z") + item["publish_threshold"] = datetime.strptime(item.pop("publish_threshold"), "%Y-%m-%dT%H:%M:%S%z") + item["channels"] = DiscordChannel.from_list(item.pop("channels_detail")) + item["filters"] = ContentFilter.from_list(item.pop("filters_detail")) + item["message_style"] = MessageStyle.from_dict(item.pop("message_style_detail")) + item["unique_rules"] = UniqueContentRule.from_list(item.pop("unique_rules_detail")) + return item + + @property + def server(self) -> Server: + return self._server + + @server.setter + def server(self, server: server): + self._server = server + + +@dataclass(slots=True) +class Content(DjangoDataModel): + id: int + subscription_id: int + item_id: str + item_guid: str + item_url: str + item_title: str + item_content_hash: str + + @staticmethod + def parser(item: dict) -> dict: + item["id"] = item.pop("id") + item["server_id"] = item.pop("server") + return item