LiteLLM Minor Fixes & Improvements (01/16/2025) - p2 (#7828)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 14s

* fix(vertex_ai/gemini/transformation.py): handle 'http://' image urls

* test: add base test for `http:` url's

* fix(factory.py/get_image_details): follow redirects

allows http calls to work

* fix(codestral/): fix stream chunk parsing on last chunk of stream

* Azure ad token provider (#6917)

* Update azure.py

Added optional parameter azure ad token provider

* Added parameter to main.py

* Found token provider arg location

* Fixed embeddings

* Fixed ad token provider

---------

Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>

* fix: fix linting errors

* fix(main.py): leave out o1 route for azure ad token provider, for now

get v0 out for sync azure gpt route to begin with

* test: skip http:// test for fireworks ai

model does not support it

* refactor: cleanup dead code

* fix: revert http:// url passthrough for gemini

google ai studio raises errors

* test: fix test

---------

Co-authored-by: bahtman <anton@baht.dk>
This commit is contained in:
Krish Dholakia 2025-02-02 23:17:50 -08:00 committed by GitHub
parent 10d3da7660
commit 97b8de17ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 107 additions and 27 deletions

View file

@ -2,7 +2,7 @@ import asyncio
import json import json
import os import os
import time import time
from typing import Any, Callable, List, Literal, Optional, Union from typing import Any, Callable, Dict, List, Literal, Optional, Union
import httpx # type: ignore import httpx # type: ignore
from openai import AsyncAzureOpenAI, AzureOpenAI from openai import AsyncAzureOpenAI, AzureOpenAI
@ -217,7 +217,7 @@ class AzureChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def validate_environment(self, api_key, azure_ad_token): def validate_environment(self, api_key, azure_ad_token, azure_ad_token_provider):
headers = { headers = {
"content-type": "application/json", "content-type": "application/json",
} }
@ -227,6 +227,10 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
headers["Authorization"] = f"Bearer {azure_ad_token}" headers["Authorization"] = f"Bearer {azure_ad_token}"
elif azure_ad_token_provider is not None:
azure_ad_token = azure_ad_token_provider()
headers["Authorization"] = f"Bearer {azure_ad_token}"
return headers return headers
def _get_sync_azure_client( def _get_sync_azure_client(
@ -235,6 +239,7 @@ class AzureChatCompletion(BaseLLM):
api_base: Optional[str], api_base: Optional[str],
api_key: Optional[str], api_key: Optional[str],
azure_ad_token: Optional[str], azure_ad_token: Optional[str],
azure_ad_token_provider: Optional[Callable],
model: str, model: str,
max_retries: int, max_retries: int,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
@ -242,7 +247,7 @@ class AzureChatCompletion(BaseLLM):
client_type: Literal["sync", "async"], client_type: Literal["sync", "async"],
): ):
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params: Dict[str, Any] = {
"api_version": api_version, "api_version": api_version,
"azure_endpoint": api_base, "azure_endpoint": api_base,
"azure_deployment": model, "azure_deployment": model,
@ -259,6 +264,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None: if client is None:
if client_type == "sync": if client_type == "sync":
azure_client = AzureOpenAI(**azure_client_params) # type: ignore azure_client = AzureOpenAI(**azure_client_params) # type: ignore
@ -326,6 +333,7 @@ class AzureChatCompletion(BaseLLM):
api_version: str, api_version: str,
api_type: str, api_type: str,
azure_ad_token: str, azure_ad_token: str,
azure_ad_token_provider: Callable,
dynamic_params: bool, dynamic_params: bool,
print_verbose: Callable, print_verbose: Callable,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
@ -373,6 +381,10 @@ class AzureChatCompletion(BaseLLM):
) )
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)
if acompletion is True: if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params) client = AsyncAzureOpenAI(**azure_client_params)
@ -400,6 +412,7 @@ class AzureChatCompletion(BaseLLM):
api_key=api_key, api_key=api_key,
api_version=api_version, api_version=api_version,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout, timeout=timeout,
client=client, client=client,
) )
@ -412,6 +425,7 @@ class AzureChatCompletion(BaseLLM):
api_version=api_version, api_version=api_version,
model=model, model=model,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
dynamic_params=dynamic_params, dynamic_params=dynamic_params,
timeout=timeout, timeout=timeout,
client=client, client=client,
@ -428,6 +442,7 @@ class AzureChatCompletion(BaseLLM):
api_key=api_key, api_key=api_key,
api_version=api_version, api_version=api_version,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
timeout=timeout, timeout=timeout,
client=client, client=client,
) )
@ -468,6 +483,10 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)
if ( if (
client is None client is None
@ -535,6 +554,7 @@ class AzureChatCompletion(BaseLLM):
model_response: ModelResponse, model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj, logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
convert_tool_call_to_json_mode: Optional[bool] = None, convert_tool_call_to_json_mode: Optional[bool] = None,
client=None, # this is the AsyncAzureOpenAI client=None, # this is the AsyncAzureOpenAI
): ):
@ -564,6 +584,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
# setting Azure client # setting Azure client
if client is None or dynamic_params: if client is None or dynamic_params:
@ -650,6 +672,7 @@ class AzureChatCompletion(BaseLLM):
model: str, model: str,
timeout: Any, timeout: Any,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None, client=None,
): ):
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
@ -675,6 +698,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None or dynamic_params: if client is None or dynamic_params:
azure_client = AzureOpenAI(**azure_client_params) azure_client = AzureOpenAI(**azure_client_params)
@ -718,6 +743,7 @@ class AzureChatCompletion(BaseLLM):
model: str, model: str,
timeout: Any, timeout: Any,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None, client=None,
): ):
try: try:
@ -739,6 +765,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None or dynamic_params: if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params) azure_client = AsyncAzureOpenAI(**azure_client_params)
else: else:
@ -844,6 +872,7 @@ class AzureChatCompletion(BaseLLM):
optional_params: dict, optional_params: dict,
api_key: Optional[str] = None, api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
max_retries: Optional[int] = None, max_retries: Optional[int] = None,
client=None, client=None,
aembedding=None, aembedding=None,
@ -883,6 +912,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -1240,6 +1271,7 @@ class AzureChatCompletion(BaseLLM):
api_version: Optional[str] = None, api_version: Optional[str] = None,
model_response: Optional[ImageResponse] = None, model_response: Optional[ImageResponse] = None,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None, client=None,
aimg_generation=None, aimg_generation=None,
) -> ImageResponse: ) -> ImageResponse:
@ -1266,7 +1298,7 @@ class AzureChatCompletion(BaseLLM):
) )
# init AzureOpenAI Client # init AzureOpenAI Client
azure_client_params = { azure_client_params: Dict[str, Any] = {
"api_version": api_version, "api_version": api_version,
"azure_endpoint": api_base, "azure_endpoint": api_base,
"azure_deployment": model, "azure_deployment": model,
@ -1282,6 +1314,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if aimg_generation is True: if aimg_generation is True:
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
@ -1342,6 +1376,7 @@ class AzureChatCompletion(BaseLLM):
max_retries: int, max_retries: int,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
aspeech: Optional[bool] = None, aspeech: Optional[bool] = None,
client=None, client=None,
) -> HttpxBinaryResponseContent: ) -> HttpxBinaryResponseContent:
@ -1358,6 +1393,7 @@ class AzureChatCompletion(BaseLLM):
api_base=api_base, api_base=api_base,
api_version=api_version, api_version=api_version,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
max_retries=max_retries, max_retries=max_retries,
timeout=timeout, timeout=timeout,
client=client, client=client,
@ -1368,6 +1404,7 @@ class AzureChatCompletion(BaseLLM):
api_version=api_version, api_version=api_version,
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model=model, model=model,
max_retries=max_retries, max_retries=max_retries,
timeout=timeout, timeout=timeout,
@ -1393,6 +1430,7 @@ class AzureChatCompletion(BaseLLM):
api_base: Optional[str], api_base: Optional[str],
api_version: Optional[str], api_version: Optional[str],
azure_ad_token: Optional[str], azure_ad_token: Optional[str],
azure_ad_token_provider: Optional[Callable],
max_retries: int, max_retries: int,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
client=None, client=None,
@ -1403,6 +1441,7 @@ class AzureChatCompletion(BaseLLM):
api_version=api_version, api_version=api_version,
api_key=api_key, api_key=api_key,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model=model, model=model,
max_retries=max_retries, max_retries=max_retries,
timeout=timeout, timeout=timeout,

