LiteLLM Minor Fixes & Improvements (01/16/2025) - p2 (#7828)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 14s

* fix(vertex_ai/gemini/transformation.py): handle 'http://' image urls

* test: add base test for `http:` URLs

* fix(factory.py/get_image_details): follow redirects

allows http calls to work

* fix(codestral/): fix stream chunk parsing on last chunk of stream

* Azure ad token provider (#6917)

* Update azure.py

Added optional parameter `azure_ad_token_provider`

* Added parameter to main.py

* Found token provider arg location

* Fixed embeddings

* Fixed ad token provider

---------

Co-authored-by: Krish Dholakia <krrishdholakia@gmail.com>

* fix: fix linting errors

* fix(main.py): leave out o1 route for azure ad token provider, for now

get v0 out for sync azure gpt route to begin with

* test: skip http:// test for fireworks ai

model does not support it

* refactor: cleanup dead code

* fix: revert http:// url passthrough for gemini

Google AI Studio raises errors

* test: fix test

---------

Co-authored-by: bahtman <anton@baht.dk>
This commit is contained in:
Krish Dholakia 2025-02-02 23:17:50 -08:00 committed by GitHub
parent 10d3da7660
commit 97b8de17ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 107 additions and 27 deletions

View file

@ -1214,6 +1214,10 @@ def completion( # type: ignore # noqa: PLR0915
"azure_ad_token", None
) or get_secret("AZURE_AD_TOKEN")
azure_ad_token_provider = litellm_params.get(
"azure_ad_token_provider", None
)
headers = headers or litellm.headers
if extra_headers is not None:
@ -1269,6 +1273,7 @@ def completion( # type: ignore # noqa: PLR0915
api_type=api_type,
dynamic_params=dynamic_params,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
@ -1314,6 +1319,10 @@ def completion( # type: ignore # noqa: PLR0915
"azure_ad_token", None
) or get_secret("AZURE_AD_TOKEN")
azure_ad_token_provider = litellm_params.get(
"azure_ad_token_provider", None
)
headers = headers or litellm.headers
if extra_headers is not None:
@ -1337,6 +1346,7 @@ def completion( # type: ignore # noqa: PLR0915
api_version=api_version,
api_type=api_type,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
@ -3244,6 +3254,7 @@ def embedding( # noqa: PLR0915
cooldown_time = kwargs.get("cooldown_time", None)
mock_response: Optional[List[float]] = kwargs.get("mock_response", None) # type: ignore
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
azure_ad_token_provider = kwargs.pop("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
@ -3374,6 +3385,7 @@ def embedding( # noqa: PLR0915
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
@ -4449,6 +4461,7 @@ def image_generation( # noqa: PLR0915
logger_fn = kwargs.get("logger_fn", None)
mock_response: Optional[str] = kwargs.get("mock_response", None) # type: ignore
proxy_server_request = kwargs.get("proxy_server_request", None)
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
@ -4562,6 +4575,8 @@ def image_generation( # noqa: PLR0915
timeout=timeout,
api_key=api_key,
api_base=api_base,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
@ -5251,6 +5266,7 @@ def speech(
) or get_secret(
"AZURE_AD_TOKEN"
)
azure_ad_token_provider = kwargs.get("azure_ad_token_provider", None)
if extra_headers:
optional_params["extra_headers"] = extra_headers
@ -5264,6 +5280,7 @@ def speech(
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
max_retries=max_retries,
timeout=timeout,