diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index d4dbdf0f029d4..a043905765ecf 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -609,14 +609,11 @@ impl Stream for GroupedHashAggregateStream { match &self.exec_state { ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { - // new batch to aggregate - Some(Ok(batch)) => { + // New batch to aggregate in partial aggregation operator + Some(Ok(batch)) if self.mode == AggregateMode::Partial => { let timer = elapsed_compute.timer(); let input_rows = batch.num_rows(); - // Make sure we have enough capacity for `batch`, otherwise spill - extract_ok!(self.spill_previous_if_necessary(&batch)); - // Do the grouping extract_ok!(self.group_aggregate_batch(batch)); @@ -649,10 +646,49 @@ impl Stream for GroupedHashAggregateStream { timer.done(); } + + // New batch to aggregate in terminal aggregation operator + // (Final/FinalPartitioned/Single/SinglePartitioned) + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + + // Make sure we have enough capacity for `batch`, otherwise spill + extract_ok!(self.spill_previous_if_necessary(&batch)); + + // Do the grouping + extract_ok!(self.group_aggregate_batch(batch)); + + // If we can begin emitting rows, do so, + // otherwise keep consuming input + assert!(!self.input_done); + + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + if let Some(to_emit) = self.group_ordering.emit_to() { + let batch = extract_ok!(self.emit(to_emit, false)); + self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + timer.done(); + } + + // Found error from input stream Some(Err(e)) => { // inner had error, return to caller return Poll::Ready(Some(Err(e))); } + + // Found end from input stream None => { // inner is done, emit all rows and switch to producing output extract_ok!(self.set_input_done_and_produce_output()); @@ -691,7 +727,12 @@ impl Stream for GroupedHashAggregateStream { ( if self.input_done { ExecutionState::Done - } else if self.should_skip_aggregation() { + } + // In Partial aggregation, we also need to check + // if we should trigger partial skipping + else if self.mode == AggregateMode::Partial + && self.should_skip_aggregation() + { ExecutionState::SkippingAggregation } else { ExecutionState::ReadingInput @@ -879,10 +920,10 @@ impl GroupedHashAggregateStream { if self.group_values.len() > 0 && batch.num_rows() > 0 && matches!(self.group_ordering, GroupOrdering::None) - && !matches!(self.mode, AggregateMode::Partial) && !self.spill_state.is_stream_merging && self.update_memory_reservation().is_err() { + assert_ne!(self.mode, AggregateMode::Partial); // Use input batch (Partial mode) schema for spilling because // the spilled data will be merged and re-evaluated later. self.spill_state.spill_schema = batch.schema(); @@ -927,9 +968,9 @@ impl GroupedHashAggregateStream { fn emit_early_if_necessary(&mut self) -> Result<()> { if self.group_values.len() >= self.batch_size && matches!(self.group_ordering, GroupOrdering::None) - && matches!(self.mode, AggregateMode::Partial) && self.update_memory_reservation().is_err() { + assert_eq!(self.mode, AggregateMode::Partial); let n = self.group_values.len() / self.batch_size * self.batch_size; let batch = self.emit(EmitTo::First(n), false)?; self.exec_state = ExecutionState::ProducingOutput(batch); @@ -1002,6 +1043,8 @@ impl GroupedHashAggregateStream { } /// Updates skip aggregation probe state. + /// + /// Notice: It should only be called in Partial aggregation fn update_skip_aggregation_probe(&mut self, input_rows: usize) { if let Some(probe) = self.skip_aggregation_probe.as_mut() { // Skip aggregation probe is not supported if stream has any spills, @@ -1013,6 +1056,8 @@ impl GroupedHashAggregateStream { /// In case the probe indicates that aggregation may be /// skipped, forces stream to produce currently accumulated output. + /// + /// Notice: It should only be called in Partial aggregation fn switch_to_skip_aggregation(&mut self) -> Result<()> { if let Some(probe) = self.skip_aggregation_probe.as_mut() { if probe.should_skip() { @@ -1026,6 +1071,8 @@ impl GroupedHashAggregateStream { /// Returns true if the aggregation probe indicates that aggregation /// should be skipped. + /// + /// Notice: It should only be called in Partial aggregation fn should_skip_aggregation(&self) -> bool { self.skip_aggregation_probe .as_ref()