mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
Litellm dev 01 21 2025 p1 (#7898)
* fix(utils.py): don't pass 'anthropic-beta' header to vertex - it will cause the request to fail
* fix(utils.py): add flag to allow user to disable filtering of invalid headers, so the user can control the behaviour
* style(utils.py): cleanup message
* test(test_utils.py): add unit test to cover invalid header filtering
* fix(proxy_server.py): fix custom openapi schema generation
* fix(utils.py): pass extra headers if set
* fix(main.py): fix image variation to use 'client' param
parent b73980ecd5
commit dec558ba4c
9 changed files with 161 additions and 21 deletions
@@ -105,6 +105,7 @@ turn_off_message_logging: Optional[bool] = False
 log_raw_request_response: bool = False
 redact_messages_in_exceptions: Optional[bool] = False
 redact_user_api_key_info: Optional[bool] = False
+filter_invalid_headers: Optional[bool] = False
 add_user_information_to_llm_headers: Optional[bool] = (
     None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
 )
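The new `filter_invalid_headers` flag defaults to False, so header filtering stays opt-in. A minimal sketch of opting in; the model name and beta header value are illustrative, not taken from this commit:

import litellm

# Opt in to dropping provider-specific headers that other providers reject,
# e.g. 'anthropic-beta' when the request is routed to Vertex AI or Bedrock.
litellm.filter_invalid_headers = True

response = litellm.completion(
    model="vertex_ai/claude-3-5-sonnet-v2",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
    extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},  # illustrative value
)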
@@ -844,9 +844,6 @@ def completion(  # type: ignore # noqa: PLR0915
     )
     if headers is None:
         headers = {}
-
-    if extra_headers is not None:
-        headers.update(extra_headers)
     num_retries = kwargs.get(
         "num_retries", None
     )  ## alt. param for 'max_retries'. Use this to pass retries w/ instructor.
@@ -1042,9 +1039,14 @@ def completion(  # type: ignore # noqa: PLR0915
             api_version=api_version,
             parallel_tool_calls=parallel_tool_calls,
             messages=messages,
+            extra_headers=extra_headers,
             **non_default_params,
         )

+        extra_headers = optional_params.pop("extra_headers", None)
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
         if litellm.add_function_to_prompt and optional_params.get(
             "functions_unsupported_model", None
         ):  # if user opts to add it to prompt, when API doesn't support function calling
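Read alongside the hunk at -844 above, the net effect is that `extra_headers` are no longer merged into `headers` before `get_optional_params` runs; they are cleaned inside it and merged afterwards. A toy, runnable model of that ordering (the `clean` helper is a hypothetical stand-in, not litellm's API):

# Toy model of the new ordering: clean provider-specific headers first,
# then merge whatever survives into the outgoing request headers.
def clean(extra_headers: dict) -> dict:
    # stand-in for the provider-aware filtering now done inside get_optional_params
    return {k: v for k, v in extra_headers.items() if k != "anthropic-beta"}

headers = {"x-api-key": "sk-placeholder"}
extra_headers = {"anthropic-beta": "123", "anthropic-version": "2023-06-01"}

extra_headers = clean(extra_headers)  # previously, the merge happened before any cleaning
headers.update(extra_headers)
print(headers)  # {'x-api-key': 'sk-placeholder', 'anthropic-version': '2023-06-01'}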
@@ -4670,7 +4672,7 @@ def image_variation(
     **kwargs,
 ) -> ImageResponse:
     # get non-default params
+    client = kwargs.get("client", None)
     # get logging object
     litellm_logging_obj = cast(LiteLLMLoggingObj, kwargs.get("litellm_logging_obj"))

@@ -4744,6 +4746,7 @@ def image_variation(
         logging_obj=litellm_logging_obj,
         optional_params={},
         litellm_params=litellm_params,
+        client=client,
     )

     # return the response
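With `client` threaded through `image_variation`, callers can supply their own HTTP client, which is also what lets the updated test later in this commit inject a mock. A minimal sketch; the image URL is a placeholder and the Topaz model name is taken from that test:

from litellm import image_variation
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()  # reuse one HTTP client across calls, or inject a mock in tests
image_url = "https://example.com/cat.png"  # placeholder image URL
response = image_variation(
    model="topaz/Standard V2",
    image=image_url,
    n=2,
    size="1024x1024",
    client=client,
)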
@@ -17,6 +17,7 @@ from litellm.proxy._types import (
     TeamCallbackMetadata,
     UserAPIKeyAuth,
 )
+from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
 from litellm.types.services import ServiceTypes
 from litellm.types.utils import (
     StandardLoggingUserAPIKeyMetadata,
@@ -396,6 +397,7 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
         dict: The modified data dictionary.

     """
+
     from litellm.proxy.proxy_server import llm_router, premium_user

     safe_add_api_version_from_query_params(data, request)
@@ -626,6 +628,7 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
                 parent_otel_span=user_api_key_dict.parent_otel_span,
             )
         )

     return data

+
@@ -726,10 +729,6 @@ def add_provider_specific_headers_to_request(
     data: dict,
     headers: dict,
 ):
-    ANTHROPIC_API_HEADERS = [
-        "anthropic-version",
-        "anthropic-beta",
-    ]

     extra_headers = data.get("extra_headers", {}) or {}

@@ -562,20 +562,71 @@ app = FastAPI(

 ### CUSTOM API DOCS [ENTERPRISE FEATURE] ###
 # Custom OpenAPI schema generator to include only selected routes
-def custom_openapi():
+from fastapi.routing import APIWebSocketRoute
+
+
+def get_openapi_schema():
     if app.openapi_schema:
         return app.openapi_schema
+
     openapi_schema = get_openapi(
         title=app.title,
         version=app.version,
         description=app.description,
         routes=app.routes,
     )
+
+    # Find all WebSocket routes
+    websocket_routes = [
+        route for route in app.routes if isinstance(route, APIWebSocketRoute)
+    ]
+
+    # Add each WebSocket route to the schema
+    for route in websocket_routes:
+        # Get the base path without query parameters
+        base_path = route.path.split("{")[0].rstrip("?")
+
+        # Extract parameters from the route
+        parameters = []
+        if hasattr(route, "dependant"):
+            for param in route.dependant.query_params:
+                parameters.append(
+                    {
+                        "name": param.name,
+                        "in": "query",
+                        "required": param.required,
+                        "schema": {
+                            "type": "string"
+                        },  # You can make this more specific if needed
+                    }
+                )
+
+        openapi_schema["paths"][base_path] = {
+            "get": {
+                "summary": f"WebSocket: {route.name or base_path}",
+                "description": "WebSocket connection endpoint",
+                "operationId": f"websocket_{route.name or base_path.replace('/', '_')}",
+                "parameters": parameters,
+                "responses": {"101": {"description": "WebSocket Protocol Switched"}},
+                "tags": ["WebSocket"],
+            }
+        }
+
+    app.openapi_schema = openapi_schema
+    return app.openapi_schema
+
+
+def custom_openapi():
+    if app.openapi_schema:
+        return app.openapi_schema
+    openapi_schema = get_openapi_schema()
+
     # Filter routes to include only specific ones
     openai_routes = LiteLLMRoutes.openai_routes.value
     paths_to_include: dict = {}
     for route in openai_routes:
-        paths_to_include[route] = openapi_schema["paths"][route]
+        if route in openapi_schema["paths"]:
+            paths_to_include[route] = openapi_schema["paths"][route]
     openapi_schema["paths"] = paths_to_include
     app.openapi_schema = openapi_schema
     return app.openapi_schema
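FastAPI serves whatever `app.openapi()` returns, so the split into `get_openapi_schema` plus the filtering `custom_openapi` only takes effect once the override is installed on the app object. A one-line sketch of the standard FastAPI pattern; whether litellm gates this behind an enterprise flag is not visible in this diff:

app.openapi = custom_openapi  # standard FastAPI schema override; wiring assumed, not shown in this hunk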
@@ -334,3 +334,13 @@ from .openai import ChatCompletionUsageBlock
 class AnthropicChatCompletionUsageBlock(ChatCompletionUsageBlock, total=False):
     cache_creation_input_tokens: int
     cache_read_input_tokens: int
+
+
+ANTHROPIC_API_HEADERS = {
+    "anthropic-version",
+    "anthropic-beta",
+}
+
+ANTHROPIC_API_ONLY_HEADERS = {  # fails if calling anthropic on vertex ai / bedrock
+    "anthropic-beta",
+}
@@ -112,6 +112,7 @@ from litellm.router_utils.get_retry_from_policy import (
     reset_retry_policy,
 )
 from litellm.secret_managers.main import get_secret
+from litellm.types.llms.anthropic import ANTHROPIC_API_ONLY_HEADERS
 from litellm.types.llms.openai import (
     AllMessageValues,
     AllPromptValues,
@@ -2601,6 +2602,25 @@ def _remove_unsupported_params(
     return non_default_params


+def get_clean_extra_headers(extra_headers: dict, custom_llm_provider: str) -> dict:
+    """
+    For `anthropic-beta` headers, ensure provider is anthropic.
+
+    Vertex AI raises an exception if `anthropic-beta` is passed in.
+    """
+    if litellm.filter_invalid_headers is not True:  # allow user to opt out of filtering
+        return extra_headers
+    clean_extra_headers = {}
+    for k, v in extra_headers.items():
+        if k in ANTHROPIC_API_ONLY_HEADERS and custom_llm_provider != "anthropic":
+            verbose_logger.debug(
+                f"Provider {custom_llm_provider} does not support {k} header. Dropping from request, to prevent errors."
+            )  # Switching between anthropic api and vertex ai anthropic fails when anthropic-beta is passed in. Welcome feedback on this.
+        else:
+            clean_extra_headers[k] = v
+    return clean_extra_headers
+
+
 def get_optional_params(  # noqa: PLR0915
     # use the openai defaults
     # https://platform.openai.com/docs/api-reference/chat/create
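`get_clean_extra_headers` is a pure function of the header dict and provider, so its behaviour is easy to check in isolation. A quick sketch, assuming the flag is switched on first (it defaults to False):

import litellm
from litellm.utils import get_clean_extra_headers

litellm.filter_invalid_headers = True  # cleaning is opt-in

headers = {"anthropic-beta": "123", "anthropic-version": "2023-06-01"}

# 'anthropic-beta' is in ANTHROPIC_API_ONLY_HEADERS, so it survives only for anthropic itself;
# 'anthropic-version' is accepted more widely and is always kept.
assert get_clean_extra_headers(headers, "anthropic") == headers
assert get_clean_extra_headers(headers, "vertex_ai") == {"anthropic-version": "2023-06-01"}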
@@ -2739,6 +2759,12 @@ def get_optional_params(  # noqa: PLR0915
         )
     }

+    ## Supports anthropic headers
+    if extra_headers is not None:
+        extra_headers = get_clean_extra_headers(
+            extra_headers=extra_headers, custom_llm_provider=custom_llm_provider
+        )
+
     ## raise exception if function calling passed in for a provider that doesn't support it
     if (
         "functions" in non_default_params
@@ -3508,6 +3534,12 @@ def get_optional_params(  # noqa: PLR0915
     for k in passed_params.keys():
         if k not in default_params.keys():
             optional_params[k] = passed_params[k]
+    if extra_headers is not None:
+        optional_params.setdefault("extra_headers", {})
+        optional_params["extra_headers"] = {
+            **optional_params["extra_headers"],
+            **extra_headers,
+        }
     print_verbose(f"Final returned optional params: {optional_params}")
     return optional_params

@@ -68,16 +68,21 @@ async def test_openai_image_variation_litellm_sdk(image_url, sync_mode):
         await aimage_variation(image=image_url, n=2, size="1024x1024")


+@pytest.mark.parametrize("sync_mode", [True, False])  # ,
+@pytest.mark.asyncio
-def test_topaz_image_variation(image_url):
+async def test_topaz_image_variation(image_url, sync_mode):
     from litellm import image_variation, aimage_variation
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    from unittest.mock import patch

-    if sync_mode:
-        image_variation(
-            model="topaz/Standard V2", image=image_url, n=2, size="1024x1024"
-        )
-    else:
-        response = await aimage_variation(
-            model="topaz/Standard V2", image=image_url, n=2, size="1024x1024"
-        )
+    client = HTTPHandler()
+    with patch.object(client, "post") as mock_post:
+        try:
+            image_variation(
+                model="topaz/Standard V2",
+                image=image_url,
+                n=2,
+                size="1024x1024",
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+        mock_post.assert_called_once()
@@ -1494,3 +1494,26 @@ def test_get_num_retries(num_retries):
             "num_retries": num_retries,
         },
     )


+@pytest.mark.parametrize("filter_invalid_headers", [True, False])
+@pytest.mark.parametrize(
+    "custom_llm_provider, expected_result",
+    [("anthropic", {"anthropic-beta": "123"}), ("bedrock", {}), ("vertex_ai", {})],
+)
+def test_get_clean_extra_headers(
+    filter_invalid_headers, custom_llm_provider, expected_result, monkeypatch
+):
+    from litellm.utils import get_clean_extra_headers
+
+    monkeypatch.setattr(litellm, "filter_invalid_headers", filter_invalid_headers)
+
+    if filter_invalid_headers:
+        assert (
+            get_clean_extra_headers({"anthropic-beta": "123"}, custom_llm_provider)
+            == expected_result
+        )
+    else:
+        assert get_clean_extra_headers(
+            {"anthropic-beta": "123"}, custom_llm_provider
+        ) == {"anthropic-beta": "123"}
@@ -1479,3 +1479,19 @@ async def test_health_check_not_called_when_disabled(monkeypatch):

     # Verify health check wasn't called
     mock_prisma.health_check.assert_not_called()


+@patch(
+    "litellm.proxy.proxy_server.get_openapi_schema",
+    return_value={
+        "paths": {
+            "/new/route": {"get": {"summary": "New"}},
+        }
+    },
+)
+def test_custom_openapi(mock_get_openapi_schema):
+    from litellm.proxy.proxy_server import custom_openapi
+    from litellm.proxy.proxy_server import app
+
+    openapi_schema = custom_openapi()
+    assert openapi_schema is not None