test: fix test

This commit is contained in:
Krrish Dholakia 2024-08-27 10:46:57 -07:00
parent cd7dd2a511
commit 18b67a455e
3 changed files with 21 additions and 11 deletions

View file

@ -770,7 +770,7 @@ class OpenAIChatCompletion(BaseLLM):
openai_aclient: AsyncOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
) -> Tuple[dict, BaseModel]:
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
@ -783,7 +783,10 @@ class OpenAIChatCompletion(BaseLLM):
)
)
headers = dict(raw_response.headers)
if hasattr(raw_response, "headers"):
headers = dict(raw_response.headers)
else:
headers = {}
response = raw_response.parse()
return headers, response
except OpenAIError as e:
@ -800,7 +803,7 @@ class OpenAIChatCompletion(BaseLLM):
openai_client: OpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
) -> Tuple[dict, BaseModel]:
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
@ -811,7 +814,10 @@ class OpenAIChatCompletion(BaseLLM):
**data, timeout=timeout
)
headers = dict(raw_response.headers)
if hasattr(raw_response, "headers"):
headers = dict(raw_response.headers)
else:
headers = {}
response = raw_response.parse()
return headers, response
except OpenAIError as e:

View file

@ -1637,18 +1637,19 @@ def test_completion_perplexity_api():
pydantic_obj = ChatCompletion(**response_object)
def _return_pydantic_obj(*args, **kwargs):
return pydantic_obj
new_response = MagicMock()
new_response.headers = {"hello": "world"}
print(f"pydantic_obj: {pydantic_obj}")
new_response.parse.return_value = pydantic_obj
return new_response
openai_client = OpenAI()
openai_client.chat.completions.create = MagicMock()
with patch.object(
openai_client.chat.completions, "create", side_effect=_return_pydantic_obj
openai_client.chat.completions.with_raw_response,
"create",
side_effect=_return_pydantic_obj,
) as mock_client:
pass
# litellm.set_verbose= True
messages = [
{"role": "system", "content": "You're a good bot"},

View file

@ -637,7 +637,10 @@ def client(original_function):
if is_coroutine is True:
pass
else:
if isinstance(original_response, ModelResponse):
if (
isinstance(original_response, ModelResponse)
and len(original_response.choices) > 0
):
model_response: Optional[str] = original_response.choices[
0
].message.content # type: ignore