working on tasks

This commit is contained in:
Corban-Lee Jones 2023-12-20 21:51:57 +00:00
parent 8a1f623c6f
commit 0e5af1d752
8 changed files with 152 additions and 67 deletions

View File

@ -7,3 +7,11 @@ Plans
- Multiple news providers
- Choose how much of each provider should be delivered
- Check for duplicate articles between providers, and only deliver preferred provider article
## Dev Notes:
For the sake of development, the following defintions apply:
- Feed - An RSS feed stored within the database, submitted by a user.
- Assigned Feed - A discord channel set to receive content from a Feed.

View File

@ -18,6 +18,7 @@ from sqlalchemy import (
Base = declarative_base()
# back in wed, thu, fri, off new year day then back in after
class AuditModel(Base):
"""
@ -81,7 +82,7 @@ class FeedChannelModel(Base):
search_name = Column(String, nullable=False)
rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False)
rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined")
rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete")
# the rss source must be unique, but only within the same discord channel
__table_args__ = (

View File

@ -12,10 +12,11 @@ from discord.ext import commands
from discord import Interaction, Embed, Colour, TextChannel
from discord.app_commands import Choice, Group, autocomplete, choices, rename
from sqlalchemy import insert, select, and_, delete
from sqlalchemy.exc import NoResultFound
from utils import get_rss_data, followup, audit # pylint: disable=E0401
from feed import get_source, Source # pylint: disable=E0401
from db import ( # pylint: disable=E0401
from utils import get_rss_data, followup, audit, followup_error # pylint: disable=E0401
from feed import get_source, Source # pylint: disable=E0401
from db import ( # pylint: disable=E0401
DatabaseManager,
SentArticleModel,
RssSourceModel,
@ -28,6 +29,11 @@ rss_list_sort_choices = [
Choice(name="Nickname", value=0),
Choice(name="Date Added", value=1)
]
channels_list_sort_choices=[
Choice(name="Feed Nickname", value=0),
Choice(name="Channel ID", value=1),
Choice(name="Date Added", value=2)
]
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
# target resource. Run a period task to check this.
@ -169,6 +175,7 @@ class FeedCog(commands.Cog):
await followup(inter, embed=embed)
@feed_group.command(name="remove")
@rename(url="option")
@autocomplete(url=source_autocomplete)
async def remove_rss_source(self, inter: Interaction, url: str):
"""Delete an existing RSS source.
@ -186,35 +193,34 @@ class FeedCog(commands.Cog):
log.debug("Attempting to remove RSS source (url=%s)", url)
async with DatabaseManager() as database:
select_result = await database.session.execute(
select(RssSourceModel).filter(
and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == url
)
)
whereclause = and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == url
)
rss_source = select_result.scalars().one()
# We will select the item first, so we can reference it's nickname later.
select_query = select(RssSourceModel).filter(whereclause)
select_result = await database.session.execute(select_query)
try:
rss_source = select_result.scalars().one()
except NoResultFound:
await followup_error(inter,
title="Error Deleting Feed",
message=f"I couldn't find anything for `{url}`"
)
return
nickname = rss_source.nick
delete_result = await database.session.execute(
delete(RssSourceModel).filter(
and_(
RssSourceModel.discord_server_id == inter.guild_id,
RssSourceModel.rss_url == url
)
)
)
delete_query = delete(RssSourceModel).filter(whereclause)
delete_result = await database.session.execute(delete_query)
await audit(self,
f"Added RSS source ({nickname=}, {url=})",
f"Deleted RSS source ({nickname=}, {url=})",
inter.user.id, database=database
)
if not delete_result.rowcount:
await followup(inter, "Couldn't find any RSS sources with this name.")
return
source = get_source(url)
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
@ -269,7 +275,10 @@ class FeedCog(commands.Cog):
rowcount = len(rss_sources)
if not rss_sources:
await followup(inter, "It looks like you have no rss sources.")
await followup_error(inter,
title="No Feeds Found",
message="I couldn't find any Feeds for this server."
)
return
output = "\n".join([
@ -286,9 +295,9 @@ class FeedCog(commands.Cog):
await followup(inter, embed=embed)
@feed_group.command(name="fetch")
@rename(max_="max")
@autocomplete(rss=source_autocomplete)
# @feed_group.command(name="fetch")
# @rename(max_="max")
# @autocomplete(rss=source_autocomplete)
async def fetch_rss(self, inter: Interaction, rss: str, max_: int=1):
"""Fetch an item from the specified RSS feed.
@ -467,6 +476,7 @@ class FeedCog(commands.Cog):
# )
@feed_group.command(name="assign")
@rename(rss="feed")
@autocomplete(rss=autocomplete_rss_sources)
async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None):
"""Include a feed within the specified channel.
@ -531,18 +541,17 @@ class FeedCog(commands.Cog):
result = await database.session.execute(query)
if not result.rowcount:
await followup(inter, "I couldn't find any items under that ID (placeholder response)")
await followup_error(inter,
title="Assigned Feed Not Found",
message=f"I couldn't find any assigned feeds for the option: {option}"
)
return
await followup(inter, "I've removed this item (placeholder response)")
@feed_group.command(name="channels")
# @choices(sort=[
# Choice(name="RSS Nickname", value=0),
# Choice(name="Channel ID", value=1),
# Choice(name="Date Added", value=2)
# ])
async def list_feeds(self, inter: Interaction): # sort: int
@choices(sort=channels_list_sort_choices)
async def list_feeds(self, inter: Interaction, sort: Choice[int] = 0, sort_reverse: bool = False):
"""List all of the channels and their respective included feeds.
Parameters
@ -553,6 +562,34 @@ class FeedCog(commands.Cog):
await inter.response.defer()
description = "Sort By "
if isinstance(sort, Choice):
match sort.value, sort_reverse:
case 0, False:
order_by = RssSourceModel.nick.asc()
description += "Nickname "
case 0, True:
order_by = RssSourceModel.nick.desc()
description += "Nickname "
case 1, False:
order_by = FeedChannelModel.discord_channel_id.asc()
description += "Channel ID "
case 1, True:
order_by = FeedChannelModel.discord_channel_id.desc()
description += "Channel ID "
case 2, False:
order_by = RssSourceModel.created.desc()
description += "Date Added "
case 2, True:
order_by = RssSourceModel.created.asc()
description += "Date Added "
case _, _:
raise ValueError(f"Unknown sort: {sort}")
else:
order_by = FeedChannelModel.discord_channel_id.asc()
description = ""
async with DatabaseManager() as database:
whereclause = and_(
FeedChannelModel.discord_server_id == inter.guild_id,
@ -562,7 +599,7 @@ class FeedCog(commands.Cog):
select(FeedChannelModel, RssSourceModel)
.where(whereclause)
.join(RssSourceModel)
.order_by(FeedChannelModel.discord_channel_id)
.order_by(order_by)
)
result = await database.session.execute(query)
@ -570,8 +607,9 @@ class FeedCog(commands.Cog):
rowcount = len(feed_channels)
if not feed_channels:
await followup(inter,
"It looks like there are no feed channels available."
await followup_error(inter,
title="No Assigned Feeds Found",
message="Assign a channel to receive feed content with `/feed assign`."
)
return
@ -583,7 +621,7 @@ class FeedCog(commands.Cog):
embed = Embed(
title="Saved Feed Channels",
description=f"{output}",
description=f"{description}\n{output}",
colour=Colour.blue()
)
embed.set_footer(text=f"Showing {rowcount} results")

View File

@ -4,17 +4,16 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
"""
import logging
import async_timeout
from time import process_time
import aiohttp
from feedparser import parse
from sqlalchemy import insert, select, and_
from discord import Interaction, app_commands, TextChannel
from discord import Interaction, TextChannel
from discord.ext import commands, tasks
from discord.errors import Forbidden
from feed import Source, Article, get_unparsed_feed
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel
from feed import Source, Article, get_unparsed_feed # pylint disable=E0401
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401
log = logging.getLogger(__name__)
@ -27,6 +26,7 @@ class TaskCog(commands.Cog):
def __init__(self, bot):
super().__init__()
self.bot = bot
self.time = None
@commands.Cog.listener()
async def on_ready(self):
@ -36,15 +36,12 @@ class TaskCog(commands.Cog):
log.info("%s cog is ready", self.__class__.__name__)
# @app_commands.command(name="test-trigger-task")
async def test_trigger_task(self, inter: Interaction):
await inter.response.defer()
await self.rss_task()
await inter.followup.send("done")
@tasks.loop(minutes=10)
async def rss_task(self):
"""Automated task responsible for processing rss feeds."""
log.info("Running rss task")
time = process_time()
async with DatabaseManager() as database:
query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
@ -54,10 +51,21 @@ class TaskCog(commands.Cog):
for feed in feeds:
await self.process_feed(feed, database)
log.info("Finished rss task")
log.info("Finished rss task, time elapsed: %s", process_time() - time)
async def process_feed(self, feed: FeedChannelModel, database: DatabaseManager):
"""Process the passed feed. Will also call process for each article found in the feed.
Parameters
----------
feed : FeedChannelModel
Database model for the feed.
database : DatabaseManager
Database connection handler, must be open.
"""
log.debug("Processing feed: %s", feed.id)
async def process_feed(self, feed: FeedChannelModel, database):
log.info("Processing feed: %s", feed.id)
channel = self.bot.get_channel(feed.discord_channel_id)
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
@ -72,8 +80,22 @@ class TaskCog(commands.Cog):
for article in articles:
await self.process_article(article, channel, database)
async def process_article(self, article: Article, channel: TextChannel, database):
log.info("Processing article: %s", article.url)
async def process_article(
self, article: Article, channel: TextChannel, database: DatabaseManager
):
"""Process the passed article. Will send the embed to a channel if all is valid.
Parameters
----------
article : Article
Database model for the article.
channel : TextChannel
Where the article will be sent to.
database : DatabaseManager
Database connection handler, must be open.
"""
log.debug("Processing article: %s", article.url)
query = select(SentArticleModel).where(and_(
SentArticleModel.article_url == article.url,
@ -82,14 +104,14 @@ class TaskCog(commands.Cog):
result = await database.session.execute(query)
if result.scalars().all():
log.info("Article already processed: %s", article.url)
log.debug("Article already processed: %s", article.url)
return
embed = await article.to_embed()
try:
await channel.send(embed=embed)
except Forbidden:
log.error("Forbidden: %s · %s", channel.name, channel.id)
log.error("Can't send article to channel: %s · %s", channel.name, channel.id)
return
query = insert(SentArticleModel).values(
@ -100,8 +122,7 @@ class TaskCog(commands.Cog):
)
await database.session.execute(query)
log.info("new Article processed: %s", article.url)
log.debug("new Article processed: %s", article.url)
async def setup(bot):

View File

@ -1,6 +1,3 @@
"""
"""
import json
import logging

View File

@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
class LogSetup:
def __init__(self, BASE_DIR: Path):
self.BASE_DIR = BASE_DIR
self.LOGS_DIR = BASE_DIR / "logs/"
@ -100,4 +100,4 @@ class LogSetup:
# Clear up old log files
self._delete_old_logs()
return file.name
return file.name

View File

@ -33,7 +33,7 @@ async def main():
# Setup logging settings and mute spammy loggers
logsetup = LogSetup(BASE_DIR)
logsetup.setup_logs(logging.INFO)
logsetup.setup_logs(logging.DEBUG)
logsetup.update_log_levels(
('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'),
level=logging.WARNING

View File

@ -3,7 +3,7 @@
import aiohttp
import logging
from discord import Interaction
from discord import Interaction, Embed, Colour
log = logging.getLogger(__name__)
@ -29,3 +29,23 @@ async def audit(cog, *args, **kwargs):
"""Shorthand for auditing an interaction."""
await cog.bot.audit(*args, **kwargs)
async def followup_error(inter: Interaction, title: str, message: str, *args, **kwargs):
"""Shorthand for following up on an interaction, except returns an embed styled in
error colours.
Parameters
----------
inter : Interaction
Represents an app command interaction.
"""
await inter.followup.send(
*args,
embed=Embed(
title=title,
description=message,
colour=Colour.red()
),
**kwargs
)