diff --git a/.travis.yml b/.travis.yml index 48cddd83..854db307 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,5 @@ env: + - VERSION=19.3.3 - VERSION=18.12.17 - VERSION=18.12.13 - VERSION=18.10.3 diff --git a/README.rst b/README.rst index a3911c0f..7b230760 100644 --- a/README.rst +++ b/README.rst @@ -42,6 +42,7 @@ Features * Nullable(T) * UUID * Decimal + * IPv4/IPv6 - Query progress information. diff --git a/clickhouse_driver/columns/ipcolumn.py b/clickhouse_driver/columns/ipcolumn.py new file mode 100644 index 00000000..4092c8ce --- /dev/null +++ b/clickhouse_driver/columns/ipcolumn.py @@ -0,0 +1,93 @@ +from ipaddress import IPv4Address, IPv6Address, AddressValueError + +from .. import errors +from ..util import compat +from .exceptions import ColumnTypeMismatchException +from .stringcolumn import ByteFixedString +from .intcolumn import UInt32Column + + +class IPv4Column(UInt32Column): + ch_type = "IPv4" + py_types = compat.string_types + (IPv4Address, int) + + def __init__(self, types_check=False, **kwargs): + # UIntColumn overrides before_write_item and check_item + # in its __init__ when types_check is True so we force + # __init__ without it then add the appropriate check method for IPv4 + super(UInt32Column, self).__init__(types_check=False, **kwargs) + + self.types_check_enabled = types_check + if types_check: + + def check_item(value): + if isinstance(value, int) and value < 0: + raise ColumnTypeMismatchException(value) + + if not isinstance(value, IPv4Address): + try: + value = IPv4Address(value) + except AddressValueError: + # Cannot parse input in a valid IPv4 + raise ColumnTypeMismatchException(value) + + self.check_item = check_item + + def before_write_item(self, value): + # allow Ipv4 in integer, string or IPv4Address object + try: + if isinstance(value, int): + return value + + if not isinstance(value, IPv4Address): + value = IPv4Address(value) + + return int(value) + except AddressValueError: + raise errors.CannotParseDomainError( + "Cannot parse IPv4 '{}'".format(value) + ) + + def after_read_item(self, value): + return IPv4Address(value) + + +class IPv6Column(ByteFixedString): + ch_type = "IPv6" + py_types = compat.string_types + (IPv6Address, bytes) + + def __init__(self, types_check=False, **kwargs): + super(IPv6Column, self).__init__(16, types_check=types_check, **kwargs) + + if types_check: + + def check_item(value): + if isinstance(value, bytes) and len(value) != 16: + raise ColumnTypeMismatchException(value) + + if not isinstance(value, IPv6Address): + try: + value = IPv6Address(value) + except AddressValueError: + # Cannot parse input in a valid IPv6 + raise ColumnTypeMismatchException(value) + + self.check_item = check_item + + def before_write_item(self, value): + # allow Ipv6 in bytes or python IPv6Address + # this is raw bytes (not encoded) in order to fit FixedString(16) + try: + if isinstance(value, bytes): + return value + + if not isinstance(value, IPv6Address): + value = IPv6Address(value) + return value.packed + except AddressValueError: + raise errors.CannotParseDomainError( + "Cannot parse IPv6 '{}'".format(value) + ) + + def after_read_item(self, value): + return IPv6Address(value) diff --git a/clickhouse_driver/columns/service.py b/clickhouse_driver/columns/service.py index cde0de83..62ad9b8f 100644 --- a/clickhouse_driver/columns/service.py +++ b/clickhouse_driver/columns/service.py @@ -20,6 +20,7 @@ IntervalDayColumn, IntervalHourColumn, IntervalMinuteColumn, IntervalSecondColumn ) +from .ipcolumn import IPv4Column, IPv6Column column_by_type = {c.ch_type: c for c in [ @@ -29,7 +30,7 @@ NothingColumn, NullColumn, UUIDColumn, IntervalYearColumn, IntervalMonthColumn, IntervalWeekColumn, IntervalDayColumn, IntervalHourColumn, IntervalMinuteColumn, - IntervalSecondColumn + IntervalSecondColumn, IPv4Column, IPv6Column ]} diff --git a/clickhouse_driver/errors.py b/clickhouse_driver/errors.py index d4c464a5..86b59d38 100644 --- a/clickhouse_driver/errors.py +++ b/clickhouse_driver/errors.py @@ -365,6 +365,7 @@ class ErrorCodes(object): FUNCTION_THROW_IF_VALUE_IS_NON_ZERO = 395 TOO_MANY_ROWS_OR_BYTES = 396 QUERY_IS_NOT_SUPPORTED_IN_MATERIALIZED_VIEW = 397 + CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING = 441 KEEPER_EXCEPTION = 999 POCO_EXCEPTION = 1000 @@ -466,3 +467,7 @@ class UnknownPacketFromServerError(Error): class CannotParseUuidError(Error): code = ErrorCodes.CANNOT_PARSE_UUID + + +class CannotParseDomainError(Error): + code = ErrorCodes.CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING diff --git a/setup.py b/setup.py index 1d5e01ec..c3708f01 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ install_requires = ['pytz'] if not PY34: install_requires.append('enum34') + install_requires.append('ipaddress') def read_version(): diff --git a/tests/columns/test_ip.py b/tests/columns/test_ip.py new file mode 100644 index 00000000..c70ae410 --- /dev/null +++ b/tests/columns/test_ip.py @@ -0,0 +1,228 @@ +from __future__ import unicode_literals + +from clickhouse_driver import errors +from ipaddress import IPv6Address, IPv4Address + +from tests.testcase import BaseTestCase +from tests.util import require_server_version + + +class IPv4TestCase(BaseTestCase): + @require_server_version(19, 3, 3) + def test_simple(self): + with self.create_table('a IPv4'): + data = [ + (IPv4Address("10.0.0.1"),), + (IPv4Address("192.168.253.42"),) + ] + self.client.execute( + 'INSERT INTO test (a) VALUES', data + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, ( + '10.0.0.1\n' + '192.168.253.42\n' + )) + inserted = self.client.execute(query) + self.assertEqual(inserted, [ + (IPv4Address("10.0.0.1"),), + (IPv4Address("192.168.253.42"),) + ]) + + @require_server_version(19, 3, 3) + def test_from_int(self): + with self.create_table('a IPv4'): + data = [ + (167772161,), + ] + self.client.execute( + 'INSERT INTO test (a) VALUES', data, types_check=True + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, ( + '10.0.0.1\n' + )) + inserted = self.client.execute(query) + self.assertEqual(inserted, [ + (IPv4Address("10.0.0.1"),), + ]) + + @require_server_version(19, 3, 3) + def test_from_str(self): + with self.create_table('a IPv4'): + data = [ + ("10.0.0.1",), + ] + self.client.execute( + 'INSERT INTO test (a) VALUES', data, types_check=True + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, ( + '10.0.0.1\n' + )) + inserted = self.client.execute(query) + self.assertEqual(inserted, [ + (IPv4Address("10.0.0.1"),), + ]) + + @require_server_version(19, 3, 3) + def test_type_mismatch(self): + data = [(1025.2147,)] + with self.create_table('a IPv4'): + with self.assertRaises(errors.TypeMismatchError): + self.client.execute( + 'INSERT INTO test (a) VALUES', data, types_check=True + ) + + @require_server_version(19, 3, 3) + def test_bad_ipv4(self): + data = [('985.512.12.0',)] + with self.create_table('a IPv4'): + with self.assertRaises(errors.CannotParseDomainError): + self.client.execute( + 'INSERT INTO test (a) VALUES', data + ) + + @require_server_version(19, 3, 3) + def test_bad_ipv4_with_type_check(self): + data = [('985.512.12.0',)] + with self.create_table('a IPv4'): + with self.assertRaises(errors.TypeMismatchError): + self.client.execute( + 'INSERT INTO test (a) VALUES', data, types_check=True + ) + + @require_server_version(19, 3, 3) + def test_nullable(self): + with self.create_table('a Nullable(IPv4)'): + data = [(IPv4Address('10.10.10.10'),), (None,)] + self.client.execute( + 'INSERT INTO test (a) VALUES', data + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, + '10.10.10.10\n\\N\n') + + inserted = self.client.execute(query) + self.assertEqual(inserted, data) + + +class IPv6TestCase(BaseTestCase): + @require_server_version(19, 3, 3) + def test_simple(self): + with self.create_table('a IPv6'): + data = [ + (IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),), + (IPv6Address('a22:cc64:cf47:1653:4976:3c0c:ff8d:417c'),), + (IPv6Address('12ff:0000:0000:0000:0000:0000:0000:0001'),) + ] + self.client.execute( + 'INSERT INTO test (a) VALUES', data + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, ( + '79f4:e698:45de:a59b:2765:28e3:8d3a:35ae\n' + 'a22:cc64:cf47:1653:4976:3c0c:ff8d:417c\n' + '12ff::1\n' + )) + inserted = self.client.execute(query) + self.assertEqual(inserted, [ + (IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),), + (IPv6Address('a22:cc64:cf47:1653:4976:3c0c:ff8d:417c'),), + (IPv6Address('12ff::1'),) + ]) + + @require_server_version(19, 3, 3) + def test_from_str(self): + with self.create_table('a IPv6'): + data = [ + ('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae',), + ] + self.client.execute( + 'INSERT INTO test (a) VALUES', data, types_check=True + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, ( + '79f4:e698:45de:a59b:2765:28e3:8d3a:35ae\n' + )) + inserted = self.client.execute(query) + self.assertEqual(inserted, [ + (IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),), + ]) + + @require_server_version(19, 3, 3) + def test_from_bytes(self): + with self.create_table('a IPv6'): + data = [ + (b"y\xf4\xe6\x98E\xde\xa5\x9b'e(\xe3\x8d:5\xae",), + ] + self.client.execute( + 'INSERT INTO test (a) VALUES', data, types_check=True + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, ( + '79f4:e698:45de:a59b:2765:28e3:8d3a:35ae\n' + )) + inserted = self.client.execute(query) + self.assertEqual(inserted, [ + (IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),), + ]) + + @require_server_version(19, 3, 3) + def test_type_mismatch(self): + data = [(1025.2147,)] + with self.create_table('a IPv6'): + with self.assertRaises(errors.TypeMismatchError): + self.client.execute( + 'INSERT INTO test (a) VALUES', data, types_check=True + ) + + @require_server_version(19, 3, 3) + def test_bad_ipv6(self): + data = [("ghjk:e698:45de:a59b:2765:28e3:8d3a:zzzz",)] + with self.create_table('a IPv6'): + with self.assertRaises(errors.CannotParseDomainError): + self.client.execute( + 'INSERT INTO test (a) VALUES', data + ) + + @require_server_version(19, 3, 3) + def test_bad_ipv6_with_type_check(self): + data = [("ghjk:e698:45de:a59b:2765:28e3:8d3a:zzzz",)] + with self.create_table('a IPv6'): + with self.assertRaises(errors.TypeMismatchError): + self.client.execute( + 'INSERT INTO test (a) VALUES', data, types_check=True + ) + + @require_server_version(19, 3, 3) + def test_nullable(self): + with self.create_table('a Nullable(IPv6)'): + data = [ + (IPv6Address('79f4:e698:45de:a59b:2765:28e3:8d3a:35ae'),), + (None,)] + self.client.execute( + 'INSERT INTO test (a) VALUES', data + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, + '79f4:e698:45de:a59b:2765:28e3:8d3a:35ae\n\\N\n') + + inserted = self.client.execute(query) + self.assertEqual(inserted, data)