From 8a1f623c6f1d16eb5dd8240cec4544a28a18294b Mon Sep 17 00:00:00 2001 From: corbz Date: Wed, 20 Dec 2023 08:22:37 +0000 Subject: [PATCH] help commands, and tasks function --- src/db/db.py | 5 +- src/db/models.py | 18 ++++--- src/extensions/rss.py | 115 ++++++++++++++++++++++++++++++++-------- src/extensions/tasks.py | 88 ++++++++++++++++++++++++++++-- src/feed.py | 10 ++++ src/main.py | 2 +- 6 files changed, 201 insertions(+), 37 deletions(-) diff --git a/src/db/db.py b/src/db/db.py index 81c5684..9821310 100644 --- a/src/db/db.py +++ b/src/db/db.py @@ -24,7 +24,7 @@ class DatabaseManager: """ def __init__(self, no_commit: bool = False): - database_url = self.get_database_url() # TODO: This is called every time a connection is established, maybe make it once and reference it? + database_url = self.get_database_url() # This is called every time a connection is established, maybe make it once and reference it? self.engine = create_async_engine(database_url, future=True) self.session_maker = sessionmaker(self.engine, class_=AsyncSession) self.session = None @@ -36,13 +36,10 @@ class DatabaseManager: Returns a connection string for the database. """ - # TODO finish support for mysql, mariadb, etc - url = f"{DB_TYPE}://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_DATABASE}" url_addon = "" # This looks fucking ugly - # match use_async, DB_TYPE: case True, "sqlite": url_addon = "aiosqlite" diff --git a/src/db/models.py b/src/db/models.py index 96ffd23..2fd5b2d 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -3,12 +3,18 @@ Models and Enums for the database. All table classes should be suffixed with `Model`. """ -from enum import Enum, auto - -from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint, ForeignKey from sqlalchemy.sql import func from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import ( + Column, + Integer, + String, + DateTime, + BigInteger, + UniqueConstraint, + ForeignKey +) Base = declarative_base() @@ -23,7 +29,7 @@ class AuditModel(Base): id = Column(Integer, primary_key=True, autoincrement=True) discord_user_id = Column(BigInteger, nullable=False) message = Column(String, nullable=False) - created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 class SentArticleModel(Base): @@ -38,7 +44,7 @@ class SentArticleModel(Base): discord_channel_id = Column(BigInteger, nullable=False) discord_server_id = Column(BigInteger, nullable=False) article_url = Column(String, nullable=False) - when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 class RssSourceModel(Base): @@ -52,7 +58,7 @@ class RssSourceModel(Base): nick = Column(String, nullable=False) discord_server_id = Column(BigInteger, nullable=False) rss_url = Column(String, nullable=False) - created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 feed_channels = relationship("FeedChannelModel", cascade="all, delete") diff --git a/src/extensions/rss.py b/src/extensions/rss.py index 56bb18e..ceb1d98 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -1,6 +1,6 @@ """ -Extension for the `RssCog`. -Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot. +Extension for the `FeedCog`. +Loading this file via `commands.Bot.load_extension` will add `FeedCog` to the bot. """ import logging @@ -73,7 +73,7 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed return None, feed -class RssCog(commands.Cog): +class FeedCog(commands.Cog): """ Command cog. """ @@ -226,7 +226,9 @@ class RssCog(commands.Cog): @feed_group.command(name="list") @choices(sort=rss_list_sort_choices) - async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False): + async def list_rss_sources( + self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False + ): """Provides a with a list of RSS sources available for the current server. Parameters @@ -238,7 +240,7 @@ class RssCog(commands.Cog): await inter.response.defer() # Default to the first choice if not specified. - if type(sort) is Choice: + if isinstance(sort, Choice): description = "Sort by " description += "Nickname " if sort.value == 0 else "Date Added " description += '\U000025BC' if sort_reverse else '\U000025B2' @@ -246,19 +248,17 @@ class RssCog(commands.Cog): sort = rss_list_sort_choices[0] description = "" - sort = sort if type(sort) == Choice else rss_list_sort_choices[0] - match sort.value, sort_reverse: case 0, False: order_by = RssSourceModel.nick.asc() case 0, True: order_by = RssSourceModel.nick.desc() - case 1, False: # NOTE: - order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest - case 1, True: # date first, not the oldest as it would sort otherwise. + case 1, False: + order_by = RssSourceModel.created.desc() + case 1, True: order_by = RssSourceModel.created.asc() case _, _: - raise ValueError("Unknown sort: %s" % sort) + raise ValueError(f"Unknown sort: {sort}") async with DatabaseManager() as database: whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) @@ -272,7 +272,10 @@ class RssCog(commands.Cog): await followup(inter, "It looks like you have no rss sources.") return - output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)]) + output = "\n".join([ + f"{i}. **[{rss.nick}]({rss.rss_url})** " + for i, rss in enumerate(rss_sources) + ]) embed = Embed( title="Saved RSS Feeds", @@ -330,16 +333,79 @@ class RssCog(commands.Cog): for article in articles ]) await database.session.execute(query) - await audit(self, f"User is requesting {max_} articles from {source.name}", inter.user.id, database=database) + await audit(self, + f"User is requesting {max_} articles from {source.name}", + inter.user.id, database=database + ) await followup(inter, embeds=embeds) + # Help ---- ---- ---- + + @feed_group.command(name="help") + async def get_help(self, inter: Interaction): + """Get help on how to use my commands. + + Parameters + ---------- + inter : Interaction + Represents an app command interaction. + """ + + await inter.response.defer() + + description = ( + "`/feed add ` \n\n" + "Save a new RSS feed to the bot. This can be referred to later, when assigning " + "channels to receive content from these RSS feeds." + + "\n\n\n`/feed remove