Merge branch 'main' of https://gitea.corbz.dev/corbz/NewsBot
This commit is contained in:
commit
76d27c4782
17
.vscode/launch.json
vendored
Normal file
17
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: NewsBot",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/src/main.py",
|
||||
"python": "${workspaceFolder}/venv/bin/python",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
}
|
||||
]
|
||||
}
|
15
src/bot.py
15
src/bot.py
@ -46,9 +46,16 @@ class DiscordBot(commands.Bot):
|
||||
if path.suffix == ".py":
|
||||
await self.load_extension(f"extensions.{path.stem}")
|
||||
|
||||
async def audit(self, message: str, user_id: int):
|
||||
async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
|
||||
|
||||
message = f"Requesting latest article"
|
||||
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
|
||||
|
||||
if database:
|
||||
await database.session.execute(query)
|
||||
return
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
message = f"Requesting latest article"
|
||||
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
|
||||
await database.session.execute(query)
|
||||
await database.session.execute(query)
|
||||
|
||||
log.debug("Audit logged")
|
||||
|
@ -24,7 +24,7 @@ class DatabaseManager:
|
||||
"""
|
||||
|
||||
def __init__(self, no_commit: bool = False):
|
||||
database_url = self.get_database_url()
|
||||
database_url = self.get_database_url() # TODO: This is called every time a connection is established, maybe make it once and reference it?
|
||||
self.engine = create_async_engine(database_url, future=True)
|
||||
self.session_maker = sessionmaker(self.engine, class_=AsyncSession)
|
||||
self.session = None
|
||||
|
@ -5,7 +5,7 @@ All table classes should be suffixed with `Model`.
|
||||
|
||||
from enum import Enum, auto
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, BigInteger
|
||||
from sqlalchemy import Column, Integer, String, DateTime, BigInteger, UniqueConstraint, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
@ -24,7 +24,6 @@ class AuditModel(Base):
|
||||
discord_user_id = Column(BigInteger, nullable=False)
|
||||
message = Column(String, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
active = Column(Integer, default=True, nullable=False)
|
||||
|
||||
|
||||
class SentArticleModel(Base):
|
||||
@ -39,7 +38,7 @@ class SentArticleModel(Base):
|
||||
discord_channel_id = Column(BigInteger, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
article_url = Column(String, nullable=False)
|
||||
active = Column(Integer, default=True, nullable=False)
|
||||
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
|
||||
class RssSourceModel(Base):
|
||||
@ -50,9 +49,17 @@ class RssSourceModel(Base):
|
||||
__tablename__ = "rss_source"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
nick = Column(String, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
rss_url = Column(String, nullable=False)
|
||||
active = Column(Integer, default=True, nullable=False)
|
||||
created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
|
||||
feed_channels = relationship("FeedChannelModel", cascade="all, delete")
|
||||
|
||||
# the nickname must be unique, but only within the same discord server
|
||||
__table_args__ = (
|
||||
UniqueConstraint('nick', 'discord_server_id', name='uq_nick_discord_server'),
|
||||
)
|
||||
|
||||
|
||||
class FeedChannelModel(Base):
|
||||
@ -64,4 +71,6 @@ class FeedChannelModel(Base):
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
discord_channel_id = Column(BigInteger, nullable=False)
|
||||
active = Column(Integer, default=True, nullable=False)
|
||||
discord_server_id = Column(BigInteger, nullable=False)
|
||||
search_name = Column(String, nullable=False)
|
||||
rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False)
|
||||
|
106
src/extensions/channels.py
Normal file
106
src/extensions/channels.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""
|
||||
Extension for the `ChannelCog`.
|
||||
Loading this file via `commands.Bot.load_extension` will add `ChannelCog` to the bot.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
from discord import Interaction, TextChannel
|
||||
from discord.ext import commands
|
||||
from discord.app_commands import Group, Choice, autocomplete
|
||||
|
||||
from db import DatabaseManager, FeedChannelModel
|
||||
from utils import followup
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChannelCog(commands.Cog):
|
||||
"""
|
||||
Command cog.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
log.info(f"{self.__class__.__name__} cog is ready")
|
||||
|
||||
async def autocomplete_existing_feeds(self, inter: Interaction, current: str):
|
||||
"""Returns a list of existing RSS + Channel feeds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
current : str
|
||||
The current text entered for the autocomplete.
|
||||
"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
FeedChannelModel.discord_server_id == inter.guild_id,
|
||||
FeedChannelModel.search_name.ilike(f"%{current}%") # is this secure from SQL Injection atk ?
|
||||
)
|
||||
query = select(FeedChannelModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
feeds = [
|
||||
Choice(name=feed.search_name, value=feed.id)
|
||||
for feed in result.scalars().all()
|
||||
]
|
||||
|
||||
return feeds
|
||||
|
||||
# All RSS commands belong to this group.
|
||||
channel_group = Group(
|
||||
name="channel",
|
||||
description="Commands for channel assignment.",
|
||||
guild_only=True # We store guild IDs in the database, so guild only = True
|
||||
)
|
||||
|
||||
channel_group.command(name="include-feed")
|
||||
async def include_feed(self, inter: Interaction, rss: str, channel: TextChannel):
|
||||
"""Include a feed within the specified channel.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
rss : str
|
||||
The RSS feed to include.
|
||||
channel : TextChannel
|
||||
The channel to include the feed in.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
await followup(inter, "Ping")
|
||||
|
||||
channel_group.command(name="exclude-feed")
|
||||
@autocomplete(option=autocomplete_existing_feeds)
|
||||
async def exclude_feed(self, inter: Interaction, option: int):
|
||||
"""Undo command for the `/channel include-feed` command.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
option : str
|
||||
The RSS feed and channel to exclude.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
await followup(inter, "Pong")
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
Adds `ChannelCog` to the bot.
|
||||
"""
|
||||
|
||||
cog = ChannelCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
@ -1,184 +0,0 @@
|
||||
"""
|
||||
Extension for the `CommandCog`.
|
||||
Loading this file via `commands.Bot.load_extension` will add `CommandCog` to the bot.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import validators
|
||||
|
||||
import aiohttp
|
||||
import textwrap
|
||||
import feedparser
|
||||
from markdownify import markdownify
|
||||
from discord import app_commands, Interaction, Embed
|
||||
from discord.ext import commands, tasks
|
||||
from sqlalchemy import insert, select, update, and_, or_
|
||||
|
||||
from db import DatabaseManager, AuditModel, SentArticleModel, RssSourceModel, FeedChannelModel
|
||||
from feed import Feeds, get_source
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
async def get_rss_data(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
items = await response.text(), response.status
|
||||
|
||||
return items
|
||||
|
||||
|
||||
class CommandCog(commands.Cog):
|
||||
"""
|
||||
Command cog.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
log.info(f"{self.__class__.__name__} cog is ready")
|
||||
|
||||
rss_group = app_commands.Group(
|
||||
name="rss",
|
||||
description="Commands for rss sources.",
|
||||
guild_only=True
|
||||
)
|
||||
|
||||
@rss_group.command(name="add")
|
||||
async def add_rss_source(self, inter: Interaction, url: str):
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
# validate the input
|
||||
if not validators.url(url):
|
||||
await inter.followup.send(
|
||||
"The URL you have entered is malformed or invalid:\n"
|
||||
f"`{url=}`",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
feed_data, status_code = await get_rss_data(url)
|
||||
if status_code != 200:
|
||||
await inter.followup.send(
|
||||
f"The URL provided returned an invalid status code:\n"
|
||||
f"{url=}, {status_code=}",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
feed = feedparser.parse(feed_data)
|
||||
if not feed.version:
|
||||
await inter.followup.send(
|
||||
f"The provided URL '{url}' does not seem to be a valid RSS feed.",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(RssSourceModel).values(
|
||||
discord_server_id = inter.guild_id,
|
||||
rss_url = url
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
await inter.followup.send("RSS source added")
|
||||
|
||||
@rss_group.command(name="remove")
|
||||
async def remove_rss_source(self, inter: Interaction, number: int | None=None, url: str | None = None):
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
def exists(item) -> bool:
|
||||
"""
|
||||
Shorthand for `is not None`. Cant just use `if not number` because 0 int will pass.
|
||||
Ironically with this func & comment the code is longer, but at least I can read it ...
|
||||
"""
|
||||
|
||||
return item is not None
|
||||
|
||||
url_exists = exists(url)
|
||||
num_exists = exists(number)
|
||||
|
||||
if (url_exists and num_exists) or (not url_exists and not num_exists):
|
||||
await inter.followup.send(
|
||||
"Please only specify either the existing rss number or url, "
|
||||
"enter at least one of these, but don't enter both."
|
||||
)
|
||||
return
|
||||
|
||||
if url_exists and not validators.url(url):
|
||||
await inter.followup.send(
|
||||
"The URL you have entered is malformed or invalid:\n"
|
||||
f"`{url=}`",
|
||||
suppress_embeds=True
|
||||
)
|
||||
return
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
query = update(RssSourceModel).where(whereclause).values(active=False)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
await inter.followup.send(f"I've updated {result.rowcount} rows")
|
||||
|
||||
@rss_group.command(name="list")
|
||||
@app_commands.choices(filter=[
|
||||
app_commands.Choice(name="Active Only [default]", value=1),
|
||||
app_commands.Choice(name="Inactive Only", value=0),
|
||||
app_commands.Choice(name="All", value=2),
|
||||
])
|
||||
async def list_rss_sources(self, inter: Interaction, filter: app_commands.Choice[int]):
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
if filter.value == 2:
|
||||
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
||||
else:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.active == filter.value # should result to 0 or 1
|
||||
)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
rss_sources = result.scalars().all()
|
||||
embed_fields = [{
|
||||
"name": f"[{i}]",
|
||||
"value": f"{rss.rss_url} | {'inactive' if not rss.active else 'active'}"
|
||||
} for i, rss in enumerate(rss_sources)]
|
||||
|
||||
if not embed_fields:
|
||||
await inter.followup.send("It looks like you have no rss sources.")
|
||||
return
|
||||
|
||||
embed = Embed(
|
||||
title="RSS Sources",
|
||||
description="Here are your rss sources:"
|
||||
)
|
||||
|
||||
for field in embed_fields:
|
||||
embed.add_field(**field, inline=False)
|
||||
|
||||
# output = "Your rss sources:\n\n"
|
||||
# output += "\n".join([f"[{i+1}] {rss.rss_url=} {bool(rss.active)=}" for i, rss in enumerate(rss_sources)])
|
||||
|
||||
await inter.followup.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
Adds `CommandCog` to the bot.
|
||||
"""
|
||||
|
||||
cog = CommandCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
349
src/extensions/rss.py
Normal file
349
src/extensions/rss.py
Normal file
@ -0,0 +1,349 @@
|
||||
"""
|
||||
Extension for the `RssCog`.
|
||||
Loading this file via `commands.Bot.load_extension` will add `RssCog` to the bot.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import validators
|
||||
from typing import Tuple
|
||||
|
||||
import textwrap
|
||||
import feedparser
|
||||
from markdownify import markdownify
|
||||
from discord import Interaction, Embed, Colour
|
||||
from discord.ext import commands
|
||||
from discord.app_commands import Choice, Group, autocomplete, choices
|
||||
from sqlalchemy import insert, select, and_, delete
|
||||
|
||||
from utils import get_rss_data, followup, audit
|
||||
from feed import get_source, Source
|
||||
from db import DatabaseManager, SentArticleModel, RssSourceModel
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
rss_list_sort_choices = [
|
||||
Choice(name="Nickname", value=0),
|
||||
Choice(name="Date Added", value=1)
|
||||
]
|
||||
|
||||
# TODO SECURITY: a potential attack is that the user submits an rss feed then changes the target resource.
|
||||
# Run a period task to check this.
|
||||
async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, feedparser.FeedParserDict | None]:
|
||||
"""Validate a provided RSS source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nickname : str
|
||||
Nickname of the source. Must not contain URL.
|
||||
url : str
|
||||
URL of the source. Must be URL with valid status code and be an RSS feed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
String invalid message if invalid, NoneType if valid.
|
||||
FeedParserDict or None
|
||||
The feed parsed from the given URL or None if invalid.
|
||||
"""
|
||||
|
||||
# Ensure the URL is valid
|
||||
if not validators.url(url):
|
||||
return f"The URL you have entered is malformed or invalid:\n`{url=}`", None
|
||||
|
||||
# Check the nickname is not a URL
|
||||
if validators.url(nickname):
|
||||
return "It looks like the nickname you have entered is a URL.\n" \
|
||||
f"For security reasons, this is not allowed.\n`{nickname=}`", None
|
||||
|
||||
|
||||
feed_data, status_code = await get_rss_data(url)
|
||||
|
||||
# Check the URL status code is valid
|
||||
if status_code != 200:
|
||||
return f"The URL provided returned an invalid status code:\n{url=}, {status_code=}", None
|
||||
|
||||
# Check the contents is actually an RSS feed.
|
||||
feed = feedparser.parse(feed_data)
|
||||
if not feed.version:
|
||||
return f"The provided URL '{url}' does not seem to be a valid RSS feed.", None
|
||||
|
||||
return None, feed
|
||||
|
||||
|
||||
class RssCog(commands.Cog):
|
||||
"""
|
||||
Command cog.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
log.info(f"{self.__class__.__name__} cog is ready")
|
||||
|
||||
async def source_autocomplete(self, inter: Interaction, nickname: str):
|
||||
"""Provides RSS source autocomplete functionality for commands.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
nickname : str
|
||||
_description_
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of app_commands.Choice
|
||||
_description_
|
||||
"""
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.nick.ilike(f"%{nickname}%")
|
||||
)
|
||||
query = select(RssSourceModel).where(whereclause)
|
||||
result = await database.session.execute(query)
|
||||
sources = [
|
||||
Choice(name=rss.nick, value=rss.rss_url)
|
||||
for rss in result.scalars().all()
|
||||
]
|
||||
|
||||
return sources
|
||||
|
||||
# All RSS commands belong to this group.
|
||||
rss_group = Group(
|
||||
name="rss",
|
||||
description="Commands for rss sources.",
|
||||
guild_only=True # We store guild IDs in the database, so guild only = True
|
||||
)
|
||||
|
||||
@rss_group.command(name="add")
|
||||
async def add_rss_source(self, inter: Interaction, nickname: str, url: str):
|
||||
"""Add a new RSS source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
nickname : str
|
||||
A name used to identify the RSS source.
|
||||
url : str
|
||||
The RSS feed URL.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
illegal_message, feed = await validate_rss_source(nickname, url)
|
||||
if illegal_message:
|
||||
await followup(inter, illegal_message, suppress_embeds=True)
|
||||
return
|
||||
|
||||
log.debug("RSS feed added")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(RssSourceModel).values(
|
||||
discord_server_id = inter.guild_id,
|
||||
rss_url = url,
|
||||
nick=nickname
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
await audit(self,
|
||||
f"Added RSS source ({nickname=}, {url=})",
|
||||
inter.user.id, database=database
|
||||
)
|
||||
|
||||
embed = Embed(title="RSS Feed Added", colour=Colour.dark_green())
|
||||
embed.add_field(name="Nickname", value=nickname)
|
||||
embed.add_field(name="URL", value=url)
|
||||
embed.set_thumbnail(url=feed.get("feed", {}).get("image", {}).get("href"))
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@rss_group.command(name="remove")
|
||||
@autocomplete(url=source_autocomplete)
|
||||
async def remove_rss_source(self, inter: Interaction, url: str):
|
||||
"""Delete an existing RSS source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
url : str
|
||||
The RSS source to be removed. Autocomplete or enter the URL.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
log.debug(f"Attempting to remove RSS source ({url=})")
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
select_result = await database.session.execute(
|
||||
select(RssSourceModel).filter(
|
||||
and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
)
|
||||
)
|
||||
rss_source = select_result.scalars().one()
|
||||
nickname = rss_source.nick
|
||||
|
||||
delete_result = await database.session.execute(
|
||||
delete(RssSourceModel).filter(
|
||||
and_(
|
||||
RssSourceModel.discord_server_id == inter.guild_id,
|
||||
RssSourceModel.rss_url == url
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await audit(self,
|
||||
f"Added RSS source ({nickname=}, {url=})",
|
||||
inter.user.id, database=database
|
||||
)
|
||||
|
||||
if not delete_result.rowcount:
|
||||
await followup(inter, "Couldn't find any RSS sources with this name.")
|
||||
return
|
||||
|
||||
source = get_source(url)
|
||||
|
||||
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
|
||||
embed.add_field(name="Nickname", value=nickname)
|
||||
embed.add_field(name="URL", value=url)
|
||||
embed.set_thumbnail(url=source.icon_url)
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@rss_group.command(name="list")
|
||||
@choices(sort=rss_list_sort_choices)
|
||||
async def list_rss_sources(self, inter: Interaction, sort: Choice[int]=None, sort_reverse: bool=False):
|
||||
"""Provides a with a list of RSS sources available for the current server.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
# Default to the first choice if not specified.
|
||||
if type(sort) is Choice:
|
||||
description = "Sort by "
|
||||
description += "Nickname " if sort.value == 0 else "Date Added "
|
||||
description += '\U000025BC' if sort_reverse else '\U000025B2'
|
||||
else:
|
||||
sort = rss_list_sort_choices[0]
|
||||
description = ""
|
||||
|
||||
sort = sort if type(sort) == Choice else rss_list_sort_choices[0]
|
||||
|
||||
match sort.value, sort_reverse:
|
||||
case 0, False:
|
||||
order_by = RssSourceModel.nick.asc()
|
||||
case 0, True:
|
||||
order_by = RssSourceModel.nick.desc()
|
||||
case 1, False: # NOTE:
|
||||
order_by = RssSourceModel.created.desc() # Datetime order is inversed because we want the latest
|
||||
case 1, True: # date first, not the oldest as it would sort otherwise.
|
||||
order_by = RssSourceModel.created.asc()
|
||||
case _, _:
|
||||
raise ValueError("Unknown sort: %s" % sort)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
whereclause = and_(RssSourceModel.discord_server_id == inter.guild_id)
|
||||
query = select(RssSourceModel).where(whereclause).order_by(order_by)
|
||||
result = await database.session.execute(query)
|
||||
|
||||
rss_sources = result.scalars().all()
|
||||
|
||||
if not rss_sources:
|
||||
await followup(inter, "It looks like you have no rss sources.")
|
||||
return
|
||||
|
||||
output = "\n".join([f"{i}. **[{rss.nick}]({rss.rss_url})** " for i, rss in enumerate(rss_sources)])
|
||||
|
||||
embed = Embed(
|
||||
title="Saved RSS Feeds",
|
||||
description=f"{description}\n\n{output}",
|
||||
colour=Colour.lighter_grey()
|
||||
)
|
||||
|
||||
await followup(inter, embed=embed)
|
||||
|
||||
@rss_group.command(name="fetch")
|
||||
@autocomplete(rss=source_autocomplete)
|
||||
async def fetch_rss(self, inter: Interaction, rss: str, max: int=1):
|
||||
# """"""
|
||||
|
||||
await inter.response.defer()
|
||||
|
||||
if max > 5:
|
||||
followup(inter, "It looks like you have requested too many articles.\nThe limit is 5")
|
||||
return
|
||||
|
||||
invalid_message, feed = await validate_rss_source("", rss)
|
||||
if invalid_message:
|
||||
await followup(inter, invalid_message)
|
||||
return
|
||||
|
||||
source = Source.from_parsed(feed)
|
||||
articles = source.get_latest_articles(max)
|
||||
|
||||
if not articles:
|
||||
await followup(inter, "Sorry, I couldn't find any articles from this feed.")
|
||||
return
|
||||
|
||||
embeds = []
|
||||
for article in articles:
|
||||
md_description = markdownify(article.description, strip=("img",))
|
||||
article_description = textwrap.shorten(md_description, 4096)
|
||||
|
||||
embed = Embed(
|
||||
title=article.title,
|
||||
description=article_description,
|
||||
url=article.url,
|
||||
timestamp=article.published,
|
||||
colour=Colour.brand_red()
|
||||
)
|
||||
thumbail_url = await article.get_thumbnail_url()
|
||||
thumbail_url = thumbail_url if validators.url(thumbail_url) else None
|
||||
embed.set_thumbnail(url=source.icon_url)
|
||||
embed.set_image(url=thumbail_url)
|
||||
embed.set_footer(text=article.author)
|
||||
embed.set_author(
|
||||
name=source.name,
|
||||
url=source.url,
|
||||
)
|
||||
embeds.append(embed)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(SentArticleModel).values([
|
||||
{
|
||||
"discord_server_id": inter.guild_id,
|
||||
"discord_channel_id": inter.channel_id,
|
||||
"discord_message_id": inter.id,
|
||||
"article_url": article.url,
|
||||
}
|
||||
for article in articles
|
||||
])
|
||||
await database.session.execute(query)
|
||||
await audit(self, f"User is requesting {max} articles from {source.name}", inter.user.id, database=database)
|
||||
|
||||
await followup(inter, embeds=embeds)
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
Adds `RssCog` to the bot.
|
||||
"""
|
||||
|
||||
cog = RssCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
35
src/extensions/tasks.py
Normal file
35
src/extensions/tasks.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
Extension for the `TaskCog`.
|
||||
Loading this file via `commands.Bot.load_extension` will add `TaskCog` to the bot.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from discord.ext import commands
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskCog(commands.Cog):
|
||||
"""
|
||||
Command cog.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
log.info(f"{self.__class__.__name__} cog is ready")
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
Adds `TaskCog` to the bot.
|
||||
"""
|
||||
|
||||
cog = TaskCog(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
@ -1,88 +0,0 @@
|
||||
"""
|
||||
Extension for the `test` cog.
|
||||
Loading this file via `commands.Bot.load_extension` will add the `test` cog to the bot.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import textwrap
|
||||
from markdownify import markdownify
|
||||
from discord import app_commands, Interaction, Embed
|
||||
from discord.ext import commands, tasks
|
||||
from sqlalchemy import insert, select
|
||||
|
||||
from db import DatabaseManager, AuditModel, SentArticleModel
|
||||
from feed import Feeds, get_source
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Test(commands.Cog):
|
||||
"""
|
||||
News cog.
|
||||
Delivers embeds of news articles to discord channels.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
super().__init__()
|
||||
self.bot = bot
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_ready(self):
|
||||
log.info(f"{self.__class__.__name__} cog is ready")
|
||||
|
||||
@app_commands.command(name="test-latest-article")
|
||||
# @app_commands.choices(source=[
|
||||
# app_commands.Choice(name="The Babylon Bee", value=Feeds.THE_BABYLON_BEE),
|
||||
# app_commands.Choice(name="The Upper Lip", value=Feeds.THE_UPPER_LIP),
|
||||
# app_commands.Choice(name="BBC News", value=Feeds.BBC_NEWS),
|
||||
# ])
|
||||
async def test_bee(self, inter: Interaction, source: Feeds):
|
||||
|
||||
await inter.response.defer()
|
||||
await self.bot.audit("Requesting latest article.", inter.user.id)
|
||||
|
||||
source = get_source(source)
|
||||
article = source.get_latest_article()
|
||||
|
||||
md_description = markdownify(article.description, strip=("img",))
|
||||
article_description = textwrap.shorten(md_description, 4096)
|
||||
|
||||
embed = Embed(
|
||||
title=article.title,
|
||||
description=article_description,
|
||||
url=article.url,
|
||||
timestamp=article.published,
|
||||
)
|
||||
embed.set_thumbnail(url=source.icon_url)
|
||||
embed.set_image(url=await article.get_thumbnail_url())
|
||||
embed.set_footer(text=article.author)
|
||||
embed.set_author(
|
||||
name=source.name,
|
||||
url=source.url,
|
||||
)
|
||||
|
||||
async with DatabaseManager() as database:
|
||||
query = insert(SentArticleModel).values(
|
||||
discord_server_id=inter.guild_id,
|
||||
discord_channel_id=inter.channel_id,
|
||||
discord_message_id=inter.id,
|
||||
article_url=article.url
|
||||
)
|
||||
await database.session.execute(query)
|
||||
|
||||
await inter.followup.send(embed=embed)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
"""
|
||||
Setup function for this extension.
|
||||
Adds the `ErrorCog` cog to the bot.
|
||||
"""
|
||||
|
||||
cog = Test(bot)
|
||||
await bot.add_cog(cog)
|
||||
log.info(f"Added {cog.__class__.__name__} cog")
|
167
src/feed.py
167
src/feed.py
@ -1,7 +1,9 @@
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
@ -10,85 +12,136 @@ from bs4 import BeautifulSoup as bs4
|
||||
from feedparser import FeedParserDict, parse
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Feeds(Enum):
|
||||
THE_UPPER_LIP = "https://theupperlip.co.uk/rss"
|
||||
THE_BABYLON_BEE= "https://babylonbee.com/feed"
|
||||
BBC_NEWS = "https://feeds.bbci.co.uk/news/rss.xml"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Source:
|
||||
|
||||
name: str
|
||||
url: str
|
||||
icon_url: str
|
||||
feed: FeedParserDict
|
||||
|
||||
@classmethod
|
||||
def from_parsed(cls, feed:FeedParserDict):
|
||||
|
||||
# print(json.dumps(feed, indent=8))
|
||||
return cls(
|
||||
name=feed.channel.title,
|
||||
url=feed.channel.link,
|
||||
icon_url=feed.feed.image.href,
|
||||
feed=feed
|
||||
)
|
||||
|
||||
def get_latest_article(self):
|
||||
return Article.from_parsed(self.feed)
|
||||
dumps = lambda _dict: json.dumps(_dict, indent=8)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Article:
|
||||
|
||||
title: str
|
||||
description: str
|
||||
url: str
|
||||
published: datetime
|
||||
"""Represents a news article, or entry from an RSS feed."""
|
||||
|
||||
title: str | None
|
||||
description: str | None
|
||||
url: str | None
|
||||
published: datetime | None
|
||||
author: str | None
|
||||
|
||||
@classmethod
|
||||
def from_parsed(cls, feed:FeedParserDict):
|
||||
entry = feed.entries[0]
|
||||
# log.debug(json.dumps(entry, indent=8))
|
||||
def from_entry(cls, entry:FeedParserDict):
|
||||
"""Create an Article from an RSS feed entry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
entry : FeedParserDict
|
||||
An entry pulled from a complete FeedParserDict object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Article
|
||||
The Article created from the feed entry.
|
||||
"""
|
||||
|
||||
log.debug("Creating Article from entry: %s", dumps(entry))
|
||||
|
||||
published_parsed = entry.get("published_parsed")
|
||||
published = datetime(*entry.published_parsed[0:-2]) if published_parsed else None
|
||||
|
||||
return cls(
|
||||
title=entry.title,
|
||||
description=entry.description,
|
||||
url=entry.link,
|
||||
published=datetime(*entry.published_parsed[0:-2]),
|
||||
author = entry.get("author", None)
|
||||
title=entry.get("title"),
|
||||
description=entry.get("description"),
|
||||
url=entry.get("link"),
|
||||
published=published,
|
||||
author = entry.get("author")
|
||||
)
|
||||
|
||||
async def get_thumbnail_url(self):
|
||||
async def get_thumbnail_url(self) -> str | None:
|
||||
"""Returns the thumbnail URL for an article.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The thumbnail URL, or None if not found.
|
||||
"""
|
||||
|
||||
"""
|
||||
log.debug("Fetching thumbnail for article: %s", self)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.url) as response:
|
||||
html = await response.text()
|
||||
|
||||
# Parse the thumbnail for the news story
|
||||
soup = bs4(html, "html.parser")
|
||||
image_element = soup.select_one("meta[property='og:image']")
|
||||
return image_element.get("content") if image_element else None
|
||||
|
||||
|
||||
def get_source(feed: Feeds) -> Source:
|
||||
@dataclass
|
||||
class Source:
|
||||
"""Represents an RSS source."""
|
||||
|
||||
name: str | None
|
||||
url: str | None
|
||||
icon_url: str | None
|
||||
feed: FeedParserDict
|
||||
|
||||
@classmethod
|
||||
def from_parsed(cls, feed:FeedParserDict):
|
||||
"""Returns a Source object from a parsed feed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feed : FeedParserDict
|
||||
The feed used to create the Source.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Source
|
||||
The Source object
|
||||
"""
|
||||
|
||||
log.debug("Creating Source from feed: %s", dumps(feed))
|
||||
|
||||
return cls(
|
||||
name=feed.get("channel", {}).get("title"),
|
||||
url=feed.get("channel", {}).get("link"),
|
||||
icon_url=feed.get("feed", {}).get("image", {}).get("href"),
|
||||
feed=feed
|
||||
)
|
||||
|
||||
def get_latest_articles(self, max: int) -> list[Article]:
|
||||
"""Returns a list of Article objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max : int
|
||||
The maximum number of articles to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of Article
|
||||
A list of Article objects.
|
||||
"""
|
||||
|
||||
log.debug("Fetching latest articles from %s, max=%s", self, max)
|
||||
|
||||
return [
|
||||
Article.from_entry(entry)
|
||||
for i, entry in enumerate(self.feed.entries)
|
||||
if i < max
|
||||
]
|
||||
|
||||
|
||||
def get_source(rss_url: str) -> Source:
|
||||
"""_summary_
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rss_url : str
|
||||
_description_
|
||||
|
||||
Returns
|
||||
-------
|
||||
Source
|
||||
_description_
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
parsed_feed = parse("https://gitea.corbz.dev/corbz/BBC-News-Bot/rss/branch/main/src/extensions/news.py")
|
||||
parsed_feed = parse(rss_url) # TODO: make asyncronous
|
||||
return Source.from_parsed(parsed_feed)
|
||||
|
||||
|
||||
def get_test():
|
||||
|
||||
parsed = parse(Feeds.THE_UPPER_LIP.value)
|
||||
print(json.dumps(parsed, indent=4))
|
||||
return parsed
|
||||
|
31
src/utils.py
Normal file
31
src/utils.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""A collection of utility functions that can be used in various places."""
|
||||
|
||||
import aiohttp
|
||||
import logging
|
||||
|
||||
from discord import Interaction
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
async def get_rss_data(url: str):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
items = await response.text(), response.status
|
||||
|
||||
return items
|
||||
|
||||
async def followup(inter: Interaction, *args, **kwargs):
|
||||
"""Shorthand for following up on an interaction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
inter : Interaction
|
||||
Represents an app command interaction.
|
||||
"""
|
||||
|
||||
await inter.followup.send(*args, **kwargs)
|
||||
|
||||
async def audit(cog, *args, **kwargs):
|
||||
"""Shorthand for auditing an interaction."""
|
||||
|
||||
await cog.bot.audit(*args, **kwargs)
|
Loading…
x
Reference in New Issue
Block a user