Some checks are pending
Build and Push Docker Image / build (push) Waiting to run
342 lines
10 KiB
Python
342 lines
10 KiB
Python
# -*- encoding: utf-8 -*-
|
|
|
|
import logging
|
|
|
|
from django.conf import settings
|
|
from rest_framework import serializers
|
|
|
|
from apps.home.models import (
|
|
Server,
|
|
ContentFilter,
|
|
MessageMutator,
|
|
MessageStyle,
|
|
DiscordChannel,
|
|
Subscription,
|
|
Content,
|
|
)
|
|
|
|
# Module-level logger named after this module's import path.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
# region Dynamic Model
|
|
|
|
# This DynamicModelSerializer is from a StackOverflow user in an obscure thread.
|
|
# I wish that I could remember which thread, because god bless that man.
|
|
|
|
class DynamicModelSerializer(serializers.ModelSerializer):
    """
    For use with GET requests, to specify which fields to include or exclude.
    Mimics some GraphQL functionality.

    Usage: inherit your ModelSerializer from this class, then add
    "only_fields" or "exclude_fields" to the query parameters of your GET
    request.

    This also works with nested foreign keys, for example:

        ?only_fields=name,age&company__only_fields=id,name

    Some more examples:

        ?only_fields=company,name&company__exclude_fields=name
        ?exclude_fields=name&company__only_fields=id
        ?company__exclude_fields=name

    Note: the foreign-key serializer must also inherit from this class.
    Nested params whose first segment names a field that is not a
    serializer (e.g. a plain CharField) are ignored.
    """

    def only_keep_fields(self, fields_to_keep):
        """Drop every field whose name is not in comma-separated ``fields_to_keep``."""
        fields_to_keep = set(fields_to_keep.split(","))
        all_fields = set(self.fields.keys())
        for field in all_fields - fields_to_keep:
            self.fields.pop(field, None)

    def exclude_fields(self, fields_to_exclude):
        """Drop every field named in comma-separated ``fields_to_exclude``."""
        for field in fields_to_exclude.split(","):
            self.fields.pop(field, None)

    def remove_unwanted_fields(self, dynamic_params):
        """
        Apply this serializer's own ``only_fields`` / ``exclude_fields``.

        Both keys are popped from ``dynamic_params`` so that afterwards only
        the params aimed at nested children remain in the dict.
        """
        if fields_to_keep := dynamic_params.pop("only_fields", None):
            self.only_keep_fields(fields_to_keep)

        if fields_to_exclude := dynamic_params.pop("exclude_fields", None):
            self.exclude_fields(fields_to_exclude)

    def get_or_create_dynamic_params(self, child):
        """Return the ``dynamic_params`` dict in ``child``'s context, creating it if missing."""
        return self.fields[child]._context.setdefault("dynamic_params", {})

    @staticmethod
    def split_param(dynamic_param):
        """
        Split ``"child__rest"`` into ``("child", "rest")``.

        Only the first ``__`` is split on. A param without ``__`` yields
        ``(dynamic_param, None)``.
        """
        head, separator, tail = dynamic_param.partition("__")
        return head, (tail if separator else None)

    def set_dynamic_params_for_children(self, dynamic_params):
        """
        Forward each ``child__<param>`` entry to the matching nested serializer.

        Entries whose first segment is not a field of this serializer, or is a
        field that is not itself a serializer, are skipped.
        """
        for param, fields in dynamic_params.items():
            child, child_dynamic_param = self.split_param(param)
            field = self.fields.get(child)
            # Only serializers (including the ListSerializer created for
            # many=True) carry a ``_context``; a plain field such as CharField
            # has none and would raise AttributeError below.
            if not isinstance(field, serializers.BaseSerializer):
                continue
            child_params = self.get_or_create_dynamic_params(child)
            child_params[child_dynamic_param] = fields

    @staticmethod
    def is_param_dynamic(p):
        """Return True if query param ``p`` ends with only_fields or exclude_fields."""
        return p.endswith(("only_fields", "exclude_fields"))

    def get_dynamic_params_for_root(self, request):
        """Collect the only_fields / exclude_fields query params from ``request``."""
        query_params = request.query_params.items()
        return {k: v for k, v in query_params if self.is_param_dynamic(k)}

    def get_dynamic_params(self):
        """
        Return the dynamic params for this serializer.

        When dynamic params are passed down by set_dynamic_params_for_children
        and this serializer is the child of a ListSerializer (many=True), they
        were stored in the ListSerializer's context, so read them from there.
        """
        if isinstance(self.parent, serializers.ListSerializer):
            return self.parent._context.get("dynamic_params", {})

        return self._context.get("dynamic_params", {})

    def __init__(self, *args, **kwargs):
        request = kwargs.get("context", {}).get("request")
        super().__init__(*args, **kwargs)

        # Only the serializer constructed with a request in its context is the
        # root; nested serializers receive their params via to_representation.
        is_root = bool(request)

        if is_root:
            if request.method != "GET":
                return

            dynamic_params = self.get_dynamic_params_for_root(request)
            self._context.update({"dynamic_params": dynamic_params})

    def to_representation(self, *args, **kwargs):
        # Work on a copy: remove_unwanted_fields pops keys from it.
        if dynamic_params := self.get_dynamic_params().copy():
            self.remove_unwanted_fields(dynamic_params)
            self.set_dynamic_params_for_children(dynamic_params)

        return super().to_representation(*args, **kwargs)

    class Meta:
        abstract = True
