diff --git a/apps/api/views.py b/apps/api/views.py index c8d647e..12aec59 100644 --- a/apps/api/views.py +++ b/apps/api/views.py @@ -3,6 +3,7 @@ import logging from django.db.utils import IntegrityError +from django.core.exceptions import ValidationError from django_filters import rest_framework as rest_filters from rest_framework import status, permissions, filters, generics from rest_framework.response import Response @@ -61,7 +62,7 @@ class Subscription_ListView(generics.ListCreateAPIView): self.perform_create(serializer) except IntegrityError: return Response( - {"detail": "Subscription name must be unique"}, + {"detail": "Duplicate or limit reached"}, status=status.HTTP_409_CONFLICT, exception=True ) @@ -122,7 +123,7 @@ class SubscriptionChannel_ListView(generics.ListCreateAPIView): except IntegrityError as exc: log.error(exc) return Response( - {"detail": "Duplicates not allowed"}, + {"detail": "Duplicate or limit reached"}, status=status.HTTP_409_CONFLICT, exception=True ) @@ -158,7 +159,7 @@ class SubscriptionChannel_DetailView(generics.RetrieveDestroyAPIView): self.perform_create(serializer) except IntegrityError: return Response( - {"detail": "Channel must be unique"}, + {"detail": "Duplicate or limit reached"}, status=status.HTTP_409_CONFLICT, exception=True ) diff --git a/apps/home/models.py b/apps/home/models.py index eab1cf5..aebd8f3 100644 --- a/apps/home/models.py +++ b/apps/home/models.py @@ -7,6 +7,7 @@ from pathlib import Path from django.db import models from django.utils import timezone from django.dispatch import receiver +from django.db.utils import IntegrityError from django.db.models.signals import pre_save from django.utils.translation import gettext_lazy as _ from django.core.files.storage import FileSystemStorage @@ -111,6 +112,12 @@ class Subscription(models.Model): return SubscriptionChannel.objects.filter(subscription=self) + def save(self, *args, **kwargs): + if Subscription.objects.filter(server=self.server).count() >= 1: + raise IntegrityError(f"Subscription limit reached for server '{self.server}'") + + super().save(*args, **kwargs) + class SubscriptionChannel(models.Model): """ @@ -154,6 +161,13 @@ class SubscriptionChannel(models.Model): def __str__(self): return str(self.id) + def save(self, *args, **kwargs): + if SubscriptionChannel.objects.filter(subscription=self.subscription).count() >= 4: + raise IntegrityError( + f"SubscriptionChannel limit reached for subscription '{self.subscription}'" + ) + + super().save(*args, **kwargs) class TrackedContent(models.Model):