diff --git a/src/db/models.py b/src/db/models.py index 97554af..9831a7c 100644 --- a/src/db/models.py +++ b/src/db/models.py @@ -28,7 +28,7 @@ class AuditModel(Base): id = Column(Integer, primary_key=True, autoincrement=True) discord_user_id = Column(BigInteger, nullable=False) - discord_server_id = Column(BigInteger, nullable=False) + # discord_server_id = Column(BigInteger, nullable=False) # TODO: this doesnt exist, integrate it. message = Column(String, nullable=False) created = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) # pylint: disable=E1102 diff --git a/src/extensions/rss.py b/src/extensions/rss.py index b9d4b03..2b45e85 100644 --- a/src/extensions/rss.py +++ b/src/extensions/rss.py @@ -248,7 +248,7 @@ class FeedCog(commands.Cog): await inter.response.defer() try: - sources = await self.bot.functions.get_feeds(inter.guild_id) + feeds = await self.bot.functions.get_feeds(inter.guild_id) except NoResultFound: await ( Followup( @@ -261,8 +261,8 @@ class FeedCog(commands.Cog): ) else: description = "\n".join([ - f"{i}. **[{source.name}]({source.url})**" - for i, source in enumerate(sources) + f"{i}. **[{info[0]}]({info[1]})**" # info = (nick, url) + for i, info in enumerate(feeds) ]) await ( Followup( @@ -273,7 +273,6 @@ class FeedCog(commands.Cog): .send(inter) ) - # @feed_group.command(name="fetch") # @rename(max_="max") # @autocomplete(rss=source_autocomplete) @@ -308,7 +307,8 @@ class FeedCog(commands.Cog): await followup(inter, "Sorry, I couldn't find any articles from this feed.") return - embeds = [await article.to_embed() for article in articles] + async with aiohttp.ClientSession() as session: + embeds = [await article.to_embed(session) for article in articles] async with DatabaseManager() as database: query = insert(SentArticleModel).values([ @@ -630,7 +630,7 @@ class FeedCog(commands.Cog): return output = "\n".join([ - f"{i}. <#{feed.discord_channel_id}> · " + f"{i}. <#{feed.discord_channel_id}> · " # TODO: add icon indicating inaccessible channel, if is the case. f"[{feed.rss_source.nick}]({feed.rss_source.rss_url})" for i, feed in enumerate(feed_channels) ]) diff --git a/src/extensions/tasks.py b/src/extensions/tasks.py index 9f5160a..c7fcfa0 100644 --- a/src/extensions/tasks.py +++ b/src/extensions/tasks.py @@ -8,6 +8,7 @@ import datetime from os import getenv from time import process_time +import aiohttp from discord import TextChannel from discord.ext import commands, tasks from discord.errors import Forbidden @@ -61,7 +62,7 @@ class TaskCog(commands.Cog): self.rss_task.cancel() - @tasks.loop(minutes=10) + @tasks.loop(time=times) async def rss_task(self): """Automated task responsible for processing rss feeds.""" @@ -95,20 +96,23 @@ class TaskCog(commands.Cog): # TODO: integrate the `validate_feed` code into here, also do on list command and show errors. - unparsed_content = await self.bot.functions.get_unparsed_feed(feed.rss_source.rss_url) - parsed_feed = parse(unparsed_content) - source = Source.from_parsed(parsed_feed) - articles = source.get_latest_articles(5) + async with aiohttp.ClientSession() as session: - if not articles: - log.info("No articles to process") - return + unparsed_content = await get_unparsed_feed(feed.rss_source.rss_url) + parsed_feed = parse(unparsed_content) + source = Source.from_parsed(parsed_feed) + articles = source.get_latest_articles(5) - for article in articles: - await self.process_article(feed.id, article, channel, database) + if not articles: + log.info("No articles to process for %s in ", feed.rss_source.nick, feed.discord_server_id) + return + + for article in articles: + await self.process_article(feed.id, article, channel, database, session) async def process_article( - self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager + self, feed_id: int, article: Article, channel: TextChannel, database: DatabaseManager, + session: aiohttp.ClientSession ): """Process the passed article. Will send the embed to a channel if all is valid. @@ -136,11 +140,11 @@ class TaskCog(commands.Cog): log.debug("Article already processed: %s", article.url) return - embed = await article.to_embed() + embed = await article.to_embed(session) try: await channel.send(embed=embed) - except Forbidden: - log.error("Can't send article to channel: %s · %s", channel.name, channel.id) + except Forbidden as error: # TODO: find some way of informing the user about this error. + log.error("Can't send article to channel: %s · %s · %s", channel.name, channel.id, error) return query = insert(SentArticleModel).values( diff --git a/src/feed.py b/src/feed.py index 7472c6d..7b45bac 100644 --- a/src/feed.py +++ b/src/feed.py @@ -63,9 +63,14 @@ class Article: source=source ) - async def get_thumbnail_url(self) -> str | None: + async def get_thumbnail_url(self, session: aiohttp.ClientSession) -> str | None: """Returns the thumbnail URL for an article. + Parameters + ---------- + session : aiohttp.ClientSession + A client session used to get the thumbnail. + Returns ------- str or None @@ -75,11 +80,10 @@ class Article: log.debug("Fetching thumbnail for article: %s", self) try: - async with aiohttp.ClientSession() as session: - async with session.get(self.url) as response: - html = await response.text() + async with session.get(self.url, timeout=15) as response: + html = await response.text() except aiohttp.InvalidURL as error: - log.error(error) + log.error("invalid thumbnail url: %s", error) return None soup = bs4(html, "html.parser") @@ -90,9 +94,14 @@ class Article: image_content = image_element.get("content") return image_content if validators.url(image_content) else None - async def to_embed(self) -> Embed: + async def to_embed(self, session: aiohttp.ClientSession) -> Embed: """Creates and returns a Discord Embed object from the article. + Parameters + ---------- + session : aiohttp.ClientSession + A client session used to get additional article data. + Returns ------- Embed @@ -110,7 +119,7 @@ class Article: embed_url = self.url if validators.url(self.url) else None author_url = self.source.url if validators.url(self.source.url) else None icon_url = self.source.icon_url if validators.url(self.source.icon_url) else None - thumb_url = await self.get_thumbnail_url() # validation done inside func + thumb_url = await self.get_thumbnail_url(session) # validation done inside func embed = Embed( title=title, @@ -315,9 +324,9 @@ class Functions: return await Source.from_url(url) - async def get_feeds(self, guild_id: int) -> list[Source]: + async def get_feeds(self, guild_id: int) -> list[tuple[str, str]]: """Returns a list of fetched Feed objects from the database. - Note: a request will be made too all found Feed URLs. + Note: a request will be made too all found Feed UR Ls. Parameters ---------- @@ -326,7 +335,7 @@ class Functions: Returns ------- - list[Source] + list[tuple[str, str]] List of Source objects, resulting from the query. Raises @@ -344,7 +353,7 @@ class Functions: if not rss_sources: raise NoResultFound - return [await Source.from_url(feed.rss_url) for feed in rss_sources] + return [(feed.nick, feed.rss_url) for feed in rss_sources] async def assign_feed( self, url: str, channel_name: str, channel_id: int, guild_id: int diff --git a/src/utils.py b/src/utils.py index d4a03f5..98cb37d 100644 --- a/src/utils.py +++ b/src/utils.py @@ -13,7 +13,10 @@ async def fetch(session, url: str) -> str: async with session.get(url) as response: return await response.text() -async def get_unparsed_feed(url: str): +async def get_unparsed_feed(url: str, session: aiohttp.ClientSession=None): + if session is not None: + return await fetch(session, url) + async with aiohttp.ClientSession() as session: return await fetch(session, url)