diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 53a9d40c6e..3d11cd123e 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -244,7 +244,7 @@ def predict( Args: X (bigframes.dataframe.DataFrame or bigframes.series.Series): - Input DataFrame or Series, which contains only one column of prompts. + Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "prompt" column for prediction. Prompts can include preamble, questions, suggestions, instructions, or examples. temperature (float, default 0.0): @@ -307,14 +307,10 @@ def predict( (X,) = utils.convert_to_dataframe(X) - if len(X.columns) != 1: - raise ValueError( - f"Only support one column as input. {constants.FEEDBACK_LINK}" - ) - - # BQML identified the column by name - col_label = cast(blocks.Label, X.columns[0]) - X = X.rename(columns={col_label: "prompt"}) + if len(X.columns) == 1: + # BQML identified the column by name + col_label = cast(blocks.Label, X.columns[0]) + X = X.rename(columns={col_label: "prompt"}) options = { "temperature": temperature, @@ -522,7 +518,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: Args: X (bigframes.dataframe.DataFrame or bigframes.series.Series): - Input DataFrame, which needs to contain a column with name "content". Only the column will be used as input. Content can include preamble, questions, suggestions, instructions, or examples. + Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction. Returns: bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values. @@ -531,14 +527,10 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models (X,) = utils.convert_to_dataframe(X) - if len(X.columns) != 1: - raise ValueError( - f"Only support one column as input. {constants.FEEDBACK_LINK}" - ) - - # BQML identified the column by name - col_label = cast(blocks.Label, X.columns[0]) - X = X.rename(columns={col_label: "content"}) + if len(X.columns) == 1: + # BQML identified the column by name + col_label = cast(blocks.Label, X.columns[0]) + X = X.rename(columns={col_label: "content"}) options = { "flatten_json_output": True, @@ -679,7 +671,7 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: Args: X (bigframes.dataframe.DataFrame or bigframes.series.Series): - Input DataFrame, which needs to contain a column with name "content". Only the column will be used as input. Content can include preamble, questions, suggestions, instructions, or examples. + Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "content" column for prediction. Returns: bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values. @@ -688,14 +680,10 @@ def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame: # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models (X,) = utils.convert_to_dataframe(X) - if len(X.columns) != 1: - raise ValueError( - f"Only support one column as input. {constants.FEEDBACK_LINK}" - ) - - # BQML identified the column by name - col_label = cast(blocks.Label, X.columns[0]) - X = X.rename(columns={col_label: "content"}) + if len(X.columns) == 1: + # BQML identified the column by name + col_label = cast(blocks.Label, X.columns[0]) + X = X.rename(columns={col_label: "content"}) options = { "flatten_json_output": True, @@ -893,7 +881,7 @@ def predict( Args: X (bigframes.dataframe.DataFrame or bigframes.series.Series): - Input DataFrame or Series, which contains only one column of prompts. + Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "prompt" column for prediction. Prompts can include preamble, questions, suggestions, instructions, or examples. temperature (float, default 0.9): @@ -938,14 +926,10 @@ def predict( (X,) = utils.convert_to_dataframe(X) - if len(X.columns) != 1: - raise ValueError( - f"Only support one column as input. {constants.FEEDBACK_LINK}" - ) - - # BQML identified the column by name - col_label = cast(blocks.Label, X.columns[0]) - X = X.rename(columns={col_label: "prompt"}) + if len(X.columns) == 1: + # BQML identified the column by name + col_label = cast(blocks.Label, X.columns[0]) + X = X.rename(columns={col_label: "prompt"}) options = { "temperature": temperature, @@ -1181,7 +1165,7 @@ def predict( Args: X (bigframes.dataframe.DataFrame or bigframes.series.Series): - Input DataFrame or Series, which contains only one column of prompts. + Input DataFrame or Series, can contain one or more columns. If multiple columns are in the DataFrame, it must contain a "prompt" column for prediction. Prompts can include preamble, questions, suggestions, instructions, or examples. max_output_tokens (int, default 128): @@ -1222,14 +1206,10 @@ def predict( (X,) = utils.convert_to_dataframe(X) - if len(X.columns) != 1: - raise ValueError( - f"Only support one column as input. {constants.FEEDBACK_LINK}" - ) - - # BQML identified the column by name - col_label = cast(blocks.Label, X.columns[0]) - X = X.rename(columns={col_label: "prompt"}) + if len(X.columns) == 1: + # BQML identified the column by name + col_label = cast(blocks.Label, X.columns[0]) + X = X.rename(columns={col_label: "prompt"}) options = { "max_output_tokens": max_output_tokens, diff --git a/tests/system/load/test_llm.py b/tests/system/load/test_llm.py index 1d13300115..51b45485ad 100644 --- a/tests/system/load/test_llm.py +++ b/tests/system/load/test_llm.py @@ -156,3 +156,27 @@ def test_claude3_text_generator_predict_with_params_success( utils.check_pandas_df_schema_and_index( df, columns=utils.ML_GENERATE_TEXT_OUTPUT, index=3, col_exact=False ) + + +@pytest.mark.parametrize( + "model_name", + ("claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"), +) +@pytest.mark.flaky(retries=3, delay=120) +def test_claude3_text_generator_predict_multi_col_success( + llm_text_df, model_name, session, session_us_east5, bq_connection +): + if model_name in ("claude-3-5-sonnet", "claude-3-opus"): + session = session_us_east5 + + llm_text_df["additional_col"] = 1 + claude3_text_generator_model = llm.Claude3TextGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) + df = claude3_text_generator_model.predict(llm_text_df).to_pandas() + utils.check_pandas_df_schema_and_index( + df, + columns=utils.ML_GENERATE_TEXT_OUTPUT + ["additional_col"], + index=3, + col_exact=False, + ) diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index 3093a36534..a4a09731a1 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -15,6 +15,7 @@ import pytest from bigframes.ml import llm +import bigframes.pandas as bpd from tests.system import utils @@ -166,6 +167,20 @@ def test_text_generator_predict_arbitrary_col_label_success( ) +@pytest.mark.flaky(retries=2) +def test_text_generator_predict_multiple_cols_success( + palm2_text_generator_model, llm_text_df: bpd.DataFrame +): + df = llm_text_df.assign(additional_col=1) + pd_df = palm2_text_generator_model.predict(df).to_pandas() + utils.check_pandas_df_schema_and_index( + pd_df, + columns=utils.ML_GENERATE_TEXT_OUTPUT + ["additional_col"], + index=3, + col_exact=False, + ) + + @pytest.mark.flaky(retries=2) def test_text_generator_predict_with_params_success( palm2_text_generator_model, llm_text_df @@ -212,11 +227,33 @@ def test_text_embedding_generator_predict_default_params_success( model_name=model_name, connection_name=bq_connection, session=session ) df = text_embedding_model.predict(llm_text_df).to_pandas() - assert df.shape == (3, 4) - assert "ml_generate_embedding_result" in df.columns - series = df["ml_generate_embedding_result"] - value = series[0] - assert len(value) == 768 + utils.check_pandas_df_schema_and_index( + df, columns=utils.ML_GENERATE_EMBEDDING_OUTPUT, index=3, col_exact=False + ) + assert len(df["ml_generate_embedding_result"][0]) == 768 + + +@pytest.mark.parametrize( + "model_name", + ("text-embedding-004", "text-multilingual-embedding-002"), +) +@pytest.mark.flaky(retries=2) +def test_text_embedding_generator_multi_cols_predict_success( + llm_text_df: bpd.DataFrame, model_name, session, bq_connection +): + df = llm_text_df.assign(additional_col=1) + df = df.rename(columns={"prompt": "content"}) + text_embedding_model = llm.TextEmbeddingGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) + pd_df = text_embedding_model.predict(df).to_pandas() + utils.check_pandas_df_schema_and_index( + pd_df, + columns=utils.ML_GENERATE_EMBEDDING_OUTPUT + ["additional_col"], + index=3, + col_exact=False, + ) + assert len(pd_df["ml_generate_embedding_result"][0]) == 768 @pytest.mark.parametrize( @@ -295,6 +332,33 @@ def test_gemini_text_generator_predict_with_params_success( ) +@pytest.mark.parametrize( + "model_name", + ( + "gemini-pro", + "gemini-1.5-pro-preview-0514", + "gemini-1.5-flash-preview-0514", + "gemini-1.5-pro-001", + "gemini-1.5-flash-001", + ), +) +@pytest.mark.flaky(retries=2) +def test_gemini_text_generator_multi_cols_predict_success( + llm_text_df: bpd.DataFrame, model_name, session, bq_connection +): + df = llm_text_df.assign(additional_col=1) + gemini_text_generator_model = llm.GeminiTextGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) + pd_df = gemini_text_generator_model.predict(df).to_pandas() + utils.check_pandas_df_schema_and_index( + pd_df, + columns=utils.ML_GENERATE_TEXT_OUTPUT + ["additional_col"], + index=3, + col_exact=False, + ) + + @pytest.mark.flaky(retries=2) def test_llm_palm_score(llm_fine_tune_df_default_index): model = llm.PaLM2TextGenerator(model_name="text-bison") diff --git a/tests/system/utils.py b/tests/system/utils.py index 26e3e97e24..83d0e683bc 100644 --- a/tests/system/utils.py +++ b/tests/system/utils.py @@ -50,6 +50,12 @@ "ml_generate_text_status", "prompt", ] +ML_GENERATE_EMBEDDING_OUTPUT = [ + "ml_generate_embedding_result", + "ml_generate_embedding_statistics", + "ml_generate_embedding_status", + "content", +] def skip_legacy_pandas(test):