diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 435b380ffb..99c8f09bc0 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -20,7 +20,16 @@ import functools import itertools import typing -from typing import Callable, cast, Iterable, Mapping, Optional, Sequence, Tuple +from typing import ( + AbstractSet, + Callable, + cast, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, +) import google.cloud.bigquery as bq @@ -572,8 +581,39 @@ def with_id(self, id: identifiers.ColumnId) -> ScanItem: @dataclasses.dataclass(frozen=True) class ScanList: + """ + Defines the set of columns to scan from a source, along with the variable to bind the columns to. + """ + items: typing.Tuple[ScanItem, ...] + def filter_cols( + self, + ids: AbstractSet[identifiers.ColumnId], + ) -> ScanList: + """Drop columns from the scan that except those in the 'ids' arg.""" + result = ScanList(tuple(item for item in self.items if item.id in ids)) + if len(result.items) == 0: + # We need to select something, or sql syntax breaks + result = ScanList(self.items[:1]) + return result + + def project( + self, + selections: Mapping[identifiers.ColumnId, identifiers.ColumnId], + ) -> ScanList: + """Project given ids from the scanlist, dropping previous bindings.""" + by_id = {item.id: item for item in self.items} + result = ScanList( + tuple( + by_id[old_id].with_id(new_id) for old_id, new_id in selections.items() + ) + ) + if len(result.items) == 0: + # We need to select something, or sql syntax breaks + result = ScanList((self.items[:1])) + return result + @dataclasses.dataclass(frozen=True, eq=False) class ReadLocalNode(LeafNode): @@ -675,6 +715,11 @@ def from_table(table: bq.Table, columns: Sequence[str] = ()) -> GbqTable: else tuple(table.clustering_fields), ) + def get_table_ref(self) -> bq.TableReference: + return bq.TableReference( + bq.DatasetReference(self.project_id, self.dataset_id), self.table_id + ) + @property @functools.cache def schema_by_id(self): @@ -1068,6 +1113,11 @@ def variables_introduced(self) -> int: # This operation only renames variables, doesn't actually create new ones return 0 + @property + def has_multi_referenced_ids(self) -> bool: + referenced = tuple(ref.ref.id for ref in self.input_output_pairs) + return len(referenced) != len(set(referenced)) + # TODO: Reuse parent namespace # Currently, Selection node allows renaming an reusing existing names, so it must establish a # new namespace. 
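The new ScanList helpers can be exercised roughly as below. This is a sketch only: the ScanItem constructor order (id, dtype, source_id), constructing a ColumnId directly from a string, and bigframes.dtypes.INT_DTYPE are assumptions that are not spelled out in the hunk above.

    import bigframes.dtypes
    from bigframes.core import identifiers, nodes

    a, b = identifiers.ColumnId("a"), identifiers.ColumnId("b")
    scan = nodes.ScanList(
        (
            nodes.ScanItem(a, bigframes.dtypes.INT_DTYPE, "col_a"),  # assumed field order
            nodes.ScanItem(b, bigframes.dtypes.INT_DTYPE, "col_b"),
        )
    )

    # filter_cols keeps only the requested ids; an empty selection falls back to
    # the first item so the generated SQL still selects at least one column.
    only_a = scan.filter_cols({a})
    assert [item.id for item in only_a.items] == [a]

    # project re-binds old ids to new ids and drops everything not selected.
    a2 = identifiers.ColumnId("a2")
    projected = scan.project({a: a2})
    assert [item.id for item in projected.items] == [a2]

project() is what lets the scan-reduction rewrite below fold a SelectionNode's renames directly into the scan list of its child ReadTableNode.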
diff --git a/bigframes/core/rewrite/__init__.py b/bigframes/core/rewrite/__init__.py index 58730805e4..128cefe94c 100644 --- a/bigframes/core/rewrite/__init__.py +++ b/bigframes/core/rewrite/__init__.py @@ -17,6 +17,7 @@ from bigframes.core.rewrite.legacy_align import legacy_join_as_projection from bigframes.core.rewrite.order import pull_up_order from bigframes.core.rewrite.pruning import column_pruning +from bigframes.core.rewrite.scan_reduction import try_reduce_to_table_scan from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions from bigframes.core.rewrite.windows import rewrite_range_rolling @@ -31,4 +32,5 @@ "pull_up_order", "column_pruning", "rewrite_range_rolling", + "try_reduce_to_table_scan", ] diff --git a/bigframes/core/rewrite/pruning.py b/bigframes/core/rewrite/pruning.py index 5a94f2aa40..5f4990094c 100644 --- a/bigframes/core/rewrite/pruning.py +++ b/bigframes/core/rewrite/pruning.py @@ -170,7 +170,7 @@ def prune_readlocal( node: bigframes.core.nodes.ReadLocalNode, selection: AbstractSet[identifiers.ColumnId], ) -> bigframes.core.nodes.ReadLocalNode: - new_scan_list = filter_scanlist(node.scan_list, selection) + new_scan_list = node.scan_list.filter_cols(selection) return dataclasses.replace( node, scan_list=new_scan_list, @@ -183,18 +183,5 @@ def prune_readtable( node: bigframes.core.nodes.ReadTableNode, selection: AbstractSet[identifiers.ColumnId], ) -> bigframes.core.nodes.ReadTableNode: - new_scan_list = filter_scanlist(node.scan_list, selection) + new_scan_list = node.scan_list.filter_cols(selection) return dataclasses.replace(node, scan_list=new_scan_list) - - -def filter_scanlist( - scanlist: bigframes.core.nodes.ScanList, - ids: AbstractSet[identifiers.ColumnId], -): - result = bigframes.core.nodes.ScanList( - tuple(item for item in scanlist.items if item.id in ids) - ) - if len(result.items) == 0: - # We need to select something, or stuff breaks - result = bigframes.core.nodes.ScanList(scanlist.items[:1]) - return result diff --git a/bigframes/core/rewrite/scan_reduction.py b/bigframes/core/rewrite/scan_reduction.py new file mode 100644 index 0000000000..be8db4827c --- /dev/null +++ b/bigframes/core/rewrite/scan_reduction.py @@ -0,0 +1,47 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import dataclasses +import functools +from typing import Optional + +from bigframes.core import nodes + + +def try_reduce_to_table_scan(root: nodes.BigFrameNode) -> Optional[nodes.ReadTableNode]: + for node in root.unique_nodes(): + if not isinstance(node, (nodes.ReadTableNode, nodes.SelectionNode)): + return None + result = root.bottom_up(merge_scan) + if isinstance(result, nodes.ReadTableNode): + return result + return None + + +@functools.singledispatch +def merge_scan(node: nodes.BigFrameNode) -> nodes.BigFrameNode: + return node + + +@merge_scan.register +def _(node: nodes.SelectionNode) -> nodes.BigFrameNode: + if not isinstance(node.child, nodes.ReadTableNode): + return node + if node.has_multi_referenced_ids: + return node + + selection = { + aliased_ref.ref.id: aliased_ref.id for aliased_ref in node.input_output_pairs + } + new_scan_list = node.child.scan_list.project(selection) + return dataclasses.replace(node.child, scan_list=new_scan_list) diff --git a/bigframes/session/__init__.py b/bigframes/session/__init__.py index b94e6985c3..a27094952f 100644 --- a/bigframes/session/__init__.py +++ b/bigframes/session/__init__.py @@ -64,11 +64,10 @@ # to register new and replacement ops with the Ibis BigQuery backend. import bigframes.functions._function_session as bff_session import bigframes.functions.function as bff -from bigframes.session import bigquery_session +from bigframes.session import bigquery_session, bq_caching_executor, executor import bigframes.session._io.bigquery as bf_io_bigquery import bigframes.session.anonymous_dataset import bigframes.session.clients -import bigframes.session.executor import bigframes.session.loader import bigframes.session.metrics import bigframes.session.validation @@ -245,14 +244,12 @@ def __init__( self._temp_storage_manager = ( self._session_resource_manager or self._anon_dataset_manager ) - self._executor: bigframes.session.executor.Executor = ( - bigframes.session.executor.BigQueryCachingExecutor( - bqclient=self._clients_provider.bqclient, - bqstoragereadclient=self._clients_provider.bqstoragereadclient, - storage_manager=self._temp_storage_manager, - strictly_ordered=self._strictly_ordered, - metrics=self._metrics, - ) + self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor( + bqclient=self._clients_provider.bqclient, + bqstoragereadclient=self._clients_provider.bqstoragereadclient, + storage_manager=self._temp_storage_manager, + strictly_ordered=self._strictly_ordered, + metrics=self._metrics, ) self._loader = bigframes.session.loader.GbqDataLoader( session=self, diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py new file mode 100644 index 0000000000..10a42dab10 --- /dev/null +++ b/bigframes/session/bq_caching_executor.py @@ -0,0 +1,602 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import math +import os +from typing import cast, Literal, Mapping, Optional, Sequence, Tuple, Union +import warnings +import weakref + +import google.api_core.exceptions +from google.cloud import bigquery +import google.cloud.bigquery.job as bq_job +import google.cloud.bigquery.table as bq_table +import google.cloud.bigquery_storage_v1 + +import bigframes.core +import bigframes.core.compile +import bigframes.core.guid +import bigframes.core.nodes as nodes +import bigframes.core.ordering as order +import bigframes.core.tree_properties as tree_properties +import bigframes.dtypes +import bigframes.exceptions as bfe +import bigframes.features +from bigframes.session import executor, read_api_execution +import bigframes.session._io.bigquery as bq_io +import bigframes.session.metrics +import bigframes.session.planner +import bigframes.session.temporary_storage + +# Max complexity that should be executed as a single query +QUERY_COMPLEXITY_LIMIT = 1e7 +# Number of times to factor out subqueries before giving up. +MAX_SUBTREE_FACTORINGS = 5 +_MAX_CLUSTER_COLUMNS = 4 +MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G + + +class BigQueryCachingExecutor(executor.Executor): + """Computes BigFrames values using BigQuery Engine. + + This executor can cache expressions. If those expressions are executed later, this session + will re-use the pre-existing results from previous executions. + + This class is not thread-safe. + """ + + def __init__( + self, + bqclient: bigquery.Client, + storage_manager: bigframes.session.temporary_storage.TemporaryStorageManager, + bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient, + *, + strictly_ordered: bool = True, + metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None, + ): + self.bqclient = bqclient + self.storage_manager = storage_manager + self.compiler: bigframes.core.compile.SQLCompiler = ( + bigframes.core.compile.SQLCompiler() + ) + self.strictly_ordered: bool = strictly_ordered + self._cached_executions: weakref.WeakKeyDictionary[ + nodes.BigFrameNode, nodes.BigFrameNode + ] = weakref.WeakKeyDictionary() + self.metrics = metrics + self.bqstoragereadclient = bqstoragereadclient + # Simple left-to-right precedence for now + self._semi_executors = ( + read_api_execution.ReadApiSemiExecutor( + bqstoragereadclient=bqstoragereadclient, + project=self.bqclient.project, + ), + ) + + def to_sql( + self, + array_value: bigframes.core.ArrayValue, + offset_column: Optional[str] = None, + ordered: bool = False, + enable_cache: bool = True, + ) -> str: + if offset_column: + array_value, _ = array_value.promote_offsets() + node = ( + self.replace_cached_subtrees(array_value.node) + if enable_cache + else array_value.node + ) + return self.compiler.compile(node, ordered=ordered) + + def execute( + self, + array_value: bigframes.core.ArrayValue, + *, + ordered: bool = True, + use_explicit_destination: Optional[bool] = None, + page_size: Optional[int] = None, + max_results: Optional[int] = None, + ) -> executor.ExecuteResult: + if use_explicit_destination is None: + use_explicit_destination = bigframes.options.bigquery.allow_large_results + + if bigframes.options.compute.enable_multi_query_execution: + self._simplify_with_caching(array_value) + + plan = self.replace_cached_subtrees(array_value.node) + # Use explicit destination to avoid 10GB limit of temporary table + destination_table = ( + self.storage_manager.create_temp_table( + array_value.schema.to_bigquery(), cluster_cols=[] + ) + if 
use_explicit_destination + else None + ) + return self._execute_plan( + plan, + ordered=ordered, + page_size=page_size, + max_results=max_results, + destination=destination_table, + ) + + def export_gbq( + self, + array_value: bigframes.core.ArrayValue, + destination: bigquery.TableReference, + if_exists: Literal["fail", "replace", "append"] = "fail", + cluster_cols: Sequence[str] = [], + ): + """ + Export the ArrayValue to an existing BigQuery table. + """ + if bigframes.options.compute.enable_multi_query_execution: + self._simplify_with_caching(array_value) + + dispositions = { + "fail": bigquery.WriteDisposition.WRITE_EMPTY, + "replace": bigquery.WriteDisposition.WRITE_TRUNCATE, + "append": bigquery.WriteDisposition.WRITE_APPEND, + } + sql = self.to_sql(array_value, ordered=False) + job_config = bigquery.QueryJobConfig( + write_disposition=dispositions[if_exists], + destination=destination, + clustering_fields=cluster_cols if cluster_cols else None, + ) + # TODO(swast): plumb through the api_name of the user-facing api that + # caused this query. + _, query_job = self._run_execute_query( + sql=sql, + job_config=job_config, + ) + + has_timedelta_col = any( + t == bigframes.dtypes.TIMEDELTA_DTYPE for t in array_value.schema.dtypes + ) + + if if_exists != "append" and has_timedelta_col: + # Only update schema if this is not modifying an existing table, and the + # new table contains timedelta columns. + table = self.bqclient.get_table(destination) + table.schema = array_value.schema.to_bigquery() + self.bqclient.update_table(table, ["schema"]) + + return query_job + + def export_gcs( + self, + array_value: bigframes.core.ArrayValue, + uri: str, + format: Literal["json", "csv", "parquet"], + export_options: Mapping[str, Union[bool, str]], + ): + query_job = self.execute( + array_value, + ordered=False, + use_explicit_destination=True, + ).query_job + assert query_job is not None + result_table = query_job.destination + assert result_table is not None + export_data_statement = bq_io.create_export_data_statement( + f"{result_table.project}.{result_table.dataset_id}.{result_table.table_id}", + uri=uri, + format=format, + export_options=dict(export_options), + ) + + bq_io.start_query_with_client( + self.bqclient, + export_data_statement, + job_config=bigquery.QueryJobConfig(), + api_name=f"dataframe-to_{format.lower()}", + metrics=self.metrics, + ) + return query_job + + def dry_run( + self, array_value: bigframes.core.ArrayValue, ordered: bool = True + ) -> bigquery.QueryJob: + sql = self.to_sql(array_value, ordered=ordered) + job_config = bigquery.QueryJobConfig(dry_run=True) + query_job = self.bqclient.query(sql, job_config=job_config) + return query_job + + def peek( + self, + array_value: bigframes.core.ArrayValue, + n_rows: int, + use_explicit_destination: Optional[bool] = None, + ) -> executor.ExecuteResult: + """ + A 'peek' efficiently accesses a small number of rows in the dataframe. 
+ """ + plan = self.replace_cached_subtrees(array_value.node) + if not tree_properties.can_fast_peek(plan): + msg = bfe.format_message("Peeking this value cannot be done efficiently.") + warnings.warn(msg) + if use_explicit_destination is None: + use_explicit_destination = bigframes.options.bigquery.allow_large_results + + destination_table = ( + self.storage_manager.create_temp_table( + array_value.schema.to_bigquery(), cluster_cols=[] + ) + if use_explicit_destination + else None + ) + + return self._execute_plan( + plan, ordered=False, destination=destination_table, peek=n_rows + ) + + def head( + self, array_value: bigframes.core.ArrayValue, n_rows: int + ) -> executor.ExecuteResult: + + maybe_row_count = self._local_get_row_count(array_value) + if (maybe_row_count is not None) and (maybe_row_count <= n_rows): + return self.execute(array_value, ordered=True) + + if not self.strictly_ordered and not array_value.node.explicitly_ordered: + # No user-provided ordering, so just get any N rows, its faster! + return self.peek(array_value, n_rows) + + plan = self.replace_cached_subtrees(array_value.node) + if not tree_properties.can_fast_head(plan): + # If can't get head fast, we are going to need to execute the whole query + # Will want to do this in a way such that the result is reusable, but the first + # N values can be easily extracted. + # This currently requires clustering on offsets. + self._cache_with_offsets(array_value) + # Get a new optimized plan after caching + plan = self.replace_cached_subtrees(array_value.node) + assert tree_properties.can_fast_head(plan) + + head_plan = generate_head_plan(plan, n_rows) + return self._execute_plan(head_plan, ordered=True) + + def get_row_count(self, array_value: bigframes.core.ArrayValue) -> int: + # TODO: Fold row count node in and use local execution + count = self._local_get_row_count(array_value) + if count is not None: + return count + else: + row_count_plan = self.replace_cached_subtrees( + generate_row_count_plan(array_value.node) + ) + results = self._execute_plan(row_count_plan, ordered=True) + pa_table = next(results.arrow_batches()) + pa_array = pa_table.column(0) + return pa_array.tolist()[0] + + def cached( + self, + array_value: bigframes.core.ArrayValue, + *, + force: bool = False, + use_session: bool = False, + cluster_cols: Sequence[str] = (), + ) -> None: + """Write the block to a session table.""" + # use a heuristic for whether something needs to be cached + if (not force) and self._is_trivially_executable(array_value): + return + if use_session: + self._cache_with_session_awareness(array_value) + else: + self._cache_with_cluster_cols(array_value, cluster_cols=cluster_cols) + + def _local_get_row_count( + self, array_value: bigframes.core.ArrayValue + ) -> Optional[int]: + # optimized plan has cache materializations which will have row count metadata + # that is more likely to be usable than original leaf nodes. + plan = self.replace_cached_subtrees(array_value.node) + return tree_properties.row_count(plan) + + # Helpers + def _run_execute_query( + self, + sql: str, + job_config: Optional[bq_job.QueryJobConfig] = None, + api_name: Optional[str] = None, + page_size: Optional[int] = None, + max_results: Optional[int] = None, + query_with_job: bool = True, + ) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]: + """ + Starts BigQuery query job and waits for results. 
+ """ + job_config = bq_job.QueryJobConfig() if job_config is None else job_config + if bigframes.options.compute.maximum_bytes_billed is not None: + job_config.maximum_bytes_billed = ( + bigframes.options.compute.maximum_bytes_billed + ) + + if not self.strictly_ordered: + job_config.labels["bigframes-mode"] = "unordered" + + # Note: add_and_trim_labels is global scope which may have unexpected effects + # Ensure no additional labels are added to job_config after this point, + # as `add_and_trim_labels` ensures the label count does not exceed 64. + bq_io.add_and_trim_labels(job_config, api_name=api_name) + try: + iterator, query_job = bq_io.start_query_with_client( + self.bqclient, + sql, + job_config=job_config, + api_name=api_name, + max_results=max_results, + page_size=page_size, + metrics=self.metrics, + query_with_job=query_with_job, + ) + return iterator, query_job + + except google.api_core.exceptions.BadRequest as e: + # Unfortunately, this error type does not have a separate error code or exception type + if "Resources exceeded during query execution" in e.message: + new_message = "Computation is too complex to execute as a single query. Try using DataFrame.cache() on intermediate results, or setting bigframes.options.compute.enable_multi_query_execution." + raise bigframes.exceptions.QueryComplexityError(new_message) from e + else: + raise + + def replace_cached_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode: + return nodes.top_down(node, lambda x: self._cached_executions.get(x, x)) + + def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue): + """ + Can the block be evaluated very cheaply? + If True, the array_value probably is not worth caching. + """ + # Once rewriting is available, will want to rewrite before + # evaluating execution cost. 
+ return tree_properties.is_trivially_executable( + self.replace_cached_subtrees(array_value.node) + ) + + def _cache_with_cluster_cols( + self, array_value: bigframes.core.ArrayValue, cluster_cols: Sequence[str] + ): + """Executes the query and uses the resulting table to rewrite future executions.""" + + sql, schema, ordering_info = self.compiler.compile_raw( + self.replace_cached_subtrees(array_value.node) + ) + tmp_table = self._sql_as_cached_temp_table( + sql, + schema, + cluster_cols=bq_io.select_cluster_cols(schema, cluster_cols), + ) + cached_replacement = array_value.as_cached( + cache_table=self.bqclient.get_table(tmp_table), + ordering=ordering_info, + ).node + self._cached_executions[array_value.node] = cached_replacement + + def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue): + """Executes the query and uses the resulting table to rewrite future executions.""" + offset_column = bigframes.core.guid.generate_guid("bigframes_offsets") + w_offsets, offset_column = array_value.promote_offsets() + sql = self.compiler.compile( + self.replace_cached_subtrees(w_offsets.node), ordered=False + ) + + tmp_table = self._sql_as_cached_temp_table( + sql, + w_offsets.schema.to_bigquery(), + cluster_cols=[offset_column], + ) + cached_replacement = array_value.as_cached( + cache_table=self.bqclient.get_table(tmp_table), + ordering=order.TotalOrdering.from_offset_col(offset_column), + ).node + self._cached_executions[array_value.node] = cached_replacement + + def _cache_with_session_awareness( + self, + array_value: bigframes.core.ArrayValue, + ) -> None: + session_forest = [obj._block._expr.node for obj in array_value.session.objects] + # These node types are cheap to re-compute + target, cluster_cols = bigframes.session.planner.session_aware_cache_plan( + array_value.node, list(session_forest) + ) + cluster_cols_sql_names = [id.sql for id in cluster_cols] + if len(cluster_cols) > 0: + self._cache_with_cluster_cols( + bigframes.core.ArrayValue(target), cluster_cols_sql_names + ) + elif self.strictly_ordered: + self._cache_with_offsets(bigframes.core.ArrayValue(target)) + else: + self._cache_with_cluster_cols(bigframes.core.ArrayValue(target), []) + + def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue): + """Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces.""" + # Apply existing caching first + for _ in range(MAX_SUBTREE_FACTORINGS): + node_with_cache = self.replace_cached_subtrees(array_value.node) + if node_with_cache.planning_complexity < QUERY_COMPLEXITY_LIMIT: + return + + did_cache = self._cache_most_complex_subtree(array_value.node) + if not did_cache: + return + + def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool: + # TODO: If query fails, retry with lower complexity limit + selection = tree_properties.select_cache_target( + node, + min_complexity=(QUERY_COMPLEXITY_LIMIT / 500), + max_complexity=QUERY_COMPLEXITY_LIMIT, + cache=dict(self._cached_executions), + # Heuristic: subtree_compleixty * (copies of subtree)^2 + heuristic=lambda complexity, count: math.log(complexity) + + 2 * math.log(count), + ) + if selection is None: + # No good subtrees to cache, just return original tree + return False + + self._cache_with_cluster_cols(bigframes.core.ArrayValue(selection), []) + return True + + def _sql_as_cached_temp_table( + self, + sql: str, + schema: Sequence[bigquery.SchemaField], + cluster_cols: Sequence[str], + ) -> bigquery.TableReference: + assert len(cluster_cols) <= 
_MAX_CLUSTER_COLUMNS + temp_table = self.storage_manager.create_temp_table(schema, cluster_cols) + + # TODO: Get default job config settings + job_config = cast( + bigquery.QueryJobConfig, + bigquery.QueryJobConfig.from_api_repr({}), + ) + job_config.destination = temp_table + _, query_job = self._run_execute_query( + sql, + job_config=job_config, + api_name="cached", + ) + assert query_job is not None + query_job.result() + return query_job.destination + + def _validate_result_schema( + self, + array_value: bigframes.core.ArrayValue, + bq_schema: list[bigquery.SchemaField], + ): + actual_schema = _sanitize(tuple(bq_schema)) + ibis_schema = bigframes.core.compile.test_only_ibis_inferred_schema( + self.replace_cached_subtrees(array_value.node) + ).to_bigquery() + internal_schema = _sanitize(array_value.schema.to_bigquery()) + if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable: + return + + if internal_schema != actual_schema: + raise ValueError( + f"This error should only occur while testing. BigFrames internal schema: {internal_schema} does not match actual schema: {actual_schema}" + ) + + if ibis_schema != actual_schema: + raise ValueError( + f"This error should only occur while testing. Ibis schema: {ibis_schema} does not match actual schema: {actual_schema}" + ) + + def _execute_plan( + self, + plan: nodes.BigFrameNode, + ordered: bool, + page_size: Optional[int] = None, + max_results: Optional[int] = None, + destination: Optional[bq_table.TableReference] = None, + peek: Optional[int] = None, + ): + """Just execute whatever plan as is, without further caching or decomposition.""" + + # First try to execute fast-paths + # TODO: Allow page_size and max_results by rechunking/truncating results + if (not page_size) and (not max_results) and (not destination) and (not peek): + for semi_executor in self._semi_executors: + maybe_result = semi_executor.execute(plan, ordered=ordered) + if maybe_result: + return maybe_result + + # TODO(swast): plumb through the api_name of the user-facing api that + # caused this query. + job_config = bigquery.QueryJobConfig() + # Use explicit destination to avoid 10GB limit of temporary table + if destination is not None: + job_config.destination = destination + sql = self.compiler.compile(plan, ordered=ordered, limit=peek) + iterator, query_job = self._run_execute_query( + sql=sql, + job_config=job_config, + page_size=page_size, + max_results=max_results, + query_with_job=(destination is not None), + ) + + # Though we provide the read client, iterator may or may not use it based on what is efficient for the result + def iterator_supplier(): + # Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154 + if iterator._page_size is not None or iterator.max_results is not None: + return iterator.to_arrow_iterable(bqstorage_client=None) + else: + return iterator.to_arrow_iterable( + bqstorage_client=self.bqstoragereadclient + ) + + if query_job: + size_bytes = self.bqclient.get_table(query_job.destination).num_bytes + else: + size_bytes = None + + if size_bytes is not None and size_bytes >= MAX_SMALL_RESULT_BYTES: + msg = bfe.format_message( + "The query result size has exceeded 10 GB. In BigFrames 2.0 and " + "later, you might need to manually set `allow_large_results=True` in " + "the IO method or adjust the BigFrames option: " + "`bigframes.options.bigquery.allow_large_results=True`." 
+ ) + warnings.warn(msg, FutureWarning) + # Runs strict validations to ensure internal type predictions and ibis are completely in sync + # Do not execute these validations outside of testing suite. + if "PYTEST_CURRENT_TEST" in os.environ: + self._validate_result_schema( + bigframes.core.ArrayValue(plan), iterator.schema + ) + + return executor.ExecuteResult( + arrow_batches=iterator_supplier, + schema=plan.schema, + query_job=query_job, + total_bytes=size_bytes, + total_rows=iterator.total_rows, + ) + + +def _sanitize( + schema: Tuple[bigquery.SchemaField, ...] +) -> Tuple[bigquery.SchemaField, ...]: + # Schema inferred from SQL strings and Ibis expressions contain only names, types and modes, + # so we disregard other fields (e.g timedelta description for timedelta columns) for validations. + return tuple( + bigquery.SchemaField( + f.name, + f.field_type, + f.mode, # type:ignore + fields=_sanitize(f.fields), + ) + for f in schema + ) + + +def generate_head_plan(node: nodes.BigFrameNode, n: int): + return nodes.SliceNode(node, start=None, stop=n) + + +def generate_row_count_plan(node: nodes.BigFrameNode): + return nodes.RowCountNode(node) diff --git a/bigframes/session/executor.py b/bigframes/session/executor.py index 150122b7dd..4c27c25058 100644 --- a/bigframes/session/executor.py +++ b/bigframes/session/executor.py @@ -16,53 +16,13 @@ import abc import dataclasses -import math -import os -from typing import ( - Callable, - cast, - Iterator, - Literal, - Mapping, - Optional, - Sequence, - Tuple, - Union, -) -import warnings -import weakref +from typing import Callable, Iterator, Literal, Mapping, Optional, Sequence, Union -import google.api_core.exceptions from google.cloud import bigquery -import google.cloud.bigquery.job as bq_job -import google.cloud.bigquery.table as bq_table -import google.cloud.bigquery_storage_v1 import pyarrow import bigframes.core -import bigframes.core.compile -import bigframes.core.guid -import bigframes.core.identifiers -import bigframes.core.nodes as nodes -import bigframes.core.ordering as order import bigframes.core.schema -import bigframes.core.tree_properties as tree_properties -import bigframes.dtypes -import bigframes.exceptions as bfe -import bigframes.features -import bigframes.session._io.bigquery as bq_io -import bigframes.session.metrics -import bigframes.session.planner -import bigframes.session.temporary_storage - -# Max complexity that should be executed as a single query -QUERY_COMPLEXITY_LIMIT = 1e7 -# Number of times to factor out subqueries before giving up. -MAX_SUBTREE_FACTORINGS = 5 -_MAX_CLUSTER_COLUMNS = 4 -# TODO: b/338258028 Enable pruning to reduce text size. -ENABLE_PRUNING = False -MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024 # 10G @dataclasses.dataclass(frozen=True) @@ -181,536 +141,3 @@ def cached( cluster_cols: Sequence[str] = (), ) -> None: raise NotImplementedError("cached not implemented for this executor") - - -class BigQueryCachingExecutor(Executor): - """Computes BigFrames values using BigQuery Engine. - - This executor can cache expressions. If those expressions are executed later, this session - will re-use the pre-existing results from previous executions. - - This class is not thread-safe. 
- """ - - def __init__( - self, - bqclient: bigquery.Client, - storage_manager: bigframes.session.temporary_storage.TemporaryStorageManager, - bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient, - *, - strictly_ordered: bool = True, - metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None, - ): - self.bqclient = bqclient - self.storage_manager = storage_manager - self.compiler: bigframes.core.compile.SQLCompiler = ( - bigframes.core.compile.SQLCompiler() - ) - self.strictly_ordered: bool = strictly_ordered - self._cached_executions: weakref.WeakKeyDictionary[ - nodes.BigFrameNode, nodes.BigFrameNode - ] = weakref.WeakKeyDictionary() - self.metrics = metrics - self.bqstoragereadclient = bqstoragereadclient - - def to_sql( - self, - array_value: bigframes.core.ArrayValue, - offset_column: Optional[str] = None, - ordered: bool = False, - enable_cache: bool = True, - ) -> str: - if offset_column: - array_value, _ = array_value.promote_offsets() - node = ( - self.replace_cached_subtrees(array_value.node) - if enable_cache - else array_value.node - ) - return self.compiler.compile(node, ordered=ordered) - - def execute( - self, - array_value: bigframes.core.ArrayValue, - *, - ordered: bool = True, - use_explicit_destination: Optional[bool] = None, - page_size: Optional[int] = None, - max_results: Optional[int] = None, - ): - if use_explicit_destination is None: - use_explicit_destination = bigframes.options.bigquery.allow_large_results - - if bigframes.options.compute.enable_multi_query_execution: - self._simplify_with_caching(array_value) - - sql = self.to_sql(array_value, ordered=ordered) - job_config = bigquery.QueryJobConfig() - # Use explicit destination to avoid 10GB limit of temporary table - if use_explicit_destination: - destination_table = self.storage_manager.create_temp_table( - array_value.schema.to_bigquery(), cluster_cols=[] - ) - job_config.destination = destination_table - # TODO(swast): plumb through the api_name of the user-facing api that - # caused this query. - iterator, query_job = self._run_execute_query( - sql=sql, - job_config=job_config, - page_size=page_size, - max_results=max_results, - query_with_job=use_explicit_destination, - ) - - # Though we provide the read client, iterator may or may not use it based on what is efficient for the result - def iterator_supplier(): - # Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154 - if iterator._page_size is not None or iterator.max_results is not None: - return iterator.to_arrow_iterable(bqstorage_client=None) - else: - return iterator.to_arrow_iterable( - bqstorage_client=self.bqstoragereadclient - ) - - if query_job: - size_bytes = self.bqclient.get_table(query_job.destination).num_bytes - else: - size_bytes = None - - if size_bytes is not None and size_bytes >= MAX_SMALL_RESULT_BYTES: - msg = bfe.format_message( - "The query result size has exceeded 10 GB. In BigFrames 2.0 and " - "later, you might need to manually set `allow_large_results=True` in " - "the IO method or adjust the BigFrames option: " - "`bigframes.options.bigquery.allow_large_results=True`." - ) - warnings.warn(msg, FutureWarning) - # Runs strict validations to ensure internal type predictions and ibis are completely in sync - # Do not execute these validations outside of testing suite. 
- if "PYTEST_CURRENT_TEST" in os.environ: - self._validate_result_schema(array_value, iterator.schema) - - return ExecuteResult( - arrow_batches=iterator_supplier, - schema=array_value.schema, - query_job=query_job, - total_bytes=size_bytes, - total_rows=iterator.total_rows, - ) - - def export_gbq( - self, - array_value: bigframes.core.ArrayValue, - destination: bigquery.TableReference, - if_exists: Literal["fail", "replace", "append"] = "fail", - cluster_cols: Sequence[str] = [], - ): - """ - Export the ArrayValue to an existing BigQuery table. - """ - if bigframes.options.compute.enable_multi_query_execution: - self._simplify_with_caching(array_value) - - dispositions = { - "fail": bigquery.WriteDisposition.WRITE_EMPTY, - "replace": bigquery.WriteDisposition.WRITE_TRUNCATE, - "append": bigquery.WriteDisposition.WRITE_APPEND, - } - sql = self.to_sql(array_value, ordered=False) - job_config = bigquery.QueryJobConfig( - write_disposition=dispositions[if_exists], - destination=destination, - clustering_fields=cluster_cols if cluster_cols else None, - ) - # TODO(swast): plumb through the api_name of the user-facing api that - # caused this query. - _, query_job = self._run_execute_query( - sql=sql, - job_config=job_config, - ) - - has_timedelta_col = any( - t == bigframes.dtypes.TIMEDELTA_DTYPE for t in array_value.schema.dtypes - ) - - if if_exists != "append" and has_timedelta_col: - # Only update schema if this is not modifying an existing table, and the - # new table contains timedelta columns. - table = self.bqclient.get_table(destination) - table.schema = array_value.schema.to_bigquery() - self.bqclient.update_table(table, ["schema"]) - - return query_job - - def export_gcs( - self, - array_value: bigframes.core.ArrayValue, - uri: str, - format: Literal["json", "csv", "parquet"], - export_options: Mapping[str, Union[bool, str]], - ): - query_job = self.execute( - array_value, - ordered=False, - use_explicit_destination=True, - ).query_job - result_table = query_job.destination - export_data_statement = bq_io.create_export_data_statement( - f"{result_table.project}.{result_table.dataset_id}.{result_table.table_id}", - uri=uri, - format=format, - export_options=dict(export_options), - ) - - bq_io.start_query_with_client( - self.bqclient, - export_data_statement, - job_config=bigquery.QueryJobConfig(), - api_name=f"dataframe-to_{format.lower()}", - metrics=self.metrics, - ) - return query_job - - def dry_run( - self, array_value: bigframes.core.ArrayValue, ordered: bool = True - ) -> bigquery.QueryJob: - sql = self.to_sql(array_value, ordered=ordered) - job_config = bigquery.QueryJobConfig(dry_run=True) - query_job = self.bqclient.query(sql, job_config=job_config) - return query_job - - def peek( - self, - array_value: bigframes.core.ArrayValue, - n_rows: int, - use_explicit_destination: Optional[bool] = None, - ) -> ExecuteResult: - """ - A 'peek' efficiently accesses a small number of rows in the dataframe. 
- """ - plan = self.replace_cached_subtrees(array_value.node) - if not tree_properties.can_fast_peek(plan): - msg = bfe.format_message("Peeking this value cannot be done efficiently.") - warnings.warn(msg) - if use_explicit_destination is None: - use_explicit_destination = bigframes.options.bigquery.allow_large_results - - job_config = bigquery.QueryJobConfig() - # Use explicit destination to avoid 10GB limit of temporary table - if use_explicit_destination: - destination_table = self.storage_manager.create_temp_table( - array_value.schema.to_bigquery(), cluster_cols=[] - ) - job_config.destination = destination_table - - sql = self.compiler.compile(plan, ordered=False, limit=n_rows) - - # TODO(swast): plumb through the api_name of the user-facing api that - # caused this query. - iterator, query_job = self._run_execute_query( - sql=sql, job_config=job_config, query_with_job=use_explicit_destination - ) - return ExecuteResult( - # Probably don't need read client for small peek results, but let client decide - arrow_batches=lambda: iterator.to_arrow_iterable( - bqstorage_client=self.bqstoragereadclient - ), - schema=array_value.schema, - query_job=query_job, - total_rows=iterator.total_rows, - ) - - def head( - self, array_value: bigframes.core.ArrayValue, n_rows: int - ) -> ExecuteResult: - - maybe_row_count = self._local_get_row_count(array_value) - if (maybe_row_count is not None) and (maybe_row_count <= n_rows): - return self.execute(array_value, ordered=True) - - if not self.strictly_ordered and not array_value.node.explicitly_ordered: - # No user-provided ordering, so just get any N rows, its faster! - return self.peek(array_value, n_rows) - - plan = self.replace_cached_subtrees(array_value.node) - if not tree_properties.can_fast_head(plan): - # If can't get head fast, we are going to need to execute the whole query - # Will want to do this in a way such that the result is reusable, but the first - # N values can be easily extracted. - # This currently requires clustering on offsets. - self._cache_with_offsets(array_value) - # Get a new optimized plan after caching - plan = self.replace_cached_subtrees(array_value.node) - assert tree_properties.can_fast_head(plan) - - head_plan = generate_head_plan(plan, n_rows) - sql = self.compiler.compile(head_plan) - - # TODO(swast): plumb through the api_name of the user-facing api that - # caused this query. 
- iterator, query_job = self._run_execute_query(sql=sql) - return ExecuteResult( - # Probably don't need read client for small head results, but let client decide - arrow_batches=lambda: iterator.to_arrow_iterable( - bqstorage_client=self.bqstoragereadclient - ), - schema=array_value.schema, - query_job=query_job, - total_rows=iterator.total_rows, - ) - - def get_row_count(self, array_value: bigframes.core.ArrayValue) -> int: - count = self._local_get_row_count(array_value) - if count is not None: - return count - else: - row_count_plan = self.replace_cached_subtrees( - generate_row_count_plan(array_value.node) - ) - sql = self.compiler.compile(row_count_plan, ordered=False) - iter, _ = self._run_execute_query(sql, query_with_job=False) - return next(iter)[0] - - def cached( - self, - array_value: bigframes.core.ArrayValue, - *, - force: bool = False, - use_session: bool = False, - cluster_cols: Sequence[str] = (), - ) -> None: - """Write the block to a session table.""" - # use a heuristic for whether something needs to be cached - if (not force) and self._is_trivially_executable(array_value): - return - if use_session: - self._cache_with_session_awareness(array_value) - else: - self._cache_with_cluster_cols(array_value, cluster_cols=cluster_cols) - - def _local_get_row_count( - self, array_value: bigframes.core.ArrayValue - ) -> Optional[int]: - # optimized plan has cache materializations which will have row count metadata - # that is more likely to be usable than original leaf nodes. - plan = self.replace_cached_subtrees(array_value.node) - return tree_properties.row_count(plan) - - # Helpers - def _run_execute_query( - self, - sql: str, - job_config: Optional[bq_job.QueryJobConfig] = None, - api_name: Optional[str] = None, - page_size: Optional[int] = None, - max_results: Optional[int] = None, - query_with_job: bool = True, - ) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]: - """ - Starts BigQuery query job and waits for results. - """ - job_config = bq_job.QueryJobConfig() if job_config is None else job_config - if bigframes.options.compute.maximum_bytes_billed is not None: - job_config.maximum_bytes_billed = ( - bigframes.options.compute.maximum_bytes_billed - ) - - if not self.strictly_ordered: - job_config.labels["bigframes-mode"] = "unordered" - - # Note: add_and_trim_labels is global scope which may have unexpected effects - # Ensure no additional labels are added to job_config after this point, - # as `add_and_trim_labels` ensures the label count does not exceed 64. - bq_io.add_and_trim_labels(job_config, api_name=api_name) - try: - iterator, query_job = bq_io.start_query_with_client( - self.bqclient, - sql, - job_config=job_config, - api_name=api_name, - max_results=max_results, - page_size=page_size, - metrics=self.metrics, - query_with_job=query_with_job, - ) - return iterator, query_job - - except google.api_core.exceptions.BadRequest as e: - # Unfortunately, this error type does not have a separate error code or exception type - if "Resources exceeded during query execution" in e.message: - new_message = "Computation is too complex to execute as a single query. Try using DataFrame.cache() on intermediate results, or setting bigframes.options.compute.enable_multi_query_execution." 
- raise bigframes.exceptions.QueryComplexityError(new_message) from e - else: - raise - - def replace_cached_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode: - return nodes.top_down(node, lambda x: self._cached_executions.get(x, x)) - - def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue): - """ - Can the block be evaluated very cheaply? - If True, the array_value probably is not worth caching. - """ - # Once rewriting is available, will want to rewrite before - # evaluating execution cost. - return tree_properties.is_trivially_executable( - self.replace_cached_subtrees(array_value.node) - ) - - def _cache_with_cluster_cols( - self, array_value: bigframes.core.ArrayValue, cluster_cols: Sequence[str] - ): - """Executes the query and uses the resulting table to rewrite future executions.""" - - sql, schema, ordering_info = self.compiler.compile_raw( - self.replace_cached_subtrees(array_value.node) - ) - tmp_table = self._sql_as_cached_temp_table( - sql, - schema, - cluster_cols=bq_io.select_cluster_cols(schema, cluster_cols), - ) - cached_replacement = array_value.as_cached( - cache_table=self.bqclient.get_table(tmp_table), - ordering=ordering_info, - ).node - self._cached_executions[array_value.node] = cached_replacement - - def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue): - """Executes the query and uses the resulting table to rewrite future executions.""" - offset_column = bigframes.core.guid.generate_guid("bigframes_offsets") - w_offsets, offset_column = array_value.promote_offsets() - sql = self.compiler.compile( - self.replace_cached_subtrees(w_offsets.node), ordered=False - ) - - tmp_table = self._sql_as_cached_temp_table( - sql, - w_offsets.schema.to_bigquery(), - cluster_cols=[offset_column], - ) - cached_replacement = array_value.as_cached( - cache_table=self.bqclient.get_table(tmp_table), - ordering=order.TotalOrdering.from_offset_col(offset_column), - ).node - self._cached_executions[array_value.node] = cached_replacement - - def _cache_with_session_awareness( - self, - array_value: bigframes.core.ArrayValue, - ) -> None: - session_forest = [obj._block._expr.node for obj in array_value.session.objects] - # These node types are cheap to re-compute - target, cluster_cols = bigframes.session.planner.session_aware_cache_plan( - array_value.node, list(session_forest) - ) - cluster_cols_sql_names = [id.sql for id in cluster_cols] - if len(cluster_cols) > 0: - self._cache_with_cluster_cols( - bigframes.core.ArrayValue(target), cluster_cols_sql_names - ) - elif self.strictly_ordered: - self._cache_with_offsets(bigframes.core.ArrayValue(target)) - else: - self._cache_with_cluster_cols(bigframes.core.ArrayValue(target), []) - - def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue): - """Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces.""" - # Apply existing caching first - for _ in range(MAX_SUBTREE_FACTORINGS): - node_with_cache = self.replace_cached_subtrees(array_value.node) - if node_with_cache.planning_complexity < QUERY_COMPLEXITY_LIMIT: - return - - did_cache = self._cache_most_complex_subtree(array_value.node) - if not did_cache: - return - - def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool: - # TODO: If query fails, retry with lower complexity limit - selection = tree_properties.select_cache_target( - node, - min_complexity=(QUERY_COMPLEXITY_LIMIT / 500), - max_complexity=QUERY_COMPLEXITY_LIMIT, - 
cache=dict(self._cached_executions), - # Heuristic: subtree_compleixty * (copies of subtree)^2 - heuristic=lambda complexity, count: math.log(complexity) - + 2 * math.log(count), - ) - if selection is None: - # No good subtrees to cache, just return original tree - return False - - self._cache_with_cluster_cols(bigframes.core.ArrayValue(selection), []) - return True - - def _sql_as_cached_temp_table( - self, - sql: str, - schema: Sequence[bigquery.SchemaField], - cluster_cols: Sequence[str], - ) -> bigquery.TableReference: - assert len(cluster_cols) <= _MAX_CLUSTER_COLUMNS - temp_table = self.storage_manager.create_temp_table(schema, cluster_cols) - - # TODO: Get default job config settings - job_config = cast( - bigquery.QueryJobConfig, - bigquery.QueryJobConfig.from_api_repr({}), - ) - job_config.destination = temp_table - _, query_job = self._run_execute_query( - sql, - job_config=job_config, - api_name="cached", - ) - assert query_job is not None - query_job.result() - return query_job.destination - - def _validate_result_schema( - self, - array_value: bigframes.core.ArrayValue, - bq_schema: list[bigquery.SchemaField], - ): - actual_schema = _sanitize(tuple(bq_schema)) - ibis_schema = bigframes.core.compile.test_only_ibis_inferred_schema( - self.replace_cached_subtrees(array_value.node) - ).to_bigquery() - internal_schema = _sanitize(array_value.schema.to_bigquery()) - if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable: - return - - if internal_schema != actual_schema: - raise ValueError( - f"This error should only occur while testing. BigFrames internal schema: {internal_schema} does not match actual schema: {actual_schema}" - ) - - if ibis_schema != actual_schema: - raise ValueError( - f"This error should only occur while testing. Ibis schema: {ibis_schema} does not match actual schema: {actual_schema}" - ) - - -def _sanitize( - schema: Tuple[bigquery.SchemaField, ...] -) -> Tuple[bigquery.SchemaField, ...]: - # Schema inferred from SQL strings and Ibis expressions contain only names, types and modes, - # so we disregard other fields (e.g timedelta description for timedelta columns) for validations. - return tuple( - bigquery.SchemaField( - f.name, - f.field_type, - f.mode, # type:ignore - fields=_sanitize(f.fields), - ) - for f in schema - ) - - -def generate_head_plan(node: nodes.BigFrameNode, n: int): - return nodes.SliceNode(node, start=None, stop=n) - - -def generate_row_count_plan(node: nodes.BigFrameNode): - return nodes.RowCountNode(node) diff --git a/bigframes/session/read_api_execution.py b/bigframes/session/read_api_execution.py new file mode 100644 index 0000000000..32095e41f4 --- /dev/null +++ b/bigframes/session/read_api_execution.py @@ -0,0 +1,100 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from typing import Any, Optional + +from google.cloud import bigquery_storage_v1 + +from bigframes.core import bigframe_node, rewrite +from bigframes.session import executor, semi_executor + + +class ReadApiSemiExecutor(semi_executor.SemiExecutor): + """ + Executes plans reducible to a bq table scan by directly reading the table with the read api. + """ + + def __init__( + self, bqstoragereadclient: bigquery_storage_v1.BigQueryReadClient, project: str + ): + self.bqstoragereadclient = bqstoragereadclient + self.project = project + + def execute( + self, + plan: bigframe_node.BigFrameNode, + ordered: bool, + peek: Optional[int] = None, + ) -> Optional[executor.ExecuteResult]: + node = rewrite.try_reduce_to_table_scan(plan) + if not node: + return None + if node.explicitly_ordered and ordered: + return None + if peek: + # TODO: Support peeking + return None + + import google.cloud.bigquery_storage_v1.types as bq_storage_types + from google.protobuf import timestamp_pb2 + + bq_table = node.source.table.get_table_ref() + read_options: dict[str, Any] = { + "selected_fields": [item.source_id for item in node.scan_list.items] + } + if node.source.sql_predicate: + read_options["row_restriction"] = node.source.sql_predicate + read_options = bq_storage_types.ReadSession.TableReadOptions(**read_options) + + table_mod_options = {} + if node.source.at_time: + snapshot_time = timestamp_pb2.Timestamp() + snapshot_time.FromDatetime(node.source.at_time) + table_mod_options["snapshot_time"] = snapshot_time = snapshot_time + table_mods = bq_storage_types.ReadSession.TableModifiers(**table_mod_options) + + def iterator_supplier(): + requested_session = bq_storage_types.stream.ReadSession( + table=bq_table.to_bqstorage(), + data_format=bq_storage_types.DataFormat.ARROW, + read_options=read_options, + table_modifiers=table_mods, + ) + # Single stream to maintain ordering + request = bq_storage_types.CreateReadSessionRequest( + parent=f"projects/{self.project}", + read_session=requested_session, + max_stream_count=1, + ) + session = self.bqstoragereadclient.create_read_session( + request=request, retry=None + ) + + if not session.streams: + return iter([]) + + reader = self.bqstoragereadclient.read_rows( + session.streams[0].name, retry=None + ) + rowstream = reader.rows() + return map(lambda page: page.to_arrow(), rowstream.pages) + + return executor.ExecuteResult( + arrow_batches=iterator_supplier, + schema=plan.schema, + query_job=None, + total_bytes=None, + total_rows=node.source.n_rows, + ) diff --git a/bigframes/session/semi_executor.py b/bigframes/session/semi_executor.py new file mode 100644 index 0000000000..c41d7c96d3 --- /dev/null +++ b/bigframes/session/semi_executor.py @@ -0,0 +1,33 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import abc +from typing import Optional + +from bigframes.core import bigframe_node +from bigframes.session import executor + + +# Unstable interface, in development +class SemiExecutor(abc.ABC): + """ + A semi executor executes a subset of possible plans, returns None for unsupported plans. + """ + + def execute( + self, + plan: bigframe_node.BigFrameNode, + ordered: bool, + peek: Optional[int] = None, + ) -> Optional[executor.ExecuteResult]: + raise NotImplementedError("execute not implemented for this executor") diff --git a/tests/system/large/ml/test_linear_model.py b/tests/system/large/ml/test_linear_model.py index 96215c5e47..be98902007 100644 --- a/tests/system/large/ml/test_linear_model.py +++ b/tests/system/large/ml/test_linear_model.py @@ -222,8 +222,8 @@ def test_unordered_mode_linear_regression_configure_fit_score_predict( start_execution_count = end_execution_count result = model.score(X_train, y_train).to_pandas() end_execution_count = df._block._expr.session._metrics.execution_count - # The score function and to_pandas each initiate one query. - assert end_execution_count - start_execution_count == 2 + # The score function and to_pandas reuse same result. + assert end_execution_count - start_execution_count == 1 utils.check_pandas_df_schema_and_index( result, columns=utils.ML_REGRESSION_METRICS, index=1 diff --git a/tests/unit/core/test_blocks.py b/tests/unit/core/test_blocks.py index fb5a927e76..b1b276bda3 100644 --- a/tests/unit/core/test_blocks.py +++ b/tests/unit/core/test_blocks.py @@ -20,7 +20,7 @@ import bigframes import bigframes.core.blocks as blocks -import bigframes.session.executor +import bigframes.session.bq_caching_executor @pytest.mark.parametrize( @@ -80,7 +80,7 @@ def test_block_from_local(data): expected = pandas.DataFrame(data) mock_session = mock.create_autospec(spec=bigframes.Session) mock_executor = mock.create_autospec( - spec=bigframes.session.executor.BigQueryCachingExecutor + spec=bigframes.session.bq_caching_executor.BigQueryCachingExecutor ) # hard-coded the returned dimension of the session for that each of the test case contains 3 rows.
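Taken together, the pieces compose as follows: try_reduce_to_table_scan collapses plans made only of ReadTableNode and SelectionNode into a single table scan, ReadApiSemiExecutor serves that scan straight from the BigQuery Storage Read API, and BigQueryCachingExecutor._execute_plan consults its semi-executors left to right, but only when no page_size, max_results, destination, or peek was requested, falling back to compiling SQL and running a query job when every semi-executor declines. A minimal sketch of that contract, using only the SemiExecutor interface added above (the class name and its decline-everything behaviour are illustrative, not part of this change):

    from typing import Optional

    from bigframes.core import bigframe_node
    from bigframes.session import executor, semi_executor


    class DeclineAllSemiExecutor(semi_executor.SemiExecutor):
        """Illustrative only: a semi-executor that handles no plans."""

        def execute(
            self,
            plan: bigframe_node.BigFrameNode,
            ordered: bool,
            peek: Optional[int] = None,
        ) -> Optional[executor.ExecuteResult]:
            # Returning None hands the plan to the next semi-executor in
            # BigQueryCachingExecutor._semi_executors, or ultimately to a
            # regular compiled BigQuery query job.
            return None

Keeping the fast path behind this narrow interface means additional execution strategies can be slotted into the precedence chain without touching the caching executor's fallback logic.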