diff --git a/datafusion/src/datasource/csv.rs b/datafusion/src/datasource/csv.rs
index 6f6c9abe07741..1bd1b4be823ee 100644
--- a/datafusion/src/datasource/csv.rs
+++ b/datafusion/src/datasource/csv.rs
@@ -35,8 +35,9 @@ use arrow::datatypes::SchemaRef;
 
 use std::any::Any;
+use std::io::{Read, Seek};
 use std::string::String;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use crate::datasource::datasource::Statistics;
 use crate::datasource::TableProvider;
@@ -46,10 +47,17 @@ use crate::physical_plan::csv::CsvExec;
 pub use crate::physical_plan::csv::CsvReadOptions;
 use crate::physical_plan::{common, ExecutionPlan};
 
+enum Source {
+    /// Path to a single CSV file or a directory containing one or more CSV files
+    Path(String),
+
+    /// Read CSV data from a reader
+    Reader(Mutex<Option<Box<dyn Read + Send + Sync + 'static>>>),
+}
+
 /// Represents a CSV file with a provided schema
 pub struct CsvFile {
-    /// Path to a single CSV file or a directory containing one of more CSV files
-    path: String,
+    source: Source,
     schema: SchemaRef,
     has_header: bool,
     delimiter: u8,
@@ -77,7 +85,7 @@ impl CsvFile {
         });
 
         Ok(Self {
-            path: String::from(path),
+            source: Source::Path(path.to_string()),
             schema,
             has_header: options.has_header,
             delimiter: options.delimiter,
@@ -86,9 +94,64 @@ impl CsvFile {
         })
     }
 
+    /// Attempt to initialize a `CsvFile` from a reader. The schema MUST be provided in options.
+    pub fn try_new_from_reader<R: Read + Send + Sync + 'static>(
+        reader: R,
+        options: CsvReadOptions,
+    ) -> Result<Self> {
+        let schema = Arc::new(match options.schema {
+            Some(s) => s.clone(),
+            None => {
+                return Err(DataFusionError::Execution(
+                    "Schema must be provided to CsvRead".to_string(),
+                ));
+            }
+        });
+
+        Ok(Self {
+            source: Source::Reader(Mutex::new(Some(Box::new(reader)))),
+            schema,
+            has_header: options.has_header,
+            delimiter: options.delimiter,
+            statistics: Statistics::default(),
+            file_extension: String::new(),
+        })
+    }
+
+    /// Attempt to initialize a `CsvRead` from a reader that impls `Seek`. The schema can be inferred automatically.
+    pub fn try_new_from_reader_infer_schema<R: Read + Seek + Send + Sync + 'static>(
+        mut reader: R,
+        options: CsvReadOptions,
+    ) -> Result<Self> {
+        let schema = Arc::new(match options.schema {
+            Some(s) => s.clone(),
+            None => {
+                let (schema, _) = arrow::csv::reader::infer_file_schema(
+                    &mut reader,
+                    options.delimiter,
+                    Some(options.schema_infer_max_records),
+                    options.has_header,
+                )?;
+                schema
+            }
+        });
+
+        Ok(Self {
+            source: Source::Reader(Mutex::new(Some(Box::new(reader)))),
+            schema,
+            has_header: options.has_header,
+            delimiter: options.delimiter,
+            statistics: Statistics::default(),
+            file_extension: String::new(),
+        })
+    }
+
     /// Get the path for the CSV file(s) represented by this CsvFile instance
     pub fn path(&self) -> &str {
-        &self.path
+        match &self.source {
+            Source::Reader(_) => "",
+            Source::Path(path) => path,
+        }
     }
 
     /// Determine whether the CSV file(s) represented by this CsvFile instance have a header row
