(feat) replicate: add exception mapping for streaming + better logging when polling

This commit is contained in:
ishaan-jaff 2023-11-10 12:46:31 -08:00
parent 1c1a260065
commit af98f74c82

View file

@ -147,6 +147,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
status = ""
while True and (status not in ["succeeded", "failed", "canceled"]):
time.sleep(0.5) # prevent being rate limited by replicate
print_verbose(f"replicate: polling endpoint: {prediction_url}")
response = requests.get(prediction_url, headers=headers)
if response.status_code == 200:
response_data = response.json()
@ -154,9 +155,16 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
if "output" in response_data:
output_string = "".join(response_data['output'])
new_output = output_string[len(previous_output):]
print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status}
previous_output = output_string
status = response_data['status']
if status == "failed":
replicate_error = response_data.get("error", "")
raise ReplicateError(status_code=400, message=f"Error: {replicate_error}")
else:
# A failed poll can be temporary and does not mean the Replicate request failed; the Replicate request has failed only when status == "failed".
print_verbose(f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}")
# Function to extract version ID from model string
def model_to_version_id(model):