|
|
|
|
|
|
# region Servers
|
|
|
|
class ServerSerializer(DynamicModelSerializer):
    """Model serializer for ``Server``; ``id`` is declared as a CharField."""

    id = serializers.CharField()

    class Meta:
        model = Server
        fields = ("id", "name", "icon_hash", "is_bot_operational", "active")
|
|
|
|
|
|
# region Filters
|
|
|
|
class ContentFilterSerializer(DynamicModelSerializer):
    """Model serializer for ``ContentFilter``."""

    class Meta:
        model = ContentFilter
        fields = (
            "id", "server", "name",
            "match", "matching_algorithm",
            "is_insensitive", "is_whitelist",
        )
|
|
|
|
|
|
# region Msg Mutators
|
|
|
|
class MessageMutatorSerializer(DynamicModelSerializer):
    """Model serializer for ``MessageMutator`` (id, name and value)."""

    class Meta:
        model = MessageMutator
        fields = (
            "id",
            "name",
            "value",
        )
|
|
|
|
|
|
# region Msg Styles
|
|
|
|
class MessageStyleSerializer(DynamicModelSerializer):
    """
    Model serializer for ``MessageStyle``.

    ``title_mutator_detail`` / ``description_mutator_detail`` embed the
    serialized mutator only when the context holds a GET request; otherwise
    they render as an empty dict.
    """

    title_mutator_detail = serializers.SerializerMethodField()
    description_mutator_detail = serializers.SerializerMethodField()

    class Meta:
        model = MessageStyle
        fields = (
            "id", "server", "name",
            "is_embed", "colour", "is_hyperlinked",
            "show_author", "show_timestamp", "show_images", "fetch_images",
            "title_mutator", "title_mutator_detail",
            "description_mutator", "description_mutator_detail",
            "auto_created",
        )
        read_only_fields = ("auto_created",)

    def _mutator_detail(self, obj, attribute):
        """Serialize ``getattr(obj, attribute)`` for GET requests, else ``{}``."""
        request = self.context.get("request")
        if not request or request.method != "GET":
            return {}
        # The relation is only read on the GET path, as before.
        mutator = getattr(obj, attribute)
        return MessageMutatorSerializer(mutator).data

    def get_title_mutator_detail(self, obj: MessageStyle):
        return self._mutator_detail(obj, "title_mutator")

    def get_description_mutator_detail(self, obj: MessageStyle):
        return self._mutator_detail(obj, "description_mutator")
|
|
|
|
|
|
# region Subscriptions
|
|
|
|
class DiscordChannelField(serializers.PrimaryKeyRelatedField):
    """
    Primary-key relation that coerces incoming ids with ``int()`` and renders
    the related object's pk as a string.
    """

    def to_internal_value(self, data):
        """
        Convert ``data`` to ``int`` and then perform the normal pk lookup.

        Raises:
            serializers.ValidationError: if ``data`` cannot be converted to an
                integer, or (from the parent) if no object has that pk.
        """
        try:
            data = int(data)
        except (TypeError, ValueError):
            # PrimaryKeyRelatedField defines no "invalid" error message, so
            # self.fail("invalid") raised AssertionError (a 500) rather than a
            # ValidationError. "incorrect_type" is the parent's own key for this.
            self.fail("incorrect_type", data_type=type(data).__name__)

        # NOTE(review): int() also accepts floats/bools (int(1.9) == 1) —
        # confirm clients only send integer or numeric-string ids.
        return super().to_internal_value(data)

    def to_representation(self, value):
        """Render the pk as a string."""
        return str(value.pk)
|
|
|
|
|
|
class NestedDiscordChannelSerializer(DynamicModelSerializer):
    """Compact ``DiscordChannel`` serializer used for embedded channel details."""

    class Meta:
        model = DiscordChannel
        fields = (
            "id",
            "name",
            "is_nsfw",
        )
