fix(return-openai-compatible-headers): v0 is openai, azure, anthropic

Fixes https://github.com/BerriAI/litellm/issues/5957
Krrish Dholakia 2024-09-28 16:41:40 -07:00
parent 5222fc8e1b
commit 498e14ba59
7 changed files with 146 additions and 82 deletions

View file

@@ -3,20 +3,16 @@ import json
import os
import time
import types
import uuid
from typing import Any, Callable, Coroutine, Iterable, List, Literal, Optional, Union
import httpx # type: ignore
import requests
from openai import AsyncAzureOpenAI, AzureOpenAI
from pydantic import BaseModel
from typing_extensions import overload
import litellm
from litellm.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import FileTypes # type: ignore
from litellm.types.utils import EmbeddingResponse
from litellm.utils import (
CustomStreamWrapper,
@@ -28,13 +24,6 @@ from litellm.utils import (
)
from ...types.llms.openai import (
Assistant,
AssistantEventHandler,
AssistantStreamManager,
AssistantToolParam,
AsyncAssistantEventHandler,
AsyncAssistantStreamManager,
AsyncCursorPage,
Batch,
CancelBatchRequest,
ChatCompletionToolChoiceFunctionParam,
@@ -43,15 +32,10 @@ from ...types.llms.openai import (
ChatCompletionToolParamFunctionChunk,
CreateBatchRequest,
HttpxBinaryResponseContent,
MessageData,
OpenAICreateThreadParamsMessage,
OpenAIMessage,
RetrieveBatchRequest,
Run,
SyncCursorPage,
Thread,
)
from ..base import BaseLLM
from .common_utils import process_azure_headers
azure_ad_cache = DualCache()
@@ -761,6 +745,7 @@ class AzureChatCompletion(BaseLLM):
response_object=stringified_response,
model_response_object=model_response,
convert_tool_call_to_json_mode=json_mode,
_response_headers=headers,
)
except AzureOpenAIError as e:
raise e
@@ -953,6 +938,7 @@ class AzureChatCompletion(BaseLLM):
model=model,
custom_llm_provider="azure",
logging_obj=logging_obj,
_response_headers=process_azure_headers(headers),
)
return streamwrapper
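
With the Azure wiring above, rate-limit headers surface on the response's hidden params. A minimal usage sketch, assuming Azure credentials are configured via environment variables ("azure/chatgpt-v-2" is the deployment alias used in the test at the bottom of this diff; substitute your own):

```python
import litellm

response = litellm.completion(
    model="azure/chatgpt-v-2",
    messages=[{"role": "user", "content": "Hello world"}],
)

# process_azure_headers() puts the openai-compatible keys (plus the
# "llm_provider-" prefixed raw headers) into additional_headers.
headers = response._hidden_params.get("additional_headers", {})
print(headers.get("x-ratelimit-remaining-requests"))
print(headers.get("x-ratelimit-remaining-tokens"))
```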

View file

@@ -0,0 +1,26 @@
from typing import Union
import httpx
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
if "x-ratelimit-limit-requests" in headers:
openai_headers["x-ratelimit-limit-requests"] = headers[
"x-ratelimit-limit-requests"
]
if "x-ratelimit-remaining-requests" in headers:
openai_headers["x-ratelimit-remaining-requests"] = headers[
"x-ratelimit-remaining-requests"
]
if "x-ratelimit-limit-tokens" in headers:
openai_headers["x-ratelimit-limit-tokens"] = headers["x-ratelimit-limit-tokens"]
if "x-ratelimit-remaining-tokens" in headers:
openai_headers["x-ratelimit-remaining-tokens"] = headers[
"x-ratelimit-remaining-tokens"
]
llm_response_headers = {
"{}-{}".format("llm_provider", k): v for k, v in headers.items()
}
return {**llm_response_headers, **openai_headers}
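
A quick illustration of the mapping this helper performs (header values below are made up for the sketch; `process_azure_headers` as defined above must be in scope):

```python
# Hypothetical Azure response headers, for illustration only.
sample_headers = {
    "x-ratelimit-remaining-requests": "99",
    "x-ratelimit-limit-requests": "100",
    "apim-request-id": "abc-123",  # azure-specific, only kept under the prefix
}

print(process_azure_headers(sample_headers))
# {
#     "llm_provider-x-ratelimit-remaining-requests": "99",
#     "llm_provider-x-ratelimit-limit-requests": "100",
#     "llm_provider-apim-request-id": "abc-123",
#     "x-ratelimit-remaining-requests": "99",
#     "x-ratelimit-limit-requests": "100",
# }
```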

View file

@@ -10,7 +10,7 @@ import traceback
import types
from enum import Enum
from functools import partial
from typing import Callable, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
@@ -29,70 +29,28 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
)
from litellm.types.llms.anthropic import (
AnthopicMessagesAssistantMessageParam,
AnthropicChatCompletionUsageBlock,
AnthropicFinishReason,
AnthropicMessagesRequest,
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
AnthropicMessagesUserMessageParam,
AnthropicResponse,
AnthropicResponseContentBlockText,
AnthropicResponseContentBlockToolUse,
AnthropicResponseUsageBlock,
AnthropicSystemMessageContent,
ContentBlockDelta,
ContentBlockStart,
ContentBlockStop,
ContentJsonBlockDelta,
ContentTextBlockDelta,
MessageBlockDelta,
MessageDelta,
MessageStartBlock,
UsageDelta,
)
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionAssistantToolCall,
ChatCompletionImageObject,
ChatCompletionImageUrlObject,
ChatCompletionRequest,
ChatCompletionResponseMessage,
ChatCompletionSystemMessage,
ChatCompletionTextObject,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
ChatCompletionToolChoiceValues,
ChatCompletionToolMessage,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
ChatCompletionUsageBlock,
ChatCompletionUserMessage,
OpenAIMessageContent,
)
from litellm.types.utils import Choices, GenericStreamingChunk
from litellm.types.utils import GenericStreamingChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ...base import BaseLLM
from ...prompt_templates.factory import (
anthropic_messages_pt,
custom_prompt,
prompt_factory,
)
from ..common_utils import AnthropicError
from ..common_utils import AnthropicError, process_anthropic_headers
from .transformation import AnthropicConfig
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
# constants from https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_constants.py
# makes headers for API call
def validate_environment(
api_key, user_headers, model, messages: List[AllMessageValues]
@@ -130,7 +88,7 @@ async def make_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
):
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_aclient
@@ -166,7 +124,7 @@ async def make_call(
additional_args={"complete_input_dict": data},
)
return completion_stream
return completion_stream, response.headers
def make_sync_call(
@@ -178,7 +136,7 @@ def make_sync_call(
messages: list,
logging_obj,
timeout: Optional[Union[float, httpx.Timeout]],
):
) -> Tuple[Any, httpx.Headers]:
if client is None:
client = litellm.module_level_client # re-use a module level client
@@ -222,7 +180,7 @@ def make_sync_call(
additional_args={"complete_input_dict": data},
)
return completion_stream
return completion_stream, response.headers
class AnthropicChatCompletion(BaseLLM):
@@ -244,14 +202,10 @@ class AnthropicChatCompletion(BaseLLM):
encoding,
json_mode: bool,
) -> ModelResponse:
_hidden_params = {}
_response_headers = dict(response.headers)
if _response_headers is not None:
llm_response_headers = {
"{}-{}".format("llm_provider", k): v
for k, v in _response_headers.items()
}
_hidden_params["additional_headers"] = llm_response_headers
_hidden_params: Dict = {}
_hidden_params["additional_headers"] = process_anthropic_headers(
dict(response.headers)
)
## LOGGING
logging_obj.post_call(
input=messages,
@@ -370,7 +324,7 @@ class AnthropicChatCompletion(BaseLLM):
):
data["stream"] = True
completion_stream = await make_call(
completion_stream, headers = await make_call(
client=client,
api_base=api_base,
headers=headers,
@@ -385,6 +339,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
_response_headers=process_anthropic_headers(headers),
)
return streamwrapper
@@ -558,7 +513,7 @@ class AnthropicChatCompletion(BaseLLM):
stream is True
): # if function call - fake the streaming (need complete blocks for output parsing in openai format)
data["stream"] = stream
completion_stream = make_sync_call(
completion_stream, headers = make_sync_call(
client=client,
api_base=api_base,
headers=headers, # type: ignore
@@ -573,6 +528,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
custom_llm_provider="anthropic",
logging_obj=logging_obj,
_response_headers=process_anthropic_headers(headers),
)
else:

View file

@@ -2,7 +2,7 @@
This file contains common utils for anthropic calls.
"""
from typing import Optional
from typing import Optional, Union
import httpx
@@ -24,3 +24,30 @@ class AnthropicError(Exception):
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
if "anthropic-ratelimit-requests-limit" in headers:
openai_headers["x-ratelimit-limit-requests"] = headers[
"anthropic-ratelimit-requests-limit"
]
if "anthropic-ratelimit-requests-remaining" in headers:
openai_headers["x-ratelimit-remaining-requests"] = headers[
"anthropic-ratelimit-requests-remaining"
]
if "anthropic-ratelimit-tokens-limit" in headers:
openai_headers["x-ratelimit-limit-tokens"] = headers[
"anthropic-ratelimit-tokens-limit"
]
if "anthropic-ratelimit-tokens-remaining" in headers:
openai_headers["x-ratelimit-remaining-tokens"] = headers[
"anthropic-ratelimit-tokens-remaining"
]
llm_response_headers = {
"{}-{}".format("llm_provider", k): v for k, v in headers.items()
}
additional_headers = {**llm_response_headers, **openai_headers}
return additional_headers
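
For reference, this is how the Anthropic-native rate-limit headers get translated into the openai-compatible names (values below are invented for the sketch; `process_anthropic_headers` as defined above must be in scope):

```python
# Hypothetical Anthropic response headers, for illustration only.
sample_headers = {
    "anthropic-ratelimit-requests-remaining": "49",
    "anthropic-ratelimit-tokens-remaining": "39950",
    "request-id": "req_123",
}

print(process_anthropic_headers(sample_headers))
# {
#     "llm_provider-anthropic-ratelimit-requests-remaining": "49",
#     "llm_provider-anthropic-ratelimit-tokens-remaining": "39950",
#     "llm_provider-request-id": "req_123",
#     "x-ratelimit-remaining-requests": "49",
#     "x-ratelimit-remaining-tokens": "39950",
# }
```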

View file

@@ -1355,3 +1355,13 @@ class CustomStreamingDecoder:
class StandardPassThroughResponseObject(TypedDict):
response: str
OPENAI_RESPONSE_HEADERS = [
"x-ratelimit-remaining-requests",
"x-ratelimit-remaining-tokens",
"x-ratelimit-limit-requests",
"x-ratelimit-limit-tokens",
"x-ratelimit-reset-requests",
"x-ratelimit-reset-tokens",
]
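
A small sketch of how this allow-list can filter an arbitrary header dict down to the openai-compatible keys, mirroring the CustomStreamWrapper change further down (the dict below is illustrative):

```python
# Illustrative raw headers; only the allow-listed keys are kept.
raw = {
    "x-ratelimit-remaining-requests": "10",
    "x-ratelimit-reset-tokens": "6ms",
    "content-type": "application/json",
}
openai_compatible = {k: v for k, v in raw.items() if k in OPENAI_RESPONSE_HEADERS}
# {"x-ratelimit-remaining-requests": "10", "x-ratelimit-reset-tokens": "6ms"}
```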

View file

@@ -83,6 +83,7 @@ from litellm.types.llms.openai import (
)
from litellm.types.utils import FileTypes # type: ignore
from litellm.types.utils import (
OPENAI_RESPONSE_HEADERS,
CallTypes,
ChatCompletionDeltaToolCall,
Choices,
@@ -5760,13 +5761,35 @@ def convert_to_model_response_object(
received_args = locals()
if _response_headers is not None:
openai_headers = {}
if "x-ratelimit-limit-requests" in _response_headers:
openai_headers["x-ratelimit-limit-requests"] = _response_headers[
"x-ratelimit-limit-requests"
]
if "x-ratelimit-remaining-requests" in _response_headers:
openai_headers["x-ratelimit-remaining-requests"] = _response_headers[
"x-ratelimit-remaining-requests"
]
if "x-ratelimit-limit-tokens" in _response_headers:
openai_headers["x-ratelimit-limit-tokens"] = _response_headers[
"x-ratelimit-limit-tokens"
]
if "x-ratelimit-remaining-tokens" in _response_headers:
openai_headers["x-ratelimit-remaining-tokens"] = _response_headers[
"x-ratelimit-remaining-tokens"
]
llm_response_headers = {
"{}-{}".format("llm_provider", k): v for k, v in _response_headers.items()
}
if hidden_params is not None:
hidden_params["additional_headers"] = llm_response_headers
hidden_params["additional_headers"] = {
**llm_response_headers,
**openai_headers,
}
else:
hidden_params = {"additional_headers": llm_response_headers}
hidden_params = {
"additional_headers": {**llm_response_headers, **openai_headers}
}
### CHECK IF ERROR IN RESPONSE ### - openrouter returns these in the dictionary
if (
response_object is not None
@@ -6370,11 +6393,26 @@ class CustomStreamWrapper:
self._hidden_params = {
"model_id": (_model_info.get("id", None)),
} # returned as x-litellm-model-id response header in proxy
if _response_headers is not None:
openai_headers = {}
processed_headers = {}
additional_headers = {}
for k, v in _response_headers.items():
if k in OPENAI_RESPONSE_HEADERS: # return openai-compatible headers
openai_headers[k] = v
if k.startswith(
"llm_provider-"
): # return raw provider headers (incl. openai-compatible ones)
processed_headers[k] = v
else:
additional_headers["{}-{}".format("llm_provider", k)] = v
self._hidden_params["additional_headers"] = {
"{}-{}".format("llm_provider", k): v
for k, v in _response_headers.items()
**openai_headers,
**processed_headers,
**additional_headers,
}
self._response_headers = _response_headers
self.response_id = None
self.logging_loop = None
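
The classification above, restated as a standalone helper purely for illustration (not part of this commit): openai-compatible names pass through as-is, headers already carrying the prefix are kept, and everything else gets the llm_provider- prefix.

```python
from typing import Dict

def _split_stream_headers(response_headers: Dict[str, str]) -> Dict[str, str]:
    # Illustrative restatement of the CustomStreamWrapper logic above.
    # OPENAI_RESPONSE_HEADERS is the allow-list defined earlier in this diff.
    openai_headers: Dict[str, str] = {}
    processed_headers: Dict[str, str] = {}
    additional_headers: Dict[str, str] = {}
    for k, v in response_headers.items():
        if k in OPENAI_RESPONSE_HEADERS:  # openai-compatible names pass through
            openai_headers[k] = v
        if k.startswith("llm_provider-"):  # already prefixed (e.g. by process_azure_headers)
            processed_headers[k] = v
        else:
            additional_headers["{}-{}".format("llm_provider", k)] = v
    return {**openai_headers, **processed_headers, **additional_headers}
```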

View file

@@ -4543,4 +4543,25 @@ async def test_completion_ai21_chat():
}
],
)
pass
@pytest.mark.parametrize(
"model",
["gpt-4o", "azure/chatgpt-v-2", "claude-3-sonnet-20240229"], #
)
@pytest.mark.parametrize(
"stream",
[False, True],
)
def test_completion_response_ratelimit_headers(model, stream):
response = completion(
model=model,
messages=[{"role": "user", "content": "Hello world"}],
stream=stream,
)
hidden_params = response._hidden_params
additional_headers = hidden_params.get("additional_headers", {})
print(additional_headers)
assert "x-ratelimit-remaining-requests" in additional_headers
assert "x-ratelimit-remaining-tokens" in additional_headers