@@ -123,22 +186,75 @@ impl TableProvider for CsvFile {
         _filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        Ok(Arc::new(CsvExec::try_new(
-            &self.path,
-            CsvReadOptions::new()
-                .schema(&self.schema)
-                .has_header(self.has_header)
-                .delimiter(self.delimiter)
-                .file_extension(self.file_extension.as_str()),
-            projection.clone(),
-            limit
-                .map(|l| std::cmp::min(l, batch_size))
-                .unwrap_or(batch_size),
-            limit,
-        )?))
+        let opts = CsvReadOptions::new()
+            .schema(&self.schema)
+            .has_header(self.has_header)
+            .delimiter(self.delimiter)
+            .file_extension(self.file_extension.as_str());
+        let batch_size = limit
+            .map(|l| std::cmp::min(l, batch_size))
+            .unwrap_or(batch_size);
+
+        let exec = match &self.source {
+            Source::Reader(maybe_reader) => {
+                if let Some(rdr) = maybe_reader.lock().unwrap().take() {
+                    CsvExec::try_new_from_reader(
+                        rdr,
+                        opts,
+                        projection.clone(),
+                        batch_size,
+                        limit,
+                    )?
+                } else {
+                    return Err(DataFusionError::Execution(
+                        "You can only read once if the data comes from a reader"
+                            .to_string(),
+                    ));
+                }
+            }
+            Source::Path(p) => {
+                CsvExec::try_new(&p, opts, projection.clone(), batch_size, limit)?
+            }
+        };
+        Ok(Arc::new(exec))
     }
 
     fn statistics(&self) -> Statistics {
         self.statistics.clone()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::prelude::*;
+
+    #[tokio::test]
+    async fn csv_file_from_reader() -> Result<()> {
+        let testdata = arrow::util::test_util::arrow_test_data();
+        let filename = "aggregate_test_100.csv";
+        let path = format!("{}/csv/{}", testdata, filename);
+        let buf = std::fs::read(path).unwrap();
+        let rdr = std::io::Cursor::new(buf);
+        let mut ctx = ExecutionContext::new();
+        ctx.register_table(
+            "aggregate_test",
+            Arc::new(CsvFile::try_new_from_reader_infer_schema(
+                rdr,
+                CsvReadOptions::new(),
+            )?),
+        )?;
+        let df = ctx.sql("select max(c2) from aggregate_test")?;
+        let batches = df.collect().await?;
+        assert_eq!(
+            batches[0]
+                .column(0)
+                .as_any()
+                .downcast_ref::<arrow::array::UInt32Array>()
+                .unwrap()
+                .value(0),
+            5
+        );
+        Ok(())
+    }
+}
diff --git a/datafusion/src/physical_plan/csv.rs b/datafusion/src/physical_plan/csv.rs
index 7ee5ae3fd90b0..b96a702f27325 100644
--- a/datafusion/src/physical_plan/csv.rs
+++ b/datafusion/src/physical_plan/csv.rs
@@ -17,12 +17,6 @@
 
 //! Execution plan for reading CSV files
 
-use std::any::Any;
-use std::fs::File;
-use std::pin::Pin;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::ExecutionPlan;
 use crate::physical_plan::{common, Partitioning};
@@ -31,6 +25,13 @@ use arrow::datatypes::{Schema, SchemaRef};
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 use futures::Stream;
+use std::any::Any;
+use std::fs::File;
+use std::io::Read;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::Mutex;
+use std::task::{Context, Poll};
 
 use super::{RecordBatchStream, SendableRecordBatchStream};
 use async_trait::async_trait;
@@ -106,13 +107,69 @@ impl<'a> CsvReadOptions<'a> {
     }
 }
 
