admin commands
This commit is contained in:
parent
0e5af1d752
commit
45e5c6a04f
12
src/bot.py
12
src/bot.py
@ -47,8 +47,18 @@ class DiscordBot(commands.Bot):
|
|||||||
await self.load_extension(f"extensions.{path.stem}")
|
await self.load_extension(f"extensions.{path.stem}")
|
||||||
|
|
||||||
async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
|
async def audit(self, message: str, user_id: int, database: DatabaseManager=None):
|
||||||
|
"""Shorthand for auditing an action.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
message : str
|
||||||
|
The message to be audited.
|
||||||
|
user_id : int
|
||||||
|
Discord ID of the user being audited.
|
||||||
|
database : DatabaseManager, optional
|
||||||
|
An existing database connection to be used if specified, by default None
|
||||||
|
"""
|
||||||
|
|
||||||
message = f"Requesting latest article"
|
|
||||||
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
|
query = insert(AuditModel).values(discord_user_id=user_id, message=message)
|
||||||
|
|
||||||
if database:
|
if database:
|
||||||
|
@ -47,6 +47,9 @@ class SentArticleModel(Base):
|
|||||||
article_url = Column(String, nullable=False)
|
article_url = Column(String, nullable=False)
|
||||||
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
when = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||||
|
|
||||||
|
# feed_channel_id = Column(Integer, ForeignKey("feed_channel.id"), nullable=False)
|
||||||
|
# feed_channel = relationship("FeedChannelModel", lazy="joined", cascade="all, delete")
|
||||||
|
|
||||||
|
|
||||||
class RssSourceModel(Base):
|
class RssSourceModel(Base):
|
||||||
"""
|
"""
|
||||||
@ -80,8 +83,9 @@ class FeedChannelModel(Base):
|
|||||||
discord_channel_id = Column(BigInteger, nullable=False)
|
discord_channel_id = Column(BigInteger, nullable=False)
|
||||||
discord_server_id = Column(BigInteger, nullable=False)
|
discord_server_id = Column(BigInteger, nullable=False)
|
||||||
search_name = Column(String, nullable=False)
|
search_name = Column(String, nullable=False)
|
||||||
rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False)
|
# created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102
|
||||||
|
|
||||||
|
rss_source_id = Column(Integer, ForeignKey('rss_source.id'), nullable=False)
|
||||||
rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete")
|
rss_source = relationship("RssSourceModel", overlaps="feed_channels", lazy="joined", cascade="all, delete")
|
||||||
|
|
||||||
# the rss source must be unique, but only within the same discord channel
|
# the rss source must be unique, but only within the same discord channel
|
||||||
|
@ -15,7 +15,7 @@ from sqlalchemy import insert, select, and_, delete
|
|||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound
|
||||||
|
|
||||||
from utils import get_rss_data, followup, audit, followup_error # pylint: disable=E0401
|
from utils import get_rss_data, followup, audit, followup_error # pylint: disable=E0401
|
||||||
from feed import get_source, Source # pylint: disable=E0401
|
from feed import get_source, get_unparsed_feed, Source # pylint: disable=E0401
|
||||||
from db import ( # pylint: disable=E0401
|
from db import ( # pylint: disable=E0401
|
||||||
DatabaseManager,
|
DatabaseManager,
|
||||||
SentArticleModel,
|
SentArticleModel,
|
||||||
@ -78,6 +78,23 @@ async def validate_rss_source(nickname: str, url: str) -> Tuple[str | None, Feed
|
|||||||
|
|
||||||
return None, feed
|
return None, feed
|
||||||
|
|
||||||
|
async def set_all_articles_as_sent(inter, channel: TextChannel, rss_url: str):
|
||||||
|
unparsed_feed = await get_unparsed_feed(rss_url)
|
||||||
|
source = Source.from_parsed(parse(unparsed_feed))
|
||||||
|
articles = source.get_latest_articles()
|
||||||
|
|
||||||
|
async with DatabaseManager() as database:
|
||||||
|
query = insert(SentArticleModel).values([
|
||||||
|
{
|
||||||
|
"discord_server_id": inter.guild_id,
|
||||||
|
"discord_channel_id": channel.id,
|
||||||
|
"discord_message_id": -1,
|
||||||
|
"article_url": article.url,
|
||||||
|
}
|
||||||
|
for article in articles
|
||||||
|
])
|
||||||
|
await database.session.execute(query)
|
||||||
|
|
||||||
|
|
||||||
class FeedCog(commands.Cog):
|
class FeedCog(commands.Cog):
|
||||||
"""
|
"""
|
||||||
@ -221,7 +238,7 @@ class FeedCog(commands.Cog):
|
|||||||
inter.user.id, database=database
|
inter.user.id, database=database
|
||||||
)
|
)
|
||||||
|
|
||||||
source = get_source(url)
|
source = get_source(url) # TODO: replace with async function
|
||||||
|
|
||||||
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
|
embed = Embed(title="RSS Feed Deleted", colour=Colour.dark_red())
|
||||||
embed.add_field(name="Nickname", value=nickname)
|
embed.add_field(name="Nickname", value=nickname)
|
||||||
@ -478,14 +495,16 @@ class FeedCog(commands.Cog):
|
|||||||
@feed_group.command(name="assign")
|
@feed_group.command(name="assign")
|
||||||
@rename(rss="feed")
|
@rename(rss="feed")
|
||||||
@autocomplete(rss=autocomplete_rss_sources)
|
@autocomplete(rss=autocomplete_rss_sources)
|
||||||
async def include_feed(self, inter: Interaction, rss: int, channel: TextChannel = None):
|
async def include_feed(
|
||||||
|
self, inter: Interaction, rss: int, channel: TextChannel = None, prevent_spam: bool = True
|
||||||
|
):
|
||||||
"""Include a feed within the specified channel.
|
"""Include a feed within the specified channel.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
inter : Interaction
|
inter : Interaction
|
||||||
Represents an app command interaction.
|
Represents an app command interaction.
|
||||||
rss : str
|
rss : int
|
||||||
The RSS feed to include.
|
The RSS feed to include.
|
||||||
channel : TextChannel
|
channel : TextChannel
|
||||||
The channel to include the feed in.
|
The channel to include the feed in.
|
||||||
@ -514,6 +533,8 @@ class FeedCog(commands.Cog):
|
|||||||
|
|
||||||
await database.session.execute(insert_query)
|
await database.session.execute(insert_query)
|
||||||
|
|
||||||
|
if prevent_spam:
|
||||||
|
await set_all_articles_as_sent(inter, channel, rss_url)
|
||||||
|
|
||||||
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
|
await followup(inter, f"I've included [{nick}]({rss_url}) to {channel.mention}")
|
||||||
|
|
||||||
@ -628,6 +649,35 @@ class FeedCog(commands.Cog):
|
|||||||
|
|
||||||
await followup(inter, embed=embed)
|
await followup(inter, embed=embed)
|
||||||
|
|
||||||
|
admin_group = Group(
|
||||||
|
parent=feed_group,
|
||||||
|
name="admin",
|
||||||
|
description="Administration tasks"
|
||||||
|
)
|
||||||
|
|
||||||
|
@admin_group.command(name="clear-sent-articles")
|
||||||
|
async def clear_sent_articles(self, inter: Interaction):
|
||||||
|
"""Clear the database of all sent articles.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
inter : Interaction
|
||||||
|
Represents an app command interaction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
await inter.response.defer()
|
||||||
|
|
||||||
|
async with DatabaseManager() as database:
|
||||||
|
query = delete(SentArticleModel).where(and_(
|
||||||
|
SentArticleModel.discord_server_id == inter.guild_id
|
||||||
|
))
|
||||||
|
result = await database.session.execute(query)
|
||||||
|
|
||||||
|
await followup(inter,
|
||||||
|
f"{result.rowcount} sent articles have been cleared from the database. "
|
||||||
|
"I will no longer recognise these articles as sent, and will send them "
|
||||||
|
"again if they appear during the next RSS feed scan."
|
||||||
|
)
|
||||||
|
|
||||||
async def setup(bot):
|
async def setup(bot):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user