From abeebdcd61c9264acb46fc69da1419955aae3098 Mon Sep 17 00:00:00 2001 From: Corban-Lee Date: Thu, 14 Dec 2023 16:54:33 +0000 Subject: [PATCH] Working on user commands --- requirements.txt | 4 +- src/db/__init__.py | 2 +- src/extensions/cmd.py | 184 ++++++++++++++++++++++++++++++++++++++++++ src/feed.py | 2 +- src/main.py | 4 +- 5 files changed, 190 insertions(+), 6 deletions(-) create mode 100644 src/extensions/cmd.py diff --git a/requirements.txt b/requirements.txt index 87da7c3..8e7707f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,10 @@ aiohttp==3.9.1 -aiopg==1.4.0 aiosignal==1.3.1 aiosqlite==0.19.0 async-timeout==4.0.3 asyncpg==0.29.0 attrs==23.1.0 beautifulsoup4==4.12.2 -bs4==0.0.1 discord.py==2.3.2 feedparser==6.0.11 frozenlist==1.4.0 @@ -14,7 +12,6 @@ greenlet==3.0.2 idna==3.6 markdownify==0.11.6 multidict==6.0.4 -psycopg2==2.9.9 psycopg2-binary==2.9.9 python-dotenv==1.0.0 sgmllib3k==1.0.0 @@ -22,4 +19,5 @@ six==1.16.0 soupsieve==2.5 SQLAlchemy==2.0.23 typing_extensions==4.9.0 +validators==0.22.0 yarl==1.9.4 diff --git a/src/db/__init__.py b/src/db/__init__.py index 432d366..65d8c9f 100644 --- a/src/db/__init__.py +++ b/src/db/__init__.py @@ -5,7 +5,7 @@ Initialize the database modules, create the database tables and default data. from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from .models import Base, AuditModel, SentArticleModel +from .models import Base, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel from .db import DatabaseManager # Initialise a database session diff --git a/src/extensions/cmd.py b/src/extensions/cmd.py new file mode 100644 index 0000000..7f54a3c --- /dev/null +++ b/src/extensions/cmd.py @@ -0,0 +1,184 @@ +""" +Extension for the `CommandCog`. +Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot. +""" + +import logging +import validators + +import aiohttp +import textwrap +import feedparser +from markdownify import markdownify +from discord import app_commands, Interaction, Embed +from discord.ext import commands, tasks +from sqlalchemy import insert, select, update, and_, or_ + +from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel +from feed import Feeds, get_source + +log = logging.getLogger(__name__) + +async def get_rss_data(url: str): + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + items = await response.text(), response.status + + return items + + +class CommandCog(commands.Cog): + """ + Command cog. + """ + + def __init__(self, bot): + super().__init__() + self.bot = bot + + @commands.Cog.listener() + async def on_ready(self): + log.info(f"{self.__class__.__name__} cog is ready") + + rss_group = app_commands.Group( + name="rss", + description="Commands for rss sources.", + guild_only=True + ) + + @rss_group.command(name="add") + async def add_rss_source(self, inter: Interaction, url: str): + + await inter.response.defer() + + # validate the input + if not validators.url(url): + await inter.followup.send( + "The URL you have entered is malformed or invalid:\n" + f"`{url=}`", + suppress_embeds=True + ) + return + + feed_data, status_code = await get_rss_data(url) + if status_code != 200: + await inter.followup.send( + f"The URL provided returned an invalid status code:\n" + f"{url=}, {status_code=}", + suppress_embeds=True + ) + return + + feed = feedparser.parse(feed_data) + if not feed.version: + await inter.followup.send( + f"The provided URL '{url}' does not seem to be a valid RSS feed.", + suppress_embeds=True + ) + return + + async with DatabaseManager() as database: + query = insert(RssSourceModel).values( + discord_server_id = inter.guild_id, + rss_url = url + ) + await database.session.execute(query) + + await inter.followup.send("RSS source added") + + @rss_group.command(name="remove") + async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None): + + await inter.response.defer() + + def exists(item) -> bool: + """ + Shorthand for `is not None`. Cant just use `if not number` because 0 int will pass. + Ironically with this func & comment the code is longer, but at least I can read it ... + """ + + return item is not None + + url_exists = exists(url) + num_exists = exists(number) + + if (url_exists and num_exists) or (not url_exists and not num_exists): + await inter.followup.send( + "Please only specify either the existing rss number or url, " + "enter at least one of these, but don't enter both." + ) + return + + if url_exists and not validators.url(url): + await inter.followup.send( + "The URL you have entered is malformed or invalid:\n" + f"`{url=}`", + suppress_embeds=True + ) + return + + async with DatabaseManager() as database: + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.rss_url == url + ) + query = update(RssSourceModel).where(whereclause).values(active=False) + result = await database.session.execute(query) + + await inter.followup.send(f"I've updated {result.rowcount} rows") + + @rss_group.command(name="list") + @app_commands.choices(filter=[ + app_commands.Choice(name="Active Only [default]", value=1), + app_commands.Choice(name="Inactive Only", value=0), + app_commands.Choice(name="All", value=2), + ]) + async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]): + + await inter.response.defer() + + if filter.value == 2: + whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id) + else: + whereclause = and_( + RssSourceModel.discord_server_id == inter.guild_id, + RssSourceModel.active == filter.value # should result to 0 or 1 + ) + + async with DatabaseManager() as database: + query = select(RssSourceModel).where(whereclause) + result = await database.session.execute(query) + + rss_sources = result.scalars().all() + embed_fields = [{ + "name": f"[{i}]", + "value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}" + } for i, rss in enumerate(rss_sources)] + + if not embed_fields: + await inter.followup.send("It looks like you have no rss sources.") + return + + embed = Embed( + title="RSS Sources", + description="Here are your rss sources:" + ) + + for field in embed_fields: + embed.add_field(**field, inline=False) + + # output = "Your rss sources:\n\n" + # output += "\n".join([f"[{i+1}] {rss.rss_url=} {bool(rss.active)=}" for i, rss in enumerate(rss_sources)]) + + await inter.followup.send(embed=embed) + + +async def setup(bot): + """ + Setup function for this extension. + Adds `CommandCog` to the bot. + """ + + cog = CommandCog(bot) + await bot.add_cog(cog) + log.info(f"Added {cog.__class__.__name__} cog") diff --git a/src/feed.py b/src/feed.py index ccebb72..4183187 100644 --- a/src/feed.py +++ b/src/feed.py @@ -83,7 +83,7 @@ def get_source(feed: Feeds) -> Source: """ - parsed_feed = parse(feed.value) + parsed_feed = parse("https://gitea.corbz.dev/corbz/BBC-News-Bot/rss/branch/main/src/extensions/news.py") return Source.from_parsed(parsed_feed) diff --git a/src/main.py b/src/main.py index dc6ee0d..dc3add1 100644 --- a/src/main.py +++ b/src/main.py @@ -8,6 +8,8 @@ import asyncio from os import getenv from pathlib import Path +# it's important to load environment variables before +# importing the packages that depend on them. from dotenv import load_dotenv load_dotenv() @@ -29,7 +31,7 @@ async def main(): if not token: raise ValueError("Token is empty") - # Setup logging settings + # Setup logging settings and mute spammy loggers logsetup = LogSetup(BASE_DIR) logsetup.setup_logs() logsetup.update_log_levels(