LiteLLM Minor Fixes & Improvements (01/16/2025) - p2 (#7828)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 14s

* fix(vertex_ai/gemini/transformation.py): handle 'http://' image urls

* test: add base test for `http:` url's

* fix(factory.py/get_image_details): follow redirects

allows http calls to work

* fix(codestral/): fix stream chunk parsing on last chunk of stream

* Azure ad token provider (#6917)

* Update azure.py

Added optional parameter azure ad token provider

* Added parameter to main.py

* Found token provider arg location

* Fixed embeddings

* Fixed ad token provider

---------

Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>

* fix: fix linting errors

* fix(main.py): leave out o1 route for azure ad token provider, for now

get v0 out for sync azure gpt route to begin with

* test: skip http:// test for fireworks ai

model does not support it

* refactor: cleanup dead code

* fix: revert http:// url passthrough for gemini

google ai studio raises errors

* test: fix test

---------

Co-authored-by: bahtman <anton@baht.dk>
This commit is contained in:
Krish Dholakia 2025-02-02 23:17:50 -08:00 committed by GitHub
parent 10d3da7660
commit 97b8de17ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 107 additions and 27 deletions

View file

@ -2,7 +2,7 @@ import asyncio
import json
import os
import time
from typing import Any, Callable, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Union
import httpx # type: ignore
from openai import AsyncAzureOpenAI, AzureOpenAI
@ -217,7 +217,7 @@ class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def validate_environment(self, api_key, azure_ad_token):
def validate_environment(self, api_key, azure_ad_token, azure_ad_token_provider):
headers = {
"content-type": "application/json",
}
@ -227,6 +227,10 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
headers["Authorization"] = f"Bearer {azure_ad_token}"
elif azure_ad_token_provider is not None:
azure_ad_token = azure_ad_token_provider()
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers
def _get_sync_azure_client(
@ -235,6 +239,7 @@ class AzureChatCompletion(BaseLLM):
api_base: Optional[str],
api_key: Optional[str],
azure_ad_token: Optional[str],
azure_ad_token_provider: Optional[Callable],
model: str,
max_retries: int,
timeout: Union[float, httpx.Timeout],
@ -242,7 +247,7 @@ class AzureChatCompletion(BaseLLM):
client_type: Literal["sync", "async"],
):
# init AzureOpenAI Client
azure_client_params = {
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
@ -259,6 +264,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None:
if client_type == "sync":
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
@ -326,6 +333,7 @@ class AzureChatCompletion(BaseLLM):
api_version: str,
api_type: str,
azure_ad_token: str,
azure_ad_token_provider: Callable,
dynamic_params: bool,
print_verbose: Callable,
timeout: Union[float, httpx.Timeout],
@ -373,6 +381,10 @@ class AzureChatCompletion(BaseLLM):
)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)
if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params)
@ -400,6 +412,7 @@ class AzureChatCompletion(BaseLLM):
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout,
client=client,
)
@ -412,6 +425,7 @@ class AzureChatCompletion(BaseLLM):
api_version=api_version,
model=model,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
dynamic_params=dynamic_params,
timeout=timeout,
client=client,
@ -428,6 +442,7 @@ class AzureChatCompletion(BaseLLM):
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout,
client=client,
)
@ -468,6 +483,10 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)
if (
client is None
@ -535,6 +554,7 @@ class AzureChatCompletion(BaseLLM):
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
convert_tool_call_to_json_mode: Optional[bool] = None,
client=None, # this is the AsyncAzureOpenAI
):
@ -564,6 +584,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
# setting Azure client
if client is None or dynamic_params:
@ -650,6 +672,7 @@ class AzureChatCompletion(BaseLLM):
model: str,
timeout: Any,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
):
max_retries = data.pop("max_retries", 2)
@ -675,6 +698,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None or dynamic_params:
azure_client = AzureOpenAI(**azure_client_params)
@ -718,6 +743,7 @@ class AzureChatCompletion(BaseLLM):
model: str,
timeout: Any,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
):
try:
@ -739,6 +765,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
@ -844,6 +872,7 @@ class AzureChatCompletion(BaseLLM):
optional_params: dict,
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
max_retries: Optional[int] = None,
client=None,
aembedding=None,
@ -883,6 +912,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
## LOGGING
logging_obj.pre_call(
@ -1240,6 +1271,7 @@ class AzureChatCompletion(BaseLLM):
api_version: Optional[str] = None,
model_response: Optional[ImageResponse] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
aimg_generation=None,
) -> ImageResponse:
@ -1266,7 +1298,7 @@ class AzureChatCompletion(BaseLLM):
)
# init AzureOpenAI Client
azure_client_params = {
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
@ -1282,6 +1314,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if aimg_generation is True:
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
@ -1342,6 +1376,7 @@ class AzureChatCompletion(BaseLLM):
max_retries: int,
timeout: Union[float, httpx.Timeout],
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
aspeech: Optional[bool] = None,
client=None,
) -> HttpxBinaryResponseContent:
@ -1358,6 +1393,7 @@ class AzureChatCompletion(BaseLLM):
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
max_retries=max_retries,
timeout=timeout,
client=client,
@ -1368,6 +1404,7 @@ class AzureChatCompletion(BaseLLM):
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model=model,
max_retries=max_retries,
timeout=timeout,
@ -1393,6 +1430,7 @@ class AzureChatCompletion(BaseLLM):
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
azure_ad_token_provider: Optional[Callable],
max_retries: int,
timeout: Union[float, httpx.Timeout],
client=None,
@ -1403,6 +1441,7 @@ class AzureChatCompletion(BaseLLM):
api_version=api_version,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model=model,
max_retries=max_retries,
timeout=timeout,

