From 500eb920e4498c0573ceadd22f65117fd55fd075 Mon Sep 17 00:00:00 2001 From: Manuel Hermann Date: Tue, 7 Aug 2012 17:04:03 +0200 Subject: [PATCH 1/2] Use info from getlasterror whether a document has been updated or created. --- mongoengine/document.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mongoengine/document.py b/mongoengine/document.py index f8bf769db..bb5a60fce 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -208,11 +208,20 @@ def save(self, safe=True, force_insert=False, validate=True, write_options=None, actual_key = self._db_field_map.get(k, k) select_dict[actual_key] = doc[actual_key] + def is_new_object(last_error): + if last_error is not None: + updated = last_error.get("updatedExisting") + if updated is not None: + return not updated + return created + upsert = self._created if updates: - collection.update(select_dict, {"$set": updates}, upsert=upsert, safe=safe, **write_options) + last_error = collection.update(select_dict, {"$set": updates}, upsert=upsert, safe=safe, **write_options) + created = is_new_object(last_error) if removals: - collection.update(select_dict, {"$unset": removals}, upsert=upsert, safe=safe, **write_options) + last_error = collection.update(select_dict, {"$unset": removals}, upsert=upsert, safe=safe, **write_options) + created = created or is_new_object(last_error) cascade = self._meta.get('cascade', True) if cascade is None else cascade if cascade: From f5ef81c5da0603c5cf58b8b7286e62f6a474a6ff Mon Sep 17 00:00:00 2001 From: helduel Date: Thu, 8 Nov 2012 14:54:32 +0100 Subject: [PATCH 2/2] Add IPFields and IPNetworkFields. --- mongoengine/fields.py | 120 +++++++++++++++++++- mongoengine/queryset.py | 63 +++++++++-- setup.py | 3 +- tests/test_fields.py | 242 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 419 insertions(+), 9 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 8e3cf15a5..20f2148d9 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -14,6 +14,8 @@ from connection import get_db, DEFAULT_CONNECTION_NAME from operator import itemgetter +from IPy import IP + try: from PIL import Image, ImageOps @@ -33,7 +35,8 @@ 'DecimalField', 'ComplexDateTimeField', 'URLField', 'DynamicField', 'GenericReferenceField', 'FileField', 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField', 'ImageField', - 'SequenceField', 'UUIDField', 'GenericEmbeddedDocumentField'] + 'SequenceField', 'UUIDField', 'GenericEmbeddedDocumentField', + 'IPv4Field', 'IPv6Field', 'IPv4NetworkField', 'IPv6NetworkField'] RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -1351,3 +1354,118 @@ def validate(self, value): value = uuid.UUID(value) except Exception, exc: self.error('Could not convert to UUID: %s' % exc) + + +class IPField(BaseField): + """An IP field. + """ + def __init__(self, v=4, **kwargs): + if v not in (4, 6): + raise ValueError("IP version must be 4 or 6") + self.v = v + super(IPField, self).__init__(**kwargs) + + def __get__(self, instance, owner): + value = super(IPField, self).__get__(instance, owner) + if value is not None: + value = IP(value) + return value + + def validate(self, value): + if value.version() != self.v: + self.error("IP version mismatch") + + def to_mongo(self, value): + if self.v == 4: + return IP(value).int() + else: + return IP(value).strHex() + + def to_python(self, value): + return IP(value) + + def prepare_query_value(self, op, value): + if self.v == 4: + return IP(value).int() + return IP(value).strHex() + + +class IPv4Field(IPField): + """An IPv4 field. + """ + def __init__(self, **kwargs): + super(IPv4Field, self).__init__(v=4, **kwargs) + + +class IPv6Field(IPField): + """An IPv6 field. + """ + def __init__(self, **kwargs): + super(IPv6Field, self).__init__(v=6, **kwargs) + + +class IPNetworkField(BaseField): + """An IP network field. + """ + def __init__(self, v=4, **kwargs): + if v not in (4, 6): + raise ValueError("IP version must be 4 or 6") + self.v = v + super(IPNetworkField, self).__init__(**kwargs) + + def __get__(self, instance, owner): + value = super(IPNetworkField, self).__get__(instance, owner) + if value is not None: + value = self.to_python(value) + return value + + def validate(self, value): + if value.version() != self.v: + self.error("IP version mismatch") + + def to_mongo(self, value): + value = IP(value) + if self.v == 4: + return { + "net$prefix": value.prefixlen(), + "net$lower": value[0].int(), + "net$upper": value[-1].int(), + } + return { + "net$prefix": value.prefixlen(), + "net$lower": value[0].strHex(), + "net$upper": value[-1].strHex(), + } + + def to_python(self, value): + if isinstance(value, dict): + value = "%s/%i" % (value["net$lower"], value["net$prefix"]) + return IP(value) + + def prepare_query_value(self, op, value): + if self.v == 4: + value = IP(value).int() + else: + value = IP(value).strHex() + if op == "contains": + value = { + "net$lower": {"$lte": value}, + "net$upper": {"$gte": value}, + } + return value + + +class IPv4NetworkField(IPNetworkField): + """An IPv6 network field. + """ + def __init__(self, **kwargs): + super(IPv4NetworkField, self).__init__(v=4, **kwargs) + + +class IPv6NetworkField(IPNetworkField): + """An IPv6 network field. + """ + def __init__(self, **kwargs): + super(IPv6NetworkField, self).__init__(v=6, **kwargs) + + diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 703b6e5fd..2f153a283 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -8,6 +8,7 @@ import pymongo from bson.code import Code +from IPy import IP from mongoengine import signals @@ -669,6 +670,8 @@ def _translate_field_name(cls, doc_cls, field, sep='.'): def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): """Transform a query from Django-style format to Mongo format. """ + from mongoengine.fields import IPNetworkField + from mongoengine.fields import IPField operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists', 'not'] geo_operators = ['within_distance', 'within_spherical_distance', 'within_box', 'within_polygon', 'near', 'near_sphere'] @@ -678,6 +681,18 @@ def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): custom_operators = ['match'] mongo_query = {} + def mongo_query_and(value): + if "$and" in mongo_query: + mongo_query["$and"].append(value) + else: + mongo_query["$and"] = value + + def mongo_query_or(value): + if "$or" in mongo_query: + mongo_query["$or"].append(value) + else: + mongo_query["$or"] = value + for key, value in query.items(): if key == "__raw__": mongo_query.update(value) @@ -726,8 +741,44 @@ def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): value = field else: value = field.prepare_query_value(op, value) + if isinstance(field, IPNetworkField): + if op == "contains": + contains_query = list() + for (k, v) in value.items(): + new_key = ".".join(parts + [k]) + if negate: + if "$lte" in v: + v["$gt"] = v.pop("$lte") + else: + v["$lt"] = v.pop("$gte") + contains_query.append({new_key: v}) + if negate: + mongo_query_or(contains_query) + else: + mongo_query_and(contains_query) + continue + parts.append("net$lower") elif op in ('in', 'nin', 'all', 'near'): # 'in', 'nin' and 'all' require a list of values + if isinstance(value, IP): + if op not in ('in', 'nin'): + raise NotImplemented('%s not implemented for IP fields' % op) + key = '.'.join(parts) + lower = field.prepare_query_value(None, value[0]) + upper = field.prepare_query_value(None, value[-1]) + if op == 'in' and not negate: + value = [ + {key: {"$gte": lower}}, + {key: {"$lte": upper}}, + ] + mongo_query_and(value) + else: + value = [ + {key: {"$lt": lower}}, + {key: {"$gt": upper}}, + ] + mongo_query_or(value) + continue value = [field.prepare_query_value(op, v) for v in value] # if op and op not in match_operators: @@ -757,12 +808,14 @@ def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): elif op not in match_operators: value = {'$' + op: value} - if negate: + if negate and not isinstance(field, IPNetworkField): value = {'$not': value} for i, part in indices: parts.insert(i, part) + key = '.'.join(parts) + if op is None or key not in mongo_query: mongo_query[key] = value elif key in mongo_query: @@ -774,14 +827,10 @@ def _transform_query(cls, _doc_cls=None, _field_operation=False, **query): mongo_query[key] = [mongo_query[key], value] for k, v in mongo_query.items(): - if isinstance(v, list): + if k not in ("$and", "$or") and isinstance(v, list): value = [{k:val} for val in v] - if '$and' in mongo_query.keys(): - mongo_query['$and'].append(value) - else: - mongo_query['$and'] = value + mongo_query_and(value) del mongo_query[k] - return mongo_query def get(self, *q_objs, **query): diff --git a/setup.py b/setup.py index 20f3ea389..05d12bec4 100644 --- a/setup.py +++ b/setup.py @@ -42,11 +42,12 @@ def get_version(version_tuple): maintainer_email="ross.lawley@{nospam}gmail.com", url='http://mongoengine.org/', license='MIT', + require=['IPy'], include_package_data=True, description=DESCRIPTION, long_description=LONG_DESCRIPTION, platforms=['any'], classifiers=CLASSIFIERS, install_requires=['pymongo'], - tests_require=['nose', 'coverage', 'blinker', 'django>=1.3', 'PIL'] + tests_require=['nose', 'coverage', 'blinker', 'django>=1.3', 'PIL', 'IPy'] ) diff --git a/tests/test_fields.py b/tests/test_fields.py index a6eaca434..eb554cd17 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -9,6 +9,8 @@ from bson import Binary from decimal import Decimal +from IPy import IP + from mongoengine import * from mongoengine.connection import get_db from mongoengine.base import _document_registry, NotRegistered @@ -2174,6 +2176,246 @@ class Post(Document): post.comments[1].content = 'here we go' post.validate() + def test_ipv4_field(self): + """Ensure that an IPv4 field works as expected. + """ + class Host(Document): + name = StringField(required=True) + ip = IPv4Field(required=True) + + Host.drop_collection() + + ip1 = "192.168.0.1" + # given as string + host1 = Host(name="foo", ip=ip1) + self.assertEquals(host1.ip, IP(ip1)) + host1 = host1.save() + self.assertEquals(host1.ip, IP(ip1)) + + loaded = Host.objects()[0] + self.assertEquals(loaded.ip, IP(ip1)) + + ip2 = "192.168.0.2" + # given as IP object + host2 = Host(name="bar", ip=IP(ip2)) + host2.save() + + # search with IP object + result = Host.objects(ip=IP(ip1)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + # lower than + result = Host.objects(ip__lt=IP(ip2)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + # greater than + result = Host.objects(ip__gt=IP(ip1)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "bar") + + ip3 = "192.168.0.3" + host3 = Host(name="baz", ip=IP(ip3)) + host3.save() + + # lower than equals + result = Host.objects(ip__lte=IP(ip2)).order_by("name") + self.assertEquals(len(result), 2) + self.assertEquals(result[0].name, "bar") + self.assertEquals(result[1].name, "foo") + + # greater than equals + result = Host.objects(ip__gte=IP(ip2)).order_by("name") + self.assertEquals(len(result), 2) + self.assertEquals(result[0].name, "bar") + self.assertEquals(result[1].name, "baz") + + # fail on wrong IP version + host4 = Host(name="fail") + host4.ip = IP("2001:0db8:85a3:08d3:1319:8a2e:0370:7344") + self.assertRaises(ValidationError, host4.validate) + + ip4 = "172.30.1.2" + host4 = Host(name="boo", ip=IP(ip4)) + host4.save() + + net4 = "172.30.0.0/16" + + # in network + result = Host.objects(ip__in=IP(net4)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "boo") + + # nin network + result = Host.objects(ip__nin=IP(net4)).order_by("name") + self.assertEquals(len(result), 3) + self.assertEquals(result[0].name, "bar") + self.assertEquals(result[1].name, "baz") + self.assertEquals(result[2].name, "foo") + + # not in network + result = Host.objects(ip__not__in=IP(net4)).order_by("name") + self.assertEquals(len(result), 3) + self.assertEquals(result[0].name, "bar") + self.assertEquals(result[1].name, "baz") + self.assertEquals(result[2].name, "foo") + + def test_ipv6_field(self): + """Ensure that an IPv6 field works as expected. + """ + class Host(Document): + name = StringField(required=True) + ip = IPv6Field(required=True) + + Host.drop_collection() + + ip1 = "2001:0db8:85a3:08d3:1319:8a2e:0370:0001" + # given as string + host1 = Host(name="foo", ip=ip1) + self.assertEquals(host1.ip, IP(ip1)) + host1 = host1.save() + self.assertEquals(host1.ip, IP(ip1)) + + loaded = Host.objects()[0] + self.assertEquals(loaded.ip, IP(ip1)) + + ip2 = "2001:0db8:85a3:08d3:1319:8a2e:0370:0002" + # given as IP object + host2 = Host(name="bar", ip=IP(ip2)) + host2.save() + + # search with IP object + result = Host.objects(ip=IP(ip1)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + # lower than + result = Host.objects(ip__lt=IP(ip2)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + # greater than + result = Host.objects(ip__gt=IP(ip1)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "bar") + + ip3 = "2001:0db8:85a3:08d3:1319:8a2e:0370:0003" + host3 = Host(name="baz", ip=IP(ip3)) + host3.save() + + # lower than equals + result = Host.objects(ip__lte=IP(ip2)).order_by("name") + self.assertEquals(len(result), 2) + self.assertEquals(result[0].name, "bar") + self.assertEquals(result[1].name, "foo") + + # greater than equals + result = Host.objects(ip__gte=IP(ip2)).order_by("name") + self.assertEquals(len(result), 2) + self.assertEquals(result[0].name, "bar") + self.assertEquals(result[1].name, "baz") + + # fail on wrong IP version + host4 = Host(name="fail") + host4.ip = IP("192.168.0.1") + self.assertRaises(ValidationError, host4.validate) + + def test_ipv4_network_field(self): + """Ensure that an IPv4 network field works as expected. + """ + class Site(Document): + name = StringField(required=True) + network = IPv4NetworkField(required=True) + + Site.drop_collection() + + network1 = "192.168.0.0/24" + site1 = Site(name="foo", network=network1) + self.assertEquals(site1.network, IP(network1)) + site1 = site1.save() + self.assertEquals(site1.network, IP(network1)) + + # search with IP object + result = Site.objects(network=IP(network1)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + network2 = "172.30.0.0/16" + # add with IP object + site2 = Site(name="bar", network=IP(network2)) + self.assertEquals(site2.network, IP(network2)) + site2 = site2.save() + self.assertEquals(site2.network, IP(network2)) + + # search network + result = Site.objects(network=IP(network1)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + # search not network + result = Site.objects(network__ne=IP(network2)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + # search network by containing IP + result = Site.objects(network__contains=IP("172.30.1.2")) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "bar") + + # search network by NOT containing IP + result = Site.objects(network__not__contains=IP("172.30.1.2")) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + def test_ipv6_network_field(self): + """Ensure that an IPv6 network field works as expected. + """ + class Site(Document): + name = StringField(required=True) + network = IPv6NetworkField(required=True) + + Site.drop_collection() + + network1 = "2001:db8:85a3:1111::/64" + site1 = Site(name="foo", network=network1) + self.assertEquals(site1.network, IP(network1)) + site1 = site1.save() + self.assertEquals(site1.network, IP(network1)) + + # search with IP object + result = Site.objects(network=IP(network1)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + network2 = "2001:db8:85a3:2222::/64" + # add with IP object + site2 = Site(name="bar", network=IP(network2)) + self.assertEquals(site2.network, IP(network2)) + site2 = site2.save() + self.assertEquals(site2.network, IP(network2)) + + # search network + result = Site.objects(network=IP(network1)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + # search not network + result = Site.objects(network__ne=IP(network2)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + + v6addr = "2001:db8:85a3:2222::3039" + # search network by containing IP + result = Site.objects(network__contains=IP(v6addr)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "bar") + + # search network by NOT containing IP + result = Site.objects(network__not__contains=IP(v6addr)) + self.assertEquals(len(result), 1) + self.assertEquals(result[0].name, "foo") + if __name__ == '__main__': unittest.main()