help commands, and tasks function
This commit is contained in:
parent
df60652da6
commit
8a1f623c6f
@ -24,7 +24,7 @@ class DatabaseManager:
|
||||
"""
|
||||
|
||||
def __init__(self, no_commit: bool = False):
|
||||
database_url = self.get_database_url() # TODO: This is called every time a connection is established, maybe make it once and reference it?
|
||||
database_url = self.get_database_url() # This is called every time a connection is established, maybe make it once and reference it?
|
||||
self.engine = create_async_engine(database_url, future=True)
|
||||
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
|
||||
self.session = None
|
||||
@ -36,13 +36,10 @@ class DatabaseManager:
|
||||
Returns a connection string for the database.
|
||||
"""
|
||||
|
||||
# TODO finish support for mysql, mariadb, etc
|
||||
|
||||
url = f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}"
|
||||
url_addon = ""
|
||||
|
||||
# This looks fucking ugly
|
||||
#
|
||||
match use_async, DB_TYPE:
|
||||
case True, "sqlite":
|
||||
url_addon = "aiosqlite"
|
||||
|
@ -3,12 +3,18 @@ Models and Enums for the database.
|
||||
All table classes should be suffixed with `Model`.
|
||||
"""
|
||||
|
||||
from enum import Enum, auto
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
BigInteger,
|
||||
UniqueConstraint,
|
||||
ForeignKey
|
||||
)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@ -23,7 +29,7 @@ class AuditModel(Base):
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
discord_user_id = Column(BigInteger, nullable=False)
|
||||
message = Column(String, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||
|
||||
|
||||
class SentArticleModel(Base):
|
||||
@ -38,7 +44,7 @@ class SentArticleModel(Base):
|
||||
discord_channel_id = Column(BigInteger, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
article_url = Column(String, nullable=False)
|
||||
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||
|
||||
|
||||
class RssSourceModel(Base):
|
||||
@ -52,7 +58,7 @@ class RssSourceModel(Base):
|
||||
nick = Column(String, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
rss_url = Column(String, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||
|
||||
feed_channels = relationship("FeedChannelModel", cascade="all, delete")
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""
|
||||
Extension for the `RssCog`.
|
||||
Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot.
|
||||
Extension for the `FeedCog`.
|
||||
Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bot.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@ -73,7 +73,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
|
||||
return None, feed
|
||||
|
||||
|
||||
class RssCog(commands.Cog):
|
||||
class FeedCog(commands.Cog):
|
||||
"""
|
||||
Command cog.
|
||||
"""
|
||||
@ -226,7 +226,9 @@ class RssCog(commands.Cog):
|
||||
|
||||
@feed_group.command(name="list")
|
||||
@choices(sort=rss_list_sort_choices)
|
||||
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
|
||||
async def list_rss_sources(
|
||||
self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False
|
||||
):
|
||||
"""Provides a with a list of RSS sources available for the current server.
|
||||
|
||||
Parameters
|
||||
@ -238,7 +240,7 @@ class RssCog(commands.Cog):
|
||||
await inter.response.defer()
|
||||
|
||||
# Default to the first choice if not specified.
|
||||
if type(sort) is Choice:
|
||||
if isinstance(sort, Choice):
|
||||
description = "Sort by "
|
||||
description += "Nickname " if sort.value == 0 else "Date Added "
|
||||
description += '\U000025BC' if sort_reverse else '\U000025B2'
|
||||
@ -246,19 +248,17 @@ class RssCog(commands.Cog):
|
||||
sort = rss_list_sort_choices[0]
|
||||
description = ""
|
||||
|
||||
sort = sort if type(sort) == Choice else rss_list_sort_choices[0]
|
||||
|
||||
match sort.value, sort_reverse:
|
||||
case 0, False:
|
||||
order_by = RssSourceModel.nick.asc()
|
||||
case 0, True:
|
||||
order_by = RssSourceModel.nick.desc()
|
||||
case 1, False: # NOTE:
|
||||
order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest
|
||||
case 1, True: # date first, not the oldest as it would sort otherwise.
|
||||
case 1, False:
|
||||
order_by = RssSourceModel.created.desc()
|
||||
case 1, True:
|
||||
order_by = RssSourceModel.created.asc()
|
||||
case _, _:
|
||||
raise ValueError("Unknown sort: %s" % sort)
|
||||
raise ValueError(f"Unknown sort: {sort}")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
||||
@ -272,7 +272,10 @@ class RssCog(commands.Cog):
|
||||
await followup(inter, "It looks like you have no rss sources.")
|
||||
return
|
||||
|
||||
output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)])
|
||||
output = "\n".join([
|
||||
f"{i}. **[{rss.nick}]({rss.rss_url})** "
|
||||
for i, rss in enumerate(rss_sources)
|
||||
])
|
||||
|
||||
embed = Embed(
|
||||
title="Saved RSS Feeds",
|
||||
@ -330,16 +333,79 @@ class RssCog(commands.Cog):
|
||||
for article in articles
|
||||
])
|
||||
await database.session.execute(query)
|
||||
await audit(self, f"User is requesting {max_} articles from {source.name}", inter.user.id, database=database)
|
||||
await audit(self,
|
||||
f"User is requesting {max_} articles from {source.name}",
|
||||
inter.user.id, database=database
|
||||
)
|
||||
|
||||
await followup(inter, embeds=embeds)
|
||||
|
||||
# Help ---- ---- ----
|
||||
|
||||
@feed_group.command(name="help")
|
||||
async def get_help(self, inter: Interaction):
|
||||
"""Get help on how to use my commands.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
description = (
|
||||
"`/feed add <nickname> <url>` \n\n"
|
||||
"Save a new RSS feed to the bot. This can be referred to later, when assigning "
|
||||
"channels to receive content from these RSS feeds."
|
||||
|
||||
"\n\n\n`/feed remove <option>` \n\n"
|
||||
"Remove a previously saved RSS feed. Select the nickname from the shown options "
|
||||
"if any. You can re-add the RSS feed to the bot using the `/feeds add` command."
|
||||
|
||||
"\n\n\n`/feed list <sort> <sort_reverse>` \n\n"
|
||||
"List all saved RSS feeds numerically. Use the `<sort>` option to order "
|
||||
"the results by either nickname or date & time added. Use the `<sort_reverse>` "
|
||||
"option to order by ascending or descending in conjunction with the `<sort>` option."
|
||||
|
||||
"\n\n\n`/feed assign <rss> <textchannel>` \n\n"
|
||||
"Assign a channel to an RSS feed. Previously saved RSS feeds will be selectable "
|
||||
"under the `<rss>` option. The channel will be assumed as the current channel, "
|
||||
"unless specified otherwise using the `<textchannel` option."
|
||||
|
||||
"\n\n\n`/feed unassign <option>` \n\n"
|
||||
"Unassigned channel from an RSS feed. Previously assigned channels will be shown "
|
||||
"as an `<option>`, select one to remove it."
|
||||
|
||||
"\n\n\n`/feed channels` \n\n"
|
||||
"List all channels assigned to an RSS feed numerically."
|
||||
)
|
||||
|
||||
embed = Embed(
|
||||
title="Help",
|
||||
description=description,
|
||||
colour=Colour.blue(),
|
||||
)
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
# Channels ---- ---- ----
|
||||
|
||||
|
||||
async def autocomplete_rss_sources(self, inter: Interaction, nickname: str):
|
||||
""""""
|
||||
"""_summary_
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
_description_
|
||||
nickname : str
|
||||
_description_
|
||||
|
||||
Returns
|
||||
-------
|
||||
_type_
|
||||
_description_
|
||||
"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
@ -371,7 +437,7 @@ class RssCog(commands.Cog):
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
FeedChannelModel.discord_server_id == inter.guild_id,
|
||||
FeedChannelModel.search_name.ilike(f"%{current}%"), # is this secure from SQL Injection atk ?
|
||||
FeedChannelModel.search_name.ilike(f"%{current}%"),
|
||||
RssSourceModel.id == FeedChannelModel.rss_source_id
|
||||
)
|
||||
query = (
|
||||
@ -384,7 +450,10 @@ class RssCog(commands.Cog):
|
||||
feeds = []
|
||||
for feed in result.scalars().all():
|
||||
channel = inter.guild.get_channel(feed.discord_channel_id)
|
||||
feeds.append(Choice(name=f"# {channel.name} | {feed.rss_source.nick}", value=feed.id))
|
||||
feeds.append(Choice(
|
||||
name=f"# {channel.name} | {feed.rss_source.nick}",
|
||||
value=feed.id
|
||||
))
|
||||
|
||||
log.debug("Autocomplete existing_feeds returned %s results", len(feeds))
|
||||
|
||||
@ -501,11 +570,14 @@ class RssCog(commands.Cog):
|
||||
rowcount = len(feed_channels)
|
||||
|
||||
if not feed_channels:
|
||||
await followup(inter, "It looks like there are no feed channels available.")
|
||||
await followup(inter,
|
||||
"It looks like there are no feed channels available."
|
||||
)
|
||||
return
|
||||
|
||||
output = "\n".join([
|
||||
f"{i}. <#{feed.discord_channel_id}> · [{feed.rss_source.nick}]({feed.rss_source.rss_url})"
|
||||
f"{i}. <#{feed.discord_channel_id}> · "
|
||||
f"[{feed.rss_source.nick}]({feed.rss_source.rss_url})"
|
||||
for i, feed in enumerate(feed_channels)
|
||||
])
|
||||
|
||||
@ -519,13 +591,12 @@ class RssCog(commands.Cog):
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
Adds `RssCog` to the bot.
|
||||
Adds `FeedCog` to the bot.
|
||||
"""
|
||||
|
||||
cog = RssCog(bot)
|
||||
cog = FeedCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info("Added %s cog", cog.__class__.__name__)
|
||||
|
@ -4,15 +4,24 @@ Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bo
|
||||
"""
|
||||
|
||||
import logging
|
||||
import async_timeout
|
||||
|
||||
from discord.ext import commands
|
||||
import aiohttp
|
||||
from feedparser import parse
|
||||
from sqlalchemy import insert, select, and_
|
||||
from discord import Interaction, app_commands, TextChannel
|
||||
from discord.ext import commands, tasks
|
||||
from discord.errors import Forbidden
|
||||
|
||||
from feed import Source, Article, get_unparsed_feed
|
||||
from db import DatabaseManager, FeedChannelModel, RssSourceModel, SentArticleModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskCog(commands.Cog):
|
||||
"""
|
||||
Command cog.
|
||||
Tasks cog.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
@ -21,7 +30,78 @@ class TaskCog(commands.Cog):
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
log.info(f"{self.__class__.__name__} cog is ready")
|
||||
"""Instructions to call when the cog is ready."""
|
||||
|
||||
self.rss_task.start() # pylint disable=E1101
|
||||
|
||||
log.info("%s cog is ready", self.__class__.__name__)
|
||||
|
||||
# @app_commands.command(name="test-trigger-task")
|
||||
async def test_trigger_task(self, inter: Interaction):
|
||||
await inter.response.defer()
|
||||
await self.rss_task()
|
||||
await inter.followup.send("done")
|
||||
|
||||
@tasks.loop(minutes=10)
|
||||
async def rss_task(self):
|
||||
log.info("Running rss task")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = select(FeedChannelModel, RssSourceModel).join(RssSourceModel)
|
||||
result = await database.session.execute(query)
|
||||
feeds = result.scalars().all()
|
||||
|
||||
for feed in feeds:
|
||||
await self.process_feed(feed, database)
|
||||
|
||||
log.info("Finished rss task")
|
||||
|
||||
async def process_feed(self, feed: FeedChannelModel, database):
|
||||
log.info("Processing feed: %s", feed.id)
|
||||
channel = self.bot.get_channel(feed.discord_channel_id)
|
||||
|
||||
unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url)
|
||||
parsed_feed = parse(unparsed_content)
|
||||
source = Source.from_parsed(parsed_feed)
|
||||
articles = source.get_latest_articles(5)
|
||||
|
||||
if not articles:
|
||||
log.info("No articles to process")
|
||||
return
|
||||
|
||||
for article in articles:
|
||||
await self.process_article(article, channel, database)
|
||||
|
||||
async def process_article(self, article: Article, channel: TextChannel, database):
|
||||
log.info("Processing article: %s", article.url)
|
||||
|
||||
query = select(SentArticleModel).where(and_(
|
||||
SentArticleModel.article_url == article.url,
|
||||
SentArticleModel.discord_channel_id == channel.id
|
||||
))
|
||||
result = await database.session.execute(query)
|
||||
|
||||
if result.scalars().all():
|
||||
log.info("Article already processed: %s", article.url)
|
||||
return
|
||||
|
||||
embed = await article.to_embed()
|
||||
try:
|
||||
await channel.send(embed=embed)
|
||||
except Forbidden:
|
||||
log.error("Forbidden: %s · %s", channel.name, channel.id)
|
||||
return
|
||||
|
||||
query = insert(SentArticleModel).values(
|
||||
article_url = article.url,
|
||||
discord_channel_id = channel.id,
|
||||
discord_server_id = channel.guild.id,
|
||||
discord_message_id = -1
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
log.info("new Article processed: %s", article.url)
|
||||
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
@ -32,4 +112,4 @@ async def setup(bot):
|
||||
|
||||
cog = TaskCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
||||
log.info("Added %s cog", cog.__class__.__name__)
|
||||
|
10
src/feed.py
10
src/feed.py
@ -4,6 +4,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import async_timeout
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
@ -179,6 +180,15 @@ class Source:
|
||||
for i, entry in enumerate(self.feed.entries)
|
||||
if i < max
|
||||
]
|
||||
|
||||
async def fetch(session, url: str) -> str:
|
||||
async with async_timeout.timeout(20):
|
||||
async with session.get(url) as response:
|
||||
return await response.text()
|
||||
|
||||
async def get_unparsed_feed(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
return await fetch(session, url)
|
||||
|
||||
|
||||
def get_source(rss_url: str) -> Source:
|
||||
|
@ -33,7 +33,7 @@ async def main():
|
||||
|
||||
# Setup logging settings and mute spammy loggers
|
||||
logsetup = LogSetup(BASE_DIR)
|
||||
logsetup.setup_logs()
|
||||
logsetup.setup_logs(logging.INFO)
|
||||
logsetup.update_log_levels(
|
||||
('discord', 'PIL', 'urllib3', 'aiosqlite', 'charset_normalizer'),
|
||||
level=logging.WARNING
|
||||
|
Loading…
x
Reference in New Issue
Block a user