+/// Source represents where the data comes from.
+enum Source {
+    /// The data comes from partitioned files
+    PartitionedFiles {
+        /// Path to directory containing partitioned files with the same schema
+        path: String,
+        /// The individual files under path
+        filenames: Vec<String>,
+    },
+
+    /// The data comes from anything impl Read trait
+    Reader(Mutex<Option<Box<dyn Read + Send + Sync + 'static>>>),
+}
+
+impl std::fmt::Debug for Source {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Source::PartitionedFiles { path, filenames } => f
+                .debug_struct("PartitionedFiles")
+                .field("path", path)
+                .field("filenames", filenames)
+                .finish()?,
+            Source::Reader(_) => f.write_str("Reader")?,
+        };
+        Ok(())
+    }
+}
+
+impl Clone for Source {
+    fn clone(&self) -> Self {
+        match self {
+            Source::PartitionedFiles { path, filenames } => Self::PartitionedFiles {
+                path: path.clone(),
+                filenames: filenames.clone(),
+            },
+            Source::Reader(_) => Self::Reader(Mutex::new(None)),
+        }
+    }
+}
+
+impl Source {
+    /// Path to directory containing partitioned files with the same schema
+    pub fn path(&self) -> &str {
+        match self {
+            Source::PartitionedFiles { path, .. } => path.as_str(),
+            Source::Reader(_) => "",
+        }
+    }
+
+    /// The individual files under path
+    pub fn filenames(&self) -> &[String] {
+        match self {
+            Source::PartitionedFiles { filenames, .. } => filenames,
+            Source::Reader(_) => &[],
+        }
+    }
+}
+
 /// Execution plan for scanning a CSV file
 #[derive(Debug, Clone)]
 pub struct CsvExec {
-    /// Path to directory containing partitioned CSV files with the same schema
-    path: String,
-    /// The individual files under path
-    filenames: Vec<String>,
+    /// Where the data comes from.
+    source: Source,
     /// Schema representing the CSV file
     schema: SchemaRef,
     /// Does the CSV file have a header?
@@ -163,8 +220,10 @@
         };
 
         Ok(Self {
-            path: path.to_string(),
-            filenames,
+            source: Source::PartitionedFiles {
+                path: path.to_string(),
+                filenames,
+            },
             schema: Arc::new(schema),
             has_header: options.has_header,
             delimiter: Some(options.delimiter),
@@ -175,15 +234,50 @@
             limit,
         })
     }
+    /// Create a new execution plan for reading from a reader
+    pub fn try_new_from_reader(
+        reader: impl Read + Send + Sync + 'static,
+        options: CsvReadOptions,
+        projection: Option<Vec<usize>>,
+        batch_size: usize,
+        limit: Option<usize>,
+    ) -> Result<Self> {
+        let schema = match options.schema {
+            Some(s) => s.clone(),
+            None => {
+                return Err(DataFusionError::Execution(
+                    "The schema must be provided in options when reading from a reader"
+                        .to_string(),
+                ));
+            }
+        };
+
+        let projected_schema = match &projection {
+            None => schema.clone(),
+            Some(p) => Schema::new(p.iter().map(|i| schema.field(*i).clone()).collect()),
+        };
+
+        Ok(Self {
+            source: Source::Reader(Mutex::new(Some(Box::new(reader)))),
+            schema: Arc::new(schema),
+            has_header: options.has_header,
+            delimiter: Some(options.delimiter),
+            file_extension: String::new(),
+            projection,
+            projected_schema: Arc::new(projected_schema),
+            batch_size,
+            limit,
+        })
+    }
 
     /// Path to directory containing partitioned CSV files with the same schema
     pub fn path(&self) -> &str {
-        &self.path
+        self.source.path()
     }
 
     /// The individual files under path
     pub fn filenames(&self) -> &[String] {
-        &self.filenames
+        self.source.filenames()
     }
 
     /// Does the CSV file have a header?
@@ -249,7 +343,10 @@ impl ExecutionPlan for CsvExec {
 
     /// Get the output partitioning of this plan
     fn output_partitioning(&self) -> Partitioning {
-        Partitioning::UnknownPartitioning(self.filenames.len())
+        Partitioning::UnknownPartitioning(match &self.source {
+            Source::PartitionedFiles { filenames, .. } => filenames.len(),
+            Source::Reader(_) => 1,
+        })
     }
 
     fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
@@ -272,25 +369,51 @@
     }
 
     async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
