Refactor of tags queries using djorm-pgarray queries

remotes/origin/enhancement/email-actions
Jesús Espino 2014-12-02 16:42:25 +01:00
parent 50ad5046e6
commit af1450ccf1
5 changed files with 8 additions and 117 deletions

View File

@ -210,12 +210,16 @@ class TagsFilter(FilterBackend):
self.filter_name = filter_name self.filter_name = filter_name
def _get_tags_queryparams(self, params): def _get_tags_queryparams(self, params):
return params.get(self.filter_name, "") tags = params.get(self.filter_name, None)
if tags:
return tags.split(",")
return None
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
query_tags = self._get_tags_queryparams(request.QUERY_PARAMS) query_tags = self._get_tags_queryparams(request.QUERY_PARAMS)
if query_tags: if query_tags:
queryset = tags.filter(queryset, contains=query_tags) queryset = queryset.filter(tags__contains=query_tags)
return super().filter_queryset(request, queryset, view) return super().filter_queryset(request, queryset, view)

View File

@ -29,94 +29,3 @@ class TaggedMixin(models.Model):
class Meta: class Meta:
abstract = True abstract = True
def get_queryset_table(queryset):
"""Return queryset model's table name"""
return queryset.model._meta.db_table
def _filter_bin(queryset, value, operator):
"""tags <operator> <value>"""
if not isinstance(value, str):
value = ",".join(value)
sql = "{table_name}.tags {operator} string_to_array(%s, ',')"
where_clause = sql.format(table_name=get_queryset_table(queryset), operator=operator)
queryset = queryset.extra(where=[where_clause], params=[value])
return queryset
_filter_contains = partial(_filter_bin, operator="@>")
_filter_contained_by = partial(_filter_bin, operator="<@")
_filter_overlap = partial(_filter_bin, operator="&&")
def _filter_index(queryset, index, value):
"""tags[<index>] == <value>"""
sql = "{table_name}.tags[{index}] = %s"
where_clause = sql.format(table_name=get_queryset_table(queryset), index=index)
queryset = queryset.extra(where=[where_clause], params=[value])
return queryset
def _filter_len(queryset, value):
"""len(tags) == <value>"""
sql = "array_length({table_name}.tags, 1) = %s"
where_clause = sql.format(table_name=get_queryset_table(queryset))
queryset = queryset.extra(where=[where_clause], params=[value])
return queryset
def _filter_len_operator(queryset, value, operator):
"""len(tags) <operator> <value>"""
operator = {"gt": ">", "lt": "<", "gte": ">=", "lte": "<="}[operator]
sql = "array_length({table_name}.tags, 1) {operator} %s"
where_clause = sql.format(table_name=get_queryset_table(queryset), operator=operator)
queryset = queryset.extra(where=[where_clause], params=[value])
return queryset
def _filter_index_operator(queryset, value, operator):
"""tags[<operator>] == value"""
index = int(operator) + 1
sql = "{table_name}.tags[{index}] = %s"
where_clause = sql.format(table_name=get_queryset_table(queryset), index=index)
queryset = queryset.extra(where=[where_clause], params=[value])
return queryset
def _tags_filter(**filters_map):
filter_re = re.compile(r"""(?:(len__)(gte|lte|lt|gt)
|
(index__)(\d+))""", re.VERBOSE)
def get_filter(filter_name, strict=False):
return filters_map[filter_name] if strict else filters_map.get(filter_name)
def get_filter_matching(filter_name):
match = filter_re.search(filter_name)
filter_name, operator = (group for group in match.groups() if group)
return partial(get_filter(filter_name, strict=True), operator=operator)
def tags_filter(model_or_qs, **filters):
"Filter a queryset but adding support to filters that work with postgresql array fields"
if hasattr(model_or_qs, "_meta"):
qs = model_or_qs._default_manager.get_queryset()
else:
qs = model_or_qs
for filter_name, filter_value in filters.items():
try:
filter = get_filter(filter_name) or get_filter_matching(filter_name)
except (LookupError, AttributeError):
qs = qs.filter(**{filter_name: filter_value})
else:
qs = filter(queryset=qs, value=filter_value)
return qs
return tags_filter
filter = _tags_filter(contains=_filter_contains,
contained_by=_filter_contained_by,
overlap=_filter_overlap,
len=_filter_len,
len__=_filter_len_operator,
index__=_filter_index_operator)

View File

@ -79,7 +79,7 @@ class IssuesFilter(filters.FilterBackend):
filterdata = self._prepare_filters_data(request) filterdata = self._prepare_filters_data(request)
if "tags" in filterdata: if "tags" in filterdata:
queryset = tags.filter(queryset, contains=filterdata["tags"]) queryset = queryset.filter(tags__contains=filterdata["tags"])
for name, value in filter(lambda x: x[0] != "tags", filterdata.items()): for name, value in filter(lambda x: x[0] != "tags", filterdata.items()):
if None in value: if None in value:

View File

@ -21,7 +21,6 @@ import pytest
from taiga.projects.userstories.models import UserStory from taiga.projects.userstories.models import UserStory
from taiga.projects.issues.models import Issue from taiga.projects.issues.models import Issue
from taiga.base import tags
from taiga.base import neighbors as n from taiga.base import neighbors as n
from .. import factories as f from .. import factories as f
@ -58,7 +57,7 @@ class TestUserStories:
us1 = f.UserStoryFactory.create(project=project, tags=tag_names) us1 = f.UserStoryFactory.create(project=project, tags=tag_names)
us2 = f.UserStoryFactory.create(project=project, tags=tag_names) us2 = f.UserStoryFactory.create(project=project, tags=tag_names)
test_user_stories = tags.filter(UserStory.objects.get_queryset(), contains=tag_names) test_user_stories = UserStory.objects.get_queryset().filter(tags__contains=tag_names)
neighbors = n.get_neighbors(us1, results_set=test_user_stories) neighbors = n.get_neighbors(us1, results_set=test_user_stories)

View File

@ -1,21 +0,0 @@
import pytest
pytestmark = pytest.mark.django_db
from taiga.base import tags
from tests.models import TaggedModel
def test_tags():
tags1 = TaggedModel.objects.create(tags=["foo", "bar"])
tags2 = TaggedModel.objects.create(tags=["foo"])
assert list(tags.filter(TaggedModel, contains=["foo"])) == [tags1, tags2]
assert list(tags.filter(TaggedModel, contained_by=["foo"])) == [tags2]
assert list(tags.filter(TaggedModel, overlap=["bar"])) == [tags1]
assert list(tags.filter(TaggedModel, len=2)) == [tags1]
assert list(tags.filter(TaggedModel, len__gte=1)) == [tags1, tags2]
assert list(tags.filter(TaggedModel, len__lt=2)) == [tags2]
assert list(tags.filter(TaggedModel, index__1="bar")) == [tags1]
assert list(tags.filter(TaggedModel, index__1="bar", id__isnull=False)) == [tags1]