Fix train_on_batch/test_on_batch returning cumulative metric instead of per-batch loss #22617
ssam18 wants to merge 3 commits into keras-team:master from
Conversation
…of per-batch loss. Call reset_metrics() at the start of train_on_batch and test_on_batch across all three backends (TF, Torch, JAX) so that each call returns the loss for that batch only, consistent with model.evaluate() and manual metric computation. Fixes keras-team#22596.
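To make the bug concrete, here is a minimal self-contained sketch (not Keras source) of why `train_on_batch` reported a running average: the loss tracker is a stateful mean metric that accumulates across calls unless it is reset first. The `Mean` class and `train_on_batch` function below are illustrative stand-ins, analogous to `keras.metrics.Mean` and the real method.

```python
class Mean:
    """Stateful mean metric, analogous to keras.metrics.Mean (illustrative)."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total, self.count = 0.0, 0


loss_tracker = Mean()


def train_on_batch(batch_loss, reset_first=False):
    # The fix in this PR amounts to resetting the tracker at the start
    # of each call, so result() reflects only the current batch.
    if reset_first:
        loss_tracker.reset_state()
    loss_tracker.update_state(batch_loss)
    return loss_tracker.result()


# Without the fix, the second call returns the running average 3.0.
cumulative = [train_on_batch(2.0), train_on_batch(4.0)]

loss_tracker.reset_state()

# With the fix, each call returns that batch's loss only.
per_batch = [train_on_batch(2.0, reset_first=True),
             train_on_batch(4.0, reset_first=True)]

print(cumulative)  # [2.0, 3.0]
print(per_batch)   # [2.0, 4.0]
```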
Code Review
This pull request introduces self.reset_metrics() calls into the train_on_batch and test_on_batch methods across the JAX, TensorFlow, and Torch backends. The review comments correctly identify that this change forces per-batch metric reporting, which breaks compatibility with Keras 2 workflows that rely on cumulative metric accumulation. Furthermore, the placement of these calls before the model build step is identified as a functional issue, as it fails to reset metrics initialized during the build process. The reviewer suggests moving these calls after the build and training function initialization to maintain proper state management and API consistency.
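The reviewer's ordering concern can be illustrated with a small sketch. The class below is hypothetical (only the method name `_symbolic_build` mirrors the PR discussion); it shows that if `reset_metrics()` runs before the build step, any metric objects created during the build keep whatever state they already held, whereas resetting after the build clears everything.

```python
class ToyModel:
    """Illustrative stand-in for a Keras model; not Keras source."""
    def __init__(self):
        self.metrics = []   # metric state, modeled here as plain lists
        self.built = False

    def reset_metrics(self):
        # Clears the state of every metric currently attached.
        for m in self.metrics:
            m.clear()

    def _symbolic_build(self):
        # Building can attach new metric objects (e.g. loss trackers)
        # that may carry stale state from a previous run.
        if not self.built:
            self.metrics.append([1.0, 2.0])
            self.built = True


model = ToyModel()

# Wrong order: reset before build, so the build-time metric stays stale.
model.reset_metrics()
model._symbolic_build()
stale = [len(m) for m in model.metrics]

# Right order: build first, then reset, so every metric is cleared.
model = ToyModel()
model._symbolic_build()
model.reset_metrics()
clean = [len(m) for m in model.metrics]

print(stale)  # [2] -> build-time metric still holds old state
print(clean)  # [0] -> all metrics cleared
```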
Codecov Report ✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff            @@
##           master   #22617     +/-   ##
=========================================
  Coverage   83.29%   83.29%
=========================================
  Files         596      596
  Lines       68138    68144       +6
  Branches    10613    10613
=========================================
+ Hits        56754    56760       +6
  Misses       8638     8638
  Partials     2746     2746
```
…eview feedback: reset_metrics() must run after _symbolic_build() and make_train_function()/make_test_function(), so that any metrics initialized during the build step are also included in the reset.
…ior. The previous expected values in test_on_batch_methods and test_nested_inputs were based on the old cumulative behavior, where each call accumulated state from prior calls. Now that reset_metrics() runs at the start of each train_on_batch/test_on_batch call, every call returns the loss for that batch only, so all expected values after the first call are updated to match the actual per-batch losses.
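The test-expectation change can be sketched numerically. Assuming (hypothetically) per-call losses of 2.0 and 4.0 for two successive batches, the old expectations were running means and the new ones are the batch losses themselves:

```python
# Hypothetical per-call losses for two successive batches.
losses = [2.0, 4.0]

# Old cumulative behavior: each call reported the running mean so far.
old_expected = []
total = 0.0
for i, loss in enumerate(losses, start=1):
    total += loss
    old_expected.append(total / i)

# New per-batch behavior: each call reports that batch's loss only.
new_expected = losses[:]

print(old_expected)  # [2.0, 3.0] -> expectations the tests used to assert
print(new_expected)  # [2.0, 4.0] -> expectations after the fix
```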
c5d5207 to 69c02a1
train_on_batch and test_on_batch were accumulating metric state across calls, so the reported loss was a running average over all previous batches rather than the loss for the current batch alone. Added reset_metrics() at the start of both methods across all three backends (TF, Torch, JAX) to match the behavior of model.evaluate() and manual loss calculations.
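For workflows that relied on the old cumulative behavior (the compatibility concern raised in review), the running average can be recovered on the caller's side. A minimal sketch, assuming each call now returns only the current batch's loss; `running_mean` is a hypothetical helper, not a Keras API:

```python
def running_mean(per_batch_losses):
    """Rebuild the old cumulative readout from per-batch losses."""
    total = 0.0
    out = []
    for i, loss in enumerate(per_batch_losses, start=1):
        total += loss
        out.append(total / i)  # mean of all losses seen so far
    return out


# Per-batch returns from successive train_on_batch calls (hypothetical).
print(running_mean([2.0, 4.0, 6.0]))  # [2.0, 3.0, 4.0]
```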