diff --git a/src/databricks/sqlalchemy/__init__.py b/src/databricks/sqlalchemy/__init__.py index 25584506f..2a17ac3e0 100644 --- a/src/databricks/sqlalchemy/__init__.py +++ b/src/databricks/sqlalchemy/__init__.py @@ -1,4 +1,4 @@ from databricks.sqlalchemy.base import DatabricksDialect -from databricks.sqlalchemy._types import TINYINT +from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ -__all__ = ["TINYINT"] +__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ"] diff --git a/src/databricks/sqlalchemy/_types.py b/src/databricks/sqlalchemy/_types.py index 133abb299..1dc6d9e97 100644 --- a/src/databricks/sqlalchemy/_types.py +++ b/src/databricks/sqlalchemy/_types.py @@ -1,12 +1,29 @@ +from datetime import datetime, time, timezone +from itertools import product +from typing import Any, Union, Optional + import sqlalchemy +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.ext.compiler import compiles -from typing import Union +from databricks.sql.utils import ParamEscaper -from datetime import datetime, time +def process_literal_param_hack(value: Any): + """This method is supposed to accept a Python type and return a string representation of that type. + But due to some weirdness in the way SQLAlchemy's literal rendering works, we have to return + the value itself because, by the time it reaches our custom type code, it's already been converted + into a string. -from databricks.sql.utils import ParamEscaper + TimeTest + DateTimeTest + DateTimeTZTest + + This dynamic only seems to affect the literal rendering of datetime and time objects. + + All fail without this hack in-place. I'm not sure why. But it works. 
+ """ + return value @compiles(sqlalchemy.types.Enum, "databricks") @@ -64,7 +81,7 @@ def compile_numeric_databricks(type_, compiler, **kw): @compiles(sqlalchemy.types.DateTime, "databricks") def compile_datetime_databricks(type_, compiler, **kw): """ - We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP" instead of "DATETIME" + We need to override the default DateTime compilation rendering because Databricks uses "TIMESTAMP_NTZ" instead of "DATETIME" """ return "TIMESTAMP_NTZ" @@ -87,13 +104,15 @@ def compile_array_databricks(type_, compiler, **kw): return f"ARRAY<{inner}>" -class DatabricksDateTimeNoTimezoneType(sqlalchemy.types.TypeDecorator): - """The decimal that pysql creates when it receives the contents of a TIMESTAMP_NTZ - includes a timezone of 'Etc/UTC'. But since SQLAlchemy's test suite assumes that - the sqlalchemy.types.DateTime type will return a datetime.datetime _without_ any - timezone set, we need to strip the timezone off the value received from pysql. +class TIMESTAMP_NTZ(sqlalchemy.types.TypeDecorator): + """Represents values comprising values of fields year, month, day, hour, minute, and second. + All operations are performed without taking any time zone into account. + + Our dialect maps sqlalchemy.types.DateTime() to this type, which means that all DateTime() + objects are stored without tzinfo. To read and write timezone-aware datetimes use + databricks.sql.TIMESTAMP instead. - It's not clear if DBR sends a timezone to pysql or if pysql is adding it. This could be a bug. 
+    https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-ntz-type.html
     """
 
     impl = sqlalchemy.types.DateTime
@@ -106,36 +125,115 @@ def process_result_value(self, value: Union[None, datetime], dialect):
         return value.replace(tzinfo=None)
 
 
+class TIMESTAMP(sqlalchemy.types.TypeDecorator):
+    """Represents values comprising values of fields year, month, day, hour, minute, and second,
+    with the session local time-zone.
+
+    Our dialect maps sqlalchemy.types.DateTime() to TIMESTAMP_NTZ, which means that all DateTime()
+    objects are stored without tzinfo. To read and write timezone-aware datetimes use
+    this type instead.
+
+    ```python
+    # This won't work
+    `Column(sqlalchemy.DateTime(timezone=True))`
+
+    # But this does
+    `Column(TIMESTAMP)`
+    ```
+
+    https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-type.html
+    """
+
+    impl = sqlalchemy.types.DateTime
+
+    cache_ok = True
+
+    def process_result_value(self, value: Union[None, datetime], dialect):
+        if value is None:
+            return None
+
+        if not value.tzinfo:
+            return value.replace(tzinfo=timezone.utc)
+        return value
+
+    def process_bind_param(
+        self, value: Union[datetime, None], dialect
+    ) -> Optional[datetime]:
+        """pysql can pass datetime.datetime() objects directly to DBR"""
+        return value
+
+    def process_literal_param(
+        self, value: Union[datetime, None], dialect: Dialect
+    ) -> str:
+        """ """
+        return process_literal_param_hack(value)
+
+
+@compiles(TIMESTAMP, "databricks")
+def compile_timestamp_databricks(type_, compiler, **kw):
+    """
+    We need to override the default compilation rendering because this timezone-aware type must render as "TIMESTAMP", not SQLAlchemy's default "DATETIME"
+    """
+    return "TIMESTAMP"
+
+
 class DatabricksTimeType(sqlalchemy.types.TypeDecorator):
     """Databricks has no native TIME type. 
