diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index d033f1f234..0c83420682 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -480,7 +480,7 @@ async def create_session( storage_app_state.state, storage_user_state.state, session_state ) session = storage_session.to_session( - state=merged_state, is_sqlite=is_sqlite + state=merged_state, is_sqlite=is_sqlite, is_postgresql=is_postgresql ) return session @@ -498,6 +498,8 @@ async def get_session( # 2. Get all the events based on session id and filtering config # 3. Convert and return the session schema = self._get_schema_classes() + is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT + is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT async with self._rollback_on_exception_session( read_only=True ) as sql_session: @@ -543,9 +545,11 @@ async def get_session( # Convert storage session to session events = [e.to_event() for e in reversed(storage_events)] - is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT session = storage_session.to_session( - state=merged_state, events=events, is_sqlite=is_sqlite + state=merged_state, + events=events, + is_sqlite=is_sqlite, + is_postgresql=is_postgresql, ) return session @@ -592,12 +596,17 @@ async def list_sessions( sessions = [] is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT + is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT for storage_session in results: session_state = storage_session.state user_state = user_states_map.get(storage_session.user_id, {}) merged_state = _merge_state(app_state, user_state, session_state) sessions.append( - storage_session.to_session(state=merged_state, is_sqlite=is_sqlite) + storage_session.to_session( + state=merged_state, + is_sqlite=is_sqlite, + is_postgresql=is_postgresql, + ) ) return ListSessionsResponse(sessions=sessions) @@ -633,6 +642,7 @@ async def append_event(self, session: Session, event: Event) -> Event: # 3. Store the new event. schema = self._get_schema_classes() is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT + is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT use_row_level_locking = self._supports_row_level_locking() state_delta = ( @@ -662,7 +672,9 @@ async def append_event(self, session: Session, event: Event) -> Event: storage_session = storage_session_result.scalars().one_or_none() if storage_session is None: raise ValueError(f"Session {session.id} not found.") - storage_update_time = storage_session.get_update_timestamp(is_sqlite) + storage_update_time = storage_session.get_update_timestamp( + is_sqlite, is_postgresql + ) storage_update_marker = storage_session.get_update_marker() storage_app_state = await _select_required_state( @@ -728,12 +740,12 @@ async def append_event(self, session: Session, event: Event) -> Event: storage_session.state | state_deltas["session"] ) - if is_sqlite: + if is_sqlite or is_postgresql: update_time = datetime.fromtimestamp( event.timestamp, timezone.utc ).replace(tzinfo=None) else: - update_time = datetime.fromtimestamp(event.timestamp) + update_time = datetime.fromtimestamp(event.timestamp, timezone.utc) storage_session.update_time = update_time sql_session.add(schema.StorageEvent.from_event(session, event)) @@ -741,7 +753,7 @@ async def append_event(self, session: Session, event: Event) -> Event: # Update timestamp with commit time session.last_update_time = storage_session.get_update_timestamp( - is_sqlite + is_sqlite, is_postgresql ) session._storage_update_marker = storage_session.get_update_marker() diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index e4a4368c6d..0a6dc21524 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -167,19 +167,24 @@ def update_timestamp_tz(self) -> float: This is a compatibility alias for callers that used the pre-`main` API. """ sqlalchemy_session = inspect(self).session - is_sqlite = bool( - sqlalchemy_session - and sqlalchemy_session.bind - and sqlalchemy_session.bind.dialect.name == "sqlite" + dialect_name = ( + sqlalchemy_session.bind.dialect.name + if sqlalchemy_session and sqlalchemy_session.bind + else None + ) + is_sqlite = dialect_name == "sqlite" + is_postgresql = dialect_name == "postgresql" + return self.get_update_timestamp( + is_sqlite=is_sqlite, is_postgresql=is_postgresql ) - return self.get_update_timestamp(is_sqlite=is_sqlite) - def get_update_timestamp(self, is_sqlite: bool) -> float: + def get_update_timestamp( + self, is_sqlite: bool, is_postgresql: bool = False + ) -> float: """Returns the time zone aware update timestamp.""" - if is_sqlite: - # SQLite does not support timezone. SQLAlchemy returns a naive datetime - # object without timezone information. We need to convert it to UTC - # manually. + if is_sqlite or is_postgresql: + # SQLite and PostgreSQL store naive datetimes as UTC values. We need to + # attach UTC timezone info before converting to a POSIX timestamp. return self.update_time.replace(tzinfo=timezone.utc).timestamp() return self.update_time.timestamp() @@ -195,6 +200,7 @@ def to_session( state: dict[str, Any] | None = None, events: list[Event] | None = None, is_sqlite: bool = False, + is_postgresql: bool = False, ) -> Session: """Converts the storage session to a session object.""" if state is None: @@ -208,7 +214,9 @@ def to_session( id=self.id, state=state, events=events, - last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite), + last_update_time=self.get_update_timestamp( + is_sqlite=is_sqlite, is_postgresql=is_postgresql + ), ) session._storage_update_marker = self.get_update_marker() return session diff --git a/src/google/adk/sessions/schemas/v1.py b/src/google/adk/sessions/schemas/v1.py index 12d8ee9061..a4ab40e7ef 100644 --- a/src/google/adk/sessions/schemas/v1.py +++ b/src/google/adk/sessions/schemas/v1.py @@ -114,19 +114,24 @@ def update_timestamp_tz(self) -> float: This is a compatibility alias for callers that used the pre-`main` API. """ sqlalchemy_session = inspect(self).session - is_sqlite = bool( - sqlalchemy_session - and sqlalchemy_session.bind - and sqlalchemy_session.bind.dialect.name == "sqlite" + dialect_name = ( + sqlalchemy_session.bind.dialect.name + if sqlalchemy_session and sqlalchemy_session.bind + else None + ) + is_sqlite = dialect_name == "sqlite" + is_postgresql = dialect_name == "postgresql" + return self.get_update_timestamp( + is_sqlite=is_sqlite, is_postgresql=is_postgresql ) - return self.get_update_timestamp(is_sqlite=is_sqlite) - def get_update_timestamp(self, is_sqlite: bool) -> float: + def get_update_timestamp( + self, is_sqlite: bool, is_postgresql: bool = False + ) -> float: """Returns the time zone aware update timestamp.""" - if is_sqlite: - # SQLite does not support timezone. SQLAlchemy returns a naive datetime - # object without timezone information. We need to convert it to UTC - # manually. + if is_sqlite or is_postgresql: + # SQLite and PostgreSQL store naive datetimes as UTC values. We need to + # attach UTC timezone info before converting to a POSIX timestamp. return self.update_time.replace(tzinfo=timezone.utc).timestamp() return self.update_time.timestamp() @@ -142,6 +147,7 @@ def to_session( state: dict[str, Any] | None = None, events: list[Event] | None = None, is_sqlite: bool = False, + is_postgresql: bool = False, ) -> Session: """Converts the storage session to a session object.""" if state is None: @@ -155,7 +161,9 @@ def to_session( id=self.id, state=state, events=events, - last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite), + last_update_time=self.get_update_timestamp( + is_sqlite=is_sqlite, is_postgresql=is_postgresql + ), ) session._storage_update_marker = self.get_update_marker() return session