feat - set response headers in azure requests

This commit is contained in:
Ishaan Jaff 2024-07-01 20:12:39 -07:00
parent 107876ea46
commit 8295cd7be8

View file

@ -23,6 +23,7 @@ from typing_extensions import overload
import litellm import litellm
from litellm import OpenAIConfig from litellm import OpenAIConfig
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.utils import ( from litellm.utils import (
Choices, Choices,
CustomStreamWrapper, CustomStreamWrapper,
@ -500,7 +501,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: str, azure_ad_token: str,
print_verbose: Callable, print_verbose: Callable,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
logging_obj, logging_obj: LiteLLMLoggingObj,
optional_params, optional_params,
litellm_params, litellm_params,
logger_fn, logger_fn,
@ -679,9 +680,9 @@ class AzureChatCompletion(BaseLLM):
data: dict, data: dict,
timeout: Any, timeout: Any,
model_response: ModelResponse, model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI client=None, # this is the AsyncAzureOpenAI
logging_obj=None,
): ):
response = None response = None
try: try:
@ -737,6 +738,7 @@ class AzureChatCompletion(BaseLLM):
data=data, data=data,
timeout=timeout, timeout=timeout,
) )
logging_obj.model_call_details["response_headers"] = headers
stringified_response = response.model_dump() stringified_response = response.model_dump()
logging_obj.post_call( logging_obj.post_call(
@ -845,7 +847,7 @@ class AzureChatCompletion(BaseLLM):
async def async_streaming( async def async_streaming(
self, self,
logging_obj, logging_obj: LiteLLMLoggingObj,
api_base: str, api_base: str,
api_key: str, api_key: str,
api_version: str, api_version: str,
@ -900,6 +902,7 @@ class AzureChatCompletion(BaseLLM):
data=data, data=data,
timeout=timeout, timeout=timeout,
) )
logging_obj.model_call_details["response_headers"] = headers
# return response # return response
streamwrapper = CustomStreamWrapper( streamwrapper = CustomStreamWrapper(