So we store it as a string.""" impl = sqlalchemy.types.Time cache_ok = True - TIME_WITH_MICROSECONDS_FMT = "%H:%M:%S.%f" - TIME_NO_MICROSECONDS_FMT = "%H:%M:%S" + BASE_FMT = "%H:%M:%S" + MICROSEC_PART = ".%f" + TIMEZONE_PART = "%z" + + def _generate_fmt_string(self, ms: bool, tz: bool) -> str: + """Return a format string for datetime.strptime() that includes or excludes microseconds and timezone.""" + _ = lambda x, y: x if y else "" + return f"{self.BASE_FMT}{_(self.MICROSEC_PART,ms)}{_(self.TIMEZONE_PART,tz)}" + + @property + def allowed_fmt_strings(self): + """Time strings can be read with or without microseconds and with or without a timezone.""" + + if not hasattr(self, "_allowed_fmt_strings"): + ms_switch = tz_switch = [True, False] + self._allowed_fmt_strings = [ + self._generate_fmt_string(x, y) + for x, y in product(ms_switch, tz_switch) + ] + + return self._allowed_fmt_strings + + def _parse_result_string(self, value: str) -> time: + """Parse a string into a time object. Try all allowed formats until one works.""" + for fmt in self.allowed_fmt_strings: + try: + # We use timetz() here because we want to preserve the timezone information + # Calling .time() will strip the timezone information + return datetime.strptime(value, fmt).timetz() + except ValueError: + pass + + raise ValueError(f"Could not parse time string {value}") + + def _determine_fmt_string(self, value: time) -> str: + """Determine which format string to use to render a time object as a string.""" + ms_bool = value.microsecond > 0 + tz_bool = value.tzinfo is not None + return self._generate_fmt_string(ms_bool, tz_bool) def process_bind_param(self, value: Union[time, None], dialect) -> Union[None, str]: """Values sent to the database are converted to %:H:%M:%S strings.""" if value is None: return None - return value.strftime(self.TIME_WITH_MICROSECONDS_FMT) + fmt_string = self._determine_fmt_string(value) + return value.strftime(fmt_string) # mypy doesn't like this workaround because 
TypeEngine wants process_literal_param to return a string def process_literal_param(self, value, dialect) -> time: # type: ignore - """It's not clear to me why this is necessary. Without it, SQLAlchemy's Timetest:test_literal fails - because the string literal renderer receives a str() object and calls .isoformat() on it. - - Whereas this method receives a datetime.time() object which is subsequently passed to that - same renderer. And that works. - - UPDATE: After coping with the literal_processor override in DatabricksStringType, I suspect a similar - mechanism is at play. Two different processors are are called in sequence. This is likely a byproduct - of Databricks not having a true TIME type. I think the string representation of Time() types is - somehow affecting the literal rendering process. But as long as this passes the tests, I'm not - worried about it. - """ - return value + """ """ + return process_literal_param_hack(value) def process_result_value( self, value: Union[None, str], dialect @@ -144,13 +242,7 @@ def process_result_value( if value is None: return None - try: - _parsed = datetime.strptime(value, self.TIME_WITH_MICROSECONDS_FMT) - except ValueError: - # If the string doesn't have microseconds, try parsing it without them - _parsed = datetime.strptime(value, self.TIME_NO_MICROSECONDS_FMT) - - return _parsed.time() + return self._parse_result_string(value) class DatabricksStringType(sqlalchemy.types.TypeDecorator): diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 053a45e2c..072c5111d 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -66,7 +66,7 @@ class DatabricksDialect(default.DefaultDialect): supports_sequences: bool = False colspecs = { - sqlalchemy.types.DateTime: dialect_type_impl.DatabricksDateTimeNoTimezoneType, + sqlalchemy.types.DateTime: dialect_type_impl.TIMESTAMP_NTZ, sqlalchemy.types.Time: dialect_type_impl.DatabricksTimeType, sqlalchemy.types.String: 
dialect_type_impl.DatabricksStringType, } diff --git a/src/databricks/sqlalchemy/requirements.py b/src/databricks/sqlalchemy/requirements.py index b6ff46641..b68f63448 100644 --- a/src/databricks/sqlalchemy/requirements.py +++ b/src/databricks/sqlalchemy/requirements.py @@ -228,3 +228,10 @@ def denormalized_names(self): UPPERCASE as case insensitive names.""" return sqlalchemy.testing.exclusions.open() + + @property + def time_timezone(self): + """target dialect supports representation of Python + datetime.time() with tzinfo with Time(timezone=True).""" + + return sqlalchemy.testing.exclusions.open() diff --git a/src/databricks/sqlalchemy/test/_extra.py b/src/databricks/sqlalchemy/test/_extra.py index f8e11bde6..2f3e7a7db 100644 --- a/src/databricks/sqlalchemy/test/_extra.py +++ b/src/databricks/sqlalchemy/test/_extra.py @@ -1,6 +1,8 @@ """Additional tests authored by Databricks that use SQLAlchemy's test fixtures """ +import datetime + from sqlalchemy.testing.suite.test_types import ( _LiteralRoundTripFixture, fixtures, @@ -10,8 +12,10 @@ Table, Column, config, + _DateFixture, + literal, ) -from databricks.sqlalchemy import TINYINT +from databricks.sqlalchemy import TINYINT, TIMESTAMP class TinyIntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): @@ -46,3 +50,21 @@ def run(datatype, data): assert isinstance(row[0], int) return run + + +class DateTimeTZTestCustom(_DateFixture, fixtures.TablesTest): + """This test confirms that when a user uses the TIMESTAMP + type to store a datetime object, it retains its timezone + """ + + __backend__ = True + datatype = TIMESTAMP + data = datetime.datetime(2012, 10, 15, 12, 57, 18, tzinfo=datetime.timezone.utc) + + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + + # We need to pass the TIMESTAMP type to the literal function + # so that the value is processed correctly. 
+ result = connection.scalar(select(literal(self.data, TIMESTAMP))) + eq_(result, self.data) diff --git a/src/databricks/sqlalchemy/test/_future.py b/src/databricks/sqlalchemy/test/_future.py index 519a4e092..7c7d8608e 100644 --- a/src/databricks/sqlalchemy/test/_future.py +++ b/src/databricks/sqlalchemy/test/_future.py @@ -24,7 +24,6 @@ CollateTest, ComputedColumnTest, ComputedReflectionTest, - DateTimeTZTest, DifficultParametersTest, FutureWeCanSetDefaultSchemaWEventsTest, IdentityColumnTest, @@ -35,7 +34,6 @@ QuotedNameArgumentTest, RowCountTest, SimpleUpdateDeleteTest, - TimeTZTest, WeCanSetDefaultSchemaWEventsTest, ) @@ -58,7 +56,6 @@ class FutureFeature(Enum): TBL_COMMENTS = "table comment reflection" TBL_OPTS = "get_table_options method" TEST_DESIGN = "required test-fixture overrides" - TIMEZONE = "timezone handling for DateTime() or Time() types" TUPLE_LITERAL = "tuple-like IN markers completely" UUID = "native Uuid() type" VIEW_DEF = "get_view_definition method" @@ -202,26 +199,6 @@ def test_regexp_match(self): pass -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.TIMEZONE, True)) -class DateTimeTZTest(DateTimeTZTest): - """When I initially implemented DateTime type handling, I started using TIMESTAMP_NTZ because - that's the default behaviour of the DateTime() type and the other tests passed. I simply missed - this group of tests. Will need to modify the compilation and result_processor for our type override - so that we can pass both DateTimeTZTest and DateTimeTest. Currently, only DateTimeTest passes. - """ - - pass - - -@pytest.mark.reviewed -@pytest.mark.skip(render_future_feature(FutureFeature.TIMEZONE, True)) -class TimeTZTest(TimeTZTest): - """Similar to DateTimeTZTest, this should be possible for the dialect since we can override type compilation - and processing in _types.py. Implementation has been deferred. 
- """ - - @pytest.mark.reviewed @pytest.mark.skip(render_future_feature(FutureFeature.COLLATE)) class CollateTest(CollateTest): diff --git a/src/databricks/sqlalchemy/test/_regression.py b/src/databricks/sqlalchemy/test/_regression.py index 6342d2d51..c797bbb70 100644 --- a/src/databricks/sqlalchemy/test/_regression.py +++ b/src/databricks/sqlalchemy/test/_regression.py @@ -45,6 +45,7 @@ TimeMicrosecondsTest, TimestampMicrosecondsTest, TimeTest, + TimeTZTest, TrueDivTest, UnicodeTextTest, UnicodeVarcharTest, @@ -300,3 +301,8 @@ class IdentityAutoincrementTest(IdentityAutoincrementTest): @pytest.mark.reviewed class LikeFunctionsTest(LikeFunctionsTest): pass + + +@pytest.mark.reviewed +class TimeTZTest(TimeTZTest): + pass diff --git a/src/databricks/sqlalchemy/test/_unsupported.py b/src/databricks/sqlalchemy/test/_unsupported.py index 899e73e43..1fce4467a 100644 --- a/src/databricks/sqlalchemy/test/_unsupported.py +++ b/src/databricks/sqlalchemy/test/_unsupported.py @@ -19,6 +19,7 @@ # These are test suites that are fully skipped with a SkipReason from sqlalchemy.testing.suite import ( AutocommitIsolationTest, + DateTimeTZTest, ExceptionTest, HasIndexTest, HasSequenceTest, @@ -51,6 +52,7 @@ class SkipReason(Enum): STRING_FEAT = "required STRING type features" SYMBOL_CHARSET = "symbols expected by test" TEMP_TBL = "temporary tables" + TIMEZONE_OPT = "timezone-optional TIMESTAMP fields" TRANSACTIONS = "transactions" UNIQUE = "UNIQUE constraints" @@ -415,3 +417,18 @@ def test_delete_scalar_subq_round_trip(self): This suggests a limitation of the platform. But a workaround may be possible if customers require it. """ pass + + +@pytest.mark.reviewed +@pytest.mark.skip(render_skip_reason(SkipReason.TIMEZONE_OPT, True)) +class DateTimeTZTest(DateTimeTZTest): + """Test whether the sqlalchemy.DateTime() type can _optionally_ include timezone info. + This dialect maps DateTime() → TIMESTAMP, which _always_ includes tzinfo. 
+ + Users can use databricks.sqlalchemy.TIMESTAMP_NTZ for a tzinfo-less timestamp. The SQLA docs + acknowledge this is expected for some dialects. + + https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.DateTime + """ + + pass diff --git a/src/databricks/sqlalchemy/test/test_suite.py b/src/databricks/sqlalchemy/test/test_suite.py index 1c3920498..54f071731 100644 --- a/src/databricks/sqlalchemy/test/test_suite.py +++ b/src/databricks/sqlalchemy/test/test_suite.py @@ -23,4 +23,4 @@ def start_protocol_patch(): from databricks.sqlalchemy.test._regression import * from databricks.sqlalchemy.test._unsupported import * from databricks.sqlalchemy.test._future import * -from databricks.sqlalchemy.test._extra import TinyIntegerTest +from databricks.sqlalchemy.test._extra import TinyIntegerTest, DateTimeTZTestCustom diff --git a/src/databricks/sqlalchemy/test_local/test_types.py b/src/databricks/sqlalchemy/test_local/test_types.py index c29edfcec..73e286699 100644 --- a/src/databricks/sqlalchemy/test_local/test_types.py +++ b/src/databricks/sqlalchemy/test_local/test_types.py @@ -4,7 +4,7 @@ import sqlalchemy from databricks.sqlalchemy.base import DatabricksDialect -from databricks.sqlalchemy._types import TINYINT +from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ class DatabricksDataType(enum.Enum): @@ -129,6 +129,8 @@ def test_numeric_renders_as_decimal_with_precision_and_scale(self): sqlalchemy.types.SMALLINT: DatabricksDataType.SMALLINT, sqlalchemy.types.TIMESTAMP: DatabricksDataType.TIMESTAMP, TINYINT: DatabricksDataType.TINYINT, + TIMESTAMP: DatabricksDataType.TIMESTAMP, + TIMESTAMP_NTZ: DatabricksDataType.TIMESTAMP_NTZ, }