diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py
index ea063669d5..0e0bd30be3 100644
--- a/bigframes/core/blocks.py
+++ b/bigframes/core/blocks.py
@@ -508,11 +508,24 @@ def try_peek(
         else:
             return None
 
-    def to_pandas_batches(self):
-        """Download results one message at a time."""
+    def to_pandas_batches(
+        self, page_size: Optional[int] = None, max_results: Optional[int] = None
+    ):
+        """Download results one message at a time.
+
+        page_size and max_results determine the size and number of batches,
+        see https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.QueryJob#google_cloud_bigquery_job_QueryJob_result"""
         dtypes = dict(zip(self.index_columns, self.index.dtypes))
         dtypes.update(zip(self.value_columns, self.dtypes))
-        results_iterator, _ = self.session._execute(self.expr, sorted=True)
+        _, query_job = self.session._query_to_destination(
+            self.session._to_sql(self.expr, sorted=True),
+            list(self.index_columns),
+            api_name="cached",
+            do_clustering=False,
+        )
+        results_iterator = query_job.result(
+            page_size=page_size, max_results=max_results
+        )
         for arrow_table in results_iterator.to_arrow_iterable(
             bqstorage_client=self.session.bqstoragereadclient
         ):
diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index e404e439ab..874ef76f6e 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -1215,10 +1215,30 @@ def to_pandas(
         self._set_internal_query_job(query_job)
         return df.set_axis(self._block.column_labels, axis=1, copy=False)
 
-    def to_pandas_batches(self) -> Iterable[pandas.DataFrame]:
-        """Stream DataFrame results to an iterable of pandas DataFrame"""
+    def to_pandas_batches(
+        self, page_size: Optional[int] = None, max_results: Optional[int] = None
+    ) -> Iterable[pandas.DataFrame]:
+        """Stream DataFrame results to an iterable of pandas DataFrame.
+
+        page_size and max_results determine the size and number of batches,
+        see https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.QueryJob#google_cloud_bigquery_job_QueryJob_result
+
+        Args:
+            page_size (int, default None):
+                The size of each batch.
+            max_results (int, default None):
+                If given, only download this many rows at maximum.
+
+        Returns:
+            Iterable[pandas.DataFrame]:
+                An iterable of smaller dataframes which combine to
+                form the original dataframe. Results stream from bigquery,
+                see https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.table.RowIterator#google_cloud_bigquery_table_RowIterator_to_arrow_iterable
+        """
         self._optimize_query_complexity()
-        return self._block.to_pandas_batches()
+        return self._block.to_pandas_batches(
+            page_size=page_size, max_results=max_results
+        )
 
     def _compute_dry_run(self) -> bigquery.QueryJob:
         return self._block._compute_dry_run()
diff --git a/tests/system/load/test_large_tables.py b/tests/system/load/test_large_tables.py
index cf1c787a58..f92207b191 100644
--- a/tests/system/load/test_large_tables.py
+++ b/tests/system/load/test_large_tables.py
@@ -75,22 +75,17 @@ def test_index_repr_large_table():
 
 
 def test_to_pandas_batches_large_table():
-    df = bpd.read_gbq("load_testing.scalars_10gb")
-    # df will be downloaded locally
-    expected_row_count, expected_column_count = df.shape
-
-    row_count = 0
-    # TODO(b/340890167): fix type error
-    for df in df.to_pandas_batches():  # type: ignore
-        batch_row_count, batch_column_count = df.shape
+    df = bpd.read_gbq("load_testing.scalars_1tb")
+    _, expected_column_count = df.shape
+
+    # download only a few batches, since 1tb would be too much
+    iterable = df.to_pandas_batches(page_size=500, max_results=1500)
+    # use page size since client library doesn't support
+    # streaming only part of the dataframe via bqstorage
+    for pdf in iterable:
+        batch_row_count, batch_column_count = pdf.shape
         assert batch_column_count == expected_column_count
-        row_count += batch_row_count
-
-        # Attempt to save on memory by manually removing the batch df
-        # from local memory after finishing with processing.
-        del df
-
-    assert row_count == expected_row_count
+        assert batch_row_count > 0
 
 
 @pytest.mark.skip(reason="See if it caused kokoro build aborted.")
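
For reference, a minimal usage sketch of the new parameters, mirroring what the updated load test exercises (assuming the standard `bigframes.pandas` import; the table name below is only illustrative, any table readable via `read_gbq` works):

    import bigframes.pandas as bpd

    # Illustrative public table; substitute any table you can read via read_gbq.
    df = bpd.read_gbq("bigquery-public-data.usa_names.usa_1910_2013")

    # Download at most 1,500 rows in total, streamed in pages of 500 rows each.
    for batch in df.to_pandas_batches(page_size=500, max_results=1500):
        # Each batch is an ordinary pandas DataFrame with the full set of columns.
        print(batch.shape)

Because results now stream page by page rather than materializing the whole result set, callers can bound both memory use (page_size) and total rows fetched (max_results) when working with very large tables.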