View file

@ -49,6 +49,7 @@ class AzureTextCompletion(BaseLLM):
api_version: str, api_version: str,
api_type: str, api_type: str,
azure_ad_token: str, azure_ad_token: str,
azure_ad_token_provider: Optional[Callable],
print_verbose: Callable, print_verbose: Callable,
timeout, timeout,
logging_obj, logging_obj,
@ -170,6 +171,7 @@ class AzureTextCompletion(BaseLLM):
"http_client": litellm.client_session, "http_client": litellm.client_session,
"max_retries": max_retries, "max_retries": max_retries,
"timeout": timeout, "timeout": timeout,
"azure_ad_token_provider": azure_ad_token_provider,
} }
azure_client_params = select_azure_base_url_or_endpoint( azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params azure_client_params=azure_client_params

View file

@ -5,6 +5,7 @@ import litellm
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
from litellm.types.llms.databricks import GenericStreamingChunk from litellm.types.llms.databricks import GenericStreamingChunk
class CodestralTextCompletionConfig(OpenAITextCompletionConfig): class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
""" """
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
@ -77,6 +78,7 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
return optional_params return optional_params
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk: def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
text = "" text = ""
is_finished = False is_finished = False
finish_reason = None finish_reason = None
@ -90,7 +92,15 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
"is_finished": is_finished, "is_finished": is_finished,
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
try:
chunk_data_dict = json.loads(chunk_data) chunk_data_dict = json.loads(chunk_data)
except json.JSONDecodeError:
return {
"text": "",
"is_finished": is_finished,
"finish_reason": finish_reason,
}
original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True) original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
_choices = chunk_data_dict.get("choices", []) or [] _choices = chunk_data_dict.get("choices", []) or []
_choice = _choices[0] _choice = _choices[0]

