# -*- encoding: utf-8 -*-

import logging

from django.conf import settings
from rest_framework import serializers

from apps.home.models import (
    Server,
    ContentFilter,
    MessageMutator,
    MessageStyle,
    UniqueContentRule,
    DiscordChannel,
    Subscription,
    Content,
)

log = logging.getLogger(__name__)


def _is_get_request(context):
    """Return True when the serializer ``context`` carries a GET request."""
    request = context.get("request")
    return bool(request) and request.method == "GET"


# region Dynamic Model

# This DynamicModelSerializer is from a StackOverflow user in an obscure thread.
# I wish that I could remember which thread, because god bless that man.
class DynamicModelSerializer(serializers.ModelSerializer):
    """
    For use with GET requests, to specify which fields to include or exclude.
    Mimics some GraphQL functionality.

    Usage:
        Make your ModelSerializer inherit from this class, then add
        "only_fields" or "exclude_fields" to the query parameters of your
        GET request.

    This also works with nested foreign keys, for example:
        ?only_fields=name,age&company__only_fields=id,name

    Some more examples:
        ?only_fields=company,name&company__exclude_fields=name
        ?exclude_fields=name&company__only_fields=id
        ?company__exclude_fields=name

    Note: the Foreign Key serializer must also inherit from this class.
    Prefixed params that target a field which is not a serializer are ignored.
    """

    def only_keep_fields(self, fields_to_keep):
        """Drop every field whose name is not in the comma-separated ``fields_to_keep``."""
        keep = set(fields_to_keep.split(","))
        for field in set(self.fields.keys()) - keep:
            self.fields.pop(field, None)

    def exclude_fields(self, fields_to_exclude):
        """Drop every field named in the comma-separated ``fields_to_exclude``."""
        for field in fields_to_exclude.split(","):
            self.fields.pop(field, None)

    def remove_unwanted_fields(self, dynamic_params):
        """Apply, and pop from ``dynamic_params``, this serializer's own only/exclude params."""
        if fields_to_keep := dynamic_params.pop("only_fields", None):
            self.only_keep_fields(fields_to_keep)
        if fields_to_exclude := dynamic_params.pop("exclude_fields", None):
            self.exclude_fields(fields_to_exclude)

    def get_or_create_dynamic_params(self, child):
        """Return the mutable ``dynamic_params`` dict stored on ``child``'s own ``_context``."""
        return self.fields[child]._context.setdefault("dynamic_params", {})

    @staticmethod
    def split_param(dynamic_param):
        """Split ``"a__b__c"`` into ``("a", "b__c")``; a param without ``__`` gives ``("a", None)``."""
        head, sep, tail = dynamic_param.partition("__")
        return head, (tail if sep else None)

    def set_dynamic_params_for_children(self, dynamic_params):
        """Forward prefixed params (``child__...``) to the matching nested serializer field."""
        for param, fields in dynamic_params.items():
            child, child_param = self.split_param(param)
            if child_param is None:
                continue
            field = self.fields.get(child)
            # Only serializers have a ``_context`` and understand dynamic params.
            # Plain fields (e.g. ``?name__only_fields=x``) would otherwise raise
            # AttributeError and turn a bad query string into a server error.
            if not isinstance(field, serializers.BaseSerializer):
                continue
            self.get_or_create_dynamic_params(child)[child_param] = fields

    @staticmethod
    def is_param_dynamic(p):
        """Return True for query params such as ``only_fields`` or ``company__exclude_fields``."""
        return p.endswith(("only_fields", "exclude_fields"))

    def get_dynamic_params_for_root(self, request):
        """Collect the dynamic query params from the root request."""
        return {
            key: value
            for key, value in request.query_params.items()
            if self.is_param_dynamic(key)
        }

    def get_dynamic_params(self):
        """
        Return the dynamic params that apply to this serializer.

        set_dynamic_params_for_children stores them on the child field's
        context. If the child has many=True it is wrapped in a ListSerializer,
        so the params must be read from that ListSerializer parent instead.
        """
        if isinstance(self.parent, serializers.ListSerializer):
            return self.parent._context.get("dynamic_params", {})
        return self._context.get("dynamic_params", {})

    def __init__(self, *args, **kwargs):
        # Only the root serializer is built with the request in its context;
        # ``or {}`` also tolerates an explicit ``context=None``.
        request = (kwargs.get("context") or {}).get("request")
        super().__init__(*args, **kwargs)
        if not request or request.method != "GET":
            return
        self._context.update(
            {"dynamic_params": self.get_dynamic_params_for_root(request)}
        )

    def to_representation(self, *args, **kwargs):
        # Work on a copy: remove_unwanted_fields pops keys, and the stored
        # params are read again on later calls (e.g. each item of a list).
        if dynamic_params := self.get_dynamic_params().copy():
            self.remove_unwanted_fields(dynamic_params)
            self.set_dynamic_params_for_children(dynamic_params)
        return super().to_representation(*args, **kwargs)

    class Meta:
        abstract = True


