fix(openai.py): coverage for correctly re-raising exception headers on openai chat completion + embedding endpoints

This commit is contained in:
Krrish Dholakia 2024-08-24 12:55:15 -07:00
parent 068aafdff9
commit de2373d52b
3 changed files with 153 additions and 33 deletions

View file

@ -786,8 +786,14 @@ class OpenAIChatCompletion(BaseLLM):
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except Exception as e:
except OpenAIError as e:
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
def make_sync_openai_chat_completion_request(
self,
@ -801,21 +807,21 @@ class OpenAIChatCompletion(BaseLLM):
- call chat.completions.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = openai_client.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
raw_response = openai_client.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = openai_client.chat.completions.create(
**data, timeout=timeout
)
return None, response
except Exception as e:
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except OpenAIError as e:
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
def completion(
self,
@ -1290,16 +1296,12 @@ class OpenAIChatCompletion(BaseLLM):
- call embeddings.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = await openai_aclient.embeddings.with_raw_response.create(
**data, timeout=timeout
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = await openai_aclient.embeddings.create(**data, timeout=timeout) # type: ignore
return None, response
raw_response = await openai_aclient.embeddings.with_raw_response.create(
**data, timeout=timeout
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
@ -1365,14 +1367,14 @@ class OpenAIChatCompletion(BaseLLM):
response_type="embedding",
_response_headers=headers,
) # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
except OpenAIError as e:
raise e
except Exception as e:
status_code = getattr(e, "status_code", 500)
error_headers = getattr(e, "headers", None)
raise OpenAIError(
status_code=status_code, message=str(e), headers=error_headers
)
def embedding(
self,