heavy rewrites. (rm old code, improve tasks, etc)
This commit is contained in:
parent
574d54a2eb
commit
1f199c36f9
29
src/bot.py
29
src/bot.py
@ -7,10 +7,6 @@ from pathlib import Path
|
||||
|
||||
from discord import Intents, Game
|
||||
from discord.ext import commands
|
||||
from sqlalchemy import insert
|
||||
|
||||
from feed import Functions
|
||||
from db import DatabaseManager, AuditModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -20,7 +16,6 @@ class DiscordBot(commands.Bot):
|
||||
def __init__(self, BASE_DIR: Path, developing: bool, api_token: str):
|
||||
activity = Game("Indev") if developing else None
|
||||
super().__init__(command_prefix="-", intents=Intents.all(), activity=activity)
|
||||
self.functions = Functions(self, api_token)
|
||||
self.BASE_DIR = BASE_DIR
|
||||
self.developing = developing
|
||||
self.api_token = api_token
|
||||
@ -52,27 +47,3 @@ class DiscordBot(commands.Bot):
|
||||
for path in (self.BASE_DIR / "src/extensions").iterdir():
|
||||
if path.suffix == ".py":
|
||||
await self.load_extension(f"extensions.{path.stem}")
|
||||
|
||||
async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
|
||||
"""Shorthand for auditing an action.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
message : str
|
||||
The message to be audited.
|
||||
user_id : int
|
||||
Discord ID of the user being audited.
|
||||
database : DatabaseManager, optional
|
||||
An existing database connection to be used if specified, by default None
|
||||
"""
|
||||
|
||||
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
|
||||
|
||||
log.debug("Logging audit")
|
||||
|
||||
if database:
|
||||
await database.session.execute(query)
|
||||
return
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
await database.session.execute(query)
|
||||
|
@ -1,20 +0,0 @@
|
||||
"""
|
||||
Initialize the database modules, create the database tables and default data.
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from .models import Base, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
|
||||
from .db import DatabaseManager
|
||||
|
||||
# # Initialise a database session
|
||||
# engine = create_engine(DatabaseManager.get_database_url(use_async=False))
|
||||
# session = sessionmaker(bind=engine)()
|
||||
|
||||
# # Create tables if not exists
|
||||
# Base.metadata.create_all(engine)
|
||||
|
||||
# session.commit()
|
||||
# session.close()
|
||||
|
77
src/db/db.py
77
src/db/db.py
@ -1,77 +0,0 @@
|
||||
"""
|
||||
Database Manager
|
||||
"""
|
||||
|
||||
import logging
|
||||
from os import getenv
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
DB_TYPE = getenv("DB_TYPE", default="sqlite")
|
||||
DB_HOST = getenv("DB_HOST", default="db.sqlite")
|
||||
DB_PORT = getenv("DB_PORT")
|
||||
DB_USERNAME = getenv("DB_USERNAME")
|
||||
DB_PASSWORD = getenv("DB_PASSWORD")
|
||||
DB_DATABASE = getenv("DB_DATABASE")
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""
|
||||
Asynchronous database context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, no_commit: bool = False):
|
||||
database_url = self.get_database_url() # This is called every time a connection is established, maybe make it once and reference it?
|
||||
self.engine = create_async_engine(database_url, future=True)
|
||||
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
|
||||
self.session = None
|
||||
self.no_commit = no_commit
|
||||
|
||||
@staticmethod
|
||||
def get_database_url(use_async=True):
|
||||
"""
|
||||
Returns a connection string for the database.
|
||||
"""
|
||||
|
||||
url = f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}"
|
||||
url_addon = ""
|
||||
|
||||
# This looks fucking ugly
|
||||
match use_async, DB_TYPE:
|
||||
case True, "sqlite":
|
||||
url_addon = "aiosqlite"
|
||||
|
||||
case True, "postgresql":
|
||||
url_addon = "asyncpg"
|
||||
|
||||
case False, "sqlite":
|
||||
pass
|
||||
|
||||
case False, "postgresql":
|
||||
pass
|
||||
|
||||
case _, _:
|
||||
raise ValueError(f"Unknown Database Type: {DB_TYPE}")
|
||||
|
||||
url = url.replace(":/", f"+{url_addon}:/") if url_addon else url
|
||||
|
||||
|
||||
return url
|
||||
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = self.session_maker()
|
||||
log.debug("Database connection open")
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_):
|
||||
if not self.no_commit:
|
||||
await self.session.commit()
|
||||
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
await self.engine.dispose()
|
||||
log.debug("Database connection closed")
|
@ -1,96 +0,0 @@
|
||||
"""
|
||||
Models and Enums for the database.
|
||||
All table classes should be suffixed with `Model`.
|
||||
"""
|
||||
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
BigInteger,
|
||||
UniqueConstraint,
|
||||
ForeignKey
|
||||
)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# back in wed, thu, fri, off new year day then back in after
|
||||
|
||||
class AuditModel(Base):
|
||||
"""
|
||||
Table for taking audits.
|
||||
"""
|
||||
|
||||
__tablename__ = "audit"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
discord_user_id = Column(BigInteger, nullable=False)
|
||||
# discord_server_id = Column(BigInteger, nullable=False) # TODO: this doesnt exist, integrate it.
|
||||
message = Column(String, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||
|
||||
|
||||
class SentArticleModel(Base):
|
||||
"""
|
||||
Table for tracking articles that have been served by the bot.
|
||||
"""
|
||||
|
||||
__tablename__ = "sent_articles"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
discord_message_id = Column(BigInteger, nullable=False)
|
||||
discord_channel_id = Column(BigInteger, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
article_url = Column(String, nullable=False)
|
||||
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||
|
||||
feed_channel_id = Column(Integer, ForeignKey("feed_channel.id", ondelete="CASCADE"), nullable=False)
|
||||
feed_channel = relationship("FeedChannelModel", overlaps="sent_articles", lazy="joined", cascade="all, delete")
|
||||
|
||||
|
||||
class RssSourceModel(Base):
|
||||
"""
|
||||
Table for user submitted news feeds.
|
||||
"""
|
||||
|
||||
__tablename__ = "rss_source"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
nick = Column(String, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
rss_url = Column(String, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||
|
||||
feed_channels = relationship("FeedChannelModel", cascade="all, delete")
|
||||
|
||||
# the nickname must be unique, but only within the same discord server
|
||||
__table_args__ = (
|
||||
UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'),
|
||||
)
|
||||
|
||||
|
||||
class FeedChannelModel(Base):
|
||||
"""
|
||||
Table representing discord channels to be used for news feeds.
|
||||
"""
|
||||
|
||||
__tablename__ = "feed_channel"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
discord_channel_id = Column(BigInteger, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
search_name = Column(String, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||
|
||||
sent_articles = relationship("SentArticleModel", cascade="all, delete")
|
||||
|
||||
rss_source_id = Column(Integer, ForeignKey('rss_source.id', ondelete="CASCADE"), nullable=False)
|
||||
rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete")
|
||||
|
||||
# the rss source must be unique, but only within the same discord channel
|
||||
__table_args__ = (
|
||||
UniqueConstraint('rss_source_id', 'discord_channel_id', name='uq_rss_discord_channel'),
|
||||
)
|
@ -1,10 +1,4 @@
|
||||
|
||||
class IllegalFeed(Exception):
|
||||
def __init__(self, message: str, **items):
|
||||
super().__init__(message)
|
||||
self.items = items
|
||||
|
||||
|
||||
class TokenMissingError(Exception):
|
||||
"""
|
||||
Exception to indicate a token couldnt be found.
|
||||
|
@ -1,152 +0,0 @@
|
||||
"""
|
||||
|
||||
task flow
|
||||
|
||||
=-=-=-=-=
|
||||
|
||||
1. every 10 minutes
|
||||
2. get all bot guild_ids
|
||||
3. get all subscriptions against known guild ids (account for pagination of 25)
|
||||
4. iterate through and process each subscription
|
||||
5. get all articles from subscription
|
||||
6. iterate through and process each article
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class Cog:
|
||||
|
||||
async def subscription_task():
|
||||
|
||||
subscriptions = self.get_subscriptions()
|
||||
|
||||
for subscription in subscriptions:
|
||||
await self.process_subscription(subscription)
|
||||
|
||||
async def get_subscriptions():
|
||||
# return subs from api, handle pagination
|
||||
|
||||
async def process_subscription(subscription):
|
||||
|
||||
articles = subscription.get_articles()
|
||||
|
||||
for article in articles:
|
||||
await self.process_article(article)
|
||||
|
||||
async def process_article(article):
|
||||
# validate article then:
|
||||
|
||||
if await self.track_article(article):
|
||||
await self.send_article(article)
|
||||
|
||||
|
||||
async def track_article():
|
||||
pass
|
||||
|
||||
async def send_article():
|
||||
pass
|
||||
|
||||
|
||||
class TaskCog(commands.Cog):
|
||||
"""
|
||||
Tasks cog for PYRSS.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
@tasks.loop(time=times)
|
||||
async def subscription_task(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
api = API(self.bot.api_token, session)
|
||||
subscriptions = await self.get_subscriptions(api)
|
||||
await self.process_subscriptions(api, subscriptions)
|
||||
# articles = [*(await self.get_articles(api, sub)) for sub in subscriptions]
|
||||
# await self.send_articles(api, articles)
|
||||
|
||||
async def get_subscriptions(self, api) -> list:
|
||||
guild_ids = [guild.id for guild in self.bot.guilds]
|
||||
subscriptions = []
|
||||
|
||||
for page in iter(int, 1):
|
||||
try:
|
||||
page_data = (await api.get_subscriptions(server__in=guild_ids, page=page+1))[0]
|
||||
except apihttp.ClientResponseError as error:
|
||||
if error.status == 404:
|
||||
break
|
||||
|
||||
except Exception as error:
|
||||
log.error("Exception while gathering page data %s", error)
|
||||
break
|
||||
|
||||
subscriptions.extend(page_data)
|
||||
|
||||
async def process_subscriptions(self, api, subscriptions):
|
||||
for sub in subscriptions:
|
||||
if not sub.active or not sub.channel_count:
|
||||
continue
|
||||
|
||||
unparsed_feed = await get_unparsed_feed(sub.url, api.session)
|
||||
parsed_feed = await parse(unparsed_feed)
|
||||
|
||||
rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
|
||||
await self.process_items(api, sub, rss_feed)
|
||||
|
||||
async def process_items(self, api, sub, feed):
|
||||
|
||||
channels = [self.bot.get_channel(channel.channel_id) for channel in await sub.get_channels()]
|
||||
filters = [await api.get_filter(filter_id) for filter_id in sub.filters]
|
||||
|
||||
for item in sub.items:
|
||||
blocked = any(self.filter_item(_filter, item) for _filter in filters)
|
||||
mutated_item = item.create_mutated_copy(sub.mutators)
|
||||
|
||||
for channel in channels:
|
||||
await self.mark_tracked_item(api, sub, item, channel.id, blocked)
|
||||
|
||||
if not blocked:
|
||||
channel.send(embed=item.to_embed(sub, feed))
|
||||
|
||||
async def filter_item(self, _filter: dict, item: RSSItem) -> bool:
|
||||
"""
|
||||
Returns True if item should be ignored due to filters.
|
||||
"""
|
||||
|
||||
match_found = False # This is the flag to determine if the content should be filtered
|
||||
|
||||
keywords = _filter["keywords"].split(",")
|
||||
regex_pattern = _filter["regex"]
|
||||
is_whitelist = _filter["whitelist"]
|
||||
|
||||
log.debug(
|
||||
"trying filter '%s', keyword '%s', regex '%s', is whitelist: '%s'",
|
||||
_filter["name"], keywords, regex_pattern, is_whitelist
|
||||
)
|
||||
|
||||
assert not (keywords and regex_pattern), "Keywords and Regex used, only 1 can be used."
|
||||
|
||||
if any(word in item.title or word in item.description for word in keywords):
|
||||
match_found = True
|
||||
|
||||
if regex_pattern:
|
||||
regex = re.compile(regex_pattern)
|
||||
match_found = regex.search(item.title) or regex.search(item.description)
|
||||
|
||||
return not match_found if is_whitelist else match_found
|
||||
|
||||
async def mark_tracked_item(self, sub, item, channel_id, blocked):
|
||||
try:
|
||||
api.create_tracked_content(
|
||||
guid=item.guid,
|
||||
title=item.title,
|
||||
url=item.url,
|
||||
subscription=sub.id,
|
||||
channel_id=channel_id,
|
||||
blocked=blocked
|
||||
)
|
||||
except aiohttp.ClientResponseError as error:
|
||||
if error.status == 409:
|
||||
log.debug(error)
|
||||
else:
|
||||
log.error(error)
|
@ -4,7 +4,6 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
import datetime
|
||||
from os import getenv
|
||||
@ -15,249 +14,129 @@ from discord import TextChannel, Embed, Colour
|
||||
from discord import app_commands
|
||||
from discord.ext import commands, tasks
|
||||
from discord.errors import Forbidden
|
||||
from sqlalchemy import insert, select, and_
|
||||
from feedparser import parse
|
||||
|
||||
from feed import Source, Article, RSSFeed, Subscription, SubscriptionChannel, SubChannel
|
||||
from db import (
|
||||
DatabaseManager,
|
||||
FeedChannelModel,
|
||||
RssSourceModel,
|
||||
SentArticleModel
|
||||
)
|
||||
from feed import RSSFeed, Subscription, RSSItem
|
||||
from utils import get_unparsed_feed
|
||||
from api import API
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
TASK_INTERVAL_MINUTES = getenv("TASK_INTERVAL_MINUTES")
|
||||
|
||||
# task trigger times : must be of type list
|
||||
times = [
|
||||
subscription_task_times = [
|
||||
datetime.time(hour, minute, tzinfo=datetime.timezone.utc)
|
||||
for hour in range(24)
|
||||
for minute in range(0, 60, int(TASK_INTERVAL_MINUTES))
|
||||
]
|
||||
|
||||
log.debug("Task will trigger every %s minutes", TASK_INTERVAL_MINUTES)
|
||||
|
||||
|
||||
class TaskCog(commands.Cog):
|
||||
"""
|
||||
Tasks cog.
|
||||
Tasks cog for PYRSS.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
self.time = None
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
"""Instructions to execute when the cog is ready."""
|
||||
|
||||
# if not self.bot.developing:
|
||||
self.rss_task.start()
|
||||
|
||||
"""
|
||||
Instructions to execute when the cog is ready.
|
||||
"""
|
||||
self.subscription_task.start()
|
||||
log.info("%s cog is ready", self.__class__.__name__)
|
||||
|
||||
@commands.Cog.listener(name="cog_unload")
|
||||
async def on_unload(self):
|
||||
"""Instructions to execute before the cog is unloaded."""
|
||||
|
||||
self.rss_task.cancel()
|
||||
"""
|
||||
Instructions to execute before the cog is unloaded.
|
||||
"""
|
||||
self.subscription_task.cancel()
|
||||
|
||||
@app_commands.command(name="debug-trigger-task")
|
||||
async def debug_trigger_task(self, inter):
|
||||
await inter.response.defer()
|
||||
await self.rss_task()
|
||||
await self.subscription_task()
|
||||
await inter.followup.send("done")
|
||||
|
||||
@tasks.loop(time=times)
|
||||
async def rss_task(self):
|
||||
"""Automated task responsible for processing rss feeds."""
|
||||
|
||||
@tasks.loop(time=subscription_task_times)
|
||||
async def subscription_task(self):
|
||||
"""
|
||||
Task for fetching and processing subscriptions.
|
||||
"""
|
||||
log.info("Running subscription task")
|
||||
time = process_time()
|
||||
|
||||
guild_ids = [guild.id for guild in self.bot.guilds]
|
||||
data = []
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
api = API(self.bot.api_token, session)
|
||||
page = 0
|
||||
subscriptions = await self.get_subscriptions(api)
|
||||
await self.process_subscriptions(api, subscriptions)
|
||||
|
||||
while True:
|
||||
page += 1
|
||||
page_data = await self.get_subscriptions(api, guild_ids, page)
|
||||
async def get_subscriptions(self, api: API) -> list[Subscription]:
|
||||
guild_ids = [guild.id for guild in self.bot.guilds]
|
||||
sub_data = []
|
||||
|
||||
if not page_data:
|
||||
break
|
||||
for page, _ in enumerate(iter(int, 1)):
|
||||
try:
|
||||
log.debug("fetching page '%s'", page + 1)
|
||||
sub_data.extend(
|
||||
(await api.get_subscriptions(server__in=guild_ids, page=page+1))[0]
|
||||
)
|
||||
except aiohttp.ClientResponseError as error:
|
||||
match error.status:
|
||||
case 404:
|
||||
log.debug("final page reached '%s'", page)
|
||||
break
|
||||
case 403:
|
||||
log.critical(error)
|
||||
self.subscription_task.cancel()
|
||||
return [] # returning an empty list should gracefully end the task
|
||||
case _:
|
||||
log.error(error)
|
||||
break
|
||||
|
||||
data.extend(page_data)
|
||||
log.debug("extending data by '%s' items", len(page_data))
|
||||
except Exception as error:
|
||||
log.error("Exception while gathering page data %s", error)
|
||||
break
|
||||
|
||||
log.debug("finished api data collection, browsed %s pages for %s subscriptions", page, len(data))
|
||||
|
||||
subscriptions = Subscription.from_list(data)
|
||||
for sub in subscriptions:
|
||||
await self.process_subscription(api, session, sub)
|
||||
return Subscription.from_list(sub_data)
|
||||
|
||||
log.info("Finished subscription task, time elapsed: %s", process_time() - time)
|
||||
async def process_subscriptions(self, api: API, subscriptions: list[Subscription]):
|
||||
for sub in subscriptions:
|
||||
log.debug("processing subscription '%s'", sub.id)
|
||||
|
||||
async def get_subscriptions(self, api, guild_ids: list[int], page: int):
|
||||
if not sub.active or not sub.channels_count:
|
||||
continue
|
||||
|
||||
log.debug("attempting to get subscriptions for page: %s", page)
|
||||
unparsed_feed = await get_unparsed_feed(sub.url, api.session)
|
||||
parsed_feed = parse(unparsed_feed)
|
||||
|
||||
try:
|
||||
return (await api.get_subscriptions(server__in=guild_ids, page=page))[0]
|
||||
except aiohttp.ClientResponseError as error:
|
||||
if error.status == 404:
|
||||
log.debug(error)
|
||||
return []
|
||||
rss_feed = RSSFeed.from_parsed_feed(parsed_feed)
|
||||
await self.process_items(api, sub, rss_feed)
|
||||
|
||||
log.error(error)
|
||||
|
||||
async def process_subscription(self, api: API, session: aiohttp.ClientSession, sub: Subscription):
|
||||
"""
|
||||
Process a given Subscription.
|
||||
"""
|
||||
|
||||
log.debug("processing subscription '%s' '%s' for '%s'", sub.id, sub.name, sub.guild_id)
|
||||
|
||||
if not sub.active:
|
||||
log.debug("skipping sub because it's active flag is 'False'")
|
||||
return
|
||||
|
||||
channels: list[TextChannel] = [self.bot.get_channel(subchannel.channel_id) for subchannel in await sub.get_channels(api)]
|
||||
if not channels:
|
||||
log.warning("No channels to send this to")
|
||||
return
|
||||
async def process_items(self, api: API, sub: Subscription, feed: RSSFeed):
|
||||
log.debug("processing items")
|
||||
|
||||
channels = [self.bot.get_channel(channel.channel_id) for channel in await sub.get_channels(api)]
|
||||
filters = [await api.get_filter(filter_id) for filter_id in sub.filters]
|
||||
log.debug("found %s filter(s)", len(filters))
|
||||
|
||||
unparsed_content = await get_unparsed_feed(sub.url, session)
|
||||
parsed_content = parse(unparsed_content)
|
||||
source = Source.from_parsed(parsed_content)
|
||||
articles = source.get_latest_articles(10)
|
||||
articles.reverse()
|
||||
for item in feed.items:
|
||||
log.debug("processing item '%s'", item.guid)
|
||||
|
||||
if not articles:
|
||||
log.debug("No articles found")
|
||||
|
||||
embeds = await self.get_articles_as_embeds(api, session, sub.id, sub.mutators, filters, articles, Colour.from_str("#" + sub.embed_colour))
|
||||
await self.send_embeds_in_chunks(embeds, channels)
|
||||
|
||||
async def get_articles_as_embeds(
|
||||
self,
|
||||
api: API,
|
||||
session: aiohttp.ClientSession,
|
||||
sub_id: int,
|
||||
mutators: dict[str, list[dict]],
|
||||
filters: list[dict],
|
||||
articles: list[Article],
|
||||
embed_colour: str
|
||||
) -> list[Embed]:
|
||||
"""
|
||||
Process articles and return their respective embeds.
|
||||
"""
|
||||
|
||||
embeds = []
|
||||
for article in articles:
|
||||
embed = await self.process_article(api, session, sub_id, mutators, filters, article, embed_colour)
|
||||
if embed:
|
||||
embeds.append(embed)
|
||||
|
||||
return embeds
|
||||
|
||||
async def send_embeds_in_chunks(self, embeds: list[Embed], channels: list[TextChannel], embed_limit=10):
|
||||
"""
|
||||
Send embeds to a list of `TextChannel` in chunks of `embed_limit` size.
|
||||
"""
|
||||
|
||||
log.debug("about to send %s embeds")
|
||||
|
||||
for i in range(0, len(embeds), embed_limit):
|
||||
embeds_chunk = embeds[i:i + embed_limit]
|
||||
|
||||
log.debug("sending chunk of %s embeds", len(embeds_chunk))
|
||||
blocked = any(self.filter_item(_filter, item) for _filter in filters)
|
||||
mutated_item = item.create_mutated_copy(sub.mutators)
|
||||
|
||||
for channel in channels:
|
||||
await self.try_send_embeds(embeds, channel)
|
||||
successful_track = await self.mark_tracked_item(api, sub, item, channel.id, blocked)
|
||||
|
||||
async def try_send_embeds(self, embeds: list[Embed], channel: TextChannel):
|
||||
if successful_track and not blocked:
|
||||
await channel.send(embed=await item.to_embed(sub, feed, api.session))
|
||||
|
||||
def filter_item(self, _filter: dict, item: RSSItem) -> bool:
|
||||
"""
|
||||
Attempt to send embeds to a given `TextChannel`. Gracefully handles errors.
|
||||
"""
|
||||
|
||||
try:
|
||||
await channel.send(embeds=embeds)
|
||||
|
||||
except Forbidden:
|
||||
log.debug(
|
||||
"Forbidden from sending embed to channel '%s', guild '%s'",
|
||||
channel.id, channel.guild.id
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
log.error(exc)
|
||||
|
||||
async def process_article(
|
||||
self,
|
||||
api: API,
|
||||
session: aiohttp.ClientSession,
|
||||
sub_id: int,
|
||||
mutators: dict[str, list[dict]],
|
||||
filters: list[dict],
|
||||
article: Article,
|
||||
embed_colour: str
|
||||
) -> Embed | None:
|
||||
"""
|
||||
Process a given Article.
|
||||
Returns an Embed representing the given Article.
|
||||
"""
|
||||
|
||||
log.debug("processing article '%s' '%s'", article.guid, article.title)
|
||||
|
||||
blocked = any(self.filter_article(_filter, article) for _filter in filters)
|
||||
log.debug("filter result: %s", "blocked" if blocked else "ok")
|
||||
|
||||
self.mutate_article(article, mutators)
|
||||
|
||||
try:
|
||||
await api.create_tracked_content(
|
||||
guid=article.guid,
|
||||
title=article.title,
|
||||
url=article.url,
|
||||
subscription=sub_id,
|
||||
blocked=blocked,
|
||||
channel_id="-_-"
|
||||
)
|
||||
log.debug("successfully tracked %s", article.guid)
|
||||
|
||||
except aiohttp.ClientResponseError as error:
|
||||
if error.status == 409:
|
||||
log.debug("It looks like this article already exists, skipping")
|
||||
else:
|
||||
log.error(error)
|
||||
|
||||
return
|
||||
|
||||
if not blocked:
|
||||
return await article.to_embed(session, embed_colour)
|
||||
|
||||
def mutate_article(self, article: Article, mutators: list[dict]):
|
||||
|
||||
for mutator in mutators["title"]:
|
||||
article.mutate("title", mutator)
|
||||
|
||||
for mutator in mutators["desc"]:
|
||||
article.mutate("description", mutator)
|
||||
|
||||
def filter_article(self, _filter: dict, article: Article) -> bool:
|
||||
"""
|
||||
Returns True if article should be ignored due to filters.
|
||||
Returns True if item should be ignored due to filters.
|
||||
"""
|
||||
|
||||
match_found = False # This is the flag to determine if the content should be filtered
|
||||
@ -273,15 +152,35 @@ class TaskCog(commands.Cog):
|
||||
|
||||
assert not (keywords and regex_pattern), "Keywords and Regex used, only 1 can be used."
|
||||
|
||||
if any(word in article.title or word in article.description for word in keywords):
|
||||
match_found = True
|
||||
|
||||
if regex_pattern:
|
||||
regex = re.compile(regex_pattern)
|
||||
match_found = regex.search(article.title) or regex.search(article.description)
|
||||
match_found = regex.search(item.title) or regex.search(item.description)
|
||||
else:
|
||||
match_found = any(word in item.title or word in item.description for word in keywords)
|
||||
|
||||
return not match_found if is_whitelist else match_found
|
||||
|
||||
async def mark_tracked_item(self, api: API, sub: Subscription, item: RSSItem, channel_id: int, blocked: bool):
|
||||
try:
|
||||
log.debug("marking as tracked 'blocked: %s'", blocked)
|
||||
await api.create_tracked_content(
|
||||
guid=item.guid,
|
||||
title=item.title,
|
||||
url=item.link,
|
||||
subscription=sub.id,
|
||||
channel_id=channel_id,
|
||||
blocked=blocked
|
||||
)
|
||||
return True
|
||||
except aiohttp.ClientResponseError as error:
|
||||
if error.status == 409:
|
||||
log.debug(error)
|
||||
else:
|
||||
log.error(error)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
|
449
src/feed.py
449
src/feed.py
@ -4,7 +4,6 @@ import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
from random import shuffle, sample
|
||||
|
||||
import aiohttp
|
||||
import validators
|
||||
@ -12,30 +11,25 @@ from discord import Embed, Colour
|
||||
from bs4 import BeautifulSoup as bs4
|
||||
from feedparser import FeedParserDict, parse
|
||||
from markdownify import markdownify
|
||||
from sqlalchemy import select, insert, delete, and_
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from textwrap import shorten
|
||||
from feedparser import parse
|
||||
|
||||
from mutators import registry as mutator_registry
|
||||
from errors import IllegalFeed
|
||||
from db import DatabaseManager, RssSourceModel, FeedChannelModel
|
||||
from utils import get_rss_data, get_unparsed_feed
|
||||
from utils import get_unparsed_feed
|
||||
from api import API
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
dumps = lambda _dict: json.dumps(_dict, indent=8)
|
||||
|
||||
|
||||
from xml.etree.ElementTree import Element, SubElement, tostring
|
||||
from feedparser import parse
|
||||
|
||||
class RSSItem:
|
||||
def __init__(self, title, link, description, pub_date, guid):
|
||||
def __init__(self, title, link, description, pub_date, guid, image_url):
|
||||
self.title = title
|
||||
self.link = link
|
||||
self.description = description
|
||||
self.pub_date = pub_date
|
||||
self.guid = guid
|
||||
self.image_url = image_url
|
||||
|
||||
def __str__(self):
|
||||
return self.title
|
||||
@ -43,7 +37,7 @@ class RSSItem:
|
||||
def create_mutated_copy(self, mutators):
|
||||
pass
|
||||
|
||||
def create_embed(self, sub, feed):
|
||||
async def to_embed(self, sub, feed, session):
|
||||
|
||||
title = shorten(markdownify(self.title, strip=["img", "a"]), 256)
|
||||
desc = shorten(markdownify(self.description, strip=["img"]), 4096)
|
||||
@ -60,7 +54,6 @@ class RSSItem:
|
||||
# validate urls
|
||||
# author_url = self.source.url if validators.url(self.source.url) else None
|
||||
# icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
|
||||
# thumb_url = await self.get_thumbnail_url(session) # validation done inside func
|
||||
|
||||
# Combined length validation
|
||||
# Can't exceed combined 6000 characters, [400 Bad Request] if failed.
|
||||
@ -71,27 +64,64 @@ class RSSItem:
|
||||
embed = Embed(
|
||||
title=title,
|
||||
description=desc,
|
||||
timestamp=self.published,
|
||||
# timestamp=self.published,
|
||||
url=self.link if validators.url(self.link) else None,
|
||||
colour=colour
|
||||
colour=Colour.from_str("#" + sub.embed_colour)
|
||||
)
|
||||
|
||||
# embed.set_thumbnail(url=icon_url)
|
||||
log.debug("has image url without search: '%s'", self.image_url)
|
||||
|
||||
if sub.article_fetch_image:
|
||||
embed.set_image(url=self.image_url or await self.get_thumbnail_url(session))
|
||||
embed.set_thumbnail(url=feed.image_href if validators.url(feed.image_href) else None)
|
||||
|
||||
# embed.set_image(url=thumb_url)
|
||||
# embed.set_author(url=author_url, name=author)
|
||||
# embed.set_footer(text=self.author)
|
||||
|
||||
return embed
|
||||
|
||||
async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None:
|
||||
"""Returns the thumbnail URL for an article.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : aiohttp.ClientSession
|
||||
A client session used to get the thumbnail.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The thumbnail URL, or None if not found.
|
||||
"""
|
||||
|
||||
# log.debug("Fetching thumbnail for article: %s", self)
|
||||
|
||||
try:
|
||||
async with session.get(self.link, timeout=15) as response:
|
||||
html = await response.text()
|
||||
except aiohttp.InvalidURL as error:
|
||||
log.error("invalid thumbnail url: %s", error)
|
||||
return None
|
||||
|
||||
soup = bs4(html, "html.parser")
|
||||
image_element = soup.select_one("meta[property='og:image']")
|
||||
if not image_element:
|
||||
return None
|
||||
|
||||
image_content = image_element.get("content")
|
||||
return image_content if validators.url(image_content) else None
|
||||
|
||||
import json
|
||||
class RSSFeed:
|
||||
def __init__(self, title, link, description, language='en-gb', pub_date=None, last_build_date=None):
|
||||
def __init__(self, title, link, description, language='en-gb', pub_date=None, last_build_date=None, image_href=None):
|
||||
self.title = title
|
||||
self.link = link
|
||||
self.description = description
|
||||
self.language = language
|
||||
self.pub_date = pub_date
|
||||
self.last_build_date = last_build_date
|
||||
self.image_href = image_href
|
||||
self.items = []
|
||||
|
||||
def add_item(self, item: RSSItem):
|
||||
@ -111,8 +141,9 @@ class RSSFeed:
|
||||
language = parsed_feed.feed.get('language', 'en-gb')
|
||||
pub_date = parsed_feed.feed.get('published', None)
|
||||
last_build_date = parsed_feed.feed.get('updated', None)
|
||||
image_href = parsed_feed.get("image", {}).get("href")
|
||||
|
||||
feed = cls(title, link, description, language, pub_date, last_build_date)
|
||||
feed = cls(title, link, description, language, pub_date, last_build_date, image_href)
|
||||
|
||||
for entry in parsed_feed.entries:
|
||||
item_title = entry.get('title', 'No title')
|
||||
@ -120,10 +151,11 @@ class RSSFeed:
|
||||
item_description = entry.get('description', 'No description')
|
||||
item_pub_data = entry.get('published_parsed', None)
|
||||
item_guid = entry.get('id', None) or entry.get("guid", None)
|
||||
item_image_url = entry.get("media_content", [{}])[0].get("url")
|
||||
|
||||
item_published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
|
||||
# item_published = datetime(*entry.published_parsed[0:-2]) if item_pub_data else None
|
||||
|
||||
item = RSSItem(item_title, item_link, item_description, item_published, item_guid)
|
||||
item = RSSItem(item_title, item_link, item_description, item_pub_data, item_guid, item_image_url)
|
||||
feed.add_item(item)
|
||||
|
||||
feed.items.reverse()
|
||||
@ -131,244 +163,6 @@ class RSSFeed:
|
||||
return feed
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class RSSArticle:
|
||||
"""Represents a news article, or entry from an RSS feed."""
|
||||
|
||||
guid: str
|
||||
title: str | None
|
||||
description: str | None
|
||||
url: str | None
|
||||
published: datetime | None
|
||||
author: str | None
|
||||
source: object
|
||||
|
||||
@classmethod
|
||||
def from_entry(cls, source, entry:FeedParserDict):
|
||||
"""Create an Article from an RSS feed entry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
entry : FeedParserDict
|
||||
An entry pulled from a complete FeedParserDict object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Article
|
||||
The Article created from the feed entry.
|
||||
"""
|
||||
|
||||
# log.debug("Creating Article from entry: %s", dumps(entry))
|
||||
|
||||
published_parsed = entry.get("published_parsed")
|
||||
published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
|
||||
|
||||
return cls(
|
||||
guid=entry["guid"],
|
||||
title=entry.get("title"),
|
||||
description=entry.get("description"),
|
||||
url=entry.get("link"),
|
||||
published=published,
|
||||
author = entry.get("author"),
|
||||
source=source
|
||||
)
|
||||
|
||||
def mutate(self, attr: str, mutator: dict[str, str]):
|
||||
"""
|
||||
Apply a mutation to a certain text attribute of
|
||||
this Article instance.
|
||||
"""
|
||||
|
||||
# WARN:
|
||||
# This could be really bad if the end user is able to effect the 'attr' value.
|
||||
# Shouldn't happen though.
|
||||
|
||||
log.debug("applying mutator '%s'", mutator["name"])
|
||||
val = mutator["value"]
|
||||
|
||||
try:
|
||||
mutator = mutator_registry.get_mutator(val)
|
||||
except ValueError as err:
|
||||
log.error(err)
|
||||
|
||||
setattr(self, attr, mutator.mutate(getattr(self, attr)))
|
||||
log.debug("mutated %s, to: %s", attr, getattr(self, attr))
|
||||
|
||||
|
||||
async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None:
|
||||
"""Returns the thumbnail URL for an article.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : aiohttp.ClientSession
|
||||
A client session used to get the thumbnail.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The thumbnail URL, or None if not found.
|
||||
"""
|
||||
|
||||
# log.debug("Fetching thumbnail for article: %s", self)
|
||||
|
||||
try:
|
||||
async with session.get(self.url, timeout=15) as response:
|
||||
html = await response.text()
|
||||
except aiohttp.InvalidURL as error:
|
||||
log.error("invalid thumbnail url: %s", error)
|
||||
return None
|
||||
|
||||
soup = bs4(html, "html.parser")
|
||||
image_element = soup.select_one("meta[property='og:image']")
|
||||
if not image_element:
|
||||
return None
|
||||
|
||||
image_content = image_element.get("content")
|
||||
return image_content if validators.url(image_content) else None
|
||||
|
||||
async def to_embed(self, session: aiohttp.ClientSession, colour: Colour) -> Embed:
|
||||
"""Creates and returns a Discord Embed object from the article.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : aiohttp.ClientSession
|
||||
A client session used to get additional article data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Embed
|
||||
A Discord Embed object representing the article.
|
||||
"""
|
||||
|
||||
# log.debug(f"Creating embed from article: {self}")
|
||||
|
||||
# Replace HTML with Markdown, and shorten text.
|
||||
title = shorten(markdownify(self.title, strip=["img", "a"]), 256)
|
||||
desc = shorten(markdownify(self.description, strip=["img"]), 4096)
|
||||
author = shorten(self.source.name, 256)
|
||||
|
||||
# validate urls
|
||||
embed_url = self.url if validators.url(self.url) else None
|
||||
author_url = self.source.url if validators.url(self.source.url) else None
|
||||
icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None
|
||||
thumb_url = await self.get_thumbnail_url(session) # validation done inside func
|
||||
|
||||
# Combined length validation
|
||||
# Can't exceed combined 6000 characters, [400 Bad Request] if failed.
|
||||
combined_length = len(title) + len(desc) + (len(author) * 2)
|
||||
cutoff = combined_length - 6000
|
||||
desc = shorten(desc, cutoff) if cutoff > 0 else desc
|
||||
|
||||
embed = Embed(
|
||||
title=title,
|
||||
description=desc,
|
||||
timestamp=self.published,
|
||||
url=embed_url,
|
||||
colour=colour
|
||||
)
|
||||
|
||||
embed.set_thumbnail(url=icon_url)
|
||||
embed.set_image(url=thumb_url)
|
||||
embed.set_author(url=author_url, name=author)
|
||||
embed.set_footer(text=self.author)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
@dataclass
|
||||
class RSSFeedSource:
|
||||
"""Represents an RSS Feed."""
|
||||
|
||||
name: str | None
|
||||
description: str | None
|
||||
url: str | None
|
||||
icon_url: str | None
|
||||
feed: FeedParserDict
|
||||
|
||||
@classmethod
|
||||
def from_parsed(cls, feed:FeedParserDict):
|
||||
"""Returns a Source object from a parsed feed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feed : FeedParserDict
|
||||
The feed used to create the Source.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Source
|
||||
The Source object
|
||||
"""
|
||||
|
||||
# log.debug("Creating Source from feed: %s", dumps(feed))
|
||||
|
||||
channel = feed.get("channel", {})
|
||||
|
||||
return cls(
|
||||
name=channel.get("title"),
|
||||
description=channel.get("description"),
|
||||
url=channel.get("link"),
|
||||
icon_url=feed.get("feed", {}).get("image", {}).get("href"),
|
||||
feed=feed
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def from_url(cls, url: str):
|
||||
unparsed_content = await get_unparsed_feed(url)
|
||||
source = cls.from_parsed(parse(unparsed_content))
|
||||
source.url = url
|
||||
return source
|
||||
|
||||
def get_latest_articles(self, max: int = 999) -> list[Article]:
|
||||
"""Returns a list of Article objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max : int
|
||||
The maximum number of articles to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of Article
|
||||
A list of Article objects.
|
||||
"""
|
||||
|
||||
# log.debug("Fetching latest articles from %s, max=%s", self, max)
|
||||
|
||||
return [
|
||||
Article.from_entry(self, entry)
|
||||
for i, entry in enumerate(self.feed.entries)
|
||||
if i < max
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RSSFeedSource_:
|
||||
|
||||
uuid: str
|
||||
name: str
|
||||
url: str
|
||||
image: str
|
||||
discord_server_id: id
|
||||
created_at: str
|
||||
|
||||
@classmethod
|
||||
def from_list(cls, data: list) -> list:
|
||||
result = []
|
||||
|
||||
for item in data:
|
||||
key = "discord_server_id"
|
||||
item[key] = int(item.get(key))
|
||||
result.append(cls(**item))
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
return cls(**data)
|
||||
|
||||
|
||||
DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
|
||||
|
||||
@ -401,6 +195,7 @@ class Subscription(DjangoDataModel):
|
||||
extra_notes: str
|
||||
filters: list[int]
|
||||
mutators: dict[str, list[dict]]
|
||||
article_fetch_image: bool
|
||||
embed_colour: str
|
||||
active: bool
|
||||
channels_count: int
|
||||
@ -488,143 +283,3 @@ class TrackedContent(DjangoDataModel):
|
||||
|
||||
item["creation_datetime"] = datetime.strptime(item["creation_datetime"], DATETIME_FORMAT)
|
||||
return item
|
||||
|
||||
|
||||
class Functions:
|
||||
|
||||
def __init__(self, bot, api_token: str):
|
||||
self.bot = bot
|
||||
self.api_token = api_token
|
||||
|
||||
async def validate_feed(self, nickname: str, url: str) -> FeedParserDict:
|
||||
"""Validates a feed based on the given nickname and url.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nickname : str
|
||||
Human readable nickname used to refer to the feed.
|
||||
url : str
|
||||
URL to fetch content from the feed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FeedParserDict
|
||||
A Parsed Dictionary of the feed.
|
||||
|
||||
Raises
|
||||
------
|
||||
IllegalFeed
|
||||
If the feed is invalid.
|
||||
"""
|
||||
|
||||
# Ensure the URL is valid
|
||||
if not validators.url(url):
|
||||
raise IllegalFeed(f"The URL you have entered is malformed or invalid:\n`{url=}`")
|
||||
|
||||
# Check the nickname is not a URL
|
||||
if validators.url(nickname):
|
||||
raise IllegalFeed(
|
||||
"It looks like the nickname you have entered is a URL.\n" \
|
||||
"For security reasons, this is not allowed.",
|
||||
nickname=nickname
|
||||
)
|
||||
|
||||
feed_data, status_code = await get_rss_data(url)
|
||||
|
||||
if status_code != 200:
|
||||
raise IllegalFeed(
|
||||
"The URL provided returned an invalid status code:",
|
||||
url=url, status_code=status_code
|
||||
)
|
||||
|
||||
# Check the contents is actually an RSS feed.
|
||||
feed = parse(feed_data)
|
||||
if not feed.version:
|
||||
raise IllegalFeed(
|
||||
"The provided URL does not seem to be a valid RSS feed.",
|
||||
url=url
|
||||
)
|
||||
|
||||
return feed
|
||||
|
||||
async def create_new_rssfeed(self, name: str, url: str, guild_id: int) -> RSSFeed:
|
||||
|
||||
|
||||
log.info("Creating new Feed: %s", name)
|
||||
|
||||
parsed_feed = await self.validate_feed(name, url)
|
||||
source = Source.from_parsed(parsed_feed)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data = await API(self.api_token, session).create_new_rssfeed(
|
||||
name, url, source.icon_url, guild_id
|
||||
)
|
||||
|
||||
return RSSFeed.from_dict(data)
|
||||
|
||||
async def delete_rssfeed(self, uuid: str) -> RSSFeed:
|
||||
|
||||
log.info("Deleting Feed '%s'", uuid)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
api = API(self.api_token, session)
|
||||
data = await api.get_rssfeed(uuid)
|
||||
await api.delete_rssfeed(uuid)
|
||||
|
||||
return RSSFeed.from_dict(data)
|
||||
|
||||
async def get_rssfeeds(self, guild_id: int, page: int, pagesize: int) -> list[RSSFeed]:
|
||||
"""Get a list of RSS Feeds.
|
||||
|
||||
Args:
|
||||
guild_id (int): The guild_id to filter by.
|
||||
|
||||
Returns:
|
||||
list[RSSFeed]: Resulting list of RSS Feeds
|
||||
"""
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
data, count = await API(self.api_token, session).get_rssfeed_list(
|
||||
discord_server_id=guild_id,
|
||||
page=page,
|
||||
page_size=pagesize
|
||||
)
|
||||
|
||||
return RSSFeed.from_list(data), count
|
||||
|
||||
async def assign_feed(
|
||||
self, url: str, channel_name: str, channel_id: int, guild_id: int
|
||||
) -> tuple[int, Source]:
|
||||
""""""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
select_query = select(RssSourceModel).where(and_(
|
||||
RssSourceModel.rss_url == url,
|
||||
RssSourceModel.discord_server_id == guild_id
|
||||
))
|
||||
select_result = await database.session.execute(select_query)
|
||||
|
||||
rss_source = select_result.scalars().one()
|
||||
|
||||
insert_query = insert(FeedChannelModel).values(
|
||||
discord_server_id = guild_id,
|
||||
discord_channel_id = channel_id,
|
||||
rss_source_id=rss_source.id,
|
||||
search_name=f"{rss_source.nick} #{channel_name}"
|
||||
)
|
||||
|
||||
insert_result = await database.session.execute(insert_query)
|
||||
return insert_result.inserted_primary_key.id, await Source.from_url(url)
|
||||
|
||||
async def unassign_feed( self, assigned_feed_id: int, guild_id: int):
|
||||
""""""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = delete(FeedChannelModel).where(and_(
|
||||
FeedChannelModel.id == assigned_feed_id,
|
||||
FeedChannelModel.discord_server_id == guild_id
|
||||
))
|
||||
|
||||
result = await database.session.execute(query)
|
||||
if not result.rowcount:
|
||||
raise NoResultFound
|
||||
|
54
src/tests.py
54
src/tests.py
@ -1,54 +0,0 @@
|
||||
|
||||
import unittest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.engine.cursor import CursorResult
|
||||
from sqlalchemy.engine.result import ChunkedIteratorResult
|
||||
|
||||
from db import DatabaseManager, FeedChannelModel, AuditModel
|
||||
|
||||
|
||||
class TestDatabaseConnections(unittest.IsolatedAsyncioTestCase):
|
||||
"""The purpose of this test, is to ensure that the database connections function properly."""
|
||||
|
||||
async def test_select__feed_channel_model(self):
|
||||
"""This test runs a select query on the `FeedChannelModel`"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = select(FeedChannelModel).limit(1000)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
self.assertIsInstance(
|
||||
result,
|
||||
ChunkedIteratorResult,
|
||||
f"Result should be `ChunkedIteratorResult`, not {type(result)!r}"
|
||||
)
|
||||
|
||||
async def test_select__rss_source_model(self):
|
||||
"""This test runs a select query on the `RssSourceModel`"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = select(RssSourceModel).limit(1000)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
self.assertIsInstance(
|
||||
result,
|
||||
ChunkedIteratorResult,
|
||||
f"Result should be `ChunkedIteratorResult`, not {type(result)!r}"
|
||||
)
|
||||
|
||||
async def test_select__audit_model(self):
|
||||
"""This test runs a select query on the `AuditModel`"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = select(AuditModel).limit(1000)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
self.assertIsInstance(
|
||||
result,
|
||||
ChunkedIteratorResult,
|
||||
f"Result should be `ChunkedIteratorResult`, not {type(result)!r}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -42,15 +42,9 @@ async def followup(inter: Interaction, *args, **kwargs):
|
||||
|
||||
await inter.followup.send(*args, **kwargs)
|
||||
|
||||
async def audit(cog, *args, **kwargs):
|
||||
"""Shorthand for auditing an interaction."""
|
||||
|
||||
await cog.bot.audit(*args, **kwargs)
|
||||
|
||||
|
||||
# https://img.icons8.com/fluency-systems-filled/48/FA5252/trash.png
|
||||
|
||||
|
||||
class FollowupIcons:
|
||||
error = "https://img.icons8.com/fluency-systems-filled/48/DC573C/box-important.png"
|
||||
success = "https://img.icons8.com/fluency-systems-filled/48/5BC873/ok--v1.png"
|
||||
|
Loading…
x
Reference in New Issue
Block a user