mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm dev 01 22 2025 p4 (#7932)
* feat(main.py): add new 'provider_specific_header' param, which allows passing an extra header for a specific provider
* fix(litellm_pre_call_utils.py): add unit test for pre-call utils
* test(test_bedrock_completion.py): skip test now that Bedrock supports this
This commit is contained in:
parent 65ca5f74b0
commit bf1639cb92

5 changed files with 119 additions and 5 deletions
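
Before the diff, a minimal usage sketch of the new param, adapted from the unit test added in this commit (the model name and the anthropic-beta value are the test's placeholders):

import litellm
from litellm.types.utils import ProviderSpecificHeader

# The extra headers are merged into the request only when custom_llm_provider
# matches the provider resolved for this call; other providers ignore them.
response = litellm.completion(
    model="anthropic/claude-3-5-sonnet-v2@20241022",
    messages=[{"role": "user", "content": "Hello world"}],
    provider_specific_header=ProviderSpecificHeader(
        custom_llm_provider="anthropic",
        extra_headers={"anthropic-beta": "test"},
    ),
)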
@@ -179,6 +179,7 @@ from .types.utils import (
     HiddenParams,
     LlmProviders,
     PromptTokensDetails,
+    ProviderSpecificHeader,
     all_litellm_params,
 )

@@ -832,6 +833,9 @@ def completion( # type: ignore # noqa: PLR0915
     model_info = kwargs.get("model_info", None)
     proxy_server_request = kwargs.get("proxy_server_request", None)
     fallbacks = kwargs.get("fallbacks", None)
+    provider_specific_header = cast(
+        Optional[ProviderSpecificHeader], kwargs.get("provider_specific_header", None)
+    )
     headers = kwargs.get("headers", None) or extra_headers
     ensure_alternating_roles: Optional[bool] = kwargs.get(
         "ensure_alternating_roles", None

@@ -937,6 +941,13 @@ def completion( # type: ignore # noqa: PLR0915
             api_base=api_base,
             api_key=api_key,
         )
+
+        if (
+            provider_specific_header is not None
+            and provider_specific_header["custom_llm_provider"] == custom_llm_provider
+        ):
+            headers.update(provider_specific_header["extra_headers"])
+
         if model_response is not None and hasattr(model_response, "_hidden_params"):
             model_response._hidden_params["custom_llm_provider"] = custom_llm_provider
             model_response._hidden_params["region_name"] = kwargs.get(

@@ -20,6 +20,7 @@ from litellm.proxy._types import (
 from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
 from litellm.types.services import ServiceTypes
 from litellm.types.utils import (
+    ProviderSpecificHeader,
     StandardLoggingUserAPIKeyMetadata,
     SupportedCacheControls,
 )

@@ -729,19 +730,20 @@ def add_provider_specific_headers_to_request(
     data: dict,
     headers: dict,
 ):
-    extra_headers = data.get("extra_headers", {}) or {}
-
+    anthropic_headers = {}

     # boolean to indicate if a header was added
     added_header = False
     for header in ANTHROPIC_API_HEADERS:
         if header in headers:
             header_value = headers[header]
-            extra_headers.update({header: header_value})
+            anthropic_headers[header] = header_value
             added_header = True

     if added_header is True:
-        data["extra_headers"] = extra_headers
-
+        data["provider_specific_header"] = ProviderSpecificHeader(
+            custom_llm_provider="anthropic",
+            extra_headers=anthropic_headers,
+        )

     return

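To make the behavioral change above concrete: matched Anthropic headers now land under data["provider_specific_header"] instead of being merged into data["extra_headers"]. A trimmed-down sketch, mirroring the proxy unit test at the end of this diff (the header value and the empty messages list are illustrative):

from litellm.proxy.litellm_pre_call_utils import (
    add_provider_specific_headers_to_request,
)

data = {"model": "gemini-1.5-flash", "messages": []}
headers = {"anthropic-beta": "prompt-caching-2024-07-31"}

# Mutates data in place; only headers listed in ANTHROPIC_API_HEADERS are picked up.
add_provider_specific_headers_to_request(data=data, headers=headers)

assert data["provider_specific_header"] == {
    "custom_llm_provider": "anthropic",
    "extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"},
}
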
@@ -1658,6 +1658,7 @@ all_litellm_params = [
     "api_key",
     "api_version",
     "prompt_id",
+    "provider_specific_header",
     "prompt_variables",
     "api_base",
     "force_timeout",

@@ -1879,3 +1880,8 @@ class HttpHandlerRequestFields(TypedDict, total=False):
     params: dict  # query params
     files: dict  # file uploads
     content: Any  # raw content
+
+
+class ProviderSpecificHeader(TypedDict):
+    custom_llm_provider: str
+    extra_headers: dict

@@ -4550,3 +4550,33 @@ def test_deepseek_reasoning_content_completion():
         resp.choices[0].message.provider_specific_fields["reasoning_content"]
         is not None
     )
+
+
+@pytest.mark.parametrize(
+    "custom_llm_provider, expected_result",
+    [("anthropic", {"anthropic-beta": "test"}), ("bedrock", {}), ("vertex_ai", {})],
+)
+def test_provider_specific_header(custom_llm_provider, expected_result):
+    from litellm.types.utils import ProviderSpecificHeader
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from unittest.mock import patch
+
+    litellm.set_verbose = True
+    client = HTTPHandler()
+    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
+        try:
+            resp = litellm.completion(
+                model="anthropic/claude-3-5-sonnet-v2@20241022",
+                messages=[{"role": "user", "content": "Hello world"}],
+                provider_specific_header=ProviderSpecificHeader(
+                    custom_llm_provider="anthropic",
+                    extra_headers={"anthropic-beta": "test"},
+                ),
+                client=client,
+            )
+        except Exception as e:
+            print(f"Error: {e}")
+
+        mock_post.assert_called_once()
+        print(mock_post.call_args.kwargs["headers"])
+        assert "anthropic-beta" in mock_post.call_args.kwargs["headers"]

@@ -1495,3 +1495,68 @@ def test_custom_openapi(mock_get_openapi_schema):

     openapi_schema = custom_openapi()
     assert openapi_schema is not None
+
+
+def test_provider_specific_header():
+    from litellm.proxy.litellm_pre_call_utils import (
+        add_provider_specific_headers_to_request,
+    )
+
+    data = {
+        "model": "gemini-1.5-flash",
+        "messages": [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": "Tell me a joke"}],
+            }
+        ],
+        "stream": True,
+        "proxy_server_request": {
+            "url": "http://0.0.0.0:4000/v1/chat/completions",
+            "method": "POST",
+            "headers": {
+                "content-type": "application/json",
+                "anthropic-beta": "prompt-caching-2024-07-31",
+                "user-agent": "PostmanRuntime/7.32.3",
+                "accept": "*/*",
+                "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529",
+                "host": "0.0.0.0:4000",
+                "accept-encoding": "gzip, deflate, br",
+                "connection": "keep-alive",
+                "content-length": "240",
+            },
+            "body": {
+                "model": "gemini-1.5-flash",
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": [{"type": "text", "text": "Tell me a joke"}],
+                    }
+                ],
+                "stream": True,
+            },
+        },
+    }
+
+    headers = {
+        "content-type": "application/json",
+        "anthropic-beta": "prompt-caching-2024-07-31",
+        "user-agent": "PostmanRuntime/7.32.3",
+        "accept": "*/*",
+        "postman-token": "81cccd87-c91d-4b2f-b252-c0fe0ca82529",
+        "host": "0.0.0.0:4000",
+        "accept-encoding": "gzip, deflate, br",
+        "connection": "keep-alive",
+        "content-length": "240",
+    }
+
+    add_provider_specific_headers_to_request(
+        data=data,
+        headers=headers,
+    )
+    assert data["provider_specific_header"] == {
+        "custom_llm_provider": "anthropic",
+        "extra_headers": {
+            "anthropic-beta": "prompt-caching-2024-07-31",
+        },
+    }