return azure response headers

This commit is contained in:
Ishaan Jaff 2024-07-01 17:09:06 -07:00
parent 48946f7528
commit 140f7fe254
2 changed files with 46 additions and 4 deletions

View file

@ -458,6 +458,36 @@ class AzureChatCompletion(BaseLLM):
return azure_client
async def make_azure_openai_chat_completion_request(
self,
azure_client: AsyncAzureOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = (
await azure_client.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
)
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
else:
response = await azure_client.chat.completions.create(
**data, timeout=timeout
)
return None, response
except Exception as e:
raise e
def completion(
self,
model: str,
@ -701,8 +731,11 @@ class AzureChatCompletion(BaseLLM):
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(
**data, timeout=timeout
headers, response = await self.make_azure_openai_chat_completion_request(
azure_client=azure_client,
data=data,
timeout=timeout,
)
stringified_response = response.model_dump()
@ -861,9 +894,13 @@ class AzureChatCompletion(BaseLLM):
"complete_input_dict": data,
},
)
response = await azure_client.chat.completions.create(
**data, timeout=timeout
headers, response = await self.make_azure_openai_chat_completion_request(
azure_client=azure_client,
data=data,
timeout=timeout,
)
# return response
streamwrapper = CustomStreamWrapper(
completion_stream=response,

View file

@ -658,6 +658,11 @@ class OpenAIChatCompletion(BaseLLM):
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
try:
if litellm.return_response_headers is True:
raw_response = (