View file

@ -1214,6 +1214,10 @@ def completion( # type: ignore # noqa: PLR0915
"azure_ad_token", None "azure_ad_token", None
) or get_secret("AZURE_AD_TOKEN") ) or get_secret("AZURE_AD_TOKEN")
azure_ad_token_provider = litellm_params.get(
"azure_ad_token_provider", None
)
headers = headers or litellm.headers headers = headers or litellm.headers
if extra_headers is not None: if extra_headers is not None:
@ -1269,6 +1273,7 @@ def completion( # type: ignore # noqa: PLR0915
api_type=api_type, api_type=api_type,
dynamic_params=dynamic_params, dynamic_params=dynamic_params,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
@ -1314,6 +1319,10 @@ def completion( # type: ignore # noqa: PLR0915
"azure_ad_token", None "azure_ad_token", None
) or get_secret("AZURE_AD_TOKEN") ) or get_secret("AZURE_AD_TOKEN")
azure_ad_token_provider = litellm_params.get(
"azure_ad_token_provider", None
)
headers = headers or litellm.headers headers = headers or litellm.headers
if extra_headers is not None: if extra_headers is not None:
@ -1337,6 +1346,7 @@ def completion( # type: ignore # noqa: PLR0915
api_version=api_version, api_version=api_version,
api_type=api_type, api_type=api_type,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,
optional_params=optional_params, optional_params=optional_params,
@ -3244,6 +3254,7 @@ def embedding( # noqa: PLR0915
cooldown_time = kwargs.get("cooldown_time", None) cooldown_time = kwargs.get("cooldown_time", None)
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
max_parallel_requests = kwargs.pop("max_parallel_requests", None) max_parallel_requests = kwargs.pop("max_parallel_requests", None)
azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None) model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None) metadata = kwargs.get("metadata", None)
proxy_server_request = kwargs.get("proxy_server_request", None) proxy_server_request = kwargs.get("proxy_server_request", None)
@ -3374,6 +3385,7 @@ def embedding( # noqa: PLR0915
api_key=api_key, api_key=api_key,
api_version=api_version, api_version=api_version,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
logging_obj=logging, logging_obj=logging,
timeout=timeout, timeout=timeout,
model_response=EmbeddingResponse(), model_response=EmbeddingResponse(),
@ -4449,6 +4461,7 @@ def image_generation( # noqa: PLR0915
logger_fn = kwargs.get("logger_fn", None) logger_fn = kwargs.get("logger_fn", None)
mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore
proxy_server_request = kwargs.get("proxy_server_request", None) proxy_server_request = kwargs.get("proxy_server_request", None)
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None) model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {}) metadata = kwargs.get("metadata", {})
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
@ -4562,6 +4575,8 @@ def image_generation( # noqa: PLR0915
timeout=timeout, timeout=timeout,
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
logging_obj=litellm_logging_obj, logging_obj=litellm_logging_obj,
optional_params=optional_params, optional_params=optional_params,
model_response=model_response, model_response=model_response,
@ -5251,6 +5266,7 @@ def speech(
) or get_secret( ) or get_secret(
"AZURE_AD_TOKEN" "AZURE_AD_TOKEN"
) )
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
if extra_headers: if extra_headers:
optional_params["extra_headers"] = extra_headers optional_params["extra_headers"] = extra_headers
@ -5264,6 +5280,7 @@ def speech(
api_base=api_base, api_base=api_base,
api_version=api_version, api_version=api_version,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization, organization=organization,
max_retries=max_retries, max_retries=max_retries,
timeout=timeout, timeout=timeout,

View file

@ -29,11 +29,3 @@ model_list:
litellm_settings: litellm_settings:
callbacks: ["langsmith"] callbacks: ["langsmith"]
disable_no_log_param: true
general_settings:
enable_jwt_auth: True
litellm_jwtauth:
user_id_jwt_field: "sub"
user_email_jwt_field: "email"
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD

View file

@ -461,8 +461,15 @@ class BaseLLMChatTest(ABC):
pass pass
@pytest.mark.parametrize("detail", [None, "low", "high"]) @pytest.mark.parametrize("detail", [None, "low", "high"])
@pytest.mark.parametrize(
"image_url",
[
"http://img1.etsystatic.com/260/0/7813604/il_fullxfull.4226713999_q86e.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
],
)
@pytest.mark.flaky(retries=4, delay=1) @pytest.mark.flaky(retries=4, delay=1)
def test_image_url(self, detail): def test_image_url(self, detail, image_url):
litellm.set_verbose = True litellm.set_verbose = True
from litellm.utils import supports_vision from litellm.utils import supports_vision
@ -472,6 +479,10 @@ class BaseLLMChatTest(ABC):
base_completion_call_args = self.get_base_completion_call_args() base_completion_call_args = self.get_base_completion_call_args()
if not supports_vision(base_completion_call_args["model"], None): if not supports_vision(base_completion_call_args["model"], None):
pytest.skip("Model does not support image input") pytest.skip("Model does not support image input")
elif "http://" in image_url and "fireworks_ai" in base_completion_call_args.get(
"model"
):
pytest.skip("Model does not support http:// input")
messages = [ messages = [
{ {
@ -481,7 +492,7 @@ class BaseLLMChatTest(ABC):
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" "url": image_url,
}, },
}, },
], ],