# region Servers

class ServerSerializer(DynamicModelSerializer):
    # Serialized as a string rather than the model's native id type.
    id = serializers.CharField()

    class Meta:
        model = Server
        fields = (
            "id",
            "name",
            "icon_hash",
            "is_bot_operational",
            "active",
        )


# region Filters

class ContentFilterSerializer(DynamicModelSerializer):
    class Meta:
        model = ContentFilter
        fields = (
            "id",
            "server",
            "name",
            "match",
            "matching_algorithm",
            "is_insensitive",
            "is_whitelist",
        )


# region Msg Mutators

class MessageMutatorSerializer(DynamicModelSerializer):
    class Meta:
        model = MessageMutator
        fields = ("id", "name", "value")


# region Msg Styles

class MessageStyleSerializer(DynamicModelSerializer):
    """Message style, with read-only nested mutator details on GET requests."""

    title_mutator_detail = serializers.SerializerMethodField()
    description_mutator_detail = serializers.SerializerMethodField()

    class Meta:
        model = MessageStyle
        fields = (
            "id",
            "server",
            "name",
            "is_embed",
            "colour",
            "is_hyperlinked",
            "show_author",
            "show_timestamp",
            "show_images",
            "fetch_images",
            "title_mutator",
            "title_mutator_detail",
            "description_mutator",
            "description_mutator_detail",
            "auto_created",
        )
        read_only_fields = ("auto_created",)

    def get_title_mutator_detail(self, obj: MessageStyle):
        """Nested title mutator on GET; an empty dict otherwise."""
        if _is_get_request(self.context):
            return MessageMutatorSerializer(obj.title_mutator).data
        return {}

    def get_description_mutator_detail(self, obj: MessageStyle):
        """Nested description mutator on GET; an empty dict otherwise."""
        if _is_get_request(self.context):
            return MessageMutatorSerializer(obj.description_mutator).data
        return {}


# region Rules

class UniqueContentRuleSerializer(DynamicModelSerializer):
    class Meta:
        model = UniqueContentRule
        fields = ("id", "name", "value")


# region Subscriptions

class DiscordChannelField(serializers.PrimaryKeyRelatedField):
    """Channel primary-key field that accepts int-like input and emits the pk as a string."""

    def to_internal_value(self, data):
        try:
            data = int(data)
        except (TypeError, ValueError):
            self.fail("invalid", pk_value=data)
        return super().to_internal_value(data)

    def to_representation(self, value):
        return str(value.pk)


class NestedDiscordChannelSerializer(DynamicModelSerializer):
    class Meta:
        model = DiscordChannel
        fields = ("id", "name", "is_nsfw")


