fix(return-openai-compatible-headers): v0 is openai, azure, anthropic

Fixes https://github.com/BerriAI/litellm/issues/5957
Krrish Dholakia 2024-09-28 16:41:40 -07:00
parent 5222fc8e1b
commit 498e14ba59
7 changed files with 146 additions and 82 deletions
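
Note: with this change, openai, azure, and anthropic responses all expose rate-limit headers under the same OpenAI-compatible names in _hidden_params["additional_headers"], alongside the raw provider headers prefixed with "llm_provider-". A minimal sketch of how a caller might read them after this commit (assumes the relevant provider API key is set; the model name is only an example):

    import litellm

    # assumes e.g. ANTHROPIC_API_KEY is set in the environment
    response = litellm.completion(
        model="claude-3-sonnet-20240229",
        messages=[{"role": "user", "content": "Hello world"}],
    )

    headers = response._hidden_params.get("additional_headers", {})
    print(headers.get("x-ratelimit-remaining-requests"))  # openai-compatible name
    print(headers.get("x-ratelimit-remaining-tokens"))
    # raw provider headers remain available under the "llm_provider-" prefix
    print(headers.get("llm_provider-anthropic-ratelimit-requests-remaining"))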


@@ -3,20 +3,16 @@ import json
 import os
 import time
 import types
-import uuid
 from typing import Any, Callable, Coroutine, Iterable, List, Literal, Optional, Union

 import httpx  # type: ignore
-import requests
 from openai import AsyncAzureOpenAI, AzureOpenAI
-from pydantic import BaseModel
 from typing_extensions import overload

 import litellm
 from litellm.caching import DualCache
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.utils import FileTypes  # type: ignore
 from litellm.types.utils import EmbeddingResponse
 from litellm.utils import (
     CustomStreamWrapper,
@@ -28,13 +24,6 @@ from litellm.utils import (
 )

 from ...types.llms.openai import (
-    Assistant,
-    AssistantEventHandler,
-    AssistantStreamManager,
-    AssistantToolParam,
-    AsyncAssistantEventHandler,
-    AsyncAssistantStreamManager,
-    AsyncCursorPage,
     Batch,
     CancelBatchRequest,
     ChatCompletionToolChoiceFunctionParam,
@@ -43,15 +32,10 @@ from ...types.llms.openai import (
     ChatCompletionToolParamFunctionChunk,
     CreateBatchRequest,
     HttpxBinaryResponseContent,
-    MessageData,
-    OpenAICreateThreadParamsMessage,
-    OpenAIMessage,
     RetrieveBatchRequest,
-    Run,
-    SyncCursorPage,
-    Thread,
 )
 from ..base import BaseLLM
+from .common_utils import process_azure_headers

 azure_ad_cache = DualCache()
@@ -761,6 +745,7 @@ class AzureChatCompletion(BaseLLM):
                 response_object=stringified_response,
                 model_response_object=model_response,
                 convert_tool_call_to_json_mode=json_mode,
+                _response_headers=headers,
             )
         except AzureOpenAIError as e:
             raise e
@@ -953,6 +938,7 @@ class AzureChatCompletion(BaseLLM):
             model=model,
             custom_llm_provider="azure",
             logging_obj=logging_obj,
+            _response_headers=process_azure_headers(headers),
         )
         return streamwrapper


@@ -0,0 +1,26 @@
+from typing import Union
+
+import httpx
+
+
+def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
+    openai_headers = {}
+    if "x-ratelimit-limit-requests" in headers:
+        openai_headers["x-ratelimit-limit-requests"] = headers[
+            "x-ratelimit-limit-requests"
+        ]
+    if "x-ratelimit-remaining-requests" in headers:
+        openai_headers["x-ratelimit-remaining-requests"] = headers[
+            "x-ratelimit-remaining-requests"
+        ]
+    if "x-ratelimit-limit-tokens" in headers:
+        openai_headers["x-ratelimit-limit-tokens"] = headers["x-ratelimit-limit-tokens"]
+    if "x-ratelimit-remaining-tokens" in headers:
+        openai_headers["x-ratelimit-remaining-tokens"] = headers[
+            "x-ratelimit-remaining-tokens"
+        ]
+    llm_response_headers = {
+        "{}-{}".format("llm_provider", k): v for k, v in headers.items()
+    }
+
+    return {**llm_response_headers, **openai_headers}
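
Note: a small sketch of what the new process_azure_headers helper returns when given a plain dict of response headers; the import path is an assumption, since file paths are not shown in this view:

    from litellm.llms.AzureOpenAI.common_utils import process_azure_headers  # path assumed

    raw = {
        "x-ratelimit-remaining-requests": "99",
        "x-ratelimit-remaining-tokens": "9980",
        "apim-request-id": "abc123",
    }
    out = process_azure_headers(raw)
    # every raw header comes back prefixed, e.g. "llm_provider-apim-request-id",
    # and the x-ratelimit-* values are also copied through under their original names
    assert out["x-ratelimit-remaining-requests"] == "99"
    assert out["llm_provider-apim-request-id"] == "abc123"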


@@ -10,7 +10,7 @@ import traceback
 import types
 from enum import Enum
 from functools import partial
-from typing import Callable, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import httpx  # type: ignore
 import requests  # type: ignore
@@ -29,70 +29,28 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
 )
 from litellm.types.llms.anthropic import (
-    AnthopicMessagesAssistantMessageParam,
     AnthropicChatCompletionUsageBlock,
-    AnthropicFinishReason,
-    AnthropicMessagesRequest,
-    AnthropicMessagesTool,
-    AnthropicMessagesToolChoice,
-    AnthropicMessagesUserMessageParam,
-    AnthropicResponse,
-    AnthropicResponseContentBlockText,
-    AnthropicResponseContentBlockToolUse,
-    AnthropicResponseUsageBlock,
-    AnthropicSystemMessageContent,
     ContentBlockDelta,
     ContentBlockStart,
     ContentBlockStop,
-    ContentJsonBlockDelta,
-    ContentTextBlockDelta,
     MessageBlockDelta,
-    MessageDelta,
     MessageStartBlock,
     UsageDelta,
 )
 from litellm.types.llms.openai import (
     AllMessageValues,
-    ChatCompletionAssistantMessage,
-    ChatCompletionAssistantToolCall,
-    ChatCompletionImageObject,
-    ChatCompletionImageUrlObject,
-    ChatCompletionRequest,
-    ChatCompletionResponseMessage,
-    ChatCompletionSystemMessage,
-    ChatCompletionTextObject,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
-    ChatCompletionToolChoiceFunctionParam,
-    ChatCompletionToolChoiceObjectParam,
-    ChatCompletionToolChoiceValues,
-    ChatCompletionToolMessage,
-    ChatCompletionToolParam,
-    ChatCompletionToolParamFunctionChunk,
     ChatCompletionUsageBlock,
-    ChatCompletionUserMessage,
-    OpenAIMessageContent,
 )
-from litellm.types.utils import Choices, GenericStreamingChunk
+from litellm.types.utils import GenericStreamingChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

 from ...base import BaseLLM
-from ...prompt_templates.factory import (
-    anthropic_messages_pt,
-    custom_prompt,
-    prompt_factory,
-)
-from ..common_utils import AnthropicError
+from ..common_utils import AnthropicError, process_anthropic_headers
 from .transformation import AnthropicConfig


-class AnthropicConstants(Enum):
-    HUMAN_PROMPT = "\n\nHuman: "
-    AI_PROMPT = "\n\nAssistant: "
-    # constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py
-
-
 # makes headers for API call
 def validate_environment(
     api_key, user_headers, model, messages: List[AllMessageValues]
@@ -130,7 +88,7 @@ async def make_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
-):
+) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_aclient
@@ -166,7 +124,7 @@ async def make_call(
         additional_args={"complete_input_dict": data},
     )

-    return completion_stream
+    return completion_stream, response.headers


 def make_sync_call(
@@ -178,7 +136,7 @@ def make_sync_call(
     messages: list,
     logging_obj,
     timeout: Optional[Union[float, httpx.Timeout]],
-):
+) -> Tuple[Any, httpx.Headers]:
     if client is None:
         client = litellm.module_level_client  # re-use a module level client
@@ -222,7 +180,7 @@ def make_sync_call(
         additional_args={"complete_input_dict": data},
     )

-    return completion_stream
+    return completion_stream, response.headers


 class AnthropicChatCompletion(BaseLLM):
@@ -244,14 +202,10 @@ class AnthropicChatCompletion(BaseLLM):
         encoding,
         json_mode: bool,
     ) -> ModelResponse:
-        _hidden_params = {}
-        _response_headers = dict(response.headers)
-        if _response_headers is not None:
-            llm_response_headers = {
-                "{}-{}".format("llm_provider", k): v
-                for k, v in _response_headers.items()
-            }
-            _hidden_params["additional_headers"] = llm_response_headers
+        _hidden_params: Dict = {}
+        _hidden_params["additional_headers"] = process_anthropic_headers(
+            dict(response.headers)
+        )
         ## LOGGING
         logging_obj.post_call(
             input=messages,
@@ -370,7 +324,7 @@ class AnthropicChatCompletion(BaseLLM):
         ):
             data["stream"] = True
-            completion_stream = await make_call(
+            completion_stream, headers = await make_call(
                 client=client,
                 api_base=api_base,
                 headers=headers,
@@ -385,6 +339,7 @@ class AnthropicChatCompletion(BaseLLM):
                 model=model,
                 custom_llm_provider="anthropic",
                 logging_obj=logging_obj,
+                _response_headers=process_anthropic_headers(headers),
             )
             return streamwrapper
@@ -558,7 +513,7 @@ class AnthropicChatCompletion(BaseLLM):
                 stream is True
             ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                 data["stream"] = stream
-                completion_stream = make_sync_call(
+                completion_stream, headers = make_sync_call(
                     client=client,
                     api_base=api_base,
                     headers=headers,  # type: ignore
@@ -573,6 +528,7 @@ class AnthropicChatCompletion(BaseLLM):
                     model=model,
                     custom_llm_provider="anthropic",
                     logging_obj=logging_obj,
+                    _response_headers=process_anthropic_headers(headers),
                 )
             else:


@@ -2,7 +2,7 @@
 This file contains common utils for anthropic calls.
 """

-from typing import Optional
+from typing import Optional, Union

 import httpx
@@ -24,3 +24,30 @@ class AnthropicError(Exception):
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
+
+
+def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict:
+    openai_headers = {}
+    if "anthropic-ratelimit-requests-limit" in headers:
+        openai_headers["x-ratelimit-limit-requests"] = headers[
+            "anthropic-ratelimit-requests-limit"
+        ]
+    if "anthropic-ratelimit-requests-remaining" in headers:
+        openai_headers["x-ratelimit-remaining-requests"] = headers[
+            "anthropic-ratelimit-requests-remaining"
+        ]
+    if "anthropic-ratelimit-tokens-limit" in headers:
+        openai_headers["x-ratelimit-limit-tokens"] = headers[
+            "anthropic-ratelimit-tokens-limit"
+        ]
+    if "anthropic-ratelimit-tokens-remaining" in headers:
+        openai_headers["x-ratelimit-remaining-tokens"] = headers[
+            "anthropic-ratelimit-tokens-remaining"
+        ]
+
+    llm_response_headers = {
+        "{}-{}".format("llm_provider", k): v for k, v in headers.items()
+    }
+
+    additional_headers = {**llm_response_headers, **openai_headers}
+    return additional_headers
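
Note: the anthropic helper does the same job but remaps Anthropic's own anthropic-ratelimit-* header names onto the OpenAI-style x-ratelimit-* names, while raw headers stay available under the llm_provider- prefix. A quick sketch with made-up values; the import path is an assumption:

    from litellm.llms.anthropic.common_utils import process_anthropic_headers  # path assumed

    raw = {
        "anthropic-ratelimit-requests-remaining": "999",
        "anthropic-ratelimit-tokens-remaining": "79000",
        "request-id": "req_123",
    }
    out = process_anthropic_headers(raw)
    assert out["x-ratelimit-remaining-requests"] == "999"
    assert out["x-ratelimit-remaining-tokens"] == "79000"
    assert out["llm_provider-request-id"] == "req_123"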


@@ -1355,3 +1355,13 @@ class CustomStreamingDecoder:

 class StandardPassThroughResponseObject(TypedDict):
     response: str
+
+
+OPENAI_RESPONSE_HEADERS = [
+    "x-ratelimit-remaining-requests",
+    "x-ratelimit-remaining-tokens",
+    "x-ratelimit-limit-requests",
+    "x-ratelimit-limit-tokens",
+    "x-ratelimit-reset-requests",
+    "x-ratelimit-reset-tokens",
+]


@@ -83,6 +83,7 @@ from litellm.types.llms.openai import (
 )
 from litellm.types.utils import FileTypes  # type: ignore
 from litellm.types.utils import (
+    OPENAI_RESPONSE_HEADERS,
     CallTypes,
     ChatCompletionDeltaToolCall,
     Choices,
@@ -5760,13 +5761,35 @@ def convert_to_model_response_object(
     received_args = locals()

     if _response_headers is not None:
+        openai_headers = {}
+        if "x-ratelimit-limit-requests" in _response_headers:
+            openai_headers["x-ratelimit-limit-requests"] = _response_headers[
+                "x-ratelimit-limit-requests"
+            ]
+        if "x-ratelimit-remaining-requests" in _response_headers:
+            openai_headers["x-ratelimit-remaining-requests"] = _response_headers[
+                "x-ratelimit-remaining-requests"
+            ]
+        if "x-ratelimit-limit-tokens" in _response_headers:
+            openai_headers["x-ratelimit-limit-tokens"] = _response_headers[
+                "x-ratelimit-limit-tokens"
+            ]
+        if "x-ratelimit-remaining-tokens" in _response_headers:
+            openai_headers["x-ratelimit-remaining-tokens"] = _response_headers[
+                "x-ratelimit-remaining-tokens"
+            ]
         llm_response_headers = {
             "{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
         }
         if hidden_params is not None:
-            hidden_params["additional_headers"] = llm_response_headers
+            hidden_params["additional_headers"] = {
+                **llm_response_headers,
+                **openai_headers,
+            }
         else:
-            hidden_params = {"additional_headers": llm_response_headers}
+            hidden_params = {
+                "additional_headers": {**llm_response_headers, **openai_headers}
+            }

     ### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary
     if (
         response_object is not None
@@ -6370,11 +6393,26 @@ class CustomStreamWrapper:
         self._hidden_params = {
             "model_id": (_model_info.get("id", None)),
         }  # returned as x-litellm-model-id response header in proxy
         if _response_headers is not None:
+            openai_headers = {}
+            processed_headers = {}
+            additional_headers = {}
+            for k, v in _response_headers.items():
+                if k in OPENAI_RESPONSE_HEADERS:  # return openai-compatible headers
+                    openai_headers[k] = v
+                if k.startswith(
+                    "llm_provider-"
+                ):  # return raw provider headers (incl. openai-compatible ones)
+                    processed_headers[k] = v
+                else:
+                    additional_headers["{}-{}".format("llm_provider", k)] = v
+
             self._hidden_params["additional_headers"] = {
-                "{}-{}".format("llm_provider", k): v
-                for k, v in _response_headers.items()
+                **openai_headers,
+                **processed_headers,
+                **additional_headers,
             }
+
         self._response_headers = _response_headers
         self.response_id = None
         self.logging_loop = None
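
Note: the streaming wrapper now splits whatever _response_headers it receives into three buckets before merging: OpenAI-compatible names kept as-is, keys already prefixed with llm_provider- (the processed provider headers), and everything else, which gets the prefix added. A standalone sketch of that loop with made-up sample headers:

    from litellm.types.utils import OPENAI_RESPONSE_HEADERS

    _response_headers = {
        "x-ratelimit-remaining-requests": "999",
        "llm_provider-anthropic-ratelimit-requests-remaining": "999",
        "content-type": "application/json",
    }

    openai_headers, processed_headers, additional_headers = {}, {}, {}
    for k, v in _response_headers.items():
        if k in OPENAI_RESPONSE_HEADERS:  # openai-compatible names pass through as-is
            openai_headers[k] = v
        if k.startswith("llm_provider-"):  # already-prefixed raw provider headers
            processed_headers[k] = v
        else:  # anything else gets the llm_provider- prefix added
            additional_headers["{}-{}".format("llm_provider", k)] = v

    merged = {**openai_headers, **processed_headers, **additional_headers}
    # merged exposes both the x-ratelimit-* names and the prefixed raw headers,
    # matching what the non-streaming path puts in _hidden_params["additional_headers"]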


@@ -4543,4 +4543,25 @@ async def test_completion_ai21_chat():
             }
         ],
     )
     pass
+
+
+@pytest.mark.parametrize(
+    "model",
+    ["gpt-4o", "azure/chatgpt-v-2", "claude-3-sonnet-20240229"],  #
+)
+@pytest.mark.parametrize(
+    "stream",
+    [False, True],
+)
+def test_completion_response_ratelimit_headers(model, stream):
+    response = completion(
+        model=model,
+        messages=[{"role": "user", "content": "Hello world"}],
+        stream=stream,
+    )
+    hidden_params = response._hidden_params
+    additional_headers = hidden_params.get("additional_headers", {})
+    print(additional_headers)
+    assert "x-ratelimit-remaining-requests" in additional_headers
+    assert "x-ratelimit-remaining-tokens" in additional_headers