View file

@ -1289,10 +1289,11 @@ def test_process_gemini_image_http_url(
http_url: Test HTTP URL http_url: Test HTTP URL
mock_convert_to_anthropic: Mocked convert_to_anthropic_image_obj function mock_convert_to_anthropic: Mocked convert_to_anthropic_image_obj function
mock_blob: Mocked BlobType instance mock_blob: Mocked BlobType instance
Vertex AI supports image urls. Ensure no network requests are made.
""" """
# Arrange
expected_image_data = "..." expected_image_data = "..."
mock_convert_url_to_base64.return_value = expected_image_data mock_convert_url_to_base64.return_value = expected_image_data
# Act # Act
result = _process_gemini_image(http_url) result = _process_gemini_image(http_url)
# assert result["file_data"]["file_uri"] == http_url

View file

@ -205,20 +205,29 @@ def anthropic_messages():
] ]
def test_anthropic_vertex_ai_prompt_caching(anthropic_messages): @pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_anthropic_vertex_ai_prompt_caching(anthropic_messages, sync_mode):
litellm._turn_on_debug() litellm._turn_on_debug()
from litellm.llms.custom_httpx.http_handler import HTTPHandler from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
load_vertex_ai_credentials() load_vertex_ai_credentials()
client = HTTPHandler() client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
with patch.object(client, "post", return_value=MagicMock()) as mock_post: with patch.object(client, "post", return_value=MagicMock()) as mock_post:
try: try:
if sync_mode:
response = completion( response = completion(
model="vertex_ai/claude-3-5-sonnet-v2@20241022 ", model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
messages=anthropic_messages, messages=anthropic_messages,
client=client, client=client,
) )
else:
response = await litellm.acompletion(
model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
messages=anthropic_messages,
client=client,
)
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")

View file

@ -730,7 +730,6 @@ def test_stream_chunk_builder_openai_audio_output_usage():
usage_dict == response_usage_dict usage_dict == response_usage_dict
), f"\nExpected: {usage_dict}\nGot: {response_usage_dict}" ), f"\nExpected: {usage_dict}\nGot: {response_usage_dict}"
def test_stream_chunk_builder_empty_initial_chunk(): def test_stream_chunk_builder_empty_initial_chunk():
from litellm.litellm_core_utils.streaming_chunk_builder_utils import ( from litellm.litellm_core_utils.streaming_chunk_builder_utils import (
ChunkProcessor, ChunkProcessor,