class SubscriptionSerializer(DynamicModelSerializer):
    """Subscription with writable relations plus read-only ``*_detail`` expansions on GET."""

    filters = serializers.PrimaryKeyRelatedField(
        queryset=ContentFilter.objects.all(),
        many=True,
    )
    channels = DiscordChannelField(
        queryset=DiscordChannel.objects.all(),
        many=True,
        required=True,
        allow_empty=False,
    )
    channels_detail = serializers.SerializerMethodField()
    filters_detail = serializers.SerializerMethodField()
    message_style = serializers.PrimaryKeyRelatedField(
        queryset=MessageStyle.objects.all(),
        required=True,
        allow_null=False,
    )
    message_style_detail = serializers.SerializerMethodField()
    unique_rules_detail = serializers.SerializerMethodField()

    class Meta:
        model = Subscription
        fields = (
            "id",
            "server",
            "name",
            "url",
            "created_at",
            "updated_at",
            "extra_notes",
            "active",
            "publish_threshold",
            "channels",
            "channels_detail",
            "filters",
            "filters_detail",
            "message_style",
            "message_style_detail",
            "unique_rules",
            "unique_rules_detail",
        )

    def get_channels_detail(self, obj: Subscription):
        if _is_get_request(self.context):
            return NestedDiscordChannelSerializer(obj.channels.all(), many=True).data
        return []

    def get_filters_detail(self, obj: Subscription):
        if _is_get_request(self.context):
            return ContentFilterSerializer(obj.filters.all(), many=True).data
        return []

    def get_message_style_detail(self, obj: Subscription):
        if _is_get_request(self.context):
            return MessageStyleSerializer(obj.message_style).data
        return {}

    def get_unique_rules_detail(self, obj: Subscription):
        if _is_get_request(self.context):
            return UniqueContentRuleSerializer(obj.unique_rules.all(), many=True).data
        return []

    def validate(self, data):
        """
        Enforce per-server constraints on a subscription.

        The server is taken from the payload, then the serializer context,
        then the instance being updated, so partial updates are also checked.
        If no server can be determined, validation is skipped.

        Raises:
            serializers.ValidationError: when the server is at its subscription
                limit (create only), a filter or message style belongs to a
                different server, or too many channels are selected.
        """
        server = (
            data.get("server")
            or self.context.get("server")
            or getattr(self.instance, "server", None)
        )
        if not server:
            return data

        # Enforce max subscriptions per server. Only creation adds a
        # subscription, so updates to an existing one are not blocked.
        if self.instance is None:
            subscriptions_count = Subscription.objects.filter(server=server).count()
            if subscriptions_count >= settings.MAX_SUBSCRIPTIONS_PER_SERVER:
                raise serializers.ValidationError(
                    f"Cannot create more than {settings.MAX_SUBSCRIPTIONS_PER_SERVER} subscriptions for this server."
                )

        # Prevent using filters from a different server
        selected_filters = data.get("filters") or []
        if selected_filters:
            valid_filter_ids = set(
                ContentFilter.objects.filter(server=server).values_list("id", flat=True)
            )
            if any(fltr.id not in valid_filter_ids for fltr in selected_filters):
                raise serializers.ValidationError(
                    {"filters": "All filters must belong to the specified server."}
                )

        # Prevent using message styles from a different server
        message_style = data.get("message_style")
        if message_style and message_style.server != server:
            raise serializers.ValidationError(
                {"message_style": "Message style must belong to the specified server."}
            )

        # Prevent assigning more channels than permitted. ``channels`` may be
        # absent on a partial update, in which case there is nothing to check.
        channels = data.get("channels")
        max_channels = settings.MAX_CHANNELS_PER_SUBSCRIPTION
        if channels is not None and len(channels) > max_channels:
            raise serializers.ValidationError(
                {"channels": f"Please select {max_channels} channels or fewer."}
            )

        return data


# region Content

class ContentSerializer(DynamicModelSerializer):
    class Meta:
        model = Content
        fields = (
            "id",
            "subscription",
            "item_id",
            "item_guid",
            "item_url",
            "item_title",
            "item_content_hash",
        )