diff --git a/src/bot.py b/src/bot.py
index 1554c1c..55ec02a 100644
--- a/src/bot.py
+++ b/src/bot.py
@@ -7,10 +7,6 @@
 from pathlib import Path
 
 from discord import Intents, Game
 from discord.ext import commands
-from sqlalchemy import insert
-
-from feed import Functions
-from db import DatabaseManager, AuditModel
 
 log = logging.getLogger(__name__)
@@ -20,7 +16,6 @@ class DiscordBot(commands.Bot):
     def __init__(self, BASE_DIR: Path, developing: bool, api_token: str):
         activity = Game("Indev") if developing else None
         super().__init__(command_prefix="-", intents=Intents.all(), activity=activity)
-        self.functions = Functions(self, api_token)
         self.BASE_DIR = BASE_DIR
         self.developing = developing
         self.api_token = api_token
@@ -52,27 +47,3 @@ class DiscordBot(commands.Bot):
         for path in (self.BASE_DIR / "src/extensions").iterdir():
             if path.suffix == ".py":
                 await self.load_extension(f"extensions.{path.stem}")
-
-    async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
-        """Shorthand for auditing an action.
-
-        Parameters
-        ----------
-        message : str
-            The message to be audited.
-        user_id : int
-            Discord ID of the user being audited.
-        database : DatabaseManager, optional
-            An existing database connection to be used if specified, by default None
-        """
-
-        query = insert(AuditModel).values(discord_user_id=user_id, message=message)
-
-        log.debug("Logging audit")
-
-        if database:
-            await database.session.execute(query)
-            return
-
-        async with DatabaseManager() as database:
-            await database.session.execute(query)
diff --git a/src/db/__init__.py b/src/db/__init__.py
deleted file mode 100644
index cdb3799..0000000
--- a/src/db/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-"""
-Initialize the database modules, create the database tables and default data.
-"""
-
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-
-from .models import Base, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
-from .db import DatabaseManager
-
-# # Initialise a database session
-# engine = create_engine(DatabaseManager.get_database_url(use_async=False))
-# session = sessionmaker(bind=engine)()
-
-# # Create tables if not exists
-# Base.metadata.create_all(engine)
-
-# session.commit()
-# session.close()
-
diff --git a/src/db/db.py b/src/db/db.py
deleted file mode 100644
index 9821310..0000000
--- a/src/db/db.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""
-Database Manager
-"""
-
-import logging
-from os import getenv
-
-from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
-from sqlalchemy.orm import sessionmaker
-
-DB_TYPE = getenv("DB_TYPE", default="sqlite")
-DB_HOST = getenv("DB_HOST", default="db.sqlite")
-DB_PORT = getenv("DB_PORT")
-DB_USERNAME = getenv("DB_USERNAME")
-DB_PASSWORD = getenv("DB_PASSWORD")
-DB_DATABASE = getenv("DB_DATABASE")
-
-log = logging.getLogger(__name__)
-
-
-class DatabaseManager:
-    """
-    Asynchronous database context manager.
-    """
-
-    def __init__(self, no_commit: bool = False):
-        database_url = self.get_database_url() # This is called every time a connection is established, maybe make it once and reference it?
-        self.engine = create_async_engine(database_url, future=True)
-        self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
-        self.session = None
-        self.no_commit = no_commit
-
-    @staticmethod
-    def get_database_url(use_async=True):
-        """
-        Returns a connection string for the database.
-        """
-
-        url = f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}"
-        url_addon = ""
-
-        # This looks fucking ugly
-        match use_async, DB_TYPE:
-            case True, "sqlite":
-                url_addon = "aiosqlite"
-
-            case True, "postgresql":
-                url_addon = "asyncpg"
-
-            case False, "sqlite":
-                pass
-
-            case False, "postgresql":
-                pass
-
-            case _, _:
-                raise ValueError(f"Unknown Database Type: {DB_TYPE}")
-
-        url = url.replace(":/", f"+{url_addon}:/") if url_addon else url
-
-
-        return url
-
-
-    async def __aenter__(self):
-        self.session = self.session_maker()
-        log.debug("Database connection open")
-        return self
-
-    async def __aexit__(self, *_):
-        if not self.no_commit:
-            await self.session.commit()
-
-        await self.session.close()
-        self.session = None
-        await self.engine.dispose()
-        log.debug("Database connection closed")
diff --git a/src/db/models.py b/src/db/models.py
deleted file mode 100644
index 9831a7c..0000000
--- a/src/db/models.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""
-Models and Enums for the database.
-All table classes should be suffixed with `Model`.
-"""
-
-from sqlalchemy.sql import func
-from sqlalchemy.orm import relationship, declarative_base
-from sqlalchemy import (
-    Column,
-    Integer,
-    String,
-    DateTime,
-    BigInteger,
-    UniqueConstraint,
-    ForeignKey
-)
-
-Base = declarative_base()
-
-# back in wed, thu, fri, off new year day then back in after
-
-class AuditModel(Base):
-    """
-    Table for taking audits.
-    """
-
-    __tablename__ = "audit"
-
-    id = Column(Integer, primary_key=True, autoincrement=True)
-    discord_user_id = Column(BigInteger, nullable=False)
-    # discord_server_id = Column(BigInteger, nullable=False) # TODO: this doesnt exist, integrate it.
-    message = Column(String, nullable=False)
-    created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
-
-
-class SentArticleModel(Base):
-    """
-    Table for tracking articles that have been served by the bot.
-    """
-
-    __tablename__ = "sent_articles"
-
-    id = Column(Integer, primary_key=True, autoincrement=True)
-    discord_message_id = Column(BigInteger, nullable=False)
-    discord_channel_id = Column(BigInteger, nullable=False)
-    discord_server_id = Column(BigInteger, nullable=False)
-    article_url = Column(String, nullable=False)
-    when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
-
-    feed_channel_id = Column(Integer, ForeignKey("feed_channel.id", ondelete="CASCADE"), nullable=False)
-    feed_channel = relationship("FeedChannelModel", overlaps="sent_articles", lazy="joined", cascade="all, delete")
-
-
-class RssSourceModel(Base):
-    """
-    Table for user submitted news feeds.
-    """
-
-    __tablename__ = "rss_source"
-
-    id = Column(Integer, primary_key=True, autoincrement=True)
-    nick = Column(String, nullable=False)
-    discord_server_id = Column(BigInteger, nullable=False)
-    rss_url = Column(String, nullable=False)
-    created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
-
-    feed_channels = relationship("FeedChannelModel", cascade="all, delete")
-
-    # the nickname must be unique, but only within the same discord server
-    __table_args__ = (
-        UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'),
-    )
-
-
-class FeedChannelModel(Base):
-    """
-    Table representing discord channels to be used for news feeds.
-    """
-
-    __tablename__ = "feed_channel"
-
-    id = Column(Integer, primary_key=True, autoincrement=True)
-    discord_channel_id = Column(BigInteger, nullable=False)
-    discord_server_id = Column(BigInteger, nullable=False)
-    search_name = Column(String, nullable=False)
-    created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
-
-    sent_articles = relationship("SentArticleModel", cascade="all, delete")
-
-    rss_source_id = Column(Integer, ForeignKey('rss_source.id', ondelete="CASCADE"), nullable=False)
-    rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete")
-
-    # the rss source must be unique, but only within the same discord channel
-    __table_args__ = (
-        UniqueConstraint('rss_source_id', 'discord_channel_id', name='uq_rss_discord_channel'),
-    )
diff --git a/src/errors.py b/src/errors.py
index 0066a15..2970bf4 100644
--- a/src/errors.py
+++ b/src/errors.py
@@ -1,10 +1,4 @@
-class IllegalFeed(Exception):
-    def __init__(self, message: str, **items):
-        super().__init__(message)
-        self.items = items
-
-
 class TokenMissingError(Exception):
     """
     Exception to indicate a token couldnt be found.
diff --git a/src/extensions/new_tasks.py b/src/extensions/new_tasks.py
deleted file mode 100644
index 998cace..0000000
--- a/src/extensions/new_tasks.py
+++ /dev/null
@@ -1,152 +0,0 @@
-"""
-
-task flow
-
-=-=-=-=-=
-
-1. every 10 minutes
-2. get all bot guild_ids
-3. get all subscriptions against known guild ids (account for pagination of 25)
-4. iterate through and process each subscription
-5. get all articles from subscription
-6. iterate through and process each article
-
-"""
-
-
-class Cog:
-
-    async def subscription_task():
-
-        subscriptions = self.get_subscriptions()
-
-        for subscription in subscriptions:
-            await self.process_subscription(subscription)
-
-    async def get_subscriptions():
-        # return subs from api, handle pagination
-
-    async def process_subscription(subscription):
-
-        articles = subscription.get_articles()
-
-        for article in articles:
-            await self.process_article(article)
-
-    async def process_article(article):
-        # validate article then:
-
-        if await self.track_article(article):
-            await self.send_article(article)
-
-
-    async def track_article():
-        pass
-
-    async def send_article():
-        pass
-
-
-class TaskCog(commands.Cog):
-    """
-    Tasks cog for PYRSS.
-    """
-
-    def __init__(self, bot):
-        super().__init__()
-        self.bot = bot
-
-    @tasks.loop(time=times)
-    async def subscription_task(self):
-        async with aiohttp.ClientSession() as session:
-            api = API(self.bot.api_token, session)
-            subscriptions = await self.get_subscriptions(api)
-            await self.process_subscriptions(api, subscriptions)
-            # articles = [*(await self.get_articles(api, sub)) for sub in subscriptions]
-            # await self.send_articles(api, articles)
-
-    async def get_subscriptions(self, api) -> list:
-        guild_ids = [guild.id for guild in self.bot.guilds]
-        subscriptions = []
-
-        for page in iter(int, 1):
-            try:
-                page_data = (await api.get_subscriptions(server__in=guild_ids, page=page+1))[0]
-            except apihttp.ClientResponseError as error:
-                if error.status == 404:
-                    break
-
-            except Exception as error:
-                log.error("Exception while gathering page data %s", error)
-                break
-
-            subscriptions.extend(page_data)
-
-    async def process_subscriptions(self, api, subscriptions):
-        for sub in subscriptions:
-            if not sub.active or not sub.channel_count:
-                continue
-
-            unparsed_feed = await get_unparsed_feed(sub.url, api.session)
-            parsed_feed = await parse(unparsed_feed)
-
-            rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
-            await self.process_items(api, sub, rss_feed)
-
-    async def process_items(self, api, sub, feed):
-
-        channels = [self.bot.get_channel(channel.channel_id) for channel in await sub.get_channels()]
-        filters = [await api.get_filter(filter_id) for filter_id in sub.filters]
-
-        for item in sub.items:
-            blocked = any(self.filter_item(_filter, item) for _filter in filters)
-            mutated_item = item.create_mutated_copy(sub.mutators)
-
-            for channel in channels:
-                await self.mark_tracked_item(api, sub, item, channel.id, blocked)
-
-                if not blocked:
-                    channel.send(embed=item.to_embed(sub, feed))
-
-    async def filter_item(self, _filter: dict, item: RSSItem) -> bool:
-        """
-        Returns True if item should be ignored due to filters.
-        """
-
-        match_found = False # This is the flag to determine if the content should be filtered
-
-        keywords = _filter["keywords"].split(",")
-        regex_pattern = _filter["regex"]
-        is_whitelist = _filter["whitelist"]
-
-        log.debug(
-            "trying filter '%s', keyword '%s', regex '%s', is whitelist: '%s'",
-            _filter["name"], keywords, regex_pattern, is_whitelist
-        )
-
-        assert not (keywords and regex_pattern), "Keywords and Regex used, only 1 can be used."
-
-        if any(word in item.title or word in item.description for word in keywords):
-            match_found = True
-
-        if regex_pattern:
-            regex = re.compile(regex_pattern)
-            match_found = regex.search(item.title) or regex.search(item.description)
-
-        return not match_found if is_whitelist else match_found
-
-    async def mark_tracked_item(self, sub, item, channel_id, blocked):
-        try:
-            api.create_tracked_content(
-                guid=item.guid,
-                title=item.title,
-                url=item.url,
-                subscription=sub.id,
-                channel_id=channel_id,
-                blocked=blocked
-            )
-        except aiohttp.ClientResponseError as error:
-            if error.status == 409:
-                log.debug(error)
-            else:
-                log.error(error)
diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py
index 7634fe9..72cda50 100644
--- a/src/extensions/tasks.py
+++ b/src/extensions/tasks.py
@@ -4,7 +4,6 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
 """
 
 import re
-import json
 import logging
 import datetime
 from os import getenv
@@ -15,249 +14,129 @@
 from discord import TextChannel, Embed, Colour
 from discord import app_commands
 from discord.ext import commands, tasks
 from discord.errors import Forbidden
-from sqlalchemy import insert, select, and_
 from feedparser import parse
 
-from feed import Source, Article, RSSFeed, Subscription, SubscriptionChannel, SubChannel
-from db import (
-    DatabaseManager,
-    FeedChannelModel,
-    RssSourceModel,
-    SentArticleModel
-)
+from feed import RSSFeed, Subscription, RSSItem
 from utils import get_unparsed_feed
 from api import API
 
 log = logging.getLogger(__name__)
 
 TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES")
-
-# task trigger times : must be of type list
-times = [
+subscription_task_times = [
     datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
     for hour in range(24)
     for minute in range(0, 60, int(TASK_INTERVAL_MINUTES))
 ]
-
 log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
 
 
 class TaskCog(commands.Cog):
     """
-    Tasks cog.
+    Tasks cog for PYRSS.
     """
 
     def __init__(self, bot):
         super().__init__()
         self.bot = bot
-        self.time = None
 
     @commands.Cog.listener()
     async def on_ready(self):
-        """Instructions to execute when the cog is ready."""
-
-        # if not self.bot.developing:
-        self.rss_task.start()
-
+        """
+        Instructions to execute when the cog is ready.
+        """
+        self.subscription_task.start()
         log.info("%s cog is ready", self.__class__.__name__)
 
     @commands.Cog.listener(name="cog_unload")
     async def on_unload(self):
-        """Instructions to execute before the cog is unloaded."""
-
-        self.rss_task.cancel()
+        """
+        Instructions to execute before the cog is unloaded.
+        """
+        self.subscription_task.cancel()
 
     @app_commands.command(name="debug-trigger-task")
     async def debug_trigger_task(self, inter):
         await inter.response.defer()
-        await self.rss_task()
+        await self.subscription_task()
         await inter.followup.send("done")
 
-    @tasks.loop(time=times)
-    async def rss_task(self):
-        """Automated task responsible for processing rss feeds."""
-
+    @tasks.loop(time=subscription_task_times)
+    async def subscription_task(self):
+        """
+        Task for fetching and processing subscriptions.
+        """
         log.info("Running subscription task")
-        time = process_time()
-
-        guild_ids = [guild.id for guild in self.bot.guilds]
-        data = []
 
         async with aiohttp.ClientSession() as session:
             api = API(self.bot.api_token, session)
-            page = 0
+            subscriptions = await self.get_subscriptions(api)
+            await self.process_subscriptions(api, subscriptions)
 
-            while True:
-                page += 1
-                page_data = await self.get_subscriptions(api, guild_ids, page)
+    async def get_subscriptions(self, api: API) -> list[Subscription]:
+        guild_ids = [guild.id for guild in self.bot.guilds]
+        sub_data = []
 
-                if not page_data:
-                    break
+        for page, _ in enumerate(iter(int, 1)):
+            try:
+                log.debug("fetching page '%s'", page + 1)
+                sub_data.extend(
+                    (await api.get_subscriptions(server__in=guild_ids, page=page+1))[0]
+                )
+            except aiohttp.ClientResponseError as error:
+                match error.status:
+                    case 404:
+                        log.debug("final page reached '%s'", page)
+                        break
+                    case 403:
+                        log.critical(error)
+                        self.subscription_task.cancel()
+                        return [] # returning an empty list should gracefully end the task
+                    case _:
+                        log.error(error)
+                        break
 
-                data.extend(page_data)
-                log.debug("extending data by '%s' items", len(page_data))
+            except Exception as error:
+                log.error("Exception while gathering page data %s", error)
+                break
 
-            log.debug("finished api data collection, browsed %s pages for %s subscriptions", page, len(data))
-            subscriptions = Subscription.from_list(data)
-            for sub in subscriptions:
-                await self.process_subscription(api, session, sub)
+        return Subscription.from_list(sub_data)
 
-        log.info("Finished subscription task, time elapsed: %s", process_time() - time)
+    async def process_subscriptions(self, api: API, subscriptions: list[Subscription]):
+        for sub in subscriptions:
+            log.debug("processing subscription '%s'", sub.id)
 
-    async def get_subscriptions(self, api, guild_ids: list[int], page: int):
+            if not sub.active or not sub.channels_count:
+                continue
 
-        log.debug("attempting to get subscriptions for page: %s", page)
+            unparsed_feed = await get_unparsed_feed(sub.url, api.session)
+            parsed_feed = parse(unparsed_feed)
 
-        try:
-            return (await api.get_subscriptions(server__in=guild_ids, page=page))[0]
-        except aiohttp.ClientResponseError as error:
-            if error.status == 404:
-                log.debug(error)
-                return []
+            rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
+            await self.process_items(api, sub, rss_feed)
 
-            log.error(error)
-
-    async def process_subscription(self, api: API, session: aiohttp.ClientSession, sub: Subscription):
-        """
-        Process a given Subscription.
-        """
-
-        log.debug("processing subscription '%s' '%s' for '%s'", sub.id, sub.name, sub.guild_id)
-
-        if not sub.active:
-            log.debug("skipping sub because it's active flag is 'False'")
-            return
-
-        channels: list[TextChannel] = [self.bot.get_channel(subchannel.channel_id) for subchannel in await sub.get_channels(api)]
-        if not channels:
-            log.warning("No channels to send this to")
-            return
+    async def process_items(self, api: API, sub: Subscription, feed: RSSFeed):
+        log.debug("processing items")
 
+        channels = [self.bot.get_channel(channel.channel_id) for channel in await sub.get_channels(api)]
         filters = [await api.get_filter(filter_id) for filter_id in sub.filters]
-        log.debug("found %s filter(s)", len(filters))
 
-        unparsed_content = await get_unparsed_feed(sub.url, session)
-        parsed_content = parse(unparsed_content)
-        source = Source.from_parsed(parsed_content)
-        articles = source.get_latest_articles(10)
-        articles.reverse()
+        for item in feed.items:
+            log.debug("processing item '%s'", item.guid)
 
-        if not articles:
-            log.debug("No articles found")
-
-        embeds = await self.get_articles_as_embeds(api, session, sub.id, sub.mutators, filters, articles, Colour.from_str("#" + sub.embed_colour))
-        await self.send_embeds_in_chunks(embeds, channels)
-
-    async def get_articles_as_embeds(
-        self,
-        api: API,
-        session: aiohttp.ClientSession,
-        sub_id: int,
-        mutators: dict[str, list[dict]],
-        filters: list[dict],
-        articles: list[Article],
-        embed_colour: str
-    ) -> list[Embed]:
-        """
-        Process articles and return their respective embeds.
-        """
-
-        embeds = []
-        for article in articles:
-            embed = await self.process_article(api, session, sub_id, mutators, filters, article, embed_colour)
-            if embed:
-                embeds.append(embed)
-
-        return embeds
-
-    async def send_embeds_in_chunks(self, embeds: list[Embed], channels: list[TextChannel], embed_limit=10):
-        """
-        Send embeds to a list of `TextChannel` in chunks of `embed_limit` size.
-        """
-
-        log.debug("about to send %s embeds")
-
-        for i in range(0, len(embeds), embed_limit):
-            embeds_chunk = embeds[i:i + embed_limit]
-
-            log.debug("sending chunk of %s embeds", len(embeds_chunk))
+            blocked = any(self.filter_item(_filter, item) for _filter in filters)
+            mutated_item = item.create_mutated_copy(sub.mutators)
 
             for channel in channels:
-                await self.try_send_embeds(embeds, channel)
+                successful_track = await self.mark_tracked_item(api, sub, item, channel.id, blocked)
 
-    async def try_send_embeds(self, embeds: list[Embed], channel: TextChannel):
+                if successful_track and not blocked:
+                    await channel.send(embed=await item.to_embed(sub, feed, api.session))
+
+    def filter_item(self, _filter: dict, item: RSSItem) -> bool:
         """
-        Attempt to send embeds to a given `TextChannel`. Gracefully handles errors.
-        """
-
-        try:
-            await channel.send(embeds=embeds)
-
-        except Forbidden:
-            log.debug(
-                "Forbidden from sending embed to channel '%s', guild '%s'",
-                channel.id, channel.guild.id
-            )
-
-        except Exception as exc:
-            log.error(exc)
-
-    async def process_article(
-        self,
-        api: API,
-        session: aiohttp.ClientSession,
-        sub_id: int,
-        mutators: dict[str, list[dict]],
-        filters: list[dict],
-        article: Article,
-        embed_colour: str
-    ) -> Embed | None:
-        """
-        Process a given Article.
-        Returns an Embed representing the given Article.
-        """
-
-        log.debug("processing article '%s' '%s'", article.guid, article.title)
-
-        blocked = any(self.filter_article(_filter, article) for _filter in filters)
-        log.debug("filter result: %s", "blocked" if blocked else "ok")
-
-        self.mutate_article(article, mutators)
-
-        try:
-            await api.create_tracked_content(
-                guid=article.guid,
-                title=article.title,
-                url=article.url,
-                subscription=sub_id,
-                blocked=blocked,
-                channel_id="-_-"
-            )
-            log.debug("successfully tracked %s", article.guid)
-
-        except aiohttp.ClientResponseError as error:
-            if error.status == 409:
-                log.debug("It looks like this article already exists, skipping")
-            else:
-                log.error(error)
-
-            return
-
-        if not blocked:
-            return await article.to_embed(session, embed_colour)
-
-    def mutate_article(self, article: Article, mutators: list[dict]):
-
-        for mutator in mutators["title"]:
-            article.mutate("title", mutator)
-
-        for mutator in mutators["desc"]:
-            article.mutate("description", mutator)
-
-    def filter_article(self, _filter: dict, article: Article) -> bool:
-        """
-        Returns True if article should be ignored due to filters.
+        Returns True if item should be ignored due to filters.
         """
 
         match_found = False # This is the flag to determine if the content should be filtered
@@ -273,15 +152,35 @@ class TaskCog(commands.Cog):
 
         assert not (keywords and regex_pattern), "Keywords and Regex used, only 1 can be used."
 
-        if any(word in article.title or word in article.description for word in keywords):
-            match_found = True
-
         if regex_pattern:
             regex = re.compile(regex_pattern)
-            match_found = regex.search(article.title) or regex.search(article.description)
+            match_found = regex.search(item.title) or regex.search(item.description)
+        else:
+            match_found = any(word in item.title or word in item.description for word in keywords)
 
         return not match_found if is_whitelist else match_found
 
+    async def mark_tracked_item(self, api: API, sub: Subscription, item: RSSItem, channel_id: int, blocked: bool):
+        try:
+            log.debug("marking as tracked 'blocked: %s'", blocked)
+            await api.create_tracked_content(
+                guid=item.guid,
+                title=item.title,
+                url=item.link,
+                subscription=sub.id,
+                channel_id=channel_id,
+                blocked=blocked
+            )
+            return True
+        except aiohttp.ClientResponseError as error:
+            if error.status == 409:
+                log.debug(error)
+            else:
+                log.error(error)
+
+        return False
+
 
 async def setup(bot):
     """
     Setup function for this extension.
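
A note on the pagination idiom in the rewritten TaskCog.get_subscriptions above: `iter(int, 1)` is an infinite iterator, because `int()` returns 0 on every call and so never equals the sentinel value 1; `enumerate` then supplies the page counter. The loop can only end through the `break` and `return []` paths in the exception handlers, with an HTTP 404 marking the final page. A more explicit equivalent as a sketch (`fetch_page` is a hypothetical stand-in for the `api.get_subscriptions` call, not part of this patch):

    # Same unbounded-pagination pattern, spelled out with itertools.count.
    # fetch_page is hypothetical; the real code calls api.get_subscriptions
    # and treats an HTTP 404 response as the "no more pages" signal.
    import itertools

    def collect_pages(fetch_page) -> list:
        data = []
        for page in itertools.count(1):  # 1, 2, 3, ... until break
            chunk = fetch_page(page)
            if not chunk:  # stand-in for the 404 break above
                break
            data.extend(chunk)
        return data
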
diff --git a/src/feed.py b/src/feed.py
index 23e5660..77e69ae 100644
--- a/src/feed.py
+++ b/src/feed.py
@@ -4,7 +4,6 @@
 import logging
 from dataclasses import dataclass
 from datetime import datetime
 from abc import ABC, abstractmethod
-from random import shuffle, sample
 
 import aiohttp
 import validators
@@ -12,30 +11,25 @@
 from discord import Embed, Colour
 from bs4 import BeautifulSoup as bs4
 from feedparser import FeedParserDict, parse
 from markdownify import markdownify
-from sqlalchemy import select, insert, delete, and_
-from sqlalchemy.exc import NoResultFound
 from textwrap import shorten
+from feedparser import parse
 
 from mutators import registry as mutator_registry
-from errors import IllegalFeed
-from db import DatabaseManager, RssSourceModel, FeedChannelModel
-from utils import get_rss_data, get_unparsed_feed
+from utils import get_unparsed_feed
 from api import API
 
 log = logging.getLogger(__name__)
 dumps = lambda _dict: json.dumps(_dict, indent=8)
 
-from xml.etree.ElementTree import Element, SubElement, tostring
-from feedparser import parse
-
 class RSSItem:
-    def __init__(self, title, link, description, pub_date, guid):
+    def __init__(self, title, link, description, pub_date, guid, image_url):
         self.title = title
         self.link = link
         self.description = description
         self.pub_date = pub_date
         self.guid = guid
+        self.image_url = image_url
 
     def __str__(self):
         return self.title
@@ -43,7 +37,7 @@ class RSSItem:
     def create_mutated_copy(self, mutators):
         pass
 
-    def create_embed(self, sub, feed):
+    async def to_embed(self, sub, feed, session):
         title = shorten(markdownify(self.title, strip=["img", "a"]), 256)
         desc = shorten(markdownify(self.description, strip=["img"]), 4096)
 
@@ -60,7 +54,6 @@ class RSSItem:
         # validate urls
         # author_url = self.source.url if validators.url(self.source.url) else None
         # icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
-        # thumb_url = await self.get_thumbnail_url(session) # validation done inside func
 
         # Combined length validation
         # Can't exceed combined 6000 characters, [400 Bad Request] if failed.
@@ -71,27 +64,64 @@ class RSSItem:
         embed = Embed(
             title=title,
             description=desc,
-            timestamp=self.published,
+            # timestamp=self.published,
             url=self.link if validators.url(self.link) else None,
-            colour=colour
+            colour=Colour.from_str("#" + sub.embed_colour)
         )
 
-        # embed.set_thumbnail(url=icon_url)
+        log.debug("has image url without search: '%s'", self.image_url)
+
+        if sub.article_fetch_image:
+            embed.set_image(url=self.image_url or await self.get_thumbnail_url(session))
+
+        embed.set_thumbnail(url=feed.image_href if validators.url(feed.image_href) else None)
+
+        # embed.set_image(url=thumb_url)
         # embed.set_author(url=author_url, name=author)
         # embed.set_footer(text=self.author)
 
         return embed
 
+    async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None:
+        """Returns the thumbnail URL for an article.
+
+        Parameters
+        ----------
+        session : aiohttp.ClientSession
+            A client session used to get the thumbnail.
+
+        Returns
+        -------
+        str or None
+            The thumbnail URL, or None if not found.
+        """
+
+        # log.debug("Fetching thumbnail for article: %s", self)
+
+        try:
+            async with session.get(self.link, timeout=15) as response:
+                html = await response.text()
+        except aiohttp.InvalidURL as error:
+            log.error("invalid thumbnail url: %s", error)
+            return None
+
+        soup = bs4(html, "html.parser")
+        image_element = soup.select_one("meta[property='og:image']")
+        if not image_element:
+            return None
+
+        image_content = image_element.get("content")
+        return image_content if validators.url(image_content) else None
+
+import json
 
 class RSSFeed:
-    def __init__(self, title, link, description, language='en-gb', pub_date=None, last_build_date=None):
+    def __init__(self, title, link, description, language='en-gb', pub_date=None, last_build_date=None, image_href=None):
         self.title = title
         self.link = link
         self.description = description
         self.language = language
         self.pub_date = pub_date
         self.last_build_date = last_build_date
+        self.image_href = image_href
         self.items = []
 
     def add_item(self, item: RSSItem):
@@ -111,8 +141,9 @@ class RSSFeed:
         language = parsed_feed.feed.get('language', 'en-gb')
         pub_date = parsed_feed.feed.get('published', None)
         last_build_date = parsed_feed.feed.get('updated', None)
+        image_href = parsed_feed.get("image", {}).get("href")
 
-        feed = cls(title, link, description, language, pub_date, last_build_date)
+        feed = cls(title, link, description, language, pub_date, last_build_date, image_href)
 
         for entry in parsed_feed.entries:
             item_title = entry.get('title', 'No title')
@@ -120,10 +151,11 @@ class RSSFeed:
             item_description = entry.get('description', 'No description')
             item_pub_data = entry.get('published_parsed', None)
             item_guid = entry.get('id', None) or entry.get("guid", None)
+            item_image_url = entry.get("media_content", [{}])[0].get("url")
 
-            item_published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
+            # item_published = datetime(*entry.published_parsed[0:-2]) if item_pub_data else None
 
-            item = RSSItem(item_title, item_link, item_description, item_published, item_guid)
+            item = RSSItem(item_title, item_link, item_description, item_pub_data, item_guid, item_image_url)
             feed.add_item(item)
 
         feed.items.reverse()
@@ -131,244 +163,6 @@ class RSSFeed:
 
         return feed
 
-
-@dataclass
-class RSSArticle:
-    """Represents a news article, or entry from an RSS feed."""
-
-    guid: str
-    title: str | None
-    description: str | None
-    url: str | None
-    published: datetime | None
-    author: str | None
-    source: object
-
-    @classmethod
-    def from_entry(cls, source, entry:FeedParserDict):
-        """Create an Article from an RSS feed entry.
-
-        Parameters
-        ----------
-        entry : FeedParserDict
-            An entry pulled from a complete FeedParserDict object.
-
-        Returns
-        -------
-        Article
-            The Article created from the feed entry.
-        """
-
-        # log.debug("Creating Article from entry: %s", dumps(entry))
-
-        published_parsed = entry.get("published_parsed")
-        published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
-
-        return cls(
-            guid=entry["guid"],
-            title=entry.get("title"),
-            description=entry.get("description"),
-            url=entry.get("link"),
-            published=published,
-            author = entry.get("author"),
-            source=source
-        )
-
-    def mutate(self, attr: str, mutator: dict[str, str]):
-        """
-        Apply a mutation to a certain text attribute of
-        this Article instance.
-        """
-
-        # WARN:
-        # This could be really bad if the end user is able to effect the 'attr' value.
-        # Shouldn't happen though.
-
-        log.debug("applying mutator '%s'", mutator["name"])
-        val = mutator["value"]
-
-        try:
-            mutator = mutator_registry.get_mutator(val)
-        except ValueError as err:
-            log.error(err)
-
-        setattr(self, attr, mutator.mutate(getattr(self, attr)))
-        log.debug("mutated %s, to: %s", attr, getattr(self, attr))
-
-
-    async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None:
-        """Returns the thumbnail URL for an article.
-
-        Parameters
-        ----------
-        session : aiohttp.ClientSession
-            A client session used to get the thumbnail.
-
-        Returns
-        -------
-        str or None
-            The thumbnail URL, or None if not found.
-        """
-
-        # log.debug("Fetching thumbnail for article: %s", self)
-
-        try:
-            async with session.get(self.url, timeout=15) as response:
-                html = await response.text()
-        except aiohttp.InvalidURL as error:
-            log.error("invalid thumbnail url: %s", error)
-            return None
-
-        soup = bs4(html, "html.parser")
-        image_element = soup.select_one("meta[property='og:image']")
-        if not image_element:
-            return None
-
-        image_content = image_element.get("content")
-        return image_content if validators.url(image_content) else None
-
-    async def to_embed(self, session: aiohttp.ClientSession, colour: Colour) -> Embed:
-        """Creates and returns a Discord Embed object from the article.
-
-        Parameters
-        ----------
-        session : aiohttp.ClientSession
-            A client session used to get additional article data.
-
-        Returns
-        -------
-        Embed
-            A Discord Embed object representing the article.
-        """
-
-        # log.debug(f"Creating embed from article: {self}")
-
-        # Replace HTML with Markdown, and shorten text.
-        title = shorten(markdownify(self.title, strip=["img", "a"]), 256)
-        desc = shorten(markdownify(self.description, strip=["img"]), 4096)
-        author = shorten(self.source.name, 256)
-
-        # validate urls
-        embed_url = self.url if validators.url(self.url) else None
-        author_url = self.source.url if validators.url(self.source.url) else None
-        icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
-        thumb_url = await self.get_thumbnail_url(session) # validation done inside func
-
-        # Combined length validation
-        # Can't exceed combined 6000 characters, [400 Bad Request] if failed.
-        combined_length = len(title) + len(desc) + (len(author) * 2)
-        cutoff = combined_length - 6000
-        desc = shorten(desc, cutoff) if cutoff > 0 else desc
-
-        embed = Embed(
-            title=title,
-            description=desc,
-            timestamp=self.published,
-            url=embed_url,
-            colour=colour
-        )
-
-        embed.set_thumbnail(url=icon_url)
-        embed.set_image(url=thumb_url)
-        embed.set_author(url=author_url, name=author)
-        embed.set_footer(text=self.author)
-
-        return embed
-
-
-@dataclass
-class RSSFeedSource:
-    """Represents an RSS Feed."""
-
-    name: str | None
-    description: str | None
-    url: str | None
-    icon_url: str | None
-    feed: FeedParserDict
-
-    @classmethod
-    def from_parsed(cls, feed:FeedParserDict):
-        """Returns a Source object from a parsed feed.
-
-        Parameters
-        ----------
-        feed : FeedParserDict
-            The feed used to create the Source.
-
-        Returns
-        -------
-        Source
-            The Source object
-        """
-
-        # log.debug("Creating Source from feed: %s", dumps(feed))
-
-        channel = feed.get("channel", {})
-
-        return cls(
-            name=channel.get("title"),
-            description=channel.get("description"),
-            url=channel.get("link"),
-            icon_url=feed.get("feed", {}).get("image", {}).get("href"),
-            feed=feed
-        )
-
-    @classmethod
-    async def from_url(cls, url: str):
-        unparsed_content = await get_unparsed_feed(url)
-        source = cls.from_parsed(parse(unparsed_content))
-        source.url = url
-        return source
-
-    def get_latest_articles(self, max: int = 999) -> list[Article]:
-        """Returns a list of Article objects.
-
-        Parameters
-        ----------
-        max : int
-            The maximum number of articles to return.
-
-        Returns
-        -------
-        list of Article
-            A list of Article objects.
-        """
-
-        # log.debug("Fetching latest articles from %s, max=%s", self, max)
-
-        return [
-            Article.from_entry(self, entry)
-            for i, entry in enumerate(self.feed.entries)
-            if i < max
-        ]
-
-
-@dataclass
-class RSSFeedSource_:
-
-    uuid: str
-    name: str
-    url: str
-    image: str
-    discord_server_id: id
-    created_at: str
-
-    @classmethod
-    def from_list(cls, data: list) -> list:
-        result = []
-
-        for item in data:
-            key = "discord_server_id"
-            item[key] = int(item.get(key))
-            result.append(cls(**item))
-
-        return result
-
-    @classmethod
-    def from_dict(cls, data: dict):
-        return cls(**data)
-
-
 DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
@@ -401,6 +195,7 @@ class Subscription(DjangoDataModel):
     extra_notes: str
     filters: list[int]
     mutators: dict[str, list[dict]]
+    article_fetch_image: bool
     embed_colour: str
     active: bool
     channels_count: int
@@ -488,143 +283,3 @@ class TrackedContent(DjangoDataModel):
         item["creation_datetime"] = datetime.strptime(item["creation_datetime"], DATETIME_FORMAT)
 
         return item
-
-
-class Functions:
-
-    def __init__(self, bot, api_token: str):
-        self.bot = bot
-        self.api_token = api_token
-
-    async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
-        """Validates a feed based on the given nickname and url.
-
-        Parameters
-        ----------
-        nickname : str
-            Human readable nickname used to refer to the feed.
-        url : str
-            URL to fetch content from the feed.
-
-        Returns
-        -------
-        FeedParserDict
-            A Parsed Dictionary of the feed.
-
-        Raises
-        ------
-        IllegalFeed
-            If the feed is invalid.
-        """
-
-        # Ensure the URL is valid
-        if not validators.url(url):
-            raise IllegalFeed(f"The URL you have entered is malformed or invalid:\n`{url=}`")
-
-        # Check the nickname is not a URL
-        if validators.url(nickname):
-            raise IllegalFeed(
-                "It looks like the nickname you have entered is a URL.\n" \
-                "For security reasons, this is not allowed.",
-                nickname=nickname
-            )
-
-        feed_data, status_code = await get_rss_data(url)
-
-        if status_code != 200:
-            raise IllegalFeed(
-                "The URL provided returned an invalid status code:",
-                url=url, status_code=status_code
-            )
-
-        # Check the contents is actually an RSS feed.
-        feed = parse(feed_data)
-        if not feed.version:
-            raise IllegalFeed(
-                "The provided URL does not seem to be a valid RSS feed.",
-                url=url
-            )
-
-        return feed
-
-    async def create_new_rssfeed(self, name: str, url: str, guild_id: int) -> RSSFeed:
-
-
-        log.info("Creating new Feed: %s", name)
-
-        parsed_feed = await self.validate_feed(name, url)
-        source = Source.from_parsed(parsed_feed)
-
-        async with aiohttp.ClientSession() as session:
-            data = await API(self.api_token, session).create_new_rssfeed(
-                name, url, source.icon_url, guild_id
-            )
-
-        return RSSFeed.from_dict(data)
-
-    async def delete_rssfeed(self, uuid: str) -> RSSFeed:
-
-        log.info("Deleting Feed '%s'", uuid)
-
-        async with aiohttp.ClientSession() as session:
-            api = API(self.api_token, session)
-            data = await api.get_rssfeed(uuid)
-            await api.delete_rssfeed(uuid)
-
-        return RSSFeed.from_dict(data)
-
-    async def get_rssfeeds(self, guild_id: int, page: int, pagesize: int) -> list[RSSFeed]:
-        """Get a list of RSS Feeds.
-
-        Args:
-            guild_id (int): The guild_id to filter by.
-
-        Returns:
-            list[RSSFeed]: Resulting list of RSS Feeds
-        """
-
-        async with aiohttp.ClientSession() as session:
-            data, count = await API(self.api_token, session).get_rssfeed_list(
-                discord_server_id=guild_id,
-                page=page,
-                page_size=pagesize
-            )
-
-        return RSSFeed.from_list(data), count
-
-    async def assign_feed(
-        self, url: str, channel_name: str, channel_id: int, guild_id: int
-    ) -> tuple[int, Source]:
-        """"""
-
-        async with DatabaseManager() as database:
-            select_query = select(RssSourceModel).where(and_(
-                RssSourceModel.rss_url == url,
-                RssSourceModel.discord_server_id == guild_id
-            ))
-            select_result = await database.session.execute(select_query)
-
-            rss_source = select_result.scalars().one()
-
-            insert_query = insert(FeedChannelModel).values(
-                discord_server_id = guild_id,
-                discord_channel_id = channel_id,
-                rss_source_id=rss_source.id,
-                search_name=f"{rss_source.nick} #{channel_name}"
-            )
-
-            insert_result = await database.session.execute(insert_query)
-            return insert_result.inserted_primary_key.id, await Source.from_url(url)
-
-    async def unassign_feed( self, assigned_feed_id: int, guild_id: int):
-        """"""
-
-        async with DatabaseManager() as database:
-            query = delete(FeedChannelModel).where(and_(
-                FeedChannelModel.id == assigned_feed_id,
-                FeedChannelModel.discord_server_id == guild_id
-            ))
-
-            result = await database.session.execute(query)
-            if not result.rowcount:
-                raise NoResultFound
diff --git a/src/tests.py b/src/tests.py
deleted file mode 100644
index cc63568..0000000
--- a/src/tests.py
+++ /dev/null
@@ -1,54 +0,0 @@
-
-import unittest
-from sqlalchemy import select
-from sqlalchemy.engine.cursor import CursorResult
-from sqlalchemy.engine.result import ChunkedIteratorResult
-
-from db import DatabaseManager, FeedChannelModel, AuditModel
-
-
-class TestDatabaseConnections(unittest.IsolatedAsyncioTestCase):
-    """The purpose of this test, is to ensure that the database connections function properly."""
-
-    async def test_select__feed_channel_model(self):
-        """This test runs a select query on the `FeedChannelModel`"""
-
-        async with DatabaseManager() as database:
-            query = select(FeedChannelModel).limit(1000)
-            result = await database.session.execute(query)
-
-        self.assertIsInstance(
-            result,
-            ChunkedIteratorResult,
-            f"Result should be `ChunkedIteratorResult`, not {type(result)!r}"
-        )
-
-    async def test_select__rss_source_model(self):
-        """This test runs a select query on the `RssSourceModel`"""
-
-        async with DatabaseManager() as database:
-            query = select(RssSourceModel).limit(1000)
-            result = await database.session.execute(query)
-
-        self.assertIsInstance(
-            result,
-            ChunkedIteratorResult,
-            f"Result should be `ChunkedIteratorResult`, not {type(result)!r}"
-        )
-
-    async def test_select__audit_model(self):
-        """This test runs a select query on the `AuditModel`"""
-
-        async with DatabaseManager() as database:
-            query = select(AuditModel).limit(1000)
-            result = await database.session.execute(query)
-
-        self.assertIsInstance(
-            result,
-            ChunkedIteratorResult,
-            f"Result should be `ChunkedIteratorResult`, not {type(result)!r}"
-        )
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/src/utils.py b/src/utils.py
index 3a5d957..6b4dc30 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -42,15 +42,9 @@ async def followup(inter: Interaction, *args, **kwargs):
     await inter.followup.send(*args, **kwargs)
 
 
-async def audit(cog, *args, **kwargs):
-    """Shorthand for auditing an interaction."""
-
-    await cog.bot.audit(*args, **kwargs)
-
 # https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
 
-
 class FollowupIcons:
     error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
     success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
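
For reference, the whitelist/blacklist semantics that TaskCog.filter_item now implements: a regex, when present, takes precedence over the comma-separated keywords, a whitelist filter blocks items that do not match, and a blacklist filter blocks items that do. A minimal self-contained sketch, assuming only the filter dict keys read in the diff (`keywords`, `regex`, `whitelist`):

    # Sketch of filter_item's decision logic, lifted out of the cog.
    # The dict keys mirror those used in the patch; everything else is assumed.
    import re

    def should_block(text: str, _filter: dict) -> bool:
        if _filter["regex"]:
            match_found = bool(re.search(_filter["regex"], text))
        else:
            match_found = any(word in text for word in _filter["keywords"].split(","))
        # A whitelist inverts the result: matches pass, everything else is blocked.
        return not match_found if _filter["whitelist"] else match_found

    # Blacklist filter: block items that mention "crypto".
    assert should_block("crypto prices surge", {"regex": "", "keywords": "crypto", "whitelist": False})
    # Whitelist filter: block items that do NOT match the pattern.
    assert should_block("unrelated post", {"regex": r"\bPython\b", "keywords": "", "whitelist": True})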