|
|
|
|
|
|
class SubscriptionSerializer(DynamicModelSerializer):
    """
    Model serializer for ``Subscription``.

    ``channels``, ``filters`` and ``message_style`` are written as primary
    keys. The matching ``*_detail`` fields embed the serialized related
    objects when the context holds a GET request, and are empty otherwise.
    """

    filters = serializers.PrimaryKeyRelatedField(
        queryset=ContentFilter.objects.all(),
        many=True,
    )
    channels = DiscordChannelField(
        queryset=DiscordChannel.objects.all(),
        many=True,
        required=True,
        allow_empty=False,
    )
    channels_detail = serializers.SerializerMethodField()
    filters_detail = serializers.SerializerMethodField()
    message_style = serializers.PrimaryKeyRelatedField(
        queryset=MessageStyle.objects.all(),
        required=True,
        allow_null=False,
    )
    message_style_detail = serializers.SerializerMethodField()

    class Meta:
        model = Subscription
        fields = (
            "id",
            "server",
            "name",
            "url",
            "created_at",
            "updated_at",
            "extra_notes",
            "active",
            "publish_threshold",
            "channels",
            "channels_detail",
            "filters",
            "filters_detail",
            "message_style",
            "message_style_detail",
        )

    def _is_get_request(self):
        """Return True when the serializer context holds a GET request."""
        # The context may carry no request at all (serializer used outside a
        # view), so guard before reading ``.method``.
        request = self.context.get("request")
        return bool(request) and request.method == "GET"

    def get_channels_detail(self, obj: Subscription):
        if not self._is_get_request():
            return []
        return NestedDiscordChannelSerializer(obj.channels.all(), many=True).data

    def get_filters_detail(self, obj: Subscription):
        if not self._is_get_request():
            return []
        return ContentFilterSerializer(obj.filters.all(), many=True).data

    def get_message_style_detail(self, obj: Subscription):
        if not self._is_get_request():
            return {}
        return MessageStyleSerializer(obj.message_style).data

    def validate(self, data):
        """
        Run cross-field validation scoped to the subscription's server.

        The server is taken from ``data``, then from the serializer context,
        then from the instance being updated (so partial updates are still
        checked). If no server is found, ``data`` is returned unchanged.
        Otherwise this enforces:

        * at most ``settings.MAX_SUBSCRIPTIONS_PER_SERVER`` subscriptions per
          server (the subscription being updated is not counted);
        * every selected filter belongs to the server;
        * the message style belongs to the server;
        * at most ``settings.MAX_CHANNELS_PER_SUBSCRIPTION`` channels.

        Raises:
            serializers.ValidationError: if any of the rules above is violated.
        """
        server = (
            data.get("server")
            or self.context.get("server")
            or getattr(self.instance, "server", None)
        )
        if not server:
            return data

        # Enforce max subscriptions per server. When updating, exclude the
        # subscription itself so editing at the limit is still allowed.
        existing = Subscription.objects.filter(server=server)
        if self.instance is not None:
            existing = existing.exclude(pk=self.instance.pk)
        if existing.count() >= settings.MAX_SUBSCRIPTIONS_PER_SERVER:
            raise serializers.ValidationError(
                f"Cannot create more than {settings.MAX_SUBSCRIPTIONS_PER_SERVER} subscriptions for this server."
            )

        # Prevent using filters from a different server.
        selected_filters = data.get("filters") or []
        if selected_filters:
            valid_filter_ids = set(
                ContentFilter.objects.filter(server=server).values_list("id", flat=True)
            )
            if any(fltr.id not in valid_filter_ids for fltr in selected_filters):
                raise serializers.ValidationError(
                    {"filters": "All filters must belong to the specified server."}
                )

        # Prevent using message styles from a different server.
        message_style = data.get("message_style")
        if message_style and message_style.server != server:
            raise serializers.ValidationError(
                {"message_style": "Message style must belong to the specified server."}
            )

        # Prevent assigning more channels than permitted. ``channels`` is
        # absent on partial updates that don't touch it.
        # NOTE(review): channels are not checked against the server here —
        # confirm whether DiscordChannel has a server relation to validate.
        channels = data.get("channels") or []
        max_channels = settings.MAX_CHANNELS_PER_SUBSCRIPTION
        if len(channels) > max_channels:
            raise serializers.ValidationError(
                {"channels": f"Please select {max_channels} channels or fewer."}
            )

        return data
|
|
|
|
|
|
# region Content
|
|
|
|
class ContentSerializer(DynamicModelSerializer):
    """Model serializer for ``Content`` items tracked per subscription."""

    class Meta:
        model = Content
        fields = (
            # identity
            "id", "subscription",
            # item identifiers and links
            "item_id", "item_guid", "item_url",
            # item body
            "item_title", "item_description", "item_content_hash",
            # media
            "item_image_url", "item_thumbnail_url",
            # metadata
            "item_published", "item_author", "item_author_url",
            "item_feed_title", "item_feed_url",
            # state
            "blocked",
        )
|