View file

@ -49,6 +49,7 @@ class AzureTextCompletion(BaseLLM):
api_version: str,
api_type: str,
azure_ad_token: str,
azure_ad_token_provider: Optional[Callable],
print_verbose: Callable,
timeout,
logging_obj,
@ -170,6 +171,7 @@ class AzureTextCompletion(BaseLLM):
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
"azure_ad_token_provider": azure_ad_token_provider,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params

View file

@ -5,6 +5,7 @@ import litellm
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from litellm.types.llms.databricks import GenericStreamingChunk
class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
"""
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
@ -77,6 +78,7 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
return optional_params
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
text = ""
is_finished = False
finish_reason = None
@ -90,7 +92,15 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
"is_finished": is_finished,
"finish_reason": finish_reason,
}
chunk_data_dict = json.loads(chunk_data)
try:
chunk_data_dict = json.loads(chunk_data)
except json.JSONDecodeError:
return {
"text": "",
"is_finished": is_finished,
"finish_reason": finish_reason,
}
original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
_choices = chunk_data_dict.get("choices", []) or []
_choice = _choices[0]

View file

@ -1214,6 +1214,10 @@ def completion( # type: ignore # noqa: PLR0915
"azure_ad_token", None
) or get_secret("AZURE_AD_TOKEN")
azure_ad_token_provider = litellm_params.get(
"azure_ad_token_provider", None
)
headers = headers or litellm.headers
if extra_headers is not None:
@ -1269,6 +1273,7 @@ def completion( # type: ignore # noqa: PLR0915
api_type=api_type,
dynamic_params=dynamic_params,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
@ -1314,6 +1319,10 @@ def completion( # type: ignore # noqa: PLR0915
"azure_ad_token", None
) or get_secret("AZURE_AD_TOKEN")
azure_ad_token_provider = litellm_params.get(
"azure_ad_token_provider", None
)
headers = headers or litellm.headers
if extra_headers is not None:
@ -1337,6 +1346,7 @@ def completion( # type: ignore # noqa: PLR0915
api_version=api_version,
api_type=api_type,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
@ -3244,6 +3254,7 @@ def embedding( # noqa: PLR0915
cooldown_time = kwargs.get("cooldown_time", None)
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
@ -3374,6 +3385,7 @@ def embedding( # noqa: PLR0915
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
@ -4449,6 +4461,7 @@ def image_generation( # noqa: PLR0915
logger_fn = kwargs.get("logger_fn", None)
mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore
proxy_server_request = kwargs.get("proxy_server_request", None)
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
@ -4562,6 +4575,8 @@ def image_generation( # noqa: PLR0915
timeout=timeout,
api_key=api_key,
api_base=api_base,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
@ -5251,6 +5266,7 @@ def speech(
) or get_secret(
"AZURE_AD_TOKEN"
)
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
if extra_headers:
optional_params["extra_headers"] = extra_headers
@ -5264,6 +5280,7 @@ def speech(
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
max_retries=max_retries,
timeout=timeout,

View file

@ -29,11 +29,3 @@ model_list:
litellm_settings:
callbacks: ["langsmith"]
disable_no_log_param: true
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
user_id_jwt_field: "sub"
user_email_jwt_field: "email"
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD

View file

@ -461,8 +461,15 @@ class BaseLLMChatTest(ABC):
pass
@pytest.mark.parametrize("detail", [None, "low", "high"])
@pytest.mark.parametrize(
"image_url",
[
"http://img1.etsystatic.com/260/0/7813604/il_fullxfull.4226713999_q86e.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
],
)
@pytest.mark.flaky(retries=4, delay=1)
def test_image_url(self, detail):
def test_image_url(self, detail, image_url):
litellm.set_verbose = True
from litellm.utils import supports_vision
@ -472,6 +479,10 @@ class BaseLLMChatTest(ABC):
base_completion_call_args = self.get_base_completion_call_args()
if not supports_vision(base_completion_call_args["model"], None):
pytest.skip("Model does not support image input")
elif "http://" in image_url and "fireworks_ai" in base_completion_call_args.get(
"model"
):
pytest.skip("Model does not support http:// input")
messages = [
{
@ -481,7 +492,7 @@ class BaseLLMChatTest(ABC):
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
"url": image_url,
},
},
],

View file

@ -1289,10 +1289,11 @@ def test_process_gemini_image_http_url(
http_url: Test HTTP URL
mock_convert_to_anthropic: Mocked convert_to_anthropic_image_obj function
mock_blob: Mocked BlobType instance
Vertex AI supports image urls. Ensure no network requests are made.
"""
# Arrange
expected_image_data = "..."
mock_convert_url_to_base64.return_value = expected_image_data
# Act
result = _process_gemini_image(http_url)
# assert result["file_data"]["file_uri"] == http_url

View file

@ -205,20 +205,29 @@ def anthropic_messages():
]
def test_anthropic_vertex_ai_prompt_caching(anthropic_messages):
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_anthropic_vertex_ai_prompt_caching(anthropic_messages, sync_mode):
litellm._turn_on_debug()
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
load_vertex_ai_credentials()
client = HTTPHandler()
client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
with patch.object(client, "post", return_value=MagicMock()) as mock_post:
try:
response = completion(
model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
messages=anthropic_messages,
client=client,
)
if sync_mode:
response = completion(
model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
messages=anthropic_messages,
client=client,
)
else:
response = await litellm.acompletion(
model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
messages=anthropic_messages,
client=client,
)
except Exception as e:
print(f"Error: {e}")

View file

@ -730,7 +730,6 @@ def test_stream_chunk_builder_openai_audio_output_usage():
usage_dict == response_usage_dict
), f"\nExpected: {usage_dict}\nGot: {response_usage_dict}"
def test_stream_chunk_builder_empty_initial_chunk():
from litellm.litellm_core_utils.streaming_chunk_builder_utils import (
ChunkProcessor,