mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
LiteLLM Minor Fixes & Improvements (01/16/2025) - p2 (#7828)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 14s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 14s
* fix(vertex_ai/gemini/transformation.py): handle 'http://' image urls * test: add base test for `http:` url's * fix(factory.py/get_image_details): follow redirects allows http calls to work * fix(codestral/): fix stream chunk parsing on last chunk of stream * Azure ad token provider (#6917) * Update azure.py Added optional parameter azure ad token provider * Added parameter to main.py * Found token provider arg location * Fixed embeddings * Fixed ad token provider --------- Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com> * fix: fix linting errors * fix(main.py): leave out o1 route for azure ad token provider, for now get v0 out for sync azure gpt route to begin with * test: skip http:// test for fireworks ai model does not support it * refactor: cleanup dead code * fix: revert http:// url passthrough for gemini google ai studio raises errors * test: fix test --------- Co-authored-by: bahtman <anton@baht.dk>
This commit is contained in:
parent
10d3da7660
commit
97b8de17ab
9 changed files with 107 additions and 27 deletions
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, List, Literal, Optional, Union
|
from typing import Any, Callable, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
import httpx # type: ignore
|
import httpx # type: ignore
|
||||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||||
|
@ -217,7 +217,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def validate_environment(self, api_key, azure_ad_token):
|
def validate_environment(self, api_key, azure_ad_token, azure_ad_token_provider):
|
||||||
headers = {
|
headers = {
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
}
|
}
|
||||||
|
@ -227,6 +227,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
azure_ad_token = azure_ad_token_provider()
|
||||||
|
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def _get_sync_azure_client(
|
def _get_sync_azure_client(
|
||||||
|
@ -235,6 +239,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
azure_ad_token: Optional[str],
|
azure_ad_token: Optional[str],
|
||||||
|
azure_ad_token_provider: Optional[Callable],
|
||||||
model: str,
|
model: str,
|
||||||
max_retries: int,
|
max_retries: int,
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
@ -242,7 +247,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
client_type: Literal["sync", "async"],
|
client_type: Literal["sync", "async"],
|
||||||
):
|
):
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
azure_client_params = {
|
azure_client_params: Dict[str, Any] = {
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
"azure_endpoint": api_base,
|
"azure_endpoint": api_base,
|
||||||
"azure_deployment": model,
|
"azure_deployment": model,
|
||||||
|
@ -259,6 +264,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||||
if client is None:
|
if client is None:
|
||||||
if client_type == "sync":
|
if client_type == "sync":
|
||||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||||
|
@ -326,6 +333,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_version: str,
|
api_version: str,
|
||||||
api_type: str,
|
api_type: str,
|
||||||
azure_ad_token: str,
|
azure_ad_token: str,
|
||||||
|
azure_ad_token_provider: Callable,
|
||||||
dynamic_params: bool,
|
dynamic_params: bool,
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
@ -373,6 +381,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
azure_client_params["azure_ad_token_provider"] = (
|
||||||
|
azure_ad_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
if acompletion is True:
|
if acompletion is True:
|
||||||
client = AsyncAzureOpenAI(**azure_client_params)
|
client = AsyncAzureOpenAI(**azure_client_params)
|
||||||
|
@ -400,6 +412,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
@ -412,6 +425,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
model=model,
|
model=model,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
dynamic_params=dynamic_params,
|
dynamic_params=dynamic_params,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
client=client,
|
client=client,
|
||||||
|
@ -428,6 +442,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
client=client,
|
client=client,
|
||||||
)
|
)
|
||||||
|
@ -468,6 +483,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
azure_client_params["azure_ad_token_provider"] = (
|
||||||
|
azure_ad_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
client is None
|
client is None
|
||||||
|
@ -535,6 +554,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
logging_obj: LiteLLMLoggingObj,
|
logging_obj: LiteLLMLoggingObj,
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
|
azure_ad_token_provider: Optional[Callable] = None,
|
||||||
convert_tool_call_to_json_mode: Optional[bool] = None,
|
convert_tool_call_to_json_mode: Optional[bool] = None,
|
||||||
client=None, # this is the AsyncAzureOpenAI
|
client=None, # this is the AsyncAzureOpenAI
|
||||||
):
|
):
|
||||||
|
@ -564,6 +584,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||||
|
|
||||||
# setting Azure client
|
# setting Azure client
|
||||||
if client is None or dynamic_params:
|
if client is None or dynamic_params:
|
||||||
|
@ -650,6 +672,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
model: str,
|
model: str,
|
||||||
timeout: Any,
|
timeout: Any,
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
|
azure_ad_token_provider: Optional[Callable] = None,
|
||||||
client=None,
|
client=None,
|
||||||
):
|
):
|
||||||
max_retries = data.pop("max_retries", 2)
|
max_retries = data.pop("max_retries", 2)
|
||||||
|
@ -675,6 +698,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||||
|
|
||||||
if client is None or dynamic_params:
|
if client is None or dynamic_params:
|
||||||
azure_client = AzureOpenAI(**azure_client_params)
|
azure_client = AzureOpenAI(**azure_client_params)
|
||||||
|
@ -718,6 +743,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
model: str,
|
model: str,
|
||||||
timeout: Any,
|
timeout: Any,
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
|
azure_ad_token_provider: Optional[Callable] = None,
|
||||||
client=None,
|
client=None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
@ -739,6 +765,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||||
if client is None or dynamic_params:
|
if client is None or dynamic_params:
|
||||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
||||||
else:
|
else:
|
||||||
|
@ -844,6 +872,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
|
azure_ad_token_provider: Optional[Callable] = None,
|
||||||
max_retries: Optional[int] = None,
|
max_retries: Optional[int] = None,
|
||||||
client=None,
|
client=None,
|
||||||
aembedding=None,
|
aembedding=None,
|
||||||
|
@ -883,6 +912,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
|
@ -1240,6 +1271,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
model_response: Optional[ImageResponse] = None,
|
model_response: Optional[ImageResponse] = None,
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
|
azure_ad_token_provider: Optional[Callable] = None,
|
||||||
client=None,
|
client=None,
|
||||||
aimg_generation=None,
|
aimg_generation=None,
|
||||||
) -> ImageResponse:
|
) -> ImageResponse:
|
||||||
|
@ -1266,7 +1298,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
azure_client_params = {
|
azure_client_params: Dict[str, Any] = {
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
"azure_endpoint": api_base,
|
"azure_endpoint": api_base,
|
||||||
"azure_deployment": model,
|
"azure_deployment": model,
|
||||||
|
@ -1282,6 +1314,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
if azure_ad_token.startswith("oidc/"):
|
if azure_ad_token.startswith("oidc/"):
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||||
|
|
||||||
if aimg_generation is True:
|
if aimg_generation is True:
|
||||||
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
|
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
|
||||||
|
@ -1342,6 +1376,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
max_retries: int,
|
max_retries: int,
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
|
azure_ad_token_provider: Optional[Callable] = None,
|
||||||
aspeech: Optional[bool] = None,
|
aspeech: Optional[bool] = None,
|
||||||
client=None,
|
client=None,
|
||||||
) -> HttpxBinaryResponseContent:
|
) -> HttpxBinaryResponseContent:
|
||||||
|
@ -1358,6 +1393,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
client=client,
|
client=client,
|
||||||
|
@ -1368,6 +1404,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
model=model,
|
model=model,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
@ -1393,6 +1430,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
api_version: Optional[str],
|
api_version: Optional[str],
|
||||||
azure_ad_token: Optional[str],
|
azure_ad_token: Optional[str],
|
||||||
|
azure_ad_token_provider: Optional[Callable],
|
||||||
max_retries: int,
|
max_retries: int,
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
client=None,
|
client=None,
|
||||||
|
@ -1403,6 +1441,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
model=model,
|
model=model,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
|
|
@ -49,6 +49,7 @@ class AzureTextCompletion(BaseLLM):
|
||||||
api_version: str,
|
api_version: str,
|
||||||
api_type: str,
|
api_type: str,
|
||||||
azure_ad_token: str,
|
azure_ad_token: str,
|
||||||
|
azure_ad_token_provider: Optional[Callable],
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
timeout,
|
timeout,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
|
@ -170,6 +171,7 @@ class AzureTextCompletion(BaseLLM):
|
||||||
"http_client": litellm.client_session,
|
"http_client": litellm.client_session,
|
||||||
"max_retries": max_retries,
|
"max_retries": max_retries,
|
||||||
"timeout": timeout,
|
"timeout": timeout,
|
||||||
|
"azure_ad_token_provider": azure_ad_token_provider,
|
||||||
}
|
}
|
||||||
azure_client_params = select_azure_base_url_or_endpoint(
|
azure_client_params = select_azure_base_url_or_endpoint(
|
||||||
azure_client_params=azure_client_params
|
azure_client_params=azure_client_params
|
||||||
|
|
|
@ -5,6 +5,7 @@ import litellm
|
||||||
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
|
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
|
||||||
from litellm.types.llms.databricks import GenericStreamingChunk
|
from litellm.types.llms.databricks import GenericStreamingChunk
|
||||||
|
|
||||||
|
|
||||||
class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
|
class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
|
||||||
"""
|
"""
|
||||||
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
|
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
|
||||||
|
@ -77,6 +78,7 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
|
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
|
||||||
|
|
||||||
text = ""
|
text = ""
|
||||||
is_finished = False
|
is_finished = False
|
||||||
finish_reason = None
|
finish_reason = None
|
||||||
|
@ -90,7 +92,15 @@ class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
|
||||||
"is_finished": is_finished,
|
"is_finished": is_finished,
|
||||||
"finish_reason": finish_reason,
|
"finish_reason": finish_reason,
|
||||||
}
|
}
|
||||||
chunk_data_dict = json.loads(chunk_data)
|
try:
|
||||||
|
chunk_data_dict = json.loads(chunk_data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {
|
||||||
|
"text": "",
|
||||||
|
"is_finished": is_finished,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
}
|
||||||
|
|
||||||
original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
|
original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
|
||||||
_choices = chunk_data_dict.get("choices", []) or []
|
_choices = chunk_data_dict.get("choices", []) or []
|
||||||
_choice = _choices[0]
|
_choice = _choices[0]
|
||||||
|
|
|
@ -1214,6 +1214,10 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
"azure_ad_token", None
|
"azure_ad_token", None
|
||||||
) or get_secret("AZURE_AD_TOKEN")
|
) or get_secret("AZURE_AD_TOKEN")
|
||||||
|
|
||||||
|
azure_ad_token_provider = litellm_params.get(
|
||||||
|
"azure_ad_token_provider", None
|
||||||
|
)
|
||||||
|
|
||||||
headers = headers or litellm.headers
|
headers = headers or litellm.headers
|
||||||
|
|
||||||
if extra_headers is not None:
|
if extra_headers is not None:
|
||||||
|
@ -1269,6 +1273,7 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
api_type=api_type,
|
api_type=api_type,
|
||||||
dynamic_params=dynamic_params,
|
dynamic_params=dynamic_params,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
@ -1314,6 +1319,10 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
"azure_ad_token", None
|
"azure_ad_token", None
|
||||||
) or get_secret("AZURE_AD_TOKEN")
|
) or get_secret("AZURE_AD_TOKEN")
|
||||||
|
|
||||||
|
azure_ad_token_provider = litellm_params.get(
|
||||||
|
"azure_ad_token_provider", None
|
||||||
|
)
|
||||||
|
|
||||||
headers = headers or litellm.headers
|
headers = headers or litellm.headers
|
||||||
|
|
||||||
if extra_headers is not None:
|
if extra_headers is not None:
|
||||||
|
@ -1337,6 +1346,7 @@ def completion( # type: ignore # noqa: PLR0915
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
api_type=api_type,
|
api_type=api_type,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
@ -3244,6 +3254,7 @@ def embedding( # noqa: PLR0915
|
||||||
cooldown_time = kwargs.get("cooldown_time", None)
|
cooldown_time = kwargs.get("cooldown_time", None)
|
||||||
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
|
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
|
||||||
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
|
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
|
||||||
|
azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
|
||||||
model_info = kwargs.get("model_info", None)
|
model_info = kwargs.get("model_info", None)
|
||||||
metadata = kwargs.get("metadata", None)
|
metadata = kwargs.get("metadata", None)
|
||||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||||
|
@ -3374,6 +3385,7 @@ def embedding( # noqa: PLR0915
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
model_response=EmbeddingResponse(),
|
model_response=EmbeddingResponse(),
|
||||||
|
@ -4449,6 +4461,7 @@ def image_generation( # noqa: PLR0915
|
||||||
logger_fn = kwargs.get("logger_fn", None)
|
logger_fn = kwargs.get("logger_fn", None)
|
||||||
mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore
|
mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore
|
||||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||||
|
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
|
||||||
model_info = kwargs.get("model_info", None)
|
model_info = kwargs.get("model_info", None)
|
||||||
metadata = kwargs.get("metadata", {})
|
metadata = kwargs.get("metadata", {})
|
||||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
|
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
|
||||||
|
@ -4562,6 +4575,8 @@ def image_generation( # noqa: PLR0915
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
logging_obj=litellm_logging_obj,
|
logging_obj=litellm_logging_obj,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
|
@ -5251,6 +5266,7 @@ def speech(
|
||||||
) or get_secret(
|
) or get_secret(
|
||||||
"AZURE_AD_TOKEN"
|
"AZURE_AD_TOKEN"
|
||||||
)
|
)
|
||||||
|
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
|
||||||
|
|
||||||
if extra_headers:
|
if extra_headers:
|
||||||
optional_params["extra_headers"] = extra_headers
|
optional_params["extra_headers"] = extra_headers
|
||||||
|
@ -5264,6 +5280,7 @@ def speech(
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
azure_ad_token=azure_ad_token,
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
organization=organization,
|
organization=organization,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
|
|
@ -29,11 +29,3 @@ model_list:
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
callbacks: ["langsmith"]
|
callbacks: ["langsmith"]
|
||||||
disable_no_log_param: true
|
|
||||||
|
|
||||||
general_settings:
|
|
||||||
enable_jwt_auth: True
|
|
||||||
litellm_jwtauth:
|
|
||||||
user_id_jwt_field: "sub"
|
|
||||||
user_email_jwt_field: "email"
|
|
||||||
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
|
|
|
@ -461,8 +461,15 @@ class BaseLLMChatTest(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@pytest.mark.parametrize("detail", [None, "low", "high"])
|
@pytest.mark.parametrize("detail", [None, "low", "high"])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"image_url",
|
||||||
|
[
|
||||||
|
"http://img1.etsystatic.com/260/0/7813604/il_fullxfull.4226713999_q86e.jpg",
|
||||||
|
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
|
||||||
|
],
|
||||||
|
)
|
||||||
@pytest.mark.flaky(retries=4, delay=1)
|
@pytest.mark.flaky(retries=4, delay=1)
|
||||||
def test_image_url(self, detail):
|
def test_image_url(self, detail, image_url):
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
from litellm.utils import supports_vision
|
from litellm.utils import supports_vision
|
||||||
|
|
||||||
|
@ -472,6 +479,10 @@ class BaseLLMChatTest(ABC):
|
||||||
base_completion_call_args = self.get_base_completion_call_args()
|
base_completion_call_args = self.get_base_completion_call_args()
|
||||||
if not supports_vision(base_completion_call_args["model"], None):
|
if not supports_vision(base_completion_call_args["model"], None):
|
||||||
pytest.skip("Model does not support image input")
|
pytest.skip("Model does not support image input")
|
||||||
|
elif "http://" in image_url and "fireworks_ai" in base_completion_call_args.get(
|
||||||
|
"model"
|
||||||
|
):
|
||||||
|
pytest.skip("Model does not support http:// input")
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
|
@ -481,7 +492,7 @@ class BaseLLMChatTest(ABC):
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
"url": image_url,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
|
|
@ -1289,10 +1289,11 @@ def test_process_gemini_image_http_url(
|
||||||
http_url: Test HTTP URL
|
http_url: Test HTTP URL
|
||||||
mock_convert_to_anthropic: Mocked convert_to_anthropic_image_obj function
|
mock_convert_to_anthropic: Mocked convert_to_anthropic_image_obj function
|
||||||
mock_blob: Mocked BlobType instance
|
mock_blob: Mocked BlobType instance
|
||||||
|
|
||||||
|
Vertex AI supports image urls. Ensure no network requests are made.
|
||||||
"""
|
"""
|
||||||
# Arrange
|
|
||||||
expected_image_data = "..."
|
expected_image_data = "..."
|
||||||
mock_convert_url_to_base64.return_value = expected_image_data
|
mock_convert_url_to_base64.return_value = expected_image_data
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = _process_gemini_image(http_url)
|
result = _process_gemini_image(http_url)
|
||||||
|
# assert result["file_data"]["file_uri"] == http_url
|
||||||
|
|
|
@ -205,20 +205,29 @@ def anthropic_messages():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_anthropic_vertex_ai_prompt_caching(anthropic_messages):
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anthropic_vertex_ai_prompt_caching(anthropic_messages, sync_mode):
|
||||||
litellm._turn_on_debug()
|
litellm._turn_on_debug()
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||||
|
|
||||||
load_vertex_ai_credentials()
|
load_vertex_ai_credentials()
|
||||||
|
|
||||||
client = HTTPHandler()
|
client = HTTPHandler() if sync_mode else AsyncHTTPHandler()
|
||||||
with patch.object(client, "post", return_value=MagicMock()) as mock_post:
|
with patch.object(client, "post", return_value=MagicMock()) as mock_post:
|
||||||
try:
|
try:
|
||||||
response = completion(
|
if sync_mode:
|
||||||
model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
|
response = completion(
|
||||||
messages=anthropic_messages,
|
model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
|
||||||
client=client,
|
messages=anthropic_messages,
|
||||||
)
|
client=client,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="vertex_ai/claude-3-5-sonnet-v2@20241022 ",
|
||||||
|
messages=anthropic_messages,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
|
|
@ -730,7 +730,6 @@ def test_stream_chunk_builder_openai_audio_output_usage():
|
||||||
usage_dict == response_usage_dict
|
usage_dict == response_usage_dict
|
||||||
), f"\nExpected: {usage_dict}\nGot: {response_usage_dict}"
|
), f"\nExpected: {usage_dict}\nGot: {response_usage_dict}"
|
||||||
|
|
||||||
|
|
||||||
def test_stream_chunk_builder_empty_initial_chunk():
|
def test_stream_chunk_builder_empty_initial_chunk():
|
||||||
from litellm.litellm_core_utils.streaming_chunk_builder_utils import (
|
from litellm.litellm_core_utils.streaming_chunk_builder_utils import (
|
||||||
ChunkProcessor,
|
ChunkProcessor,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue