Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ async def create_session(
storage_app_state.state, storage_user_state.state, session_state
)
session = storage_session.to_session(
state=merged_state, is_sqlite=is_sqlite
state=merged_state, is_sqlite=is_sqlite, is_postgresql=is_postgresql
)
return session

Expand All @@ -498,6 +498,8 @@ async def get_session(
# 2. Get all the events based on session id and filtering config
# 3. Convert and return the session
schema = self._get_schema_classes()
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
async with self._rollback_on_exception_session(
read_only=True
) as sql_session:
Expand Down Expand Up @@ -543,9 +545,11 @@ async def get_session(

# Convert storage session to session
events = [e.to_event() for e in reversed(storage_events)]
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
session = storage_session.to_session(
state=merged_state, events=events, is_sqlite=is_sqlite
state=merged_state,
events=events,
is_sqlite=is_sqlite,
is_postgresql=is_postgresql,
)
return session

Expand Down Expand Up @@ -592,12 +596,17 @@ async def list_sessions(

sessions = []
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
for storage_session in results:
session_state = storage_session.state
user_state = user_states_map.get(storage_session.user_id, {})
merged_state = _merge_state(app_state, user_state, session_state)
sessions.append(
storage_session.to_session(state=merged_state, is_sqlite=is_sqlite)
storage_session.to_session(
state=merged_state,
is_sqlite=is_sqlite,
is_postgresql=is_postgresql,
)
)
return ListSessionsResponse(sessions=sessions)

Expand Down Expand Up @@ -633,6 +642,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
# 3. Store the new event.
schema = self._get_schema_classes()
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
is_postgresql = self.db_engine.dialect.name == _POSTGRESQL_DIALECT
use_row_level_locking = self._supports_row_level_locking()

state_delta = (
Expand Down Expand Up @@ -662,7 +672,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
storage_session = storage_session_result.scalars().one_or_none()
if storage_session is None:
raise ValueError(f"Session {session.id} not found.")
storage_update_time = storage_session.get_update_timestamp(is_sqlite)
storage_update_time = storage_session.get_update_timestamp(
is_sqlite, is_postgresql
)
storage_update_marker = storage_session.get_update_marker()

storage_app_state = await _select_required_state(
Expand Down Expand Up @@ -728,20 +740,20 @@ async def append_event(self, session: Session, event: Event) -> Event:
storage_session.state | state_deltas["session"]
)

if is_sqlite:
if is_sqlite or is_postgresql:
update_time = datetime.fromtimestamp(
event.timestamp, timezone.utc
).replace(tzinfo=None)
else:
update_time = datetime.fromtimestamp(event.timestamp)
update_time = datetime.fromtimestamp(event.timestamp, timezone.utc)
storage_session.update_time = update_time
sql_session.add(schema.StorageEvent.from_event(session, event))

await sql_session.commit()

# Update timestamp with commit time
session.last_update_time = storage_session.get_update_timestamp(
is_sqlite
is_sqlite, is_postgresql
)
session._storage_update_marker = storage_session.get_update_marker()

Expand Down
30 changes: 19 additions & 11 deletions src/google/adk/sessions/schemas/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,24 @@ def update_timestamp_tz(self) -> float:
This is a compatibility alias for callers that used the pre-`main` API.
"""
sqlalchemy_session = inspect(self).session
is_sqlite = bool(
sqlalchemy_session
and sqlalchemy_session.bind
and sqlalchemy_session.bind.dialect.name == "sqlite"
dialect_name = (
sqlalchemy_session.bind.dialect.name
if sqlalchemy_session and sqlalchemy_session.bind
else None
)
is_sqlite = dialect_name == "sqlite"
is_postgresql = dialect_name == "postgresql"
return self.get_update_timestamp(
is_sqlite=is_sqlite, is_postgresql=is_postgresql
)
return self.get_update_timestamp(is_sqlite=is_sqlite)

def get_update_timestamp(self, is_sqlite: bool) -> float:
def get_update_timestamp(
self, is_sqlite: bool, is_postgresql: bool = False
) -> float:
"""Returns the time zone aware update timestamp."""
if is_sqlite:
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
# object without timezone information. We need to convert it to UTC
# manually.
if is_sqlite or is_postgresql:
# SQLite and PostgreSQL store naive datetimes as UTC values. We need to
# attach UTC timezone info before converting to a POSIX timestamp.
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
return self.update_time.timestamp()

Expand All @@ -195,6 +200,7 @@ def to_session(
state: dict[str, Any] | None = None,
events: list[Event] | None = None,
is_sqlite: bool = False,
is_postgresql: bool = False,
) -> Session:
"""Converts the storage session to a session object."""
if state is None:
Expand All @@ -208,7 +214,9 @@ def to_session(
id=self.id,
state=state,
events=events,
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
last_update_time=self.get_update_timestamp(
is_sqlite=is_sqlite, is_postgresql=is_postgresql
),
)
session._storage_update_marker = self.get_update_marker()
return session
Expand Down
30 changes: 19 additions & 11 deletions src/google/adk/sessions/schemas/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,24 @@ def update_timestamp_tz(self) -> float:
This is a compatibility alias for callers that used the pre-`main` API.
"""
sqlalchemy_session = inspect(self).session
is_sqlite = bool(
sqlalchemy_session
and sqlalchemy_session.bind
and sqlalchemy_session.bind.dialect.name == "sqlite"
dialect_name = (
sqlalchemy_session.bind.dialect.name
if sqlalchemy_session and sqlalchemy_session.bind
else None
)
is_sqlite = dialect_name == "sqlite"
is_postgresql = dialect_name == "postgresql"
return self.get_update_timestamp(
is_sqlite=is_sqlite, is_postgresql=is_postgresql
)
return self.get_update_timestamp(is_sqlite=is_sqlite)

def get_update_timestamp(self, is_sqlite: bool) -> float:
def get_update_timestamp(
self, is_sqlite: bool, is_postgresql: bool = False
) -> float:
"""Returns the time zone aware update timestamp."""
if is_sqlite:
# SQLite does not support timezone. SQLAlchemy returns a naive datetime
# object without timezone information. We need to convert it to UTC
# manually.
if is_sqlite or is_postgresql:
# SQLite and PostgreSQL store naive datetimes as UTC values. We need to
# attach UTC timezone info before converting to a POSIX timestamp.
return self.update_time.replace(tzinfo=timezone.utc).timestamp()
return self.update_time.timestamp()

Expand All @@ -142,6 +147,7 @@ def to_session(
state: dict[str, Any] | None = None,
events: list[Event] | None = None,
is_sqlite: bool = False,
is_postgresql: bool = False,
) -> Session:
"""Converts the storage session to a session object."""
if state is None:
Expand All @@ -155,7 +161,9 @@ def to_session(
id=self.id,
state=state,
events=events,
last_update_time=self.get_update_timestamp(is_sqlite=is_sqlite),
last_update_time=self.get_update_timestamp(
is_sqlite=is_sqlite, is_postgresql=is_postgresql
),
)
session._storage_update_marker = self.get_update_marker()
return session
Expand Down