fix(utils.py): catch 422-status errors

This commit is contained in:
Krrish Dholakia 2024-06-24 19:41:29 -07:00
parent 95f972ee9f
commit d182ea0f77
2 changed files with 28 additions and 7 deletions

View file

@ -1,13 +1,18 @@
import os, types import asyncio
import json import json
import requests # type: ignore import os
import time import time
from typing import Callable, Optional, Union, Tuple, Any import types
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper from typing import Any, Callable, Optional, Tuple, Union
import litellm, asyncio
import httpx # type: ignore import httpx # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .prompt_templates.factory import custom_prompt, prompt_factory
class ReplicateError(Exception): class ReplicateError(Exception):
@ -329,7 +334,15 @@ async def async_handle_prediction_response_streaming(
response_data = response.json() response_data = response.json()
status = response_data["status"] status = response_data["status"]
if "output" in response_data: if "output" in response_data:
try:
output_string = "".join(response_data["output"]) output_string = "".join(response_data["output"])
except Exception as e:
raise ReplicateError(
status_code=422,
message="Unable to parse response. Got={}".format(
response_data["output"]
),
)
new_output = output_string[len(previous_output) :] new_output = output_string[len(previous_output) :]
print_verbose(f"New chunk: {new_output}") print_verbose(f"New chunk: {new_output}")
yield {"output": new_output, "status": status} yield {"output": new_output, "status": status}

View file

@ -6068,6 +6068,14 @@ def exception_type(
model=model, model=model,
llm_provider="replicate", llm_provider="replicate",
) )
elif original_exception.status_code == 422:
exception_mapping_worked = True
raise UnprocessableEntityError(
message=f"ReplicateException - {original_exception.message}",
llm_provider="replicate",
model=model,
response=original_exception.response,
)
elif original_exception.status_code == 429: elif original_exception.status_code == 429:
exception_mapping_worked = True exception_mapping_worked = True
raise RateLimitError( raise RateLimitError(