From f003deb2bd355cc2c1f8a2886ed8596a62f6986a Mon Sep 17 00:00:00 2001 From: Garrett Wu Date: Tue, 31 Dec 2024 02:14:13 +0000 Subject: [PATCH] chore: fix wordings of Gemini max_retries --- bigframes/ml/llm.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 2427009cf1..d42138b006 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -986,9 +986,8 @@ def predict( The default is `False`. max_retries (int, default 0): - Max number of retry rounds if any rows failed in the prediction. Each round need to make progress (has succeeded rows) to continue the next retry round. - Each round will append newly succeeded rows. When the max retry rounds is reached, the remaining failed rows will be appended to the end of the result. - + Max number of retries if the prediction for any rows failed. Each try needs to make progress (i.e. has successfully predicted rows) to continue the retry. + Each retry will append newly succeeded rows. When the max retries are reached, the remaining rows (the ones without successful predictions) will be appended to the end of the result. Returns: bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values. """ @@ -1034,11 +1033,15 @@ def predict( for _ in range(max_retries + 1): df = self._bqml_model.generate_text(df_fail, options) - df_succ = df[df[_ML_GENERATE_TEXT_STATUS].str.len() == 0] - df_fail = df[df[_ML_GENERATE_TEXT_STATUS].str.len() > 0] + success = df[_ML_GENERATE_TEXT_STATUS].str.len() == 0 + df_succ = df[success] + df_fail = df[~success] if df_succ.empty: - warnings.warn("Can't make any progress, stop retrying.", RuntimeWarning) + if max_retries > 0: + warnings.warn( + "Can't make any progress, stop retrying.", RuntimeWarning + ) break df_result = (