185 lines
5.8 KiB
Python
185 lines
5.8 KiB
Python
"""
|
|
Extension for the `CommandCog`.
|
|
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
|
|
"""
|
|
|
|
import asyncio
import logging
import textwrap

import aiohttp
import feedparser
import validators
from discord import app_commands, Interaction, Embed
from discord.ext import commands, tasks
from markdownify import markdownify
from sqlalchemy import insert, select, update, and_, or_

from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
from feed import Feeds, get_source
|
|
|
|
# Logger for this extension, named after the module it lives in.
log = logging.getLogger(name=__name__)
|
|
|
|
async def get_rss_data(url: str):
    """
    Download ``url`` and return ``(body_text, status_code)``.

    The body is decoded with undecodable bytes replaced (U+FFFD) rather than
    raising, so a feed whose HTTP charset is missing or wrong still reaches
    the feed parser instead of aborting the caller.

    Connection-level failures from aiohttp (e.g. ``aiohttp.ClientError`` or a
    timeout) are not caught here and propagate to the caller.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            # Many feeds declare their encoding only in the XML prolog, so the
            # charset aiohttp chooses may not match the bytes; strict decoding
            # would then raise UnicodeDecodeError.
            body = await response.text(errors="replace")
            return body, response.status
|
|
|
|
|
|
class CommandCog(commands.Cog):
    """
    Command cog.
    """

    # NOTE: the slash-command callbacks below are documented with `#` comments
    # on purpose: discord.py uses a command callback's docstring as the command
    # description shown in the Discord UI, so docstrings there would change
    # what users see.

    # Discord rejects an embed with more than 25 fields.
    _EMBED_FIELD_LIMIT = 25

    def __init__(self, bot):
        super().__init__()
        self.bot = bot

    @commands.Cog.listener()
    async def on_ready(self):
        log.info(f"{self.__class__.__name__} cog is ready")

    # Guild-only `/rss` slash-command group holding add / remove / list.
    rss_group = app_commands.Group(
        name="rss",
        description="Commands for rss sources.",
        guild_only=True
    )

    @staticmethod
    async def _fetch_guild_sources(database, guild_id):
        """
        Return every RSS source row of ``guild_id`` (active or not), ordered by URL.

        This ordering defines the numbers shown by `/rss list` and accepted by
        `/rss remove number:`; both commands fetch through here so they agree.
        Row attributes should be read before ``database`` is closed.
        """
        query = (
            select(RssSourceModel)
            .where(RssSourceModel.discord_server_id == guild_id)
            .order_by(RssSourceModel.rss_url)
        )
        result = await database.session.execute(query)
        return result.scalars().all()

    # /rss add <url>: check that `url` is well-formed, answers with HTTP 200 and
    # parses as a feed, then store it for this guild. A URL that is already
    # active is rejected; one that was removed (deactivated) earlier is
    # re-activated instead of inserting a duplicate row.
    @rss_group.command(name="add")
    async def add_rss_source(self, inter: Interaction, url: str):
        # Acknowledge first: the network fetch below may take longer than
        # Discord allows for an initial response. Every path then answers via
        # `inter.followup`.
        await inter.response.defer()

        # validate the input
        if not validators.url(url):
            await inter.followup.send(
                "The URL you have entered is malformed or invalid:\n"
                f"`{url=}`",
                suppress_embeds=True
            )
            return

        try:
            feed_data, status_code = await get_rss_data(url)
        except (aiohttp.ClientError, asyncio.TimeoutError) as error:
            # DNS failures, refused connections, timeouts...: reply instead of
            # letting the exception escape and leaving the interaction unanswered.
            log.warning("Failed to fetch RSS source %s: %r", url, error)
            await inter.followup.send(
                "I couldn't fetch the URL you have entered:\n"
                f"{url=}",
                suppress_embeds=True
            )
            return

        if status_code != 200:
            await inter.followup.send(
                "The URL provided returned an invalid status code:\n"
                f"{url=}, {status_code=}",
                suppress_embeds=True
            )
            return

        feed = feedparser.parse(feed_data)
        # An empty `version` is treated as "not a recognised feed format".
        if not feed.version:
            await inter.followup.send(
                f"The provided URL '{url}' does not seem to be a valid RSS feed.",
                suppress_embeds=True
            )
            return

        async with DatabaseManager() as database:
            same_source = and_(
                RssSourceModel.discord_server_id == inter.guild_id,
                RssSourceModel.rss_url == url
            )
            result = await database.session.execute(
                select(RssSourceModel).where(same_source)
            )
            existing = result.scalars().all()

            if any(source.active for source in existing):
                message = "That RSS source has already been added."
            elif existing:
                # Previously removed: flip it back on rather than duplicating it.
                await database.session.execute(
                    update(RssSourceModel).where(same_source).values(active=True)
                )
                message = "RSS source re-activated"
            else:
                await database.session.execute(
                    insert(RssSourceModel).values(
                        discord_server_id=inter.guild_id,
                        rss_url=url
                    )
                )
                message = "RSS source added"

        await inter.followup.send(message)

    # /rss remove [number] [url]: deactivate a source (rows are kept; only
    # `active` is set to False). Exactly one of `number` (the index shown by
    # `/rss list`) or `url` must be given.
    @rss_group.command(name="remove")
    async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None):
        await inter.response.defer()

        # `is not None` rather than truthiness, because 0 is a valid number.
        url_given = url is not None
        number_given = number is not None

        # Reject both-given and neither-given alike.
        if url_given == number_given:
            await inter.followup.send(
                "Please only specify either the existing rss number or url, "
                "enter at least one of these, but don't enter both."
            )
            return

        if url_given and not validators.url(url):
            await inter.followup.send(
                "The URL you have entered is malformed or invalid:\n"
                f"`{url=}`",
                suppress_embeds=True
            )
            return

        async with DatabaseManager() as database:
            if number_given:
                # Resolve the number against the same ordering `/rss list` uses.
                sources = await self._fetch_guild_sources(database, inter.guild_id)
                if not 0 <= number < len(sources):
                    await inter.followup.send(
                        f"There is no rss source with number {number}. "
                        "Use the list command to see the available numbers."
                    )
                    return
                url = sources[number].rss_url

            whereclause = and_(
                RssSourceModel.discord_server_id == inter.guild_id,
                RssSourceModel.rss_url == url
            )
            query = update(RssSourceModel).where(whereclause).values(active=False)
            result = await database.session.execute(query)

        await inter.followup.send(f"I've updated {result.rowcount} rows")

    # /rss list <filter>: show this guild's sources. filter.value is 1 for
    # active only, 0 for inactive only, 2 for all.
    @rss_group.command(name="list")
    @app_commands.choices(filter=[
        app_commands.Choice(name="Active Only [default]", value=1),
        app_commands.Choice(name="Inactive Only", value=0),
        app_commands.Choice(name="All", value=2),
    ])
    async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]):
        await inter.response.defer()

        async with DatabaseManager() as database:
            rss_sources = await self._fetch_guild_sources(database, inter.guild_id)
            # Number against the full URL-ordered list *before* filtering, so a
            # number identifies the same source under every filter and can be
            # passed to the remove command.
            embed_fields = [
                {
                    "name": f"[{i}]",
                    "value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}"
                }
                for i, rss in enumerate(rss_sources)
                if filter.value == 2 or bool(rss.active) == bool(filter.value)
            ]

        if not embed_fields:
            await inter.followup.send("It looks like you have no rss sources.")
            return

        # Page the fields across several embeds to stay under Discord's limit.
        limit = self._EMBED_FIELD_LIMIT
        for start in range(0, len(embed_fields), limit):
            embed = Embed(
                title="RSS Sources",
                description="Here are your rss sources:" if start == 0 else None
            )
            for field in embed_fields[start:start + limit]:
                embed.add_field(**field, inline=False)
            await inter.followup.send(embed=embed)
|
|
|
|
|
|
async def setup(bot):
    """
    Extension entry point, awaited when this file is loaded with
    `commands.Bot.load_extension`.

    Creates a `CommandCog` bound to ``bot``, registers it, and logs the addition.
    """
    command_cog = CommandCog(bot)
    await bot.add_cog(command_cog)
    log.info(f"Added {type(command_cog).__name__} cog")
|