diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index 461ad9ed0..4fe233c61 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -147,6 +147,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
     status = ""
     while True and (status not in ["succeeded", "failed", "canceled"]):
         time.sleep(0.5) # prevent being rate limited by replicate
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
         response = requests.get(prediction_url, headers=headers)
         if response.status_code == 200:
             response_data = response.json()
@@ -154,9 +155,16 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
             if "output" in response_data:
                 output_string = "".join(response_data['output'])
                 new_output = output_string[len(previous_output):]
+                print_verbose(f"New chunk: {new_output}")
                 yield {"output": new_output, "status": status}
                 previous_output = output_string
             status = response_data['status']
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(status_code=400, message=f"Error: {replicate_error}")
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
 
 # Function to extract version ID from model string
 def model_to_version_id(model):
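
Below is a minimal, hypothetical sketch (not part of the patch) of how the streaming generator modified above might be consumed. The prediction_url and api_token values are placeholders; they would normally come from the earlier call that creates the Replicate prediction, and plain print is passed in as the print_verbose logger.

    from litellm.llms.replicate import handle_prediction_response_streaming

    prediction_url = "https://api.replicate.com/v1/predictions/<id>"  # placeholder URL
    api_token = "r8_..."  # placeholder Replicate API token

    try:
        for chunk in handle_prediction_response_streaming(prediction_url, api_token, print_verbose=print):
            # each chunk carries only the newly generated text plus the current poll status
            print(chunk["output"], end="", flush=True)
    except Exception as err:  # ReplicateError is raised when status == "failed"
        print(f"\nprediction failed: {err}")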