Working on user commands
This commit is contained in:
parent
595659e2ec
commit
abeebdcd61
@ -1,12 +1,10 @@
|
|||||||
aiohttp==3.9.1
|
aiohttp==3.9.1
|
||||||
aiopg==1.4.0
|
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
aiosqlite==0.19.0
|
aiosqlite==0.19.0
|
||||||
async-timeout==4.0.3
|
async-timeout==4.0.3
|
||||||
asyncpg==0.29.0
|
asyncpg==0.29.0
|
||||||
attrs==23.1.0
|
attrs==23.1.0
|
||||||
beautifulsoup4==4.12.2
|
beautifulsoup4==4.12.2
|
||||||
bs4==0.0.1
|
|
||||||
discord.py==2.3.2
|
discord.py==2.3.2
|
||||||
feedparser==6.0.11
|
feedparser==6.0.11
|
||||||
frozenlist==1.4.0
|
frozenlist==1.4.0
|
||||||
@ -14,7 +12,6 @@ greenlet==3.0.2
|
|||||||
idna==3.6
|
idna==3.6
|
||||||
markdownify==0.11.6
|
markdownify==0.11.6
|
||||||
multidict==6.0.4
|
multidict==6.0.4
|
||||||
psycopg2==2.9.9
|
|
||||||
psycopg2-binary==2.9.9
|
psycopg2-binary==2.9.9
|
||||||
python-dotenv==1.0.0
|
python-dotenv==1.0.0
|
||||||
sgmllib3k==1.0.0
|
sgmllib3k==1.0.0
|
||||||
@ -22,4 +19,5 @@ six==1.16.0
|
|||||||
soupsieve==2.5
|
soupsieve==2.5
|
||||||
SQLAlchemy==2.0.23
|
SQLAlchemy==2.0.23
|
||||||
typing_extensions==4.9.0
|
typing_extensions==4.9.0
|
||||||
|
validators==0.22.0
|
||||||
yarl==1.9.4
|
yarl==1.9.4
|
||||||
|
@ -5,7 +5,7 @@ Initialize the database modules, create the database tables and default data.
|
|||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from .models import Base, AuditModel, SentArticleModel
|
from .models import Base, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
|
||||||
from .db import DatabaseManager
|
from .db import DatabaseManager
|
||||||
|
|
||||||
# Initialise a database session
|
# Initialise a database session
|
||||||
|
184
src/extensions/cmd.py
Normal file
184
src/extensions/cmd.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
"""
|
||||||
|
Extension for the `CommandCog`.
|
||||||
|
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import validators
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import textwrap
|
||||||
|
import feedparser
|
||||||
|
from markdownify import markdownify
|
||||||
|
from discord import app_commands, Interaction, Embed
|
||||||
|
from discord.ext import commands, tasks
|
||||||
|
from sqlalchemy import insert, select, update, and_, or_
|
||||||
|
|
||||||
|
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
|
||||||
|
from feed import Feeds, get_source
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def get_rss_data(url: str):
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url) as response:
|
||||||
|
items = await response.text(), response.status
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
class CommandCog(commands.Cog):
|
||||||
|
"""
|
||||||
|
Command cog.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bot):
|
||||||
|
super().__init__()
|
||||||
|
self.bot = bot
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_ready(self):
|
||||||
|
log.info(f"{self.__class__.__name__} cog is ready")
|
||||||
|
|
||||||
|
rss_group = app_commands.Group(
|
||||||
|
name="rss",
|
||||||
|
description="Commands for rss sources.",
|
||||||
|
guild_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@rss_group.command(name="add")
|
||||||
|
async def add_rss_source(self, inter: Interaction, url: str):
|
||||||
|
|
||||||
|
await inter.response.defer()
|
||||||
|
|
||||||
|
# validate the input
|
||||||
|
if not validators.url(url):
|
||||||
|
await inter.followup.send(
|
||||||
|
"The URL you have entered is malformed or invalid:\n"
|
||||||
|
f"`{url=}`",
|
||||||
|
suppress_embeds=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
feed_data, status_code = await get_rss_data(url)
|
||||||
|
if status_code != 200:
|
||||||
|
await inter.followup.send(
|
||||||
|
f"The URL provided returned an invalid status code:\n"
|
||||||
|
f"{url=}, {status_code=}",
|
||||||
|
suppress_embeds=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
feed = feedparser.parse(feed_data)
|
||||||
|
if not feed.version:
|
||||||
|
await inter.followup.send(
|
||||||
|
f"The provided URL '{url}' does not seem to be a valid RSS feed.",
|
||||||
|
suppress_embeds=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
async with DatabaseManager() as database:
|
||||||
|
query = insert(RssSourceModel).values(
|
||||||
|
discord_server_id = inter.guild_id,
|
||||||
|
rss_url = url
|
||||||
|
)
|
||||||
|
await database.session.execute(query)
|
||||||
|
|
||||||
|
await inter.followup.send("RSS source added")
|
||||||
|
|
||||||
|
@rss_group.command(name="remove")
|
||||||
|
async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None):
|
||||||
|
|
||||||
|
await inter.response.defer()
|
||||||
|
|
||||||
|
def exists(item) -> bool:
|
||||||
|
"""
|
||||||
|
Shorthand for `is not None`. Cant just use `if not number` because 0 int will pass.
|
||||||
|
Ironically with this func & comment the code is longer, but at least I can read it ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
return item is not None
|
||||||
|
|
||||||
|
url_exists = exists(url)
|
||||||
|
num_exists = exists(number)
|
||||||
|
|
||||||
|
if (url_exists and num_exists) or (not url_exists and not num_exists):
|
||||||
|
await inter.followup.send(
|
||||||
|
"Please only specify either the existing rss number or url, "
|
||||||
|
"enter at least one of these, but don't enter both."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if url_exists and not validators.url(url):
|
||||||
|
await inter.followup.send(
|
||||||
|
"The URL you have entered is malformed or invalid:\n"
|
||||||
|
f"`{url=}`",
|
||||||
|
suppress_embeds=True
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
async with DatabaseManager() as database:
|
||||||
|
whereclause = and_(
|
||||||
|
RssSourceModel.discord_server_id == inter.guild_id,
|
||||||
|
RssSourceModel.rss_url == url
|
||||||
|
)
|
||||||
|
query = update(RssSourceModel).where(whereclause).values(active=False)
|
||||||
|
result = await database.session.execute(query)
|
||||||
|
|
||||||
|
await inter.followup.send(f"I've updated {result.rowcount} rows")
|
||||||
|
|
||||||
|
@rss_group.command(name="list")
|
||||||
|
@app_commands.choices(filter=[
|
||||||
|
app_commands.Choice(name="Active Only [default]", value=1),
|
||||||
|
app_commands.Choice(name="Inactive Only", value=0),
|
||||||
|
app_commands.Choice(name="All", value=2),
|
||||||
|
])
|
||||||
|
async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]):
|
||||||
|
|
||||||
|
await inter.response.defer()
|
||||||
|
|
||||||
|
if filter.value == 2:
|
||||||
|
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
||||||
|
else:
|
||||||
|
whereclause = and_(
|
||||||
|
RssSourceModel.discord_server_id == inter.guild_id,
|
||||||
|
RssSourceModel.active == filter.value # should result to 0 or 1
|
||||||
|
)
|
||||||
|
|
||||||
|
async with DatabaseManager() as database:
|
||||||
|
query = select(RssSourceModel).where(whereclause)
|
||||||
|
result = await database.session.execute(query)
|
||||||
|
|
||||||
|
rss_sources = result.scalars().all()
|
||||||
|
embed_fields = [{
|
||||||
|
"name": f"[{i}]",
|
||||||
|
"value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}"
|
||||||
|
} for i, rss in enumerate(rss_sources)]
|
||||||
|
|
||||||
|
if not embed_fields:
|
||||||
|
await inter.followup.send("It looks like you have no rss sources.")
|
||||||
|
return
|
||||||
|
|
||||||
|
embed = Embed(
|
||||||
|
title="RSS Sources",
|
||||||
|
description="Here are your rss sources:"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field in embed_fields:
|
||||||
|
embed.add_field(**field, inline=False)
|
||||||
|
|
||||||
|
# output = "Your rss sources:\n\n"
|
||||||
|
# output += "\n".join([f"[{i+1}] {rss.rss_url=} {bool(rss.active)=}" for i, rss in enumerate(rss_sources)])
|
||||||
|
|
||||||
|
await inter.followup.send(embed=embed)
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(bot):
|
||||||
|
"""
|
||||||
|
Setup function for this extension.
|
||||||
|
Adds `CommandCog` to the bot.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cog = CommandCog(bot)
|
||||||
|
await bot.add_cog(cog)
|
||||||
|
log.info(f"Added {cog.__class__.__name__} cog")
|
@ -83,7 +83,7 @@ def get_source(feed: Feeds) -> Source:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parsed_feed = parse(feed.value)
|
parsed_feed = parse("https://gitea.corbz.dev/corbz/BBC-News-Bot/rss/branch/main/src/extensions/news.py")
|
||||||
return Source.from_parsed(parsed_feed)
|
return Source.from_parsed(parsed_feed)
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,6 +8,8 @@ import asyncio
|
|||||||
from os import getenv
|
from os import getenv
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
# it's important to load environment variables before
|
||||||
|
# importing the packages that depend on them.
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
@ -29,7 +31,7 @@ async def main():
|
|||||||
if not token:
|
if not token:
|
||||||
raise ValueError("Token is empty")
|
raise ValueError("Token is empty")
|
||||||
|
|
||||||
# Setup logging settings
|
# Setup logging settings and mute spammy loggers
|
||||||
logsetup = LogSetup(BASE_DIR)
|
logsetup = LogSetup(BASE_DIR)
|
||||||
logsetup.setup_logs()
|
logsetup.setup_logs()
|
||||||
logsetup.update_log_levels(
|
logsetup.update_log_levels(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user