2024-01-23 22:53:19 +00:00

106 lines
3.8 KiB
Python

# -*- encoding: utf-8 -*-
import logging
from django.apps import apps
from django.conf import settings
from rest_framework import serializers
# Module-level logger named after this module (standard ``logging`` pattern).
log = logging.getLogger(__name__)
class DynamicModelSerializer(serializers.ModelSerializer):
    """Let GET requests choose which serializer fields to include or exclude.

    Mimics a small part of GraphQL's field selection.

    Usage: make your ModelSerializer inherit from this class, then add
    ``only_fields`` and/or ``exclude_fields`` (comma-separated field names)
    to the query parameters of a GET request.

    Nested (foreign-key) serializers are addressed by prefixing the parameter
    with the field name and a double underscore, for example::

        ?only_fields=name,age&company__only_fields=id,name
        ?only_fields=company,name&company__exclude_fields=name
        ?exclude_fields=name&company__only_fields=id
        ?company__exclude_fields=name

    Note: a nested serializer only honours its parameters if it also inherits
    from this class. Parameters that target a field which is not a serializer
    are ignored.
    """

    @staticmethod
    def _split_field_names(raw):
        """Split a comma-separated query value into a set of field names.

        Surrounding whitespace is stripped and empty entries are dropped, so
        ``"id, name,"`` yields ``{"id", "name"}``.
        """
        return {name.strip() for name in raw.split(",") if name.strip()}

    def only_keep_fields(self, fields_to_keep):
        """Remove every field whose name is not in comma-separated *fields_to_keep*."""
        keep = self._split_field_names(fields_to_keep)
        # Build the list first: popping while iterating self.fields would fail.
        for field in [name for name in self.fields if name not in keep]:
            self.fields.pop(field, None)

    def exclude_fields(self, fields_to_exclude):
        """Remove every field named in comma-separated *fields_to_exclude*."""
        for field in self._split_field_names(fields_to_exclude):
            self.fields.pop(field, None)

    def remove_unwanted_fields(self, dynamic_params):
        """Apply this serializer's own ``only_fields`` / ``exclude_fields``.

        Both keys are popped from *dynamic_params*, so only the nested
        (``<field>__...``) entries remain in it afterwards.
        """
        if fields_to_keep := dynamic_params.pop("only_fields", None):
            self.only_keep_fields(fields_to_keep)
        if fields_to_exclude := dynamic_params.pop("exclude_fields", None):
            self.exclude_fields(fields_to_exclude)

    def get_or_create_dynamic_params(self, child):
        """Return the ``dynamic_params`` dict in *child*'s context, creating it if absent."""
        return self.fields[child]._context.setdefault("dynamic_params", {})

    @staticmethod
    def split_param(dynamic_param):
        """Split ``"a__b__c"`` into ``("a", "b__c")``; ``"a"`` becomes ``("a", None)``."""
        child, sep, rest = dynamic_param.partition("__")
        return child, (rest if sep else None)

    def set_dynamic_params_for_children(self, dynamic_params):
        """Forward ``<child>__<param>`` entries to the nested serializer *child*.

        The prefix is stripped, so ``company__only_fields`` reaches the
        ``company`` serializer as ``only_fields``. Entries without a prefix,
        entries naming no field of this serializer, and entries naming a
        field that is not a serializer are ignored.
        """
        for param, value in dynamic_params.items():
            child, child_param = self.split_param(param)
            if child_param is None:
                continue
            field = self.fields.get(child)
            # Plain fields (CharField, ...) have no ``_context``; touching it
            # would raise AttributeError on a user-supplied query string.
            if not isinstance(field, serializers.BaseSerializer):
                continue
            self.get_or_create_dynamic_params(child)[child_param] = value

    @staticmethod
    def is_param_dynamic(p):
        """Return True if query parameter *p* selects or excludes fields.

        Matches ``only_fields`` / ``exclude_fields`` and their
        ``<field>__``-prefixed variants. Names that merely end with the same
        letters (e.g. ``xonly_fields``) are not matched.
        """
        return p in ("only_fields", "exclude_fields") or p.endswith(
            ("__only_fields", "__exclude_fields")
        )

    def get_dynamic_params_for_root(self, request):
        """Collect the field-selection query parameters of *request* into a dict."""
        query_params = request.query_params.items()
        return {k: v for k, v in query_params if self.is_param_dynamic(k)}

    def get_dynamic_params(self):
        """Return the dynamic params that apply to this serializer.

        ``set_dynamic_params_for_children`` stores them in the context of the
        field object. For a ``many=True`` field that object is the wrapping
        ``ListSerializer``, so in that case they are read from the parent.
        """
        if isinstance(self.parent, serializers.ListSerializer):
            return self.parent._context.get("dynamic_params", {})
        return self._context.get("dynamic_params", {})

    def __init__(self, *args, **kwargs):
        # Only a serializer constructed with a request in its context (the
        # root) reads the query string; nested serializers receive their
        # params later, from the parent's to_representation.
        request = kwargs.get("context", {}).get("request")
        super().__init__(*args, **kwargs)
        if not request or request.method != "GET":
            return
        dynamic_params = self.get_dynamic_params_for_root(request)
        self._context.update({"dynamic_params": dynamic_params})

    def to_representation(self, *args, **kwargs):
        """Prune fields and forward nested params, then serialize as usual."""
        # Work on a copy: remove_unwanted_fields pops keys, and the stored
        # dict must stay intact for later calls on this same instance.
        if dynamic_params := self.get_dynamic_params().copy():
            self.remove_unwanted_fields(dynamic_params)
            self.set_dynamic_params_for_children(dynamic_params)
        return super().to_representation(*args, **kwargs)

    class Meta:
        # NOTE(review): ``abstract`` is a Django *model* Meta option; DRF does
        # not appear to read it on serializers — confirm it is still needed.
        abstract = True