diff --git a/litellm/llms/ai21.py b/litellm/llms/ai21.py index 651ab92ec6..72708b8324 100644 --- a/litellm/llms/ai21.py +++ b/litellm/llms/ai21.py @@ -2,7 +2,7 @@ import os, types, traceback import json from enum import Enum import requests -import time +import time, httpx from typing import Callable, Optional from litellm.utils import ModelResponse, Choices, Message import litellm @@ -11,6 +11,8 @@ class AI21Error(Exception): def __init__(self, status_code, message): self.status_code = status_code self.message = message + self.request = httpx.Request(method="POST", url="https://api.replicate.com/v1/deployments") + self.response = httpx.Response(status_code=status_code, request=self.request) super().__init__( self.message ) # Call the base class constructor with the parameters it needs diff --git a/litellm/main.py b/litellm/main.py index df04e18166..c0a8837105 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -147,7 +147,7 @@ async def acompletion(*args, **kwargs): else: # Await normally init_response = completion(*args, **kwargs) - if isinstance(init_response, dict): + if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): response = init_response else: response = await init_response diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 543c3e3eaa..8b0d06d272 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -257,7 +257,7 @@ def test_acompletion_on_router(): traceback.print_exc() pytest.fail(f"Error occurred: {e}") -# test_acompletion_on_router() +test_acompletion_on_router() def test_function_calling_on_router(): try: