working on tasks
This commit is contained in:
parent
8a1f623c6f
commit
0e5af1d752
@ -7,3 +7,11 @@ Plans
|
||||
- Multiple news providers
|
||||
- Choose how much of each provider should be delivered
|
||||
- Check for duplicate articles between providers, and only deliver preferred provider article
|
||||
|
||||
|
||||
## Dev Notes:
|
||||
|
||||
For the sake of development, the following defintions apply:
|
||||
|
||||
- Feed - An RSS feed stored within the database, submitted by a user.
|
||||
- Assigned Feed - A discord channel set to receive content from a Feed.
|
@ -18,6 +18,7 @@ from sqlalchemy import (
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# back in wed, thu, fri, off new year day then back in after
|
||||
|
||||
class AuditModel(Base):
|
||||
"""
|
||||
@ -81,7 +82,7 @@ class FeedChannelModel(Base):
|
||||
search_name = Column(String, nullable=False)
|
||||
rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False)
|
||||
|
||||
rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined")
|
||||
rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete")
|
||||
|
||||
# the rss source must be unique, but only within the same discord channel
|
||||
__table_args__ = (
|
||||
|
@ -12,10 +12,11 @@ from discord.ext import commands
|
||||
from discord import Interaction, Embed, Colour, TextChannel
|
||||
from discord.app_commands import Choice, Group, autocomplete, choices, rename
|
||||
from sqlalchemy import insert, select, and_, delete
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
|
||||
from utils import get_rss_data, followup, audit # pylint: disable=E0401
|
||||
from feed import get_source, Source # pylint: disable=E0401
|
||||
from db import ( # pylint: disable=E0401
|
||||
from utils import get_rss_data, followup, audit, followup_error # pylint: disable=E0401
|
||||
from feed import get_source, Source # pylint: disable=E0401
|
||||
from db import ( # pylint: disable=E0401
|
||||
DatabaseManager,
|
||||
SentArticleModel,
|
||||
RssSourceModel,
|
||||
@ -28,6 +29,11 @@ rss_list_sort_choices = [
|
||||
Choice(name="Nickname", value=0),
|
||||
Choice(name="Date Added", value=1)
|
||||
]
|
||||
channels_list_sort_choices=[
|
||||
Choice(name="Feed Nickname", value=0),
|
||||
Choice(name="Channel ID", value=1),
|
||||
Choice(name="Date Added", value=2)
|
||||
]
|
||||
|
||||
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the
|
||||
# target resource. Run a period task to check this.
|
||||
@ -169,6 +175,7 @@ class FeedCog(commands.Cog):
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@feed_group.command(name="remove")
|
||||
@rename(url="option")
|
||||
@autocomplete(url=source_autocomplete)
|
||||
async def remove_rss_source(self, inter: Interaction, url: str):
|
||||
"""Delete an existing RSS source.
|
||||
@ -186,35 +193,34 @@ class FeedCog(commands.Cog):
|
||||
log.debug("Attempting to remove RSS source (url=%s)", url)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
select_result = await database.session.execute(
|
||||
select(RssSourceModel).filter(
|
||||
and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
)
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
rss_source = select_result.scalars().one()
|
||||
|
||||
# We will select the item first, so we can reference it's nickname later.
|
||||
select_query = select(RssSourceModel).filter(whereclause)
|
||||
select_result = await database.session.execute(select_query)
|
||||
|
||||
try:
|
||||
rss_source = select_result.scalars().one()
|
||||
except NoResultFound:
|
||||
await followup_error(inter,
|
||||
title="Error Deleting Feed",
|
||||
message=f"I couldn't find anything for `{url}`"
|
||||
)
|
||||
return
|
||||
|
||||
nickname = rss_source.nick
|
||||
|
||||
delete_result = await database.session.execute(
|
||||
delete(RssSourceModel).filter(
|
||||
and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
)
|
||||
)
|
||||
delete_query = delete(RssSourceModel).filter(whereclause)
|
||||
delete_result = await database.session.execute(delete_query)
|
||||
|
||||
await audit(self,
|
||||
f"Added RSS source ({nickname=}, {url=})",
|
||||
f"Deleted RSS source ({nickname=}, {url=})",
|
||||
inter.user.id, database=database
|
||||
)
|
||||
|
||||
if not delete_result.rowcount:
|
||||
await followup(inter, "Couldn't find any RSS sources with this name.")
|
||||
return
|
||||
|
||||
source = get_source(url)
|
||||
|
||||
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
|
||||
@ -269,7 +275,10 @@ class FeedCog(commands.Cog):
|
||||
rowcount = len(rss_sources)
|
||||
|
||||
if not rss_sources:
|
||||
await followup(inter, "It looks like you have no rss sources.")
|
||||
await followup_error(inter,
|
||||
title="No Feeds Found",
|
||||
message="I couldn't find any Feeds for this server."
|
||||
)
|
||||
return
|
||||
|
||||
output = "\n".join([
|
||||
@ -286,9 +295,9 @@ class FeedCog(commands.Cog):
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@feed_group.command(name="fetch")
|
||||
@rename(max_="max")
|
||||
@autocomplete(rss=source_autocomplete)
|
||||
# @feed_group.command(name="fetch")
|
||||
# @rename(max_="max")
|
||||
# @autocomplete(rss=source_autocomplete)
|
||||
async def fetch_rss(self, inter: Interaction, rss: str, max_: int=1):
|
||||
"""Fetch an item from the specified RSS feed.
|
||||
|
||||
@ -467,6 +476,7 @@ class FeedCog(commands.Cog):
|
||||
# )
|
||||
|
||||
@feed_group.command(name="assign")
|
||||
@rename(rss="feed")
|
||||
@autocomplete(rss=autocomplete_rss_sources)
|
||||
async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None):
|
||||
"""Include a feed within the specified channel.
|
||||
@ -531,18 +541,17 @@ class FeedCog(commands.Cog):
|
||||
result = await database.session.execute(query)
|
||||
|
||||
if not result.rowcount:
|
||||
await followup(inter, "I couldn't find any items under that ID (placeholder response)")
|
||||
await followup_error(inter,
|
||||
title="Assigned Feed Not Found",
|
||||
message=f"I couldn't find any assigned feeds for the option: {option}"
|
||||
)
|
||||
return
|
||||
|
||||
await followup(inter, "I've removed this item (placeholder response)")
|
||||
|
||||
@feed_group.command(name="channels")
|
||||
# @choices(sort=[
|
||||
# Choice(name="RSS Nickname", value=0),
|
||||
# Choice(name="Channel ID", value=1),
|
||||
# Choice(name="Date Added", value=2)
|
||||
# ])
|
||||
async def list_feeds(self, inter: Interaction): # sort: int
|
||||
@choices(sort=channels_list_sort_choices)
|
||||
async def list_feeds(self, inter: Interaction, sort: Choice[int] = 0, sort_reverse: bool = False):
|
||||
"""List all of the channels and their respective included feeds.
|
||||
|
||||
Parameters
|
||||
@ -553,6 +562,34 @@ class FeedCog(commands.Cog):
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
description = "Sort By "
|
||||
|
||||
if isinstance(sort, Choice):
|
||||
match sort.value, sort_reverse:
|
||||
case 0, False:
|
||||
order_by = RssSourceModel.nick.asc()
|
||||
description += "Nickname "
|
||||
case 0, True:
|
||||
order_by = RssSourceModel.nick.desc()
|
||||
description += "Nickname "
|
||||
case 1, False:
|
||||
order_by = FeedChannelModel.discord_channel_id.asc()
|
||||
description += "Channel ID "
|
||||
case 1, True:
|
||||
order_by = FeedChannelModel.discord_channel_id.desc()
|
||||
description += "Channel ID "
|
||||
case 2, False:
|
||||
order_by = RssSourceModel.created.desc()
|
||||
description += "Date Added "
|
||||
case 2, True:
|
||||
order_by = RssSourceModel.created.asc()
|
||||
description += "Date Added "
|
||||
case _, _:
|
||||
raise ValueError(f"Unknown sort: {sort}")
|
||||
else:
|
||||
order_by = FeedChannelModel.discord_channel_id.asc()
|
||||
description = ""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
FeedChannelModel.discord_server_id == inter.guild_id,
|
||||
@ -562,7 +599,7 @@ class FeedCog(commands.Cog):
|
||||
select(FeedChannelModel, RssSourceModel)
|
||||
.where(whereclause)
|
||||
.join(RssSourceModel)
|
||||
.order_by(FeedChannelModel.discord_channel_id)
|
||||
.order_by(order_by)
|
||||
)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
@ -570,8 +607,9 @@ class FeedCog(commands.Cog):
|
||||
rowcount = len(feed_channels)
|
||||
|
||||
if not feed_channels:
|
||||
await followup(inter,
|
||||
"It looks like there are no feed channels available."
|
||||
await followup_error(inter,
|
||||
title="No Assigned Feeds Found",
|
||||
message="Assign a channel to receive feed content with `/feed assign`."
|
||||
)
|
||||
return
|
||||
|
||||
@ -583,7 +621,7 @@ class FeedCog(commands.Cog):
|
||||
|
||||
embed = Embed(
|
||||
title="Saved Feed Channels",
|
||||
description=f"{output}",
|
||||
description=f"{description}\n{output}",
|
||||
colour=Colour.blue()
|
||||
)
|
||||
embed.set_footer(text=f"Showing {rowcount} results")
|
||||
|
@ -4,17 +4,16 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
|
||||
"""
|
||||
|
||||
import logging
|
||||
import async_timeout
|
||||
from time import process_time
|
||||
|
||||
import aiohttp
|
||||
from feedparser import parse
|
||||
from sqlalchemy import insert, select, and_
|
||||
from discord import Interaction, app_commands, TextChannel
|
||||
from discord import Interaction, TextChannel
|
||||
from discord.ext import commands, tasks
|
||||
from discord.errors import Forbidden
|
||||
|
||||
from feed import Source, Article, get_unparsed_feed
|
||||
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel
|
||||
from feed import Source, Article, get_unparsed_feed # pylint disable=E0401
|
||||
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel # pylint disable=E0401
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -27,6 +26,7 @@ class TaskCog(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
self.time = None
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
@ -36,15 +36,12 @@ class TaskCog(commands.Cog):
|
||||
|
||||
log.info("%s cog is ready", self.__class__.__name__)
|
||||
|
||||
# @app_commands.command(name="test-trigger-task")
|
||||
async def test_trigger_task(self, inter: Interaction):
|
||||
await inter.response.defer()
|
||||
await self.rss_task()
|
||||
await inter.followup.send("done")
|
||||
|
||||
@tasks.loop(minutes=10)
|
||||
async def rss_task(self):
|
||||
"""Automated task responsible for processing rss feeds."""
|
||||
|
||||
log.info("Running rss task")
|
||||
time = process_time()
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
|
||||
@ -54,10 +51,21 @@ class TaskCog(commands.Cog):
|
||||
for feed in feeds:
|
||||
await self.process_feed(feed, database)
|
||||
|
||||
log.info("Finished rss task")
|
||||
log.info("Finished rss task, time elapsed: %s", process_time() - time)
|
||||
|
||||
async def process_feed(self, feed: FeedChannelModel, database: DatabaseManager):
|
||||
"""Process the passed feed. Will also call process for each article found in the feed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feed : FeedChannelModel
|
||||
Database model for the feed.
|
||||
database : DatabaseManager
|
||||
Database connection handler, must be open.
|
||||
"""
|
||||
|
||||
log.debug("Processing feed: %s", feed.id)
|
||||
|
||||
async def process_feed(self, feed: FeedChannelModel, database):
|
||||
log.info("Processing feed: %s", feed.id)
|
||||
channel = self.bot.get_channel(feed.discord_channel_id)
|
||||
|
||||
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
|
||||
@ -72,8 +80,22 @@ class TaskCog(commands.Cog):
|
||||
for article in articles:
|
||||
await self.process_article(article, channel, database)
|
||||
|
||||
async def process_article(self, article: Article, channel: TextChannel, database):
|
||||
log.info("Processing article: %s", article.url)
|
||||
async def process_article(
|
||||
self, article: Article, channel: TextChannel, database: DatabaseManager
|
||||
):
|
||||
"""Process the passed article. Will send the embed to a channel if all is valid.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
article : Article
|
||||
Database model for the article.
|
||||
channel : TextChannel
|
||||
Where the article will be sent to.
|
||||
database : DatabaseManager
|
||||
Database connection handler, must be open.
|
||||
"""
|
||||
|
||||
log.debug("Processing article: %s", article.url)
|
||||
|
||||
query = select(SentArticleModel).where(and_(
|
||||
SentArticleModel.article_url == article.url,
|
||||
@ -82,14 +104,14 @@ class TaskCog(commands.Cog):
|
||||
result = await database.session.execute(query)
|
||||
|
||||
if result.scalars().all():
|
||||
log.info("Article already processed: %s", article.url)
|
||||
log.debug("Article already processed: %s", article.url)
|
||||
return
|
||||
|
||||
embed = await article.to_embed()
|
||||
try:
|
||||
await channel.send(embed=embed)
|
||||
except Forbidden:
|
||||
log.error("Forbidden: %s · %s", channel.name, channel.id)
|
||||
log.error("Can't send article to channel: %s · %s", channel.name, channel.id)
|
||||
return
|
||||
|
||||
query = insert(SentArticleModel).values(
|
||||
@ -100,8 +122,7 @@ class TaskCog(commands.Cog):
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
log.info("new Article processed: %s", article.url)
|
||||
|
||||
log.debug("new Article processed: %s", article.url)
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
|
@ -1,6 +1,3 @@
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogSetup:
|
||||
|
||||
|
||||
def __init__(self, BASE_DIR: Path):
|
||||
self.BASE_DIR = BASE_DIR
|
||||
self.LOGS_DIR = BASE_DIR / "logs/"
|
||||
@ -100,4 +100,4 @@ class LogSetup:
|
||||
# Clear up old log files
|
||||
self._delete_old_logs()
|
||||
|
||||
return file.name
|
||||
return file.name
|
||||
|
@ -33,7 +33,7 @@ async def main():
|
||||
|
||||
# Setup logging settings and mute spammy loggers
|
||||
logsetup = LogSetup(BASE_DIR)
|
||||
logsetup.setup_logs(logging.INFO)
|
||||
logsetup.setup_logs(logging.DEBUG)
|
||||
logsetup.update_log_levels(
|
||||
('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'),
|
||||
level=logging.WARNING
|
||||
|
22
src/utils.py
22
src/utils.py
@ -3,7 +3,7 @@
|
||||
import aiohttp
|
||||
import logging
|
||||
|
||||
from discord import Interaction
|
||||
from discord import Interaction, Embed, Colour
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -29,3 +29,23 @@ async def audit(cog, *args, **kwargs):
|
||||
"""Shorthand for auditing an interaction."""
|
||||
|
||||
await cog.bot.audit(*args, **kwargs)
|
||||
|
||||
async def followup_error(inter: Interaction, title: str, message: str, *args, **kwargs):
|
||||
"""Shorthand for following up on an interaction, except returns an embed styled in
|
||||
error colours.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
|
||||
await inter.followup.send(
|
||||
*args,
|
||||
embed=Embed(
|
||||
title=title,
|
||||
description=message,
|
||||
colour=Colour.red()
|
||||
),
|
||||
**kwargs
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user