Heavy rewrites (remove old code, improve tasks, etc.)

This commit is contained in:
Corban-Lee Jones 2024-07-08 22:33:03 +01:00
parent 574d54a2eb
commit 1f199c36f9
10 changed files with 142 additions and 1028 deletions

View File

@@ -7,10 +7,6 @@ from pathlib import Path
from discord import Intents, Game
from discord.ext import commands
from sqlalchemy import insert
from feed import Functions
from db import DatabaseManager, AuditModel
log = logging.getLogger(__name__)
@@ -20,7 +16,6 @@ class DiscordBot(commands.Bot):
def __init__(self, BASE_DIR: Path, developing: bool, api_token: str):
activity = Game("Indev") if developing else None
super().__init__(command_prefix="-", intents=Intents.all(), activity=activity)
self.functions = Functions(self, api_token)
self.BASE_DIR = BASE_DIR
self.developing = developing
self.api_token = api_token
@@ -52,27 +47,3 @@
for path in (self.BASE_DIR / "src/extensions").iterdir():
if path.suffix == ".py":
await self.load_extension(f"extensions.{path.stem}")
async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
"""Shorthand for auditing an action.
Parameters
----------
message : str
The message to be audited.
user_id : int
Discord ID of the user being audited.
database : DatabaseManager, optional
An existing database connection to be used if specified, by default None
"""
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
log.debug("Logging audit")
if database:
await database.session.execute(query)
return
async with DatabaseManager() as database:
await database.session.execute(query)
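# A minimal usage sketch of the shorthand above (the message and the
# interaction object are illustrative):
#
#     await bot.audit("Subscribed to feed", user_id=interaction.user.id)
#
# Without an explicit `database`, a fresh DatabaseManager connection is
# opened and committed on exit.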

View File

@@ -1,20 +0,0 @@
"""
Initialize the database modules, create the database tables and default data.
"""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from .models import Base, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
from .db import DatabaseManager
# # Initialise a database session
# engine = create_engine(DatabaseManager.get_database_url(use_async=False))
# session = sessionmaker(bind=engine)()
# # Create tables if not exists
# Base.metadata.create_all(engine)
# session.commit()
# session.close()

View File

@@ -1,77 +0,0 @@
"""
Database Manager
"""
import logging
from os import getenv
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
DB_TYPE = getenv("DB_TYPE", default="sqlite")
DB_HOST = getenv("DB_HOST", default="db.sqlite")
DB_PORT = getenv("DB_PORT")
DB_USERNAME = getenv("DB_USERNAME")
DB_PASSWORD = getenv("DB_PASSWORD")
DB_DATABASE = getenv("DB_DATABASE")
log = logging.getLogger(__name__)
class DatabaseManager:
"""
Asynchronous database context manager.
"""
def __init__(self, no_commit: bool = False):
database_url = self.get_database_url() # Called on every connection; consider computing this once and reusing it.
self.engine = create_async_engine(database_url, future=True)
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
self.session = None
self.no_commit = no_commit
@staticmethod
def get_database_url(use_async=True):
"""
Returns a connection string for the database.
"""
url = f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}"
url_addon = ""
# This match statement is ugly; consider refactoring.
match use_async, DB_TYPE:
case True, "sqlite":
url_addon = "aiosqlite"
case True, "postgresql":
url_addon = "asyncpg"
case False, "sqlite":
pass
case False, "postgresql":
pass
case _, _:
raise ValueError(f"Unknown Database Type: {DB_TYPE}")
url = url.replace(":/", f"+{url_addon}:/") if url_addon else url
return url
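# For example (assuming DB_TYPE=postgresql, DB_USERNAME=bot, DB_PASSWORD=secret,
# DB_HOST=localhost, DB_PORT=5432, DB_DATABASE=pyrss -- all illustrative values):
#   get_database_url()                -> "postgresql+asyncpg://bot:secret@localhost:5432/pyrss"
#   get_database_url(use_async=False) -> "postgresql://bot:secret@localhost:5432/pyrss"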
async def __aenter__(self):
self.session = self.session_maker()
log.debug("Database connection open")
return self
async def __aexit__(self, *_):
if not self.no_commit:
await self.session.commit()
await self.session.close()
self.session = None
await self.engine.dispose()
log.debug("Database connection closed")

View File

@@ -1,96 +0,0 @@
"""
Models and Enums for the database.
All table classes should be suffixed with `Model`.
"""
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship, declarative_base
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
BigInteger,
UniqueConstraint,
ForeignKey
)
Base = declarative_base()
class AuditModel(Base):
"""
Table for taking audits.
"""
__tablename__ = "audit"
id = Column(Integer, primary_key=True, autoincrement=True)
discord_user_id = Column(BigInteger, nullable=False)
# discord_server_id = Column(BigInteger, nullable=False) # TODO: this doesnt exist, integrate it.
message = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
class SentArticleModel(Base):
"""
Table for tracking articles that have been served by the bot.
"""
__tablename__ = "sent_articles"
id = Column(Integer, primary_key=True, autoincrement=True)
discord_message_id = Column(BigInteger, nullable=False)
discord_channel_id = Column(BigInteger, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
article_url = Column(String, nullable=False)
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
feed_channel_id = Column(Integer, ForeignKey("feed_channel.id", ondelete="CASCADE"), nullable=False)
feed_channel = relationship("FeedChannelModel", overlaps="sent_articles", lazy="joined", cascade="all, delete")
class RssSourceModel(Base):
"""
Table for user submitted news feeds.
"""
__tablename__ = "rss_source"
id = Column(Integer, primary_key=True, autoincrement=True)
nick = Column(String, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
rss_url = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
feed_channels = relationship("FeedChannelModel", cascade="all, delete")
# the nickname must be unique, but only within the same discord server
__table_args__ = (
UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'),
)
class FeedChannelModel(Base):
"""
Table representing discord channels to be used for news feeds.
"""
__tablename__ = "feed_channel"
id = Column(Integer, primary_key=True, autoincrement=True)
discord_channel_id = Column(BigInteger, nullable=False)
discord_server_id = Column(BigInteger, nullable=False)
search_name = Column(String, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
sent_articles = relationship("SentArticleModel", cascade="all, delete")
rss_source_id = Column(Integer, ForeignKey('rss_source.id', ondelete="CASCADE"), nullable=False)
rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete")
# the rss source must be unique, but only within the same discord channel
__table_args__ = (
UniqueConstraint('rss_source_id', 'discord_channel_id', name='uq_rss_discord_channel'),
)
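# A sketch of how these tables relate (all values are illustrative):
#
#     source = RssSourceModel(nick="BBC News", discord_server_id=1234,
#                             rss_url="https://feeds.bbci.co.uk/news/rss.xml")
#     channel = FeedChannelModel(discord_channel_id=5678, discord_server_id=1234,
#                                search_name="BBC News #general", rss_source=source)
#
# With the ondelete="CASCADE" foreign keys above, deleting the source removes
# its feed channels, which in turn removes their sent articles.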

View File

@@ -1,10 +1,4 @@
class IllegalFeed(Exception):
def __init__(self, message: str, **items):
super().__init__(message)
self.items = items
class TokenMissingError(Exception):
"""
Exception to indicate a token couldn't be found.

View File

@@ -1,152 +0,0 @@
"""
task flow
=-=-=-=-=
1. every 10 minutes
2. get all bot guild_ids
3. get all subscriptions against known guild ids (account for pagination of 25)
4. iterate through and process each subscription
5. get all articles from subscription
6. iterate through and process each article
"""
class Cog:
async def subscription_task(self):
subscriptions = await self.get_subscriptions()
for subscription in subscriptions:
await self.process_subscription(subscription)
async def get_subscriptions(self):
# return subs from api, handle pagination
...
async def process_subscription(self, subscription):
articles = subscription.get_articles()
for article in articles:
await self.process_article(article)
async def process_article(self, article):
# validate article then:
if await self.track_article(article):
await self.send_article(article)
async def track_article(self, article):
pass
async def send_article(self, article):
pass
class TaskCog(commands.Cog):
"""
Tasks cog for PYRSS.
"""
def __init__(self, bot):
super().__init__()
self.bot = bot
@tasks.loop(time=times)
async def subscription_task(self):
async with aiohttp.ClientSession() as session:
api = API(self.bot.api_token, session)
subscriptions = await self.get_subscriptions(api)
await self.process_subscriptions(api, subscriptions)
# articles = [*(await self.get_articles(api, sub)) for sub in subscriptions]
# await self.send_articles(api, articles)
async def get_subscriptions(self, api) -> list:
guild_ids = [guild.id for guild in self.bot.guilds]
subscriptions = []
for page, _ in enumerate(iter(int, 1)):
try:
page_data = (await api.get_subscriptions(server__in=guild_ids, page=page+1))[0]
except aiohttp.ClientResponseError as error:
if error.status == 404:
break
raise
except Exception as error:
log.error("Exception while gathering page data %s", error)
break
subscriptions.extend(page_data)
return subscriptions
async def process_subscriptions(self, api, subscriptions):
for sub in subscriptions:
if not sub.active or not sub.channel_count:
continue
unparsed_feed = await get_unparsed_feed(sub.url, api.session)
parsed_feed = await parse(unparsed_feed)
rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
await self.process_items(api, sub, rss_feed)
async def process_items(self, api, sub, feed):
channels = [self.bot.get_channel(channel.channel_id) for channel in await sub.get_channels()]
filters = [await api.get_filter(filter_id) for filter_id in sub.filters]
for item in sub.items:
blocked = any(self.filter_item(_filter, item) for _filter in filters)
mutated_item = item.create_mutated_copy(sub.mutators)
for channel in channels:
await self.mark_tracked_item(api, sub, item, channel.id, blocked)
if not blocked:
await channel.send(embed=item.to_embed(sub, feed))
def filter_item(self, _filter: dict, item: RSSItem) -> bool:
"""
Returns True if item should be ignored due to filters.
"""
match_found = False # This is the flag to determine if the content should be filtered
keywords = _filter["keywords"].split(",")
regex_pattern = _filter["regex"]
is_whitelist = _filter["whitelist"]
log.debug(
"trying filter '%s', keyword '%s', regex '%s', is whitelist: '%s'",
_filter["name"], keywords, regex_pattern, is_whitelist
)
assert not (keywords and regex_pattern), "Keywords and Regex are both set; only one may be used."
if any(word in item.title or word in item.description for word in keywords):
match_found = True
if regex_pattern:
regex = re.compile(regex_pattern)
match_found = regex.search(item.title) or regex.search(item.description)
return not match_found if is_whitelist else match_found
async def mark_tracked_item(self, api, sub, item, channel_id, blocked):
try:
await api.create_tracked_content(
guid=item.guid,
title=item.title,
url=item.url,
subscription=sub.id,
channel_id=channel_id,
blocked=blocked
)
except aiohttp.ClientResponseError as error:
if error.status == 409:
log.debug(error)
else:
log.error(error)

View File

@@ -4,7 +4,6 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
"""
import re
import json
import logging
import datetime
from os import getenv
@@ -15,249 +14,129 @@ from discord import TextChannel, Embed, Colour
from discord import app_commands
from discord.ext import commands, tasks
from discord.errors import Forbidden
from sqlalchemy import insert, select, and_
from feedparser import parse
from feed import Source, Article, RSSFeed, Subscription, SubscriptionChannel, SubChannel
from db import (
DatabaseManager,
FeedChannelModel,
RssSourceModel,
SentArticleModel
)
from feed import RSSFeed, Subscription, RSSItem
from utils import get_unparsed_feed
from api import API
log = logging.getLogger(__name__)
TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES")
# Task trigger times (must be a list)
times = [
subscription_task_times = [
datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
for hour in range(24)
for minute in range(0, 60, int(TASK_INTERVAL_MINUTES))
]
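# e.g. with TASK_INTERVAL_MINUTES=10 this yields 00:00, 00:10, ..., 23:50 UTC.
# Note the variable must be set: int(None) would raise a TypeError here.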
log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
class TaskCog(commands.Cog):
"""
Tasks cog.
Tasks cog for PYRSS.
"""
def __init__(self, bot):
super().__init__()
self.bot = bot
self.time = None
@commands.Cog.listener()
async def on_ready(self):
"""Instructions to execute when the cog is ready."""
# if not self.bot.developing:
self.rss_task.start()
"""
Instructions to execute when the cog is ready.
"""
self.subscription_task.start()
log.info("%s cog is ready", self.__class__.__name__)
@commands.Cog.listener(name="cog_unload")
async def on_unload(self):
"""Instructions to execute before the cog is unloaded."""
self.rss_task.cancel()
"""
Instructions to execute before the cog is unloaded.
"""
self.subscription_task.cancel()
@app_commands.command(name="debug-trigger-task")
async def debug_trigger_task(self, inter):
await inter.response.defer()
await self.rss_task()
await self.subscription_task()
await inter.followup.send("done")
@tasks.loop(time=times)
async def rss_task(self):
"""Automated task responsible for processing rss feeds."""
@tasks.loop(time=subscription_task_times)
async def subscription_task(self):
"""
Task for fetching and processing subscriptions.
"""
log.info("Running subscription task")
time = process_time()
guild_ids = [guild.id for guild in self.bot.guilds]
data = []
async with aiohttp.ClientSession() as session:
api = API(self.bot.api_token, session)
page = 0
subscriptions = await self.get_subscriptions(api)
await self.process_subscriptions(api, subscriptions)
while True:
page += 1
page_data = await self.get_subscriptions(api, guild_ids, page)
async def get_subscriptions(self, api: API) -> list[Subscription]:
guild_ids = [guild.id for guild in self.bot.guilds]
sub_data = []
if not page_data:
break
for page, _ in enumerate(iter(int, 1)):
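# NOTE: iter(int, 1) calls int() until it returns the sentinel 1, which never
# happens, so enumerate() yields 0, 1, 2, ... forever; the loop only exits via
# the break/return paths in the except clauses below.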
try:
log.debug("fetching page '%s'", page + 1)
sub_data.extend(
(await api.get_subscriptions(server__in=guild_ids, page=page+1))[0]
)
except aiohttp.ClientResponseError as error:
match error.status:
case 404:
log.debug("final page reached '%s'", page)
break
case 403:
log.critical(error)
self.subscription_task.cancel()
return [] # returning an empty list should gracefully end the task
case _:
log.error(error)
break
data.extend(page_data)
log.debug("extending data by '%s' items", len(page_data))
except Exception as error:
log.error("Exception while gathering page data %s", error)
break
log.debug("finished api data collection, browsed %s pages for %s subscriptions", page, len(data))
subscriptions = Subscription.from_list(data)
for sub in subscriptions:
await self.process_subscription(api, session, sub)
return Subscription.from_list(sub_data)
log.info("Finished subscription task, time elapsed: %s", process_time() - time)
async def process_subscriptions(self, api: API, subscriptions: list[Subscription]):
for sub in subscriptions:
log.debug("processing subscription '%s'", sub.id)
async def get_subscriptions(self, api, guild_ids: list[int], page: int):
if not sub.active or not sub.channels_count:
continue
log.debug("attempting to get subscriptions for page: %s", page)
unparsed_feed = await get_unparsed_feed(sub.url, api.session)
parsed_feed = parse(unparsed_feed)
try:
return (await api.get_subscriptions(server__in=guild_ids, page=page))[0]
except aiohttp.ClientResponseError as error:
if error.status == 404:
log.debug(error)
return []
rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
await self.process_items(api, sub, rss_feed)
log.error(error)
async def process_subscription(self, api: API, session: aiohttp.ClientSession, sub: Subscription):
"""
Process a given Subscription.
"""
log.debug("processing subscription '%s' '%s' for '%s'", sub.id, sub.name, sub.guild_id)
if not sub.active:
log.debug("skipping sub because it's active flag is 'False'")
return
channels: list[TextChannel] = [self.bot.get_channel(subchannel.channel_id) for subchannel in await sub.get_channels(api)]
if not channels:
log.warning("No channels to send this to")
return
async def process_items(self, api: API, sub: Subscription, feed: RSSFeed):
log.debug("processing items")
channels = [self.bot.get_channel(channel.channel_id) for channel in await sub.get_channels(api)]
filters = [await api.get_filter(filter_id) for filter_id in sub.filters]
log.debug("found %s filter(s)", len(filters))
unparsed_content = await get_unparsed_feed(sub.url, session)
parsed_content = parse(unparsed_content)
source = Source.from_parsed(parsed_content)
articles = source.get_latest_articles(10)
articles.reverse()
for item in feed.items:
log.debug("processing item '%s'", item.guid)
if not articles:
log.debug("No articles found")
embeds = await self.get_articles_as_embeds(api, session, sub.id, sub.mutators, filters, articles, Colour.from_str("#" + sub.embed_colour))
await self.send_embeds_in_chunks(embeds, channels)
async def get_articles_as_embeds(
self,
api: API,
session: aiohttp.ClientSession,
sub_id: int,
mutators: dict[str, list[dict]],
filters: list[dict],
articles: list[Article],
embed_colour: str
) -> list[Embed]:
"""
Process articles and return their respective embeds.
"""
embeds = []
for article in articles:
embed = await self.process_article(api, session, sub_id, mutators, filters, article, embed_colour)
if embed:
embeds.append(embed)
return embeds
async def send_embeds_in_chunks(self, embeds: list[Embed], channels: list[TextChannel], embed_limit=10):
"""
Send embeds to a list of `TextChannel` in chunks of `embed_limit` size.
"""
log.debug("about to send %s embeds")
for i in range(0, len(embeds), embed_limit):
embeds_chunk = embeds[i:i + embed_limit]
log.debug("sending chunk of %s embeds", len(embeds_chunk))
blocked = any(self.filter_item(_filter, item) for _filter in filters)
mutated_item = item.create_mutated_copy(sub.mutators)
for channel in channels:
await self.try_send_embeds(embeds_chunk, channel)
successful_track = await self.mark_tracked_item(api, sub, item, channel.id, blocked)
async def try_send_embeds(self, embeds: list[Embed], channel: TextChannel):
if successful_track and not blocked:
await channel.send(embed=await item.to_embed(sub, feed, api.session))
def filter_item(self, _filter: dict, item: RSSItem) -> bool:
"""
Attempt to send embeds to a given `TextChannel`. Gracefully handles errors.
"""
try:
await channel.send(embeds=embeds)
except Forbidden:
log.debug(
"Forbidden from sending embed to channel '%s', guild '%s'",
channel.id, channel.guild.id
)
except Exception as exc:
log.error(exc)
async def process_article(
self,
api: API,
session: aiohttp.ClientSession,
sub_id: int,
mutators: dict[str, list[dict]],
filters: list[dict],
article: Article,
embed_colour: str
) -> Embed | None:
"""
Process a given Article.
Returns an Embed representing the given Article.
"""
log.debug("processing article '%s' '%s'", article.guid, article.title)
blocked = any(self.filter_article(_filter, article) for _filter in filters)
log.debug("filter result: %s", "blocked" if blocked else "ok")
self.mutate_article(article, mutators)
try:
await api.create_tracked_content(
guid=article.guid,
title=article.title,
url=article.url,
subscription=sub_id,
blocked=blocked,
channel_id="-_-"
)
log.debug("successfully tracked %s", article.guid)
except aiohttp.ClientResponseError as error:
if error.status == 409:
log.debug("It looks like this article already exists, skipping")
else:
log.error(error)
return
if not blocked:
return await article.to_embed(session, embed_colour)
def mutate_article(self, article: Article, mutators: list[dict]):
for mutator in mutators["title"]:
article.mutate("title", mutator)
for mutator in mutators["desc"]:
article.mutate("description", mutator)
def filter_article(self, _filter: dict, article: Article) -> bool:
"""
Returns True if article should be ignored due to filters.
Returns True if item should be ignored due to filters.
"""
match_found = False # This is the flag to determine if the content should be filtered
@@ -273,15 +152,35 @@ class TaskCog(commands.Cog):
assert not (keywords and regex_pattern), "Keywords and Regex are both set; only one may be used."
if any(word in article.title or word in article.description for word in keywords):
match_found = True
if regex_pattern:
regex = re.compile(regex_pattern)
match_found = regex.search(article.title) or regex.search(article.description)
match_found = regex.search(item.title) or regex.search(item.description)
else:
match_found = any(word in item.title or word in item.description for word in keywords)
return not match_found if is_whitelist else match_found
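# Illustration with hypothetical filter dicts:
#   {"keywords": "crypto", "regex": "", "whitelist": False} blocks any item
#   mentioning "crypto"; {"keywords": "python", "regex": "", "whitelist": True}
#   blocks every item that does NOT mention "python".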
async def mark_tracked_item(self, api: API, sub: Subscription, item: RSSItem, channel_id: int, blocked: bool):
try:
log.debug("marking as tracked 'blocked: %s'", blocked)
await api.create_tracked_content(
guid=item.guid,
title=item.title,
url=item.link,
subscription=sub.id,
channel_id=channel_id,
blocked=blocked
)
return True
except aiohttp.ClientResponseError as error:
if error.status == 409:
log.debug(error)
else:
log.error(error)
return False
async def setup(bot):
"""
Setup function for this extension.

View File

@@ -4,7 +4,6 @@ import logging
from dataclasses import dataclass
from datetime import datetime
from abc import ABC, abstractmethod
from random import shuffle, sample
import aiohttp
import validators
@@ -12,30 +11,25 @@ from discord import Embed, Colour
from bs4 import BeautifulSoup as bs4
from feedparser import FeedParserDict, parse
from markdownify import markdownify
from sqlalchemy import select, insert, delete, and_
from sqlalchemy.exc import NoResultFound
from textwrap import shorten
from feedparser import parse
from mutators import registry as mutator_registry
from errors import IllegalFeed
from db import DatabaseManager, RssSourceModel, FeedChannelModel
from utils import get_rss_data, get_unparsed_feed
from utils import get_unparsed_feed
from api import API
log = logging.getLogger(__name__)
dumps = lambda _dict: json.dumps(_dict, indent=8)
from xml.etree.ElementTree import Element, SubElement, tostring
from feedparser import parse
class RSSItem:
def __init__(self, title, link, description, pub_date, guid):
def __init__(self, title, link, description, pub_date, guid, image_url):
self.title = title
self.link = link
self.description = description
self.pub_date = pub_date
self.guid = guid
self.image_url = image_url
def __str__(self):
return self.title
@@ -43,7 +37,7 @@ class RSSItem:
def create_mutated_copy(self, mutators):
pass
def create_embed(self, sub, feed):
async def to_embed(self, sub, feed, session):
title = shorten(markdownify(self.title, strip=["img", "a"]), 256)
desc = shorten(markdownify(self.description, strip=["img"]), 4096)
@@ -60,7 +54,6 @@ class RSSItem:
# validate urls
# author_url = self.source.url if validators.url(self.source.url) else None
# icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
# thumb_url = await self.get_thumbnail_url(session) # validation done inside func
# Combined length validation
# Can't exceed combined 6000 characters, [400 Bad Request] if failed.
@@ -71,27 +64,64 @@
embed = Embed(
title=title,
description=desc,
timestamp=self.published,
# timestamp=self.published,
url=self.link if validators.url(self.link) else None,
colour=colour
colour=Colour.from_str("#" + sub.embed_colour)
)
# embed.set_thumbnail(url=icon_url)
log.debug("has image url without search: '%s'", self.image_url)
if sub.article_fetch_image:
embed.set_image(url=self.image_url or await self.get_thumbnail_url(session))
embed.set_thumbnail(url=feed.image_href if validators.url(feed.image_href) else None)
# embed.set_image(url=thumb_url)
# embed.set_author(url=author_url, name=author)
# embed.set_footer(text=self.author)
return embed
async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None:
"""Returns the thumbnail URL for an article.
Parameters
----------
session : aiohttp.ClientSession
A client session used to get the thumbnail.
Returns
-------
str or None
The thumbnail URL, or None if not found.
"""
# log.debug("Fetching thumbnail for article: %s", self)
try:
async with session.get(self.link, timeout=15) as response:
html = await response.text()
except aiohttp.InvalidURL as error:
log.error("invalid thumbnail url: %s", error)
return None
soup = bs4(html, "html.parser")
image_element = soup.select_one("meta[property='og:image']")
if not image_element:
return None
image_content = image_element.get("content")
return image_content if validators.url(image_content) else None
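# For reference, this looks for Open Graph markup of the form
# <meta property="og:image" content="https://example.com/thumb.jpg">
# (URL illustrative) and validates the extracted content attribute.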
import json
class RSSFeed:
def __init__(self, title, link, description, language='en-gb', pub_date=None, last_build_date=None):
def __init__(self, title, link, description, language='en-gb', pub_date=None, last_build_date=None, image_href=None):
self.title = title
self.link = link
self.description = description
self.language = language
self.pub_date = pub_date
self.last_build_date = last_build_date
self.image_href = image_href
self.items = []
def add_item(self, item: RSSItem):
@@ -111,8 +141,9 @@ class RSSFeed:
language = parsed_feed.feed.get('language', 'en-gb')
pub_date = parsed_feed.feed.get('published', None)
last_build_date = parsed_feed.feed.get('updated', None)
image_href = parsed_feed.get("image", {}).get("href")
feed = cls(title, link, description, language, pub_date, last_build_date)
feed = cls(title, link, description, language, pub_date, last_build_date, image_href)
for entry in parsed_feed.entries:
item_title = entry.get('title', 'No title')
@@ -120,10 +151,11 @@
item_description = entry.get('description', 'No description')
item_pub_data = entry.get('published_parsed', None)
item_guid = entry.get('id', None) or entry.get("guid", None)
item_image_url = entry.get("media_content", [{}])[0].get("url")
item_published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
# item_published = datetime(*entry.published_parsed[0:-2]) if item_pub_data else None
item = RSSItem(item_title, item_link, item_description, item_published, item_guid)
item = RSSItem(item_title, item_link, item_description, item_pub_data, item_guid, item_image_url)
feed.add_item(item)
feed.items.reverse()
@@ -131,244 +163,6 @@
return feed
@dataclass
class RSSArticle:
"""Represents a news article, or entry from an RSS feed."""
guid: str
title: str | None
description: str | None
url: str | None
published: datetime | None
author: str | None
source: object
@classmethod
def from_entry(cls, source, entry:FeedParserDict):
"""Create an Article from an RSS feed entry.
Parameters
----------
entry : FeedParserDict
An entry pulled from a complete FeedParserDict object.
Returns
-------
Article
The Article created from the feed entry.
"""
# log.debug("Creating Article from entry: %s", dumps(entry))
published_parsed = entry.get("published_parsed")
published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
return cls(
guid=entry["guid"],
title=entry.get("title"),
description=entry.get("description"),
url=entry.get("link"),
published=published,
author = entry.get("author"),
source=source
)
def mutate(self, attr: str, mutator: dict[str, str]):
"""
Apply a mutation to a certain text attribute of
this Article instance.
"""
# WARN:
# This could be really bad if the end user is able to effect the 'attr' value.
# Shouldn't happen though.
log.debug("applying mutator '%s'", mutator["name"])
val = mutator["value"]
try:
mutator = mutator_registry.get_mutator(val)
except ValueError as err:
log.error(err)
return
setattr(self, attr, mutator.mutate(getattr(self, attr)))
log.debug("mutated %s, to: %s", attr, getattr(self, attr))
async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None:
"""Returns the thumbnail URL for an article.
Parameters
----------
session : aiohttp.ClientSession
A client session used to get the thumbnail.
Returns
-------
str or None
The thumbnail URL, or None if not found.
"""
# log.debug("Fetching thumbnail for article: %s", self)
try:
async with session.get(self.url, timeout=15) as response:
html = await response.text()
except aiohttp.InvalidURL as error:
log.error("invalid thumbnail url: %s", error)
return None
soup = bs4(html, "html.parser")
image_element = soup.select_one("meta[property='og:image']")
if not image_element:
return None
image_content = image_element.get("content")
return image_content if validators.url(image_content) else None
async def to_embed(self, session: aiohttp.ClientSession, colour: Colour) -> Embed:
"""Creates and returns a Discord Embed object from the article.
Parameters
----------
session : aiohttp.ClientSession
A client session used to get additional article data.
Returns
-------
Embed
A Discord Embed object representing the article.
"""
# log.debug(f"Creating embed from article: {self}")
# Replace HTML with Markdown, and shorten text.
title = shorten(markdownify(self.title, strip=["img", "a"]), 256)
desc = shorten(markdownify(self.description, strip=["img"]), 4096)
author = shorten(self.source.name, 256)
# validate urls
embed_url = self.url if validators.url(self.url) else None
author_url = self.source.url if validators.url(self.source.url) else None
icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
thumb_url = await self.get_thumbnail_url(session) # validation done inside func
# Combined length validation
# Can't exceed combined 6000 characters, [400 Bad Request] if failed.
combined_length = len(title) + len(desc) + (len(author) * 2)
cutoff = combined_length - 6000
desc = shorten(desc, cutoff) if cutoff > 0 else desc
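# Worked example: with the caps applied above, the worst case is
# 256 (title) + 4096 (desc) + 2*256 (author) = 4864 characters, so the 6000
# limit can only be hit if those caps change.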
embed = Embed(
title=title,
description=desc,
timestamp=self.published,
url=embed_url,
colour=colour
)
embed.set_thumbnail(url=icon_url)
embed.set_image(url=thumb_url)
embed.set_author(url=author_url, name=author)
embed.set_footer(text=self.author)
return embed
@dataclass
class RSSFeedSource:
"""Represents an RSS Feed."""
name: str | None
description: str | None
url: str | None
icon_url: str | None
feed: FeedParserDict
@classmethod
def from_parsed(cls, feed:FeedParserDict):
"""Returns a Source object from a parsed feed.
Parameters
----------
feed : FeedParserDict
The feed used to create the Source.
Returns
-------
Source
The Source object
"""
# log.debug("Creating Source from feed: %s", dumps(feed))
channel = feed.get("channel", {})
return cls(
name=channel.get("title"),
description=channel.get("description"),
url=channel.get("link"),
icon_url=feed.get("feed", {}).get("image", {}).get("href"),
feed=feed
)
@classmethod
async def from_url(cls, url: str):
unparsed_content = await get_unparsed_feed(url)
source = cls.from_parsed(parse(unparsed_content))
source.url = url
return source
def get_latest_articles(self, max: int = 999) -> list[Article]:
"""Returns a list of Article objects.
Parameters
----------
max : int
The maximum number of articles to return.
Returns
-------
list of Article
A list of Article objects.
"""
# log.debug("Fetching latest articles from %s, max=%s", self, max)
return [
Article.from_entry(self, entry)
for i, entry in enumerate(self.feed.entries)
if i < max
]
@dataclass
class RSSFeedSource_:
uuid: str
name: str
url: str
image: str
discord_server_id: int
created_at: str
@classmethod
def from_list(cls, data: list) -> list:
result = []
for item in data:
key = "discord_server_id"
item[key] = int(item.get(key))
result.append(cls(**item))
return result
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
@@ -401,6 +195,7 @@ class Subscription(DjangoDataModel):
extra_notes: str
filters: list[int]
mutators: dict[str, list[dict]]
article_fetch_image: bool
embed_colour: str
active: bool
channels_count: int
@@ -488,143 +283,3 @@ class TrackedContent(DjangoDataModel):
item["creation_datetime"] = datetime.strptime(item["creation_datetime"], DATETIME_FORMAT)
return item
class Functions:
def __init__(self, bot, api_token: str):
self.bot = bot
self.api_token = api_token
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
"""Validates a feed based on the given nickname and url.
Parameters
----------
nickname : str
Human readable nickname used to refer to the feed.
url : str
URL to fetch content from the feed.
Returns
-------
FeedParserDict
A Parsed Dictionary of the feed.
Raises
------
IllegalFeed
If the feed is invalid.
"""
# Ensure the URL is valid
if not validators.url(url):
raise IllegalFeed(f"The URL you have entered is malformed or invalid:\n`{url=}`")
# Check the nickname is not a URL
if validators.url(nickname):
raise IllegalFeed(
"It looks like the nickname you have entered is a URL.\n" \
"For security reasons, this is not allowed.",
nickname=nickname
)
feed_data, status_code = await get_rss_data(url)
if status_code != 200:
raise IllegalFeed(
"The URL provided returned an invalid status code:",
url=url, status_code=status_code
)
# Check the contents is actually an RSS feed.
feed = parse(feed_data)
if not feed.version:
raise IllegalFeed(
"The provided URL does not seem to be a valid RSS feed.",
url=url
)
return feed
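# A minimal usage sketch (nickname and URL are illustrative):
#
#     feed = await functions.validate_feed(
#         "BBC News", "https://feeds.bbci.co.uk/news/rss.xml"
#     )
#
# Raises IllegalFeed on a malformed URL, a URL-like nickname, a non-200
# response, or content that feedparser cannot identify as RSS.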
async def create_new_rssfeed(self, name: str, url: str, guild_id: int) -> RSSFeed:
log.info("Creating new Feed: %s", name)
parsed_feed = await self.validate_feed(name, url)
source = Source.from_parsed(parsed_feed)
async with aiohttp.ClientSession() as session:
data = await API(self.api_token, session).create_new_rssfeed(
name, url, source.icon_url, guild_id
)
return RSSFeed.from_dict(data)
async def delete_rssfeed(self, uuid: str) -> RSSFeed:
log.info("Deleting Feed '%s'", uuid)
async with aiohttp.ClientSession() as session:
api = API(self.api_token, session)
data = await api.get_rssfeed(uuid)
await api.delete_rssfeed(uuid)
return RSSFeed.from_dict(data)
async def get_rssfeeds(self, guild_id: int, page: int, pagesize: int) -> list[RSSFeed]:
"""Get a list of RSS Feeds.
Args:
guild_id (int): The guild_id to filter by.
Returns:
list[RSSFeed]: Resulting list of RSS Feeds
"""
async with aiohttp.ClientSession() as session:
data, count = await API(self.api_token, session).get_rssfeed_list(
discord_server_id=guild_id,
page=page,
page_size=pagesize
)
return RSSFeed.from_list(data), count
async def assign_feed(
self, url: str, channel_name: str, channel_id: int, guild_id: int
) -> tuple[int, Source]:
""""""
async with DatabaseManager() as database:
select_query = select(RssSourceModel).where(and_(
RssSourceModel.rss_url == url,
RssSourceModel.discord_server_id == guild_id
))
select_result = await database.session.execute(select_query)
rss_source = select_result.scalars().one()
insert_query = insert(FeedChannelModel).values(
discord_server_id = guild_id,
discord_channel_id = channel_id,
rss_source_id=rss_source.id,
search_name=f"{rss_source.nick} #{channel_name}"
)
insert_result = await database.session.execute(insert_query)
return insert_result.inserted_primary_key.id, await Source.from_url(url)
async def unassign_feed(self, assigned_feed_id: int, guild_id: int):
""""""
async with DatabaseManager() as database:
query = delete(FeedChannelModel).where(and_(
FeedChannelModel.id == assigned_feed_id,
FeedChannelModel.discord_server_id == guild_id
))
result = await database.session.execute(query)
if not result.rowcount:
raise NoResultFound

View File

@@ -1,54 +0,0 @@
import unittest
from sqlalchemy import select
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.result import ChunkedIteratorResult
from db import DatabaseManager, FeedChannelModel, RssSourceModel, AuditModel
class TestDatabaseConnections(unittest.IsolatedAsyncioTestCase):
"""The purpose of this test, is to ensure that the database connections function properly."""
async def test_select__feed_channel_model(self):
"""This test runs a select query on the `FeedChannelModel`"""
async with DatabaseManager() as database:
query = select(FeedChannelModel).limit(1000)
result = await database.session.execute(query)
self.assertIsInstance(
result,
ChunkedIteratorResult,
f"Result should be `ChunkedIteratorResult`, not {type(result)!r}"
)
async def test_select__rss_source_model(self):
"""This test runs a select query on the `RssSourceModel`"""
async with DatabaseManager() as database:
query = select(RssSourceModel).limit(1000)
result = await database.session.execute(query)
self.assertIsInstance(
result,
ChunkedIteratorResult,
f"Result should be `ChunkedIteratorResult`, not {type(result)!r}"
)
async def test_select__audit_model(self):
"""This test runs a select query on the `AuditModel`"""
async with DatabaseManager() as database:
query = select(AuditModel).limit(1000)
result = await database.session.execute(query)
self.assertIsInstance(
result,
ChunkedIteratorResult,
f"Result should be `ChunkedIteratorResult`, not {type(result)!r}"
)
if __name__ == "__main__":
unittest.main()

View File

@@ -42,15 +42,9 @@ async def followup(inter: Interaction, *args, **kwargs):
await inter.followup.send(*args, **kwargs)
async def audit(cog, *args, **kwargs):
"""Shorthand for auditing an interaction."""
await cog.bot.audit(*args, **kwargs)
# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
class FollowupIcons:
error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"