-        Ok(Box::pin(CsvStream::try_new(
-            &self.filenames[partition],
-            self.schema.clone(),
-            self.has_header,
-            self.delimiter,
-            &self.projection,
-            self.batch_size,
-            self.limit,
-        )?))
+        match &self.source {
+            Source::PartitionedFiles { filenames, .. } => {
+                Ok(Box::pin(CsvStream::try_new(
+                    &filenames[partition],
+                    self.schema.clone(),
+                    self.has_header,
+                    self.delimiter,
+                    &self.projection,
+                    self.batch_size,
+                    self.limit,
+                )?))
+            }
+            Source::Reader(rdr) => {
+                if partition != 0 {
+                    Err(DataFusionError::Internal(
+                        "Only partition 0 is valid when CSV comes from a reader"
+                            .to_string(),
+                    ))
+                } else if let Some(rdr) = rdr.lock().unwrap().take() {
+                    Ok(Box::pin(CsvStream::try_new_from_reader(
+                        rdr,
+                        self.schema.clone(),
+                        self.has_header,
+                        self.delimiter,
+                        &self.projection,
+                        self.batch_size,
+                        self.limit,
+                    )?))
+                } else {
+                    Err(DataFusionError::Execution(
+                        "Error reading CSV: Data can only be read a single time when the source is a reader"
+                            .to_string(),
+                    ))
+                }
+            }
+        }
     }
 }
 
 /// Iterator over batches
-struct CsvStream {
+struct CsvStream<R: Read> {
     /// Arrow CSV reader
-    reader: csv::Reader,
+    reader: csv::Reader<R>,
 }
-
-impl CsvStream {
+impl CsvStream<File> {
     /// Create an iterator for a CSV file
     pub fn try_new(
         filename: &str,
@@ -302,11 +425,27 @@ impl CsvStream {
         limit: Option<usize>,
     ) -> Result<Self> {
         let file = File::open(filename)?;
+        Self::try_new_from_reader(
+            file, schema, has_header, delimiter, projection, batch_size, limit,
+        )
+    }
+}
+impl<R: Read> CsvStream<R> {
+    /// Create an iterator for a reader
+    pub fn try_new_from_reader(
+        reader: R,
+        schema: SchemaRef,
+        has_header: bool,
+        delimiter: Option<u8>,
+        projection: &Option<Vec<usize>>,
+        batch_size: usize,
+        limit: Option<usize>,
+    ) -> Result<CsvStream<R>> {
         let start_line = if has_header { 1 } else { 0 };
         let bounds = limit.map(|x| (0, x + start_line));
         let reader = csv::Reader::new(
-            file,
+            reader,
             schema,
             has_header,
             delimiter,
@@ -319,7 +458,7 @@
     }
 }
 
-impl Stream for CsvStream {
+impl<R: Read + Unpin> Stream for CsvStream<R> {
     type Item = ArrowResult<RecordBatch>;
 
     fn poll_next(
@@ -330,7 +469,7 @@
     }
 }
 
-impl RecordBatchStream for CsvStream {
+impl<R: Read + Unpin> RecordBatchStream for CsvStream<R> {
     /// Get the schema
     fn schema(&self) -> SchemaRef {
         self.reader.schema()
@@ -398,4 +537,34 @@ mod tests {
         assert_eq!("c3", batch_schema.field(2).name());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn csv_exec_with_reader() -> Result<()> {
+        let schema = aggr_test_schema();
+        let testdata = arrow::util::test_util::arrow_test_data();
+        let filename = "aggregate_test_100.csv";
+        let path = format!("{}/csv/{}", testdata, filename);
+        let buf = std::fs::read(path).unwrap();
+        let rdr = std::io::Cursor::new(buf);
+        let csv = CsvExec::try_new_from_reader(
+            rdr,
+            CsvReadOptions::new().schema(&schema),
+            Some(vec![0, 2, 4]),
+            1024,
+            None,
+        )?;
+        assert_eq!(13, csv.schema.fields().len());
+        assert_eq!(3, csv.projected_schema.fields().len());
+        assert_eq!(13, csv.file_schema().fields().len());
+        assert_eq!(3, csv.schema().fields().len());
+        let mut stream = csv.execute(0).await?;
+        let batch = stream.next().await.unwrap()?;
+        assert_eq!(3, batch.num_columns());
+        let batch_schema = batch.schema();
+        assert_eq!(3, batch_schema.fields().len());
+        assert_eq!("c1", batch_schema.field(0).name());
+        assert_eq!("c3", batch_schema.field(1).name());
+        assert_eq!("c5", batch_schema.field(2).name());
+        Ok(())
+    }
 }