294 lines
9.0 KiB
Python
294 lines
9.0 KiB
Python
"""
|
|
Extension for the `TaskCog`.
|
|
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
|
|
"""
|
|
|
|
import re
|
|
import json
|
|
import logging
|
|
import datetime
|
|
from os import getenv
|
|
from time import process_time
|
|
|
|
import aiohttp
|
|
from discord import TextChannel, Embed, Colour
|
|
from discord import app_commands
|
|
from discord.ext import commands, tasks
|
|
from discord.errors import Forbidden
|
|
from sqlalchemy import insert, select, and_
|
|
from feedparser import parse
|
|
|
|
from feed import Source, Article, RSSFeed, Subscription, SubscriptionChannel, SubChannel
|
|
from db import (
|
|
DatabaseManager,
|
|
FeedChannelModel,
|
|
RssSourceModel,
|
|
SentArticleModel
|
|
)
|
|
from utils import get_unparsed_feed
|
|
from api import API
|
|
|
|
log = logging.getLogger(__name__)

# Minutes between task triggers. FIX: `getenv` returns None when the variable
# is unset, and the original `int(None)` crashed at import time; a "10" minute
# default keeps the module importable while remaining overridable.
TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES", "10")

# Parse once up front so a malformed value fails loudly here, not mid-list.
_interval_minutes = int(TASK_INTERVAL_MINUTES)

# Task trigger times : must be of type list.
# One timezone-aware `datetime.time` per interval boundary of a UTC day,
# as required by `tasks.loop(time=...)`.
times = [
    datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
    for hour in range(24)
    for minute in range(0, 60, _interval_minutes)
]

