diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py index ce62e51e90..56549cfd4a 100644 --- a/litellm/llms/replicate.py +++ b/litellm/llms/replicate.py @@ -1,13 +1,18 @@ -import os, types +import asyncio import json -import requests # type: ignore +import os import time -from typing import Callable, Optional, Union, Tuple, Any -from litellm.utils import ModelResponse, Usage, CustomStreamWrapper -import litellm, asyncio +import types +from typing import Any, Callable, Optional, Tuple, Union + import httpx # type: ignore -from .prompt_templates.factory import prompt_factory, custom_prompt +import requests # type: ignore + +import litellm from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.utils import CustomStreamWrapper, ModelResponse, Usage + +from .prompt_templates.factory import custom_prompt, prompt_factory class ReplicateError(Exception): @@ -329,7 +334,15 @@ async def async_handle_prediction_response_streaming( response_data = response.json() status = response_data["status"] if "output" in response_data: - output_string = "".join(response_data["output"]) + try: + output_string = "".join(response_data["output"]) + except Exception as e: + raise ReplicateError( + status_code=422, + message="Unable to parse response. Got={}".format( + response_data["output"] + ), + ) new_output = output_string[len(previous_output) :] print_verbose(f"New chunk: {new_output}") yield {"output": new_output, "status": status} diff --git a/litellm/utils.py b/litellm/utils.py index ce66d0fbb0..1bc8bf771f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6068,6 +6068,14 @@ def exception_type( model=model, llm_provider="replicate", ) + elif original_exception.status_code == 422: + exception_mapping_worked = True + raise UnprocessableEntityError( + message=f"ReplicateException - {original_exception.message}", + llm_provider="replicate", + model=model, + response=original_exception.response, + ) elif original_exception.status_code == 429: exception_mapping_worked = True raise RateLimitError(