
Fix train_on_batch/test_on_batch returning cumulative metric instead of per-batch loss #22617

Open
ssam18 wants to merge 3 commits into keras-team:master from ssam18:fix/train-test-on-batch-cumulative-metric

Conversation

@ssam18
Contributor

@ssam18 ssam18 commented Apr 2, 2026

train_on_batch and test_on_batch were accumulating metric state across calls, so the reported loss was a running average over the current batch and all previous batches rather than the loss for the current batch alone. This PR adds a reset_metrics() call at the start of both methods in all three backends (TF, Torch, JAX), so the returned value matches model.evaluate() and a manual loss calculation on the same batch.

…of per-batch loss. Call reset_metrics() at the start of train_on_batch and test_on_batch across all three backends (TF, Torch, JAX) so each call returns the loss for that batch only, consistent with model.evaluate() and manual metric computation.

Fixes keras-team#22596.
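To make the difference concrete, here is a minimal sketch in plain Python. It does not use Keras: `RunningMean` and `report_losses` are hypothetical names standing in for a stateful mean tracker and for repeated `train_on_batch` calls, and the three loss values are made up for illustration.

```python
class RunningMean:
    """Toy stand-in for a stateful mean metric (assumption: not the Keras class)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0


def report_losses(batch_losses, reset_each_call):
    """Return the value reported after each batch.

    reset_each_call=False mimics the accumulating behavior described above;
    reset_each_call=True mimics resetting state at the start of every call.
    """
    tracker = RunningMean()
    reported = []
    for loss in batch_losses:
        if reset_each_call:
            tracker.reset_state()
        tracker.update_state(loss)
        reported.append(tracker.result())
    return reported


batch_losses = [4.0, 2.0, 0.0]  # illustrative per-batch losses
cumulative = report_losses(batch_losses, reset_each_call=False)
per_batch = report_losses(batch_losses, reset_each_call=True)
print(cumulative)  # [4.0, 3.0, 2.0]
print(per_batch)   # [4.0, 2.0, 0.0]
```

The first reported value is the same in both modes; the two only diverge from the second call onward, once earlier batches would otherwise leak into the mean.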
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces self.reset_metrics() calls into the train_on_batch and test_on_batch methods across the JAX, TensorFlow, and Torch backends. The review comments correctly identify that this change forces per-batch metric reporting, which breaks compatibility with Keras 2 workflows that rely on cumulative metric accumulation. Furthermore, the placement of these calls before the model build step is identified as a functional issue, as it fails to reset metrics initialized during the build process. The reviewer suggests moving these calls after the build and training function initialization to maintain proper state management and API consistency.
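The ordering concern can be sketched with a toy model in plain Python. Everything here is an assumption for illustration: `ToyModel`, `ToyMetric`, `_build`, and the `reset_before_build` flag are hypothetical and do not reproduce the Keras trainer code. The sketch only shows that a reset issued before the build step has no metrics to act on during the first call.

```python
class ToyMetric:
    """Hypothetical running-mean metric."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count

    def reset_state(self):
        self.total = 0.0
        self.count = 0


class ToyModel:
    """Hypothetical model whose metrics are created lazily by a build step."""

    def __init__(self):
        self.metrics = []
        self.built = False
        self.reset_log = []  # number of metrics each reset call reached

    def _build(self):
        # Assumption: metrics only come into existence here.
        if not self.built:
            self.metrics = [ToyMetric()]
            self.built = True

    def reset_metrics(self):
        self.reset_log.append(len(self.metrics))
        for metric in self.metrics:
            metric.reset_state()

    def train_on_batch(self, loss, reset_before_build):
        if reset_before_build:
            self.reset_metrics()  # first call: nothing exists yet
            self._build()
        else:
            self._build()
            self.reset_metrics()  # always sees the built metrics
        self.metrics[0].update_state(loss)
        return self.metrics[0].result()


early = ToyModel()
late = ToyModel()
for loss in (4.0, 2.0):
    early.train_on_batch(loss, reset_before_build=True)
    late.train_on_batch(loss, reset_before_build=False)

print(early.reset_log)  # [0, 1]
print(late.reset_log)   # [1, 1]
```

In this toy, both orderings return the same loss values because freshly built metrics start empty; the only visible difference is which metrics the first reset reaches. Whether that matters in practice depends on what the real build step does to metric state.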

@codecov-commenter

codecov-commenter commented Apr 2, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 83.29%. Comparing base (9708582) to head (69c02a1).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #22617   +/-   ##
=======================================
  Coverage   83.29%   83.29%           
=======================================
  Files         596      596           
  Lines       68138    68144    +6     
  Branches    10613    10613           
=======================================
+ Hits        56754    56760    +6     
  Misses       8638     8638           
  Partials     2746     2746           
| Flag | Coverage Δ |
|---|---|
| keras | 83.10% <100.00%> (+<0.01%) ⬆️ |
| keras-jax | 59.66% <33.33%> (-0.01%) ⬇️ |
| keras-numpy | 55.33% <0.00%> (-0.01%) ⬇️ |
| keras-openvino | 53.38% <0.00%> (-0.01%) ⬇️ |
| keras-tensorflow | 61.03% <33.33%> (-0.01%) ⬇️ |
| keras-torch | 59.85% <33.33%> (-0.01%) ⬇️ |


ssam18 added 2 commits April 2, 2026 11:34
…eview feedback: reset_metrics() must run after _symbolic_build() and make_train/test_function() so that any metrics initialized during the build process are also included in the reset.
…ior.

The previous expected values in test_on_batch_methods and test_nested_inputs
were based on the old cumulative metric behavior (where each call accumulated
state from prior calls). Now that reset_metrics() is called at the start of
each train_on_batch/test_on_batch, every call returns the loss for that
batch only — so all expected values after the first call are updated to match
the actual per-batch losses.
@ssam18 ssam18 force-pushed the fix/train-test-on-batch-cumulative-metric branch from c5d5207 to 69c02a1 Compare April 2, 2026 19:26