log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
|
|
|
|
|
|
class TaskCog(commands.Cog):
    """
    Tasks cog.

    Hosts the scheduled `rss_task` loop, which collects subscriptions from
    the web API, downloads and parses each subscription's RSS feed, and
    posts new articles as embeds to the subscribed Discord channels.
    """

    def __init__(self, bot):
        super().__init__()
        self.bot = bot
        # Timing placeholder; not read anywhere in this cog at present.
        self.time = None

    @commands.Cog.listener()
    async def on_ready(self):
        """Instructions to execute when the cog is ready."""

        # if not self.bot.developing:
        self.rss_task.start()

        log.info("%s cog is ready", self.__class__.__name__)

    @commands.Cog.listener(name="cog_unload")
    async def on_unload(self):
        """Instructions to execute before the cog is unloaded."""

        self.rss_task.cancel()

    @app_commands.command(name="debug-trigger-task")
    async def debug_trigger_task(self, inter):
        """Slash command: manually trigger one run of `rss_task` (debug aid)."""

        await inter.response.defer()
        await self.rss_task()
        await inter.followup.send("done")

    @tasks.loop(time=times)
    async def rss_task(self):
        """Automated task responsible for processing rss feeds."""

        log.info("Running subscription task")
        time = process_time()

        guild_ids = [guild.id for guild in self.bot.guilds]
        data = []

        async with aiohttp.ClientSession() as session:
            api = API(self.bot.api_token, session)
            page = 0

            # Page through the API until an empty page signals the end.
            while True:
                page += 1
                page_data = await self.get_subscriptions(api, guild_ids, page)

                if not page_data:
                    break

                data.extend(page_data)
                log.debug("extending data by '%s' items", len(page_data))

            log.debug("finished api data collection, browsed %s pages for %s subscriptions", page, len(data))

            subscriptions = Subscription.from_list(data)
            for sub in subscriptions:
                await self.process_subscription(api, session, sub)

        log.info("Finished subscription task, time elapsed: %s", process_time() - time)

    async def get_subscriptions(self, api, guild_ids: list[int], page: int) -> list:
        """
        Fetch a single page of subscriptions for the given guild ids.

        Returns an empty list when the page does not exist (HTTP 404) or on
        any other client response error, so callers can simply stop paging.
        """

        log.debug("attempting to get subscriptions for page: %s", page)

        try:
            return (await api.get_subscriptions(server__in=guild_ids, page=page))[0]
        except aiohttp.ClientResponseError as error:
            if error.status == 404:
                # 404 just means we paged past the final page; not an error.
                log.debug(error)
            else:
                log.error(error)

            # FIX: the non-404 path previously fell through and returned
            # None implicitly; return an explicit, type-stable empty list.
            return []

    async def process_subscription(self, api: API, session: aiohttp.ClientSession, sub: Subscription):
        """
        Process a given Subscription.

        Skips inactive subscriptions, resolves target channels and filters,
        fetches + parses the feed, then builds and dispatches embeds.
        """

        log.debug("processing subscription '%s' '%s' for '%s'", sub.id, sub.name, sub.guild_id)

        if not sub.active:
            log.debug("skipping sub because it's active flag is 'False'")
            return

        # FIX: `bot.get_channel` returns None for channels the bot can no
        # longer see; drop those instead of passing None to channel.send.
        channels: list[TextChannel] = [
            channel
            for subchannel in await sub.get_channels(api)
            if (channel := self.bot.get_channel(subchannel.channel_id)) is not None
        ]
        if not channels:
            log.warning("No channels to send this to")
            return

        filters = [await api.get_filter(filter_id) for filter_id in sub.filters]
        log.debug("found %s filter(s)", len(filters))

        unparsed_content = await get_unparsed_feed(sub.url, session)
        parsed_content = parse(unparsed_content)
        source = Source.from_parsed(parsed_content)
        articles = source.get_latest_articles(10)
        articles.reverse()  # oldest first, so embeds arrive in chronological order

        if not articles:
            log.debug("No articles found")
            # FIX: previously fell through and processed an empty list.
            return

        embeds = await self.get_articles_as_embeds(api, session, sub.id, sub.mutators, filters, articles, Colour.from_str("#" + sub.embed_colour))
        await self.send_embeds_in_chunks(embeds, channels)

    async def get_articles_as_embeds(
        self,
        api: API,
        session: aiohttp.ClientSession,
        sub_id: int,
        mutators: dict[str, list[dict]],
        filters: list[dict],
        articles: list[Article],
        embed_colour: Colour  # FIX: annotation said str, callers pass Colour
    ) -> list[Embed]:
        """
        Process articles and return their respective embeds.

        Articles blocked by filters or already tracked produce no embed.
        """

        embeds = []
        for article in articles:
            embed = await self.process_article(api, session, sub_id, mutators, filters, article, embed_colour)
            if embed:
                embeds.append(embed)

        return embeds

    async def send_embeds_in_chunks(self, embeds: list[Embed], channels: list[TextChannel], embed_limit=10):
        """
        Send embeds to a list of `TextChannel` in chunks of `embed_limit` size.
        """

        # FIX: the original format string had no argument for %s.
        log.debug("about to send %s embeds", len(embeds))

        for i in range(0, len(embeds), embed_limit):
            embeds_chunk = embeds[i:i + embed_limit]

            log.debug("sending chunk of %s embeds", len(embeds_chunk))

            for channel in channels:
                # FIX: previously sent the whole `embeds` list instead of
                # the chunk, defeating the chunking and breaking Discord's
                # 10-embeds-per-message limit.
                await self.try_send_embeds(embeds_chunk, channel)

    async def try_send_embeds(self, embeds: list[Embed], channel: TextChannel):
        """
        Attempt to send embeds to a given `TextChannel`. Gracefully handles errors.
        """

        try:
            await channel.send(embeds=embeds)

        except Forbidden:
            # Missing permissions is expected and non-fatal; just log it.
            log.debug(
                "Forbidden from sending embed to channel '%s', guild '%s'",
                channel.id, channel.guild.id
            )

        except Exception as exc:
            # Deliberate best-effort: one bad channel must not abort the run.
            log.error(exc)

    async def process_article(
        self,
        api: API,
        session: aiohttp.ClientSession,
        sub_id: int,
        mutators: dict[str, list[dict]],
        filters: list[dict],
        article: Article,
        embed_colour: Colour  # FIX: annotation said str, callers pass Colour
    ) -> Embed | None:
        """
        Process a given Article.

        Returns an Embed representing the given Article, or None when it is
        blocked by a filter or could not be newly tracked by the API.
        """

        log.debug("processing article '%s' '%s'", article.guid, article.title)

        blocked = any(self.filter_article(_filter, article) for _filter in filters)
        log.debug("filter result: %s", "blocked" if blocked else "ok")

        self.mutate_article(article, mutators)

        try:
            await api.create_tracked_content(
                guid=article.guid,
                title=article.title,
                url=article.url,
                subscription=sub_id,
                blocked=blocked,
                # NOTE(review): placeholder value — confirm what the API
                # actually expects for channel_id.
                channel_id="-_-"
            )
            log.debug("successfully tracked %s", article.guid)

        except aiohttp.ClientResponseError as error:
            if error.status == 409:
                # 409 conflict: article already tracked in a previous run.
                log.debug("It looks like this article already exists, skipping")
            else:
                log.error(error)

            # Tracked before (or tracking failed) — never re-send it.
            return None

        if not blocked:
            return await article.to_embed(session, embed_colour)

    def mutate_article(self, article: Article, mutators: dict[str, list[dict]]):
        """Apply the subscription's title and description mutators in place."""
        # FIX: annotation said list[dict]; the body subscripts by key.

        for mutator in mutators["title"]:
            article.mutate("title", mutator)

        for mutator in mutators["desc"]:
            article.mutate("description", mutator)

    def filter_article(self, _filter: dict, article: Article) -> bool:
        """
        Returns True if article should be ignored due to filters.

        A whitelist filter inverts the match: non-matching articles are
        the ones that get blocked.
        """

        # FIX: "".split(",") yields [""], and the empty string is a
        # substring of everything — so an empty keyword field used to match
        # every article AND falsely trip the keywords/regex exclusivity check.
        keywords = [word for word in _filter["keywords"].split(",") if word]
        regex_pattern = _filter["regex"]
        is_whitelist = _filter["whitelist"]

        log.debug(
            "trying filter '%s', keyword '%s', regex '%s', is whitelist: '%s'",
            _filter["name"], keywords, regex_pattern, is_whitelist
        )

        if keywords and regex_pattern:
            # FIX: was an `assert`, which is stripped under `python -O`.
            raise ValueError("Keywords and Regex used, only 1 can be used.")

        # Flag determining whether the content matched the filter.
        match_found = any(
            word in article.title or word in article.description
            for word in keywords
        )

        if regex_pattern:
            regex = re.compile(regex_pattern)
            # FIX: bool() — `search` returns Match | None, and the raw Match
            # object used to leak out of this boolean-returning method.
            match_found = bool(regex.search(article.title) or regex.search(article.description))

        return not match_found if is_whitelist else match_found
|
|
|
|
async def setup(bot):
    """
    Setup function for this extension.

    Adds `TaskCog` to the bot.
    """

    task_cog = TaskCog(bot)
    await bot.add_cog(task_cog)
    log.info("Added %s cog", task_cog.__class__.__name__)