diff --git a/taiga/projects/userstories/validators.py b/taiga/projects/userstories/validators.py index 9ff6b2f3..9f6780c8 100644 --- a/taiga/projects/userstories/validators.py +++ b/taiga/projects/userstories/validators.py @@ -20,16 +20,17 @@ from django.utils.translation import ugettext as _ from taiga.base.api import serializers from taiga.base.api import validators -from taiga.base.api.utils import get_object_or_404 from taiga.base.exceptions import ValidationError from taiga.base.fields import PgArrayField from taiga.base.fields import PickledObjectField -from taiga.projects.milestones.validators import MilestoneExistsValidator -from taiga.projects.models import Project +from taiga.projects.milestones.models import Milestone +from taiga.projects.models import UserStoryStatus from taiga.projects.notifications.mixins import EditableWatchedResourceSerializer from taiga.projects.notifications.validators import WatchersValidator from taiga.projects.tagging.fields import TagsAndTagsColorsField -from taiga.projects.validators import ProjectExistsValidator, UserStoryStatusExistsValidator +from taiga.projects.userstories.models import UserStory +from taiga.projects.validators import ProjectExistsValidator +from taiga.projects.validators import UserStoryStatusExistsValidator from . import models @@ -67,12 +68,22 @@ class UserStoryValidator(WatchersValidator, EditableWatchedResourceSerializer, v read_only_fields = ('id', 'ref', 'created_date', 'modified_date', 'owner') -class UserStoriesBulkValidator(ProjectExistsValidator, UserStoryStatusExistsValidator, - validators.Validator): +class UserStoriesBulkValidator(ProjectExistsValidator, validators.Validator): project_id = serializers.IntegerField() status_id = serializers.IntegerField(required=False) bulk_stories = serializers.CharField() + def validate_status_id(self, attrs, source): + filters = {"project__id": attrs["project_id"]} + if source in attrs: + filters["id"] = attrs[source] + + if not UserStoryStatus.objects.filter(**filters).exists(): + raise ValidationError(_("Invalid user story status id. The status must belong to " + "the same project.")) + + return attrs + # Order bulk validators @@ -88,20 +99,42 @@ class UpdateUserStoriesOrderBulkValidator(ProjectExistsValidator, UserStoryStatu milestone_id = serializers.IntegerField(required=False) bulk_stories = _UserStoryOrderBulkValidator(many=True) - def validate(self, data): - filters = {"project__id": data["project_id"]} - if "status_id" in data: - filters["status__id"] = data["status_id"] - if "milestone_id" in data: - filters["milestone__id"] = data["milestone_id"] + def validate_status_id(self, attrs, source): + filters = {"project__id": attrs["project_id"]} + if source in attrs: + filters["id"] = attrs[source] - filters["id__in"] = [us["us_id"] for us in data["bulk_stories"]] + if not UserStoryStatus.objects.filter(**filters).exists(): + raise ValidationError(_("Invalid user story status id. The status must belong " + "to the same project.")) + + return attrs + + def validate_milestone_id(self, attrs, source): + filters = {"project__id": attrs["project_id"]} + if source in attrs: + filters["id"] = attrs[source] + + if not Milestone.objects.filter(**filters).exists(): + raise ValidationError(_("Invalid milestone id. The milistone must belong to the " + "same project.")) + + return attrs + + def validate_bulk_stories(self, attrs, source): + filters = {"project__id": attrs["project_id"]} + if "status_id" in attrs: + filters["status__id"] = attrs["status_id"] + if "milestone_id" in attrs: + filters["milestone__id"] = attrs["milestone_id"] + + filters["id__in"] = [us["us_id"] for us in attrs[source]] if models.UserStory.objects.filter(**filters).count() != len(filters["id__in"]): - raise ValidationError(_("Invalid user story ids. All stories must belong to the same project and, " - "if it exists, to the same status and milestone.")) + raise ValidationError(_("Invalid user story ids. All stories must belong to the same project " + "and, if it exists, to the same status and milestone.")) - return data + return attrs # Milestone bulk validators @@ -111,22 +144,27 @@ class _UserStoryMilestoneBulkValidator(validators.Validator): order = serializers.IntegerField() -class UpdateMilestoneBulkValidator(ProjectExistsValidator, MilestoneExistsValidator, validators.Validator): +class UpdateMilestoneBulkValidator(ProjectExistsValidator, validators.Validator): project_id = serializers.IntegerField() milestone_id = serializers.IntegerField() bulk_stories = _UserStoryMilestoneBulkValidator(many=True) - def validate(self, data): - """ - All the userstories and the milestone are from the same project - """ - user_story_ids = [us["us_id"] for us in data["bulk_stories"]] - project = get_object_or_404(Project, pk=data["project_id"]) + def validate_milestone_id(self, attrs, source): + filters = { + "project__id": attrs["project_id"], + "id": attrs[source] + } + if not Milestone.objects.filter(**filters).exists(): + raise ValidationError(_("The milestone isn't valid for the project")) + return attrs - if project.user_stories.filter(id__in=user_story_ids).count() != len(user_story_ids): + def validate_bulk_stories(self, attrs, source): + filters = { + "project__id": attrs["project_id"], + "id__in": [us["us_id"] for us in attrs[source]] + } + + if UserStory.objects.filter(**filters).count() != len(filters["id__in"]): raise ValidationError(_("All the user stories must be from the same project")) - if project.milestones.filter(id=data["milestone_id"]).count() != 1: - raise ValidationError(_("The milestone isn't valid for the project")) - - return data + return attrs diff --git a/tests/integration/test_userstories.py b/tests/integration/test_userstories.py index e158b802..57a2f520 100644 --- a/tests/integration/test_userstories.py +++ b/tests/integration/test_userstories.py @@ -157,6 +157,24 @@ def test_api_create_in_bulk_with_status(client): assert response.data[0]["status"] == project.default_us_status.id +def test_api_create_in_bulk_with_invalid_status(client): + project = f.create_project() + status = f.UserStoryStatusFactory.create() + f.MembershipFactory.create(project=project, user=project.owner, is_admin=True) + url = reverse("userstories-bulk-create") + data = { + "bulk_stories": "Story #1\nStory #2", + "project_id": project.id, + "status_id": status.id + } + + client.login(project.owner) + response = client.json.post(url, json.dumps(data)) + + assert response.status_code == 400, response.data + assert "status_id" in response.data + + def test_api_update_orders_in_bulk(client): project = f.create_project() f.MembershipFactory.create(project=project, user=project.owner, is_admin=True) @@ -175,13 +193,14 @@ def test_api_update_orders_in_bulk(client): client.login(project.owner) - response1 = client.json.post(url1, json.dumps(data)) - response2 = client.json.post(url2, json.dumps(data)) - response3 = client.json.post(url3, json.dumps(data)) + response = client.json.post(url1, json.dumps(data)) + assert response.status_code == 200, response.data - assert response1.status_code == 200, response1.data - assert response2.status_code == 200, response2.data - assert response3.status_code == 200, response3.data + response = client.json.post(url2, json.dumps(data)) + assert response.status_code == 200, response.data + + response = client.json.post(url3, json.dumps(data)) + assert response.status_code == 200, response.data def test_api_update_orders_in_bulk_invalid_userstories(client): @@ -204,19 +223,24 @@ def test_api_update_orders_in_bulk_invalid_userstories(client): client.login(project.owner) - response1 = client.json.post(url1, json.dumps(data)) - response2 = client.json.post(url2, json.dumps(data)) - response3 = client.json.post(url3, json.dumps(data)) + response = client.json.post(url1, json.dumps(data)) + assert response.status_code == 400, response.data + assert "bulk_stories" in response.data - assert response1.status_code == 400, response1.data - assert response2.status_code == 400, response2.data - assert response3.status_code == 400, response3.data + response = client.json.post(url2, json.dumps(data)) + assert response.status_code == 400, response.data + assert "bulk_stories" in response.data + + response = client.json.post(url3, json.dumps(data)) + assert response.status_code == 400, response.data + assert "bulk_stories" in response.data def test_api_update_orders_in_bulk_invalid_status(client): project = f.create_project() f.MembershipFactory.create(project=project, user=project.owner, is_admin=True) - us1 = f.create_userstory(project=project) + status = f.UserStoryStatusFactory.create() + us1 = f.create_userstory(project=project, status=status) us2 = f.create_userstory(project=project, status=us1.status) us3 = f.create_userstory(project=project) @@ -226,7 +250,7 @@ def test_api_update_orders_in_bulk_invalid_status(client): data = { "project_id": project.id, - "status_id": us1.status.id, + "status_id": status.id, "bulk_stories": [{"us_id": us1.id, "order": 1}, {"us_id": us2.id, "order": 2}, {"us_id": us3.id, "order": 3}] @@ -234,19 +258,26 @@ def test_api_update_orders_in_bulk_invalid_status(client): client.login(project.owner) - response1 = client.json.post(url1, json.dumps(data)) - response2 = client.json.post(url2, json.dumps(data)) - response3 = client.json.post(url3, json.dumps(data)) + response = client.json.post(url1, json.dumps(data)) + assert response.status_code == 400, response.data + assert "status_id" in response.data + assert "bulk_stories" in response.data - assert response1.status_code == 400, response1.data - assert response2.status_code == 400, response2.data - assert response3.status_code == 400, response3.data + response = client.json.post(url2, json.dumps(data)) + assert response.status_code == 400, response.data + assert "status_id" in response.data + assert "bulk_stories" in response.data + + response = client.json.post(url3, json.dumps(data)) + assert response.status_code == 400, response.data + assert "status_id" in response.data + assert "bulk_stories" in response.data def test_api_update_orders_in_bulk_invalid_milestione(client): project = f.create_project() f.MembershipFactory.create(project=project, user=project.owner, is_admin=True) - mil1 = f.MilestoneFactory.create(project=project) + mil1 = f.MilestoneFactory.create() us1 = f.create_userstory(project=project, milestone=mil1) us2 = f.create_userstory(project=project, milestone=mil1) us3 = f.create_userstory(project=project) @@ -265,13 +296,20 @@ def test_api_update_orders_in_bulk_invalid_milestione(client): client.login(project.owner) - response1 = client.json.post(url1, json.dumps(data)) - response2 = client.json.post(url2, json.dumps(data)) - response3 = client.json.post(url3, json.dumps(data)) + response = client.json.post(url1, json.dumps(data)) + assert response.status_code == 400, response.data + assert "milestone_id" in response.data + assert "bulk_stories" in response.data - assert response1.status_code == 400, response1.data - assert response2.status_code == 400, response2.data - assert response3.status_code == 400, response3.data + response = client.json.post(url2, json.dumps(data)) + assert response.status_code == 400, response.data + assert "milestone_id" in response.data + assert "bulk_stories" in response.data + + response = client.json.post(url3, json.dumps(data)) + assert response.status_code == 400, response.data + assert "milestone_id" in response.data + assert "bulk_stories" in response.data def test_api_update_milestone_in_bulk(client): @@ -322,7 +360,7 @@ def test_api_update_milestone_in_bulk_invalid_milestone(client): response = client.json.post(url, json.dumps(data)) assert response.status_code == 400 - assert len(response.data["non_field_errors"]) == 1 + assert "milestone_id" in response.data def test_api_update_milestone_in_bulk_invalid_userstories(client): @@ -344,7 +382,7 @@ def test_api_update_milestone_in_bulk_invalid_userstories(client): response = client.json.post(url, json.dumps(data)) assert response.status_code == 400 - assert len(response.data["non_field_errors"]) == 1 + assert "bulk_stories" in response.data def test_update_userstory_points(client):