Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions api/comms/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,21 @@ func chatReadMessages(db dbv1.DBTX, ctx context.Context, userId int32, chatId st
return err
}

// chatReadAllMessages clears unread state for every chat this user belongs to.
// The (last_active_at IS NULL OR last_active_at < $1) guard mirrors the
// per-chat chat.read handler so out-of-order RPCs can't roll back a more
// recent read. The unread_count > 0 filter keeps the write set small.
func chatReadAllMessages(db dbv1.DBTX, ctx context.Context, userId int32, readTimestamp time.Time) error {
_, err := db.Exec(ctx, `
update chat_member
set unread_count = 0, last_active_at = $1
where user_id = $2
and (last_active_at is null or last_active_at < $1)
and unread_count > 0`,
readTimestamp.UTC(), userId)
return err
}

var permissions = []ChatPermission{
ChatPermissionFollowees,
ChatPermissionFollowers,
Expand Down
65 changes: 65 additions & 0 deletions api/comms/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,68 @@ func TestChat(t *testing.T) {

assertReaction(user1Id, replyMessageId, nil)
}

func TestChatReadAllMessages(t *testing.T) {
pool := database.CreateTestDatabase(t, "test_comms")
defer pool.Close()

ctx := context.Background()
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))

user1Id := int32(1)
user2Id := int32(2)
user3Id := int32(3)

chatA := trashid.ChatID(int(user1Id), int(user2Id))
chatB := trashid.ChatID(int(user1Id), int(user3Id))
SetupChatWithMembers(t, pool, ctx, chatA, user1Id, user2Id, "a1", "a2")
SetupChatWithMembers(t, pool, ctx, chatB, user1Id, user3Id, "b1", "b3")

assertUnreadCount := func(chatId string, userId int32, expected int) {
t.Helper()
unreadCount := 0
err := pool.QueryRow(ctx, "select unread_count from chat_member where chat_id = $1 and user_id = $2", chatId, userId).Scan(&unreadCount)
assert.NoError(t, err)
assert.Equal(t, expected, unreadCount, "unread for chat %s user %d", chatId, userId)
}

// Send user1Id one message in each chat from the other party.
err := chatSendMessage(pool, ctx, user2Id, chatA, strconv.Itoa(seededRand.Int()), time.Now(), "hi from 2")
assert.NoError(t, err)
err = chatSendMessage(pool, ctx, user3Id, chatB, strconv.Itoa(seededRand.Int()), time.Now(), "hi from 3")
assert.NoError(t, err)

assertUnreadCount(chatA, user1Id, 1)
assertUnreadCount(chatB, user1Id, 1)
// Senders' own unread counts stay at zero.
assertUnreadCount(chatA, user2Id, 0)
assertUnreadCount(chatB, user3Id, 0)

// Single call clears every unread chat for user1Id without touching
// the other members' chats.
readTs := time.Now()
err = chatReadAllMessages(pool, ctx, user1Id, readTs)
assert.NoError(t, err)

assertUnreadCount(chatA, user1Id, 0)
assertUnreadCount(chatB, user1Id, 0)

// Re-confirm: a stale read (older timestamp) does NOT roll back.
// Add a new unread, advance via chatReadAllMessages, then try a stale read.
err = chatSendMessage(pool, ctx, user2Id, chatA, strconv.Itoa(seededRand.Int()), time.Now(), "another from 2")
assert.NoError(t, err)
assertUnreadCount(chatA, user1Id, 1)

freshTs := time.Now()
err = chatReadAllMessages(pool, ctx, user1Id, freshTs)
assert.NoError(t, err)
assertUnreadCount(chatA, user1Id, 0)

// Older timestamp must be a no-op for last_active_at.
err = chatReadAllMessages(pool, ctx, user1Id, freshTs.Add(-time.Hour))
assert.NoError(t, err)
var lastActive time.Time
err = pool.QueryRow(ctx, "select last_active_at from chat_member where chat_id = $1 and user_id = $2", chatA, user1Id).Scan(&lastActive)
assert.NoError(t, err)
assert.WithinDuration(t, freshTs.UTC(), lastActive.UTC(), time.Second, "stale read should not roll back last_active_at")
}
7 changes: 7 additions & 0 deletions api/comms/rpc_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,13 @@ select last_active_at from chat_member where chat_id = $1 and user_id = $2`
return err
}
}
case RPCMethodChatReadAll:
// No params to unmarshal. The per-row last_active_at guard lives
// inside chatReadAllMessages so we don't have to read first.
err = chatReadAllMessages(tx, ctx, userId, messageTs)
if err != nil {
return err
}
case RPCMethodChatPermit:
var params ChatPermitRPCParams
err = json.Unmarshal(rawRpc.Params, &params)
Expand Down
14 changes: 14 additions & 0 deletions api/comms/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ type ChatReadRPCParams struct {
ChatID string `json:"chat_id"`
}

type ChatReadAllRPC struct {
Method ChatReadAllRPCMethod `json:"method"`
Params ChatReadAllRPCParams `json:"params"`
}

type ChatReadAllRPCParams struct{}

type ChatBlockRPC struct {
Method ChatBlockRPCMethod `json:"method"`
Params ChatBlockRPCParams `json:"params"`
Expand Down Expand Up @@ -367,6 +374,12 @@ const (
MethodChatRead ChatReadRPCMethod = "chat.read"
)

type ChatReadAllRPCMethod string

const (
MethodChatReadAll ChatReadAllRPCMethod = "chat.read_all"
)

type ChatBlockRPCMethod string

const (
Expand Down Expand Up @@ -410,6 +423,7 @@ const (
RPCMethodChatPermit RPCMethod = "chat.permit"
RPCMethodChatReact RPCMethod = "chat.react"
RPCMethodChatRead RPCMethod = "chat.read"
RPCMethodChatReadAll RPCMethod = "chat.read_all"
RPCMethodChatUnblock RPCMethod = "chat.unblock"
RPCMethodUserValidateCanChat RPCMethod = "user.validate_can_chat"
)
3 changes: 3 additions & 0 deletions api/comms/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ func (vtor *Validator) Validate(ctx context.Context, userId int32, rawRpc RawRPC
return vtor.validateChatReact(vtor.pool, ctx, userId, rawRpc)
case RPCMethodChatRead:
return vtor.validateChatRead(userId, rawRpc)
case RPCMethodChatReadAll:
// No params to validate; ban check above already gates this call.
return nil
case RPCMethodChatPermit:
return vtor.validateChatPermit(userId, rawRpc)
case RPCMethodChatBlock:
Expand Down
Loading