Merge pull request #9140 from BerriAI/litellm_router_client_init_migration

Delegate router Azure client init logic to the Azure provider
Krish Dholakia 2025-03-11 18:29:03 -07:00 committed by GitHub
commit ee53e41213
19 changed files with 895 additions and 732 deletions
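
The shape of the migration, as a minimal sketch (assuming litellm is installed; the endpoint, deployment, and credential values are placeholders): callers collect dynamic params once with get_litellm_params, and the Azure provider resolves them into OpenAI SDK kwargs via BaseAzureLLM.initialize_azure_sdk_client, instead of the router pre-building clients.

from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.utils import get_litellm_params

# 1) Callers gather per-request params, including the new Azure auth fields.
litellm_params = get_litellm_params(
    tenant_id="my-tenant-id",            # placeholder credentials
    client_id="my-client-id",
    client_secret="my-client-secret",
    max_retries=2,
    timeout=600.0,
)

# 2) The Azure provider turns them into AzureOpenAI constructor kwargs.
azure_client_params = BaseAzureLLM().initialize_azure_sdk_client(
    litellm_params=litellm_params,
    api_key=None,                        # no key -> an Entra ID token provider is built
    api_base="https://my-endpoint.openai.azure.com",
    model_name="my-deployment",
    api_version="2024-06-01",
)
# azure_client_params is ready to splat into AzureOpenAI(**azure_client_params).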

View file

@ -15,6 +15,7 @@ import litellm
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import (
exception_type,
get_litellm_params,
get_llm_provider,
get_secret,
supports_httpx_timeout,
@ -86,6 +87,7 @@ def get_assistants(
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -169,6 +171,7 @@ def get_assistants(
max_retries=optional_params.max_retries,
client=client,
aget_assistants=aget_assistants, # type: ignore
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -270,6 +273,7 @@ def create_assistants(
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -371,6 +375,7 @@ def create_assistants(
client=client,
async_create_assistants=async_create_assistants,
create_assistant_data=create_assistant_data,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -445,6 +450,8 @@ def delete_assistant(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
litellm_params_dict = get_litellm_params(**kwargs)
async_delete_assistants: Optional[bool] = kwargs.pop(
"async_delete_assistants", None
)
@ -544,6 +551,7 @@ def delete_assistant(
max_retries=optional_params.max_retries,
client=client,
async_delete_assistants=async_delete_assistants,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -639,6 +647,7 @@ def create_thread(
"""
acreate_thread = kwargs.get("acreate_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -731,6 +740,7 @@ def create_thread(
max_retries=optional_params.max_retries,
client=client,
acreate_thread=acreate_thread,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -795,7 +805,7 @@ def get_thread(
"""Get the thread object, given a thread_id"""
aget_thread = kwargs.pop("aget_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
@ -884,6 +894,7 @@ def get_thread(
max_retries=optional_params.max_retries,
client=client,
aget_thread=aget_thread,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -972,6 +983,7 @@ def add_message(
_message_data = MessageData(
role=role, content=content, attachments=attachments, metadata=metadata
)
litellm_params_dict = get_litellm_params(**kwargs)
optional_params = GenericLiteLLMParams(**kwargs)
message_data = get_optional_params_add_message(
@ -1068,6 +1080,7 @@ def add_message(
max_retries=optional_params.max_retries,
client=client,
a_add_message=a_add_message,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -1139,6 +1152,7 @@ def get_messages(
) -> SyncCursorPage[OpenAIMessage]:
aget_messages = kwargs.pop("aget_messages", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -1225,6 +1239,7 @@ def get_messages(
max_retries=optional_params.max_retries,
client=client,
aget_messages=aget_messages,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@ -1337,6 +1352,7 @@ def run_thread(
"""Run a given thread + assistant."""
arun_thread = kwargs.pop("arun_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -1437,6 +1453,7 @@ def run_thread(
max_retries=optional_params.max_retries,
client=client,
arun_thread=arun_thread,
litellm_params=litellm_params_dict,
) # type: ignore
else:
raise litellm.exceptions.BadRequestError(
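
A hedged usage sketch of the effect on the Assistants entrypoints (the full get_assistants signature is not shown in this diff; all values are placeholders):

import litellm

assistants = litellm.get_assistants(
    custom_llm_provider="azure",
    api_base="https://my-endpoint.openai.azure.com",
    api_version="2024-06-01",
    # New: Entra ID credentials are picked up by get_litellm_params(**kwargs)
    # and reach the Azure client init.
    tenant_id="my-tenant-id",
    client_id="my-client-id",
    client_secret="my-client-secret",
)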

View file

@ -111,6 +111,7 @@ def create_batch(
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
_is_async = kwargs.pop("acreate_batch", False) is True
litellm_params = get_litellm_params(**kwargs)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -217,6 +218,7 @@ def create_batch(
timeout=timeout,
max_retries=optional_params.max_retries,
create_batch_data=_create_batch_request,
litellm_params=litellm_params,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
@ -320,15 +322,12 @@ def retrieve_batch(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
litellm_call_id=kwargs.get("litellm_call_id", None),
litellm_trace_id=kwargs.get("litellm_trace_id"),
litellm_metadata=kwargs.get("litellm_metadata"),
**kwargs,
)
litellm_logging_obj.update_environment_variables(
model=None,
@ -424,6 +423,7 @@ def retrieve_batch(
timeout=timeout,
max_retries=optional_params.max_retries,
retrieve_batch_data=_retrieve_batch_request,
litellm_params=litellm_params,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
@ -526,6 +526,10 @@ def list_batches(
try:
# set API KEY
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
**kwargs,
)
api_key = (
optional_params.api_key
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
@ -603,6 +607,7 @@ def list_batches(
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
litellm_params=litellm_params,
)
else:
raise litellm.exceptions.BadRequestError(
@ -678,6 +683,10 @@ def cancel_batch(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
**kwargs,
)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
@ -765,6 +774,7 @@ def cancel_batch(
timeout=timeout,
max_retries=optional_params.max_retries,
cancel_batch_data=_cancel_batch_request,
litellm_params=litellm_params,
)
else:
raise litellm.exceptions.BadRequestError(
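
The batch entrypoints get the same treatment; a sketch under the same assumptions (the batch ID and credentials are placeholders):

import litellm

batch = litellm.retrieve_batch(
    batch_id="batch_abc123",
    custom_llm_provider="azure",
    api_base="https://my-endpoint.openai.azure.com",
    api_version="2024-06-01",
    tenant_id="my-tenant-id",
    client_id="my-client-id",
    client_secret="my-client-secret",
)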

View file

@ -25,7 +25,7 @@ from litellm.types.llms.openai import (
HttpxBinaryResponseContent,
)
from litellm.types.router import *
from litellm.utils import supports_httpx_timeout
from litellm.utils import get_litellm_params, supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_files_instance = OpenAIFilesAPI()
@ -546,6 +546,7 @@ def create_file(
try:
_is_async = kwargs.pop("acreate_file", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@ -630,6 +631,7 @@ def create_file(
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""

View file

@ -58,6 +58,7 @@ def get_litellm_params(
async_call: Optional[bool] = None,
ssl_verify: Optional[bool] = None,
merge_reasoning_content_in_choices: Optional[bool] = None,
max_retries: Optional[int] = None,
**kwargs,
) -> dict:
litellm_params = {
@ -99,5 +100,13 @@ def get_litellm_params(
"async_call": async_call,
"ssl_verify": ssl_verify,
"merge_reasoning_content_in_choices": merge_reasoning_content_in_choices,
"azure_ad_token": kwargs.get("azure_ad_token"),
"tenant_id": kwargs.get("tenant_id"),
"client_id": kwargs.get("client_id"),
"client_secret": kwargs.get("client_secret"),
"azure_username": kwargs.get("azure_username"),
"azure_password": kwargs.get("azure_password"),
"max_retries": max_retries,
"timeout": kwargs.get("timeout"),
}
return litellm_params
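
For illustration, the expanded helper surfaces the new Azure fields alongside the existing ones (a sketch; values are placeholders):

from litellm.utils import get_litellm_params

params = get_litellm_params(
    azure_username="user@example.com",
    azure_password="placeholder-password",
    client_id="my-client-id",
    max_retries=3,
    timeout=30.0,
)
assert params["azure_username"] == "user@example.com"
assert params["client_id"] == "my-client-id"
assert params["max_retries"] == 3
assert params["timeout"] == 30.0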

View file

@ -18,10 +18,10 @@ from ...types.llms.openai import (
SyncCursorPage,
Thread,
)
from ..base import BaseLLM
from .common_utils import BaseAzureLLM
class AzureAssistantsAPI(BaseLLM):
class AzureAssistantsAPI(BaseAzureLLM):
def __init__(self) -> None:
super().__init__()
@ -34,18 +34,17 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> AzureOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
azure_openai_client = AzureOpenAI(**data) # type: ignore
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
model_name="",
api_version=api_version,
)
azure_openai_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_openai_client = client
@ -60,18 +59,18 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> AsyncAzureOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
azure_openai_client = AsyncAzureOpenAI(**data)
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
model_name="",
api_version=api_version,
)
azure_openai_client = AsyncAzureOpenAI(**azure_client_params)
else:
azure_openai_client = client
@ -89,6 +88,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
litellm_params: Optional[dict] = None,
) -> AsyncCursorPage[Assistant]:
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
@ -98,6 +98,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await azure_openai_client.beta.assistants.list()
@ -146,6 +147,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
aget_assistants=None,
litellm_params: Optional[dict] = None,
):
if aget_assistants is not None and aget_assistants is True:
return self.async_get_assistants(
@ -156,6 +158,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@ -165,6 +168,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
api_version=api_version,
litellm_params=litellm_params,
)
response = azure_openai_client.beta.assistants.list()
@ -184,6 +188,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> OpenAIMessage:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -193,6 +198,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
@ -222,6 +228,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
a_add_message: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, OpenAIMessage]:
...
@ -238,6 +245,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AzureOpenAI],
a_add_message: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> OpenAIMessage:
...
@ -255,6 +263,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
a_add_message: Optional[bool] = None,
litellm_params: Optional[dict] = None,
):
if a_add_message is not None and a_add_message is True:
return self.a_add_message(
@ -267,6 +276,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@ -300,6 +310,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> AsyncCursorPage[OpenAIMessage]:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -309,6 +320,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
@ -329,6 +341,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
aget_messages: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
...
@ -344,6 +357,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AzureOpenAI],
aget_messages: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> SyncCursorPage[OpenAIMessage]:
...
@ -360,6 +374,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
aget_messages=None,
litellm_params: Optional[dict] = None,
):
if aget_messages is not None and aget_messages is True:
return self.async_get_messages(
@ -371,6 +386,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@ -380,6 +396,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
@ -399,6 +416,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
litellm_params: Optional[dict] = None,
) -> Thread:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -408,6 +426,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
data = {}
@ -435,6 +454,7 @@ class AzureAssistantsAPI(BaseLLM):
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AsyncAzureOpenAI],
acreate_thread: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, Thread]:
...
@ -451,6 +471,7 @@ class AzureAssistantsAPI(BaseLLM):
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AzureOpenAI],
acreate_thread: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> Thread:
...
@ -468,6 +489,7 @@ class AzureAssistantsAPI(BaseLLM):
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client=None,
acreate_thread=None,
litellm_params: Optional[dict] = None,
):
"""
Here's an example:
@ -490,6 +512,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
messages=messages,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@ -499,6 +522,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
data = {}
@ -521,6 +545,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
litellm_params: Optional[dict] = None,
) -> Thread:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -530,6 +555,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
@ -550,6 +576,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
aget_thread: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, Thread]:
...
@ -565,6 +592,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AzureOpenAI],
aget_thread: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> Thread:
...
@ -581,6 +609,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
aget_thread=None,
litellm_params: Optional[dict] = None,
):
if aget_thread is not None and aget_thread is True:
return self.async_get_thread(
@ -592,6 +621,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@ -601,6 +631,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
@ -629,6 +660,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
litellm_params: Optional[dict] = None,
) -> Run:
openai_client = self.async_get_azure_client(
api_key=api_key,
@ -638,6 +670,7 @@ class AzureAssistantsAPI(BaseLLM):
api_version=api_version,
azure_ad_token=azure_ad_token,
client=client,
litellm_params=litellm_params,
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
@ -645,7 +678,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
metadata=metadata, # type: ignore
model=model,
tools=tools,
)
@ -663,6 +696,7 @@ class AzureAssistantsAPI(BaseLLM):
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
litellm_params: Optional[dict] = None,
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
data = {
"thread_id": thread_id,
@ -688,6 +722,7 @@ class AzureAssistantsAPI(BaseLLM):
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
litellm_params: Optional[dict] = None,
) -> AssistantStreamManager[AssistantEventHandler]:
data = {
"thread_id": thread_id,
@ -769,6 +804,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
litellm_params: Optional[dict] = None,
):
if arun_thread is not None and arun_thread is True:
if stream is not None and stream is True:
@ -780,6 +816,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
return self.async_run_thread_stream(
client=azure_client,
@ -791,13 +828,14 @@ class AzureAssistantsAPI(BaseLLM):
model=model,
tools=tools,
event_handler=event_handler,
litellm_params=litellm_params,
)
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
metadata=metadata, # type: ignore
model=model,
stream=stream,
tools=tools,
@ -808,6 +846,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@ -817,6 +856,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
if stream is not None and stream is True:
@ -830,6 +870,7 @@ class AzureAssistantsAPI(BaseLLM):
model=model,
tools=tools,
event_handler=event_handler,
litellm_params=litellm_params,
)
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
@ -837,7 +878,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
metadata=metadata, # type: ignore
model=model,
tools=tools,
)
@ -855,6 +896,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
create_assistant_data: dict,
litellm_params: Optional[dict] = None,
) -> Assistant:
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
@ -864,6 +906,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await azure_openai_client.beta.assistants.create(
@ -882,6 +925,7 @@ class AzureAssistantsAPI(BaseLLM):
create_assistant_data: dict,
client=None,
async_create_assistants=None,
litellm_params: Optional[dict] = None,
):
if async_create_assistants is not None and async_create_assistants is True:
return self.async_create_assistants(
@ -893,6 +937,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
create_assistant_data=create_assistant_data,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@ -902,6 +947,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = azure_openai_client.beta.assistants.create(**create_assistant_data)
@ -918,6 +964,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
assistant_id: str,
litellm_params: Optional[dict] = None,
):
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
@ -927,6 +974,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await azure_openai_client.beta.assistants.delete(
@ -945,6 +993,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id: str,
async_delete_assistants: Optional[bool] = None,
client=None,
litellm_params: Optional[dict] = None,
):
if async_delete_assistants is not None and async_delete_assistants is True:
return self.async_delete_assistant(
@ -956,6 +1005,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
assistant_id=assistant_id,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@ -965,6 +1015,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id)
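
Throughout this class, a caller-supplied client still short-circuits initialization; a minimal sketch (standard openai SDK constructor, placeholder values):

from openai import AzureOpenAI

prebuilt = AzureOpenAI(
    api_key="placeholder-key",
    api_version="2024-06-01",
    azure_endpoint="https://my-endpoint.openai.azure.com",
)
# get_azure_client(..., client=prebuilt) returns prebuilt unchanged;
# initialize_azure_sdk_client only runs when client is None.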

View file

@ -9,11 +9,7 @@ from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
from litellm.types.utils import FileTypes
from litellm.utils import TranscriptionResponse, convert_to_model_response_object
from .azure import (
AzureChatCompletion,
get_azure_ad_token_from_oidc,
select_azure_base_url_or_endpoint,
)
from .azure import AzureChatCompletion
class AzureAudioTranscription(AzureChatCompletion):
@ -32,29 +28,18 @@ class AzureAudioTranscription(AzureChatCompletion):
client=None,
azure_ad_token: Optional[str] = None,
atranscription: bool = False,
litellm_params: Optional[dict] = None,
) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params}
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
model_name=model,
api_version=api_version,
api_base=api_base,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if max_retries is not None:
azure_client_params["max_retries"] = max_retries
if atranscription is True:
return self.async_audio_transcriptions( # type: ignore

View file

@ -1,6 +1,5 @@
import asyncio
import json
import os
import time
from typing import Any, Callable, Dict, List, Literal, Optional, Union
@ -8,7 +7,6 @@ import httpx # type: ignore
from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
import litellm
from litellm.caching.caching import DualCache
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
@ -25,15 +23,18 @@ from litellm.types.utils import (
from litellm.utils import (
CustomStreamWrapper,
convert_to_model_response_object,
get_secret,
modify_url,
)
from ...types.llms.openai import HttpxBinaryResponseContent
from ..base import BaseLLM
from .common_utils import AzureOpenAIError, process_azure_headers
azure_ad_cache = DualCache()
from .common_utils import (
AzureOpenAIError,
BaseAzureLLM,
get_azure_ad_token_from_oidc,
process_azure_headers,
select_azure_base_url_or_endpoint,
)
class AzureOpenAIAssistantsAPIConfig:
@ -98,93 +99,6 @@ class AzureOpenAIAssistantsAPIConfig:
return optional_params
def select_azure_base_url_or_endpoint(azure_client_params: dict):
azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
if "/openai/deployments" in azure_endpoint:
# this is base_url, not an azure_endpoint
azure_client_params["base_url"] = azure_endpoint
azure_client_params.pop("azure_endpoint")
return azure_client_params
def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
azure_authority_host = os.getenv(
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
)
if azure_client_id is None or azure_tenant_id is None:
raise AzureOpenAIError(
status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
)
oidc_token = get_secret(azure_ad_token)
if oidc_token is None:
raise AzureOpenAIError(
status_code=401,
message="OIDC token could not be retrieved from secret manager.",
)
azure_ad_token_cache_key = json.dumps(
{
"azure_client_id": azure_client_id,
"azure_tenant_id": azure_tenant_id,
"azure_authority_host": azure_authority_host,
"oidc_token": oidc_token,
}
)
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
if azure_ad_token_access_token is not None:
return azure_ad_token_access_token
client = litellm.module_level_client
req_token = client.post(
f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
data={
"client_id": azure_client_id,
"grant_type": "client_credentials",
"scope": "https://cognitiveservices.azure.com/.default",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": oidc_token,
},
)
if req_token.status_code != 200:
raise AzureOpenAIError(
status_code=req_token.status_code,
message=req_token.text,
)
azure_ad_token_json = req_token.json()
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
if azure_ad_token_access_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token access_token not returned"
)
if azure_ad_token_expires_in is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token expires_in not returned"
)
azure_ad_cache.set_cache(
key=azure_ad_token_cache_key,
value=azure_ad_token_access_token,
ttl=azure_ad_token_expires_in,
)
return azure_ad_token_access_token
def _check_dynamic_azure_params(
azure_client_params: dict,
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
@ -206,7 +120,7 @@ def _check_dynamic_azure_params(
return False
class AzureChatCompletion(BaseLLM):
class AzureChatCompletion(BaseAzureLLM, BaseLLM):
def __init__(self) -> None:
super().__init__()
@ -238,27 +152,16 @@ class AzureChatCompletion(BaseLLM):
timeout: Union[float, httpx.Timeout],
client: Optional[Any],
client_type: Literal["sync", "async"],
litellm_params: Optional[dict] = None,
):
# init AzureOpenAI Client
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
model_name=model,
api_version=api_version,
api_base=api_base,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None:
if client_type == "sync":
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
@ -357,6 +260,13 @@ class AzureChatCompletion(BaseLLM):
max_retries = DEFAULT_MAX_RETRIES
json_mode: Optional[bool] = optional_params.pop("json_mode", False)
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
model_name=model,
api_version=api_version,
)
### CHECK IF CLOUDFLARE AI GATEWAY ###
### if so - set the model as part of the base url
if "gateway.ai.cloudflare.com" in api_base:
@ -417,6 +327,7 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
max_retries=max_retries,
azure_client_params=azure_client_params,
)
else:
return self.acompletion(
@ -434,6 +345,7 @@ class AzureChatCompletion(BaseLLM):
logging_obj=logging_obj,
max_retries=max_retries,
convert_tool_call_to_json_mode=json_mode,
azure_client_params=azure_client_params,
)
elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming(
@ -470,28 +382,6 @@ class AzureChatCompletion(BaseLLM):
status_code=422, message="max retries must be an int"
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)
if (
client is None
or not isinstance(client, AzureOpenAI)
@ -562,30 +452,10 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider: Optional[Callable] = None,
convert_tool_call_to_json_mode: Optional[bool] = None,
client=None, # this is the AsyncAzureOpenAI
azure_client_params: dict = {},
):
response = None
try:
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.aclient_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
# setting Azure client
if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params)
@ -742,28 +612,9 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[Callable] = None,
client=None,
azure_client_params: dict = {},
):
try:
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"http_client": litellm.aclient_session,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params)
else:
@ -824,6 +675,7 @@ class AzureChatCompletion(BaseLLM):
):
response = None
try:
if client is None:
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
else:
@ -875,6 +727,7 @@ class AzureChatCompletion(BaseLLM):
client=None,
aembedding=None,
headers: Optional[dict] = None,
litellm_params: Optional[dict] = None,
) -> EmbeddingResponse:
if headers:
optional_params["extra_headers"] = headers
@ -890,29 +743,14 @@ class AzureChatCompletion(BaseLLM):
)
# init AzureOpenAI Client
azure_client_params = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if aembedding:
azure_client_params["http_client"] = litellm.aclient_session
else:
azure_client_params["http_client"] = litellm.client_session
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
model_name=model,
api_version=api_version,
api_base=api_base,
)
## LOGGING
logging_obj.pre_call(
input=input,
@ -1272,6 +1110,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider: Optional[Callable] = None,
client=None,
aimg_generation=None,
litellm_params: Optional[dict] = None,
) -> ImageResponse:
try:
if model and len(model) > 0:
@ -1296,25 +1135,13 @@ class AzureChatCompletion(BaseLLM):
)
# init AzureOpenAI Client
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
model_name=model or "",
api_version=api_version,
api_base=api_base,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if aimg_generation is True:
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
@ -1377,6 +1204,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token_provider: Optional[Callable] = None,
aspeech: Optional[bool] = None,
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:
max_retries = optional_params.pop("max_retries", 2)
@ -1395,6 +1223,7 @@ class AzureChatCompletion(BaseLLM):
max_retries=max_retries,
timeout=timeout,
client=client,
litellm_params=litellm_params,
) # type: ignore
azure_client: AzureOpenAI = self._get_sync_azure_client(
@ -1408,6 +1237,7 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
client_type="sync",
litellm_params=litellm_params,
) # type: ignore
response = azure_client.audio.speech.create(
@ -1432,6 +1262,7 @@ class AzureChatCompletion(BaseLLM):
max_retries: int,
timeout: Union[float, httpx.Timeout],
client=None,
litellm_params: Optional[dict] = None,
) -> HttpxBinaryResponseContent:
azure_client: AsyncAzureOpenAI = self._get_sync_azure_client(
@ -1445,6 +1276,7 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
client_type="async",
litellm_params=litellm_params,
) # type: ignore
azure_response = await azure_client.audio.speech.create(
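
Because completion() now threads azure_ad_token through get_litellm_params (see litellm/main.py further down), an "oidc/" token can be supplied client-side; a hedged sketch (the model and token source are placeholders):

import litellm

response = litellm.completion(
    model="azure/my-deployment",
    messages=[{"role": "user", "content": "hello"}],
    api_base="https://my-endpoint.openai.azure.com",
    api_version="2024-06-01",
    azure_ad_token="oidc/my-oidc-provider",  # exchanged by get_azure_ad_token_from_oidc at init
)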

View file

@ -6,7 +6,6 @@ from typing import Any, Coroutine, Optional, Union, cast
import httpx
import litellm
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
from litellm.types.llms.openai import (
Batch,
@ -16,8 +15,10 @@ from litellm.types.llms.openai import (
)
from litellm.types.utils import LiteLLMBatch
from ..common_utils import BaseAzureLLM
class AzureBatchesAPI:
class AzureBatchesAPI(BaseAzureLLM):
"""
Azure methods to support batches
- create_batch()
@ -29,38 +30,6 @@ class AzureBatchesAPI:
def __init__(self) -> None:
super().__init__()
def get_azure_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_batch(
self,
create_batch_data: CreateBatchRequest,
@ -79,16 +48,16 @@ class AzureBatchesAPI:
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
)
if azure_client is None:
@ -125,16 +94,16 @@ class AzureBatchesAPI:
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
)
if azure_client is None:
@ -173,16 +142,16 @@ class AzureBatchesAPI:
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
)
if azure_client is None:
@ -212,16 +181,16 @@ class AzureBatchesAPI:
after: Optional[str] = None,
limit: Optional[int] = None,
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
api_version=api_version,
client=client,
_is_async=_is_async,
litellm_params=litellm_params or {},
)
)
if azure_client is None:
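
The shared helper replaces the per-class get_azure_openai_client deleted above; _is_async selects the client class (a sketch, placeholder values):

from litellm.llms.azure.common_utils import BaseAzureLLM

aclient = BaseAzureLLM().get_azure_openai_client(
    litellm_params={"max_retries": 2, "timeout": 30.0},
    api_key="placeholder-key",
    api_base="https://my-endpoint.openai.azure.com",
    api_version=None,   # falls back to AZURE_API_VERSION or litellm's default
    _is_async=True,     # returns AsyncAzureOpenAI instead of AzureOpenAI
)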

View file

@ -4,50 +4,68 @@ Handler file for calls to Azure OpenAI's o1/o3 family of models
Written separately to handle fake streaming for o1 and o3 models.
"""
from typing import Optional, Union
from typing import Any, Callable, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from litellm.types.utils import ModelResponse
from ...openai.openai import OpenAIChatCompletion
from ..common_utils import get_azure_openai_client
from ..common_utils import BaseAzureLLM
class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
def _get_openai_client(
class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
def completion(
self,
is_async: bool,
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
optional_params: dict,
litellm_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
max_retries: Optional[int] = 2,
dynamic_params: Optional[bool] = None,
azure_ad_token: Optional[str] = None,
acompletion: bool = False,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = None,
) -> Optional[
Union[
OpenAI,
AsyncOpenAI,
AzureOpenAI,
AsyncAzureOpenAI,
]
]:
# Override to use Azure-specific client initialization
if not isinstance(client, AzureOpenAI) and not isinstance(
client, AsyncAzureOpenAI
):
client = None
return get_azure_openai_client(
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
client = self.get_azure_openai_client(
litellm_params=litellm_params,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=is_async,
)
return super().completion(
model_response=model_response,
timeout=timeout,
optional_params=optional_params,
litellm_params=litellm_params,
logging_obj=logging_obj,
model=model,
messages=messages,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
api_version=api_version,
dynamic_params=dynamic_params,
azure_ad_token=azure_ad_token,
acompletion=acompletion,
logger_fn=logger_fn,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
client=client,
organization=organization,
custom_llm_provider=custom_llm_provider,
drop_params=drop_params,
)
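
Caller-facing behavior is unchanged; the override only guarantees that o-series Azure calls build their client through the shared Azure path. A hedged sketch (the deployment name is a placeholder):

import litellm

response = litellm.completion(
    model="azure/o1-mini",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,  # faked for o1/o3 models, per the module docstring
)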

View file

@ -1,3 +1,5 @@
import json
import os
from typing import Callable, Optional, Union
import httpx
@ -5,9 +7,15 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
import litellm
from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
from litellm.secret_managers.main import get_secret_str
azure_ad_cache = DualCache()
class AzureOpenAIError(BaseLLMException):
def __init__(
@ -27,39 +35,6 @@ class AzureOpenAIError(BaseLLMException):
)
def get_azure_openai_client(
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
organization: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
if "x-ratelimit-limit-requests" in headers:
@ -178,3 +153,199 @@ def get_azure_ad_token_from_username_password(
verbose_logger.debug("token_provider %s", token_provider)
return token_provider
def get_azure_ad_token_from_oidc(azure_ad_token: str):
azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
azure_authority_host = os.getenv(
"AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
)
if azure_client_id is None or azure_tenant_id is None:
raise AzureOpenAIError(
status_code=422,
message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
)
oidc_token = get_secret_str(azure_ad_token)
if oidc_token is None:
raise AzureOpenAIError(
status_code=401,
message="OIDC token could not be retrieved from secret manager.",
)
azure_ad_token_cache_key = json.dumps(
{
"azure_client_id": azure_client_id,
"azure_tenant_id": azure_tenant_id,
"azure_authority_host": azure_authority_host,
"oidc_token": oidc_token,
}
)
azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
if azure_ad_token_access_token is not None:
return azure_ad_token_access_token
client = litellm.module_level_client
req_token = client.post(
f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
data={
"client_id": azure_client_id,
"grant_type": "client_credentials",
"scope": "https://cognitiveservices.azure.com/.default",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion": oidc_token,
},
)
if req_token.status_code != 200:
raise AzureOpenAIError(
status_code=req_token.status_code,
message=req_token.text,
)
azure_ad_token_json = req_token.json()
azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
if azure_ad_token_access_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token access_token not returned"
)
if azure_ad_token_expires_in is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token expires_in not returned"
)
azure_ad_cache.set_cache(
key=azure_ad_token_cache_key,
value=azure_ad_token_access_token,
ttl=azure_ad_token_expires_in,
)
return azure_ad_token_access_token
def select_azure_base_url_or_endpoint(azure_client_params: dict):
azure_endpoint = azure_client_params.get("azure_endpoint", None)
if azure_endpoint is not None:
# see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
if "/openai/deployments" in azure_endpoint:
# this is base_url, not an azure_endpoint
azure_client_params["base_url"] = azure_endpoint
azure_client_params.pop("azure_endpoint")
return azure_client_params
class BaseAzureLLM:
def get_azure_openai_client(
self,
litellm_params: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params,
api_key=api_key,
api_base=api_base,
model_name="",
api_version=api_version,
)
if _is_async is True:
openai_client = AsyncAzureOpenAI(**azure_client_params)
else:
openai_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
openai_client = client
return openai_client
def initialize_azure_sdk_client(
self,
litellm_params: dict,
api_key: Optional[str],
api_base: Optional[str],
model_name: str,
api_version: Optional[str],
) -> dict:
azure_ad_token_provider: Optional[Callable[[], str]] = None
# api_key, when provided, takes priority over the Entra ID credentials below
azure_ad_token = litellm_params.get("azure_ad_token")
tenant_id = litellm_params.get("tenant_id")
client_id = litellm_params.get("client_id")
client_secret = litellm_params.get("client_secret")
azure_username = litellm_params.get("azure_username")
azure_password = litellm_params.get("azure_password")
max_retries = litellm_params.get("max_retries")
timeout = litellm_params.get("timeout")
if not api_key and tenant_id and client_id and client_secret:
verbose_logger.debug("Using Azure AD Token Provider for Azure Auth")
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
tenant_id=tenant_id,
client_id=client_id,
client_secret=client_secret,
)
if azure_username and azure_password and client_id:
azure_ad_token_provider = get_azure_ad_token_from_username_password(
azure_username=azure_username,
azure_password=azure_password,
client_id=client_id,
)
if azure_ad_token is not None and azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
elif (
not api_key
and azure_ad_token_provider is None
and litellm.enable_azure_ad_token_refresh is True
):
try:
azure_ad_token_provider = get_azure_ad_token_provider()
except ValueError:
verbose_logger.debug("Azure AD Token Provider could not be used.")
if api_version is None:
api_version = os.getenv(
"AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
)
_api_key = api_key
if _api_key is not None and isinstance(_api_key, str):
# only show first 8 chars of api_key
_api_key = _api_key[:8] + "*" * 15
verbose_logger.debug(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
)
azure_client_params = {
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
"azure_ad_token_provider": azure_ad_token_provider,
"http_client": litellm.client_session,
}
if max_retries is not None:
azure_client_params["max_retries"] = max_retries
if timeout is not None:
azure_client_params["timeout"] = timeout
if azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
return azure_client_params
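
The auth resolution order inside initialize_azure_sdk_client, summarized as a sketch (placeholder values; the comments restate the branches above):

from litellm.llms.azure.common_utils import BaseAzureLLM

# 1) tenant_id + client_id + client_secret (only when api_key is absent)
#    -> Entra ID client-credential token provider
# 2) azure_username + azure_password + client_id
#    -> username/password token provider (takes precedence over 1)
# 3) azure_ad_token starting with "oidc/" -> exchanged via OIDC
# 4) else, when no key or provider was set and
#    litellm.enable_azure_ad_token_refresh is True -> default token provider
# api_key, when set, is passed through as well and has higher priority.
params = BaseAzureLLM().initialize_azure_sdk_client(
    litellm_params={
        "azure_username": "user@example.com",      # placeholders
        "azure_password": "placeholder-password",
        "client_id": "my-client-id",
    },
    api_key=None,
    api_base="https://my-endpoint.openai.azure.com",
    model_name="my-deployment",
    api_version=None,
)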

View file

@ -5,13 +5,12 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.types.file_deleted import FileDeleted
from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.types.llms.openai import *
from ..common_utils import get_azure_openai_client
from ..common_utils import BaseAzureLLM
class AzureOpenAIFilesAPI(BaseLLM):
class AzureOpenAIFilesAPI(BaseAzureLLM):
"""
AzureOpenAI methods to support files
- create_file()
@ -45,14 +44,15 @@ class AzureOpenAIFilesAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
@ -91,17 +91,16 @@ class AzureOpenAIFilesAPI(BaseLLM):
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
organization=None,
client=client,
_is_async=_is_async,
)
@ -144,14 +143,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=None,
api_version=api_version,
client=client,
_is_async=_is_async,
@ -197,14 +195,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
organization: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,
@ -252,14 +249,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
purpose: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=None, # openai param
api_version=api_version,
client=client,
_is_async=_is_async,

View file

@ -3,11 +3,11 @@ from typing import Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from litellm.llms.azure.files.handler import get_azure_openai_client
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
"""
AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
"""
@ -24,6 +24,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
] = None,
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
) -> Optional[
Union[
OpenAI,
@ -36,12 +37,10 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
client = None
return get_azure_openai_client(
return self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,
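
The double inheritance above (OpenAIFineTuningAPI for the fine-tuning surface, BaseAzureLLM for client construction) is a plain mixin arrangement. A toy, self-contained illustration of the same layout (names here are illustrative, not litellm's):

class FineTuningSurface:          # stands in for OpenAIFineTuningAPI
    def get_client(self):
        return "openai-client"

class AzureClientMixin:           # stands in for BaseAzureLLM
    def get_azure_client(self):
        return "azure-client"

class AzureFineTuning(FineTuningSurface, AzureClientMixin):
    def get_client(self):
        # override the OpenAI default and delegate to the Azure helper
        return self.get_azure_client()

print(AzureFineTuning().get_client())  # -> azure-client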

View file

@ -27,6 +27,7 @@ class OpenAIFineTuningAPI:
] = None,
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
) -> Optional[
Union[
OpenAI,

View file

@ -1162,6 +1162,14 @@ def completion( # type: ignore # noqa: PLR0915
merge_reasoning_content_in_choices=kwargs.get(
"merge_reasoning_content_in_choices", None
),
azure_ad_token=kwargs.get("azure_ad_token"),
tenant_id=kwargs.get("tenant_id"),
client_id=kwargs.get("client_id"),
client_secret=kwargs.get("client_secret"),
azure_username=kwargs.get("azure_username"),
azure_password=kwargs.get("azure_password"),
max_retries=max_retries,
timeout=timeout,
)
logging.update_environment_variables(
model=model,
@ -3350,6 +3358,7 @@ def embedding( # noqa: PLR0915
}
}
)
litellm_params_dict = get_litellm_params(**kwargs)
logging: Logging = litellm_logging_obj # type: ignore
@ -3411,6 +3420,7 @@ def embedding( # noqa: PLR0915
aembedding=aembedding,
max_retries=max_retries,
headers=headers or extra_headers,
litellm_params=litellm_params_dict,
)
elif (
model in litellm.open_ai_embedding_models
@ -4537,6 +4547,8 @@ def image_generation( # noqa: PLR0915
**non_default_params,
)
litellm_params_dict = get_litellm_params(**kwargs)
logging: Logging = litellm_logging_obj
logging.update_environment_variables(
model=model,
@ -4607,6 +4619,7 @@ def image_generation( # noqa: PLR0915
aimg_generation=aimg_generation,
client=client,
headers=headers,
litellm_params=litellm_params_dict,
)
elif (
custom_llm_provider == "openai"
@ -5002,6 +5015,7 @@ def transcription(
custom_llm_provider=custom_llm_provider,
drop_params=drop_params,
)
litellm_params_dict = get_litellm_params(**kwargs)
litellm_logging_obj.update_environment_variables(
model=model,
@ -5055,6 +5069,7 @@ def transcription(
api_version=api_version,
azure_ad_token=azure_ad_token,
max_retries=max_retries,
litellm_params=litellm_params_dict,
)
elif (
custom_llm_provider == "openai"
@ -5157,7 +5172,7 @@ async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:
@client
def speech(
def speech( # noqa: PLR0915
model: str,
input: str,
voice: Optional[Union[str, dict]] = None,
@ -5198,7 +5213,7 @@ def speech(
if max_retries is None:
max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES
litellm_params_dict = get_litellm_params(**kwargs)
logging_obj = kwargs.get("litellm_logging_obj", None)
logging_obj.update_environment_variables(
model=model,
@ -5315,6 +5330,7 @@ def speech(
timeout=timeout,
client=client, # pass AsyncOpenAI, OpenAI client
aspeech=aspeech,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":
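
The recurring change across the main.py hunks: each entrypoint collapses its loose kwargs into one dict via get_litellm_params(**kwargs) and forwards it as litellm_params, so provider handlers (and ultimately initialize_azure_sdk_client) see clientside credentials. A rough sketch of the flow, assuming get_litellm_params surfaces these keys as the new test below asserts:

from litellm.utils import get_litellm_params

def entrypoint(**kwargs):
    # collapse azure_ad_token / tenant_id / client_id / ... into one dict
    litellm_params_dict = get_litellm_params(**kwargs)
    # ...handlers are then called with litellm_params=litellm_params_dict
    return litellm_params_dict

params = entrypoint(azure_ad_token="oidc/test-token", tenant_id="my-tenant")  # illustrative values
print(params.get("azure_ad_token"))  # expected: oidc/test-token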

View file

@ -71,7 +71,6 @@ from litellm.router_utils.batch_utils import (
_get_router_metadata_variable_name,
replace_model_in_jsonl,
)
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
from litellm.router_utils.clientside_credential_handler import (
get_dynamic_litellm_params,
is_clientside_credential,
@ -5360,36 +5359,12 @@ class Router:
client = self.cache.get_cache(
key=cache_key, local_only=True, parent_otel_span=parent_otel_span
)
if client is None:
"""
Re-initialize the client
"""
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(
key=cache_key,
local_only=True,
parent_otel_span=parent_otel_span,
)
return client
else:
cache_key = f"{model_id}_async_client"
client = self.cache.get_cache(
key=cache_key, local_only=True, parent_otel_span=parent_otel_span
)
# if client is None:
# """
# Re-initialize the client
# """
# InitalizeOpenAISDKClient.set_client(
# litellm_router_instance=self, model=deployment
# )
# client = self.cache.get_cache(
# key=cache_key,
# local_only=True,
# parent_otel_span=parent_otel_span,
# )
return client
else:
if kwargs.get("stream") is True:
@ -5397,32 +5372,12 @@ class Router:
client = self.cache.get_cache(
key=cache_key, parent_otel_span=parent_otel_span
)
if client is None:
"""
Re-initialize the client
"""
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(
key=cache_key, parent_otel_span=parent_otel_span
)
return client
else:
cache_key = f"{model_id}_client"
client = self.cache.get_cache(
key=cache_key, parent_otel_span=parent_otel_span
)
if client is None:
"""
Re-initialize the client
"""
InitalizeOpenAISDKClient.set_client(
litellm_router_instance=self, model=deployment
)
client = self.cache.get_cache(
key=cache_key, parent_otel_span=parent_otel_span
)
return client
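
With the re-initialization fallback deleted, this lookup is now a pure cache read; if no client was built at init time it returns None rather than rebuilding one inline. The cache-key convention, reconstructed from the branches above:

def client_cache_key(model_id: str, is_async: bool, is_stream: bool) -> str:
    # mirrors the f-strings above: {model_id}[_stream][_async]_client
    key = model_id
    if is_stream:
        key += "_stream"
    if is_async:
        key += "_async"
    return key + "_client"

assert client_cache_key("m1", is_async=True, is_stream=False) == "m1_async_client"
assert client_cache_key("m1", is_async=False, is_stream=True) == "m1_stream_client"
assert client_cache_key("m1", is_async=True, is_stream=True) == "m1_stream_async_client"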
def _pre_call_checks( # noqa: PLR0915

View file

@ -1,6 +1,5 @@
import asyncio
import os
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Optional
import httpx
import openai
@ -8,14 +7,6 @@ import openai
import litellm
from litellm import get_secret, get_secret_str
from litellm._logging import verbose_router_logger
from litellm.llms.azure.azure import get_azure_ad_token_from_oidc
from litellm.llms.azure.common_utils import (
get_azure_ad_token_from_entrata_id,
get_azure_ad_token_from_username_password,
)
from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
from litellm.utils import calculate_max_parallel_requests
if TYPE_CHECKING:
@ -200,274 +191,6 @@ class InitalizeOpenAISDKClient:
organization_env_name = organization.replace("os.environ/", "")
organization = get_secret_str(organization_env_name)
litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None
# If we have api_key, then we have higher priority
if not api_key and litellm_params.get("tenant_id"):
verbose_router_logger.debug(
"Using Azure AD Token Provider for Azure Auth"
)
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
tenant_id=litellm_params.get("tenant_id"),
client_id=litellm_params.get("client_id"),
client_secret=litellm_params.get("client_secret"),
)
if litellm_params.get("azure_username") and litellm_params.get(
"azure_password"
):
azure_ad_token_provider = get_azure_ad_token_from_username_password(
azure_username=litellm_params.get("azure_username"),
azure_password=litellm_params.get("azure_password"),
client_id=litellm_params.get("client_id"),
)
if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
if api_base is None or not isinstance(api_base, str):
filtered_litellm_params = {
k: v
for k, v in model["litellm_params"].items()
if k != "api_key"
}
_filtered_model = {
"model_name": model["model_name"],
"litellm_params": filtered_litellm_params,
}
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}"
)
azure_ad_token = litellm_params.get("azure_ad_token")
if azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
elif (
not api_key
and azure_ad_token_provider is None
and litellm.enable_azure_ad_token_refresh is True
):
try:
azure_ad_token_provider = get_azure_ad_token_provider()
except ValueError:
verbose_router_logger.debug(
"Azure AD Token Provider could not be used."
)
if api_version is None:
api_version = os.getenv(
"AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
)
if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"):
api_base += "/"
azure_model = model_name.replace("azure/", "")
api_base += f"{azure_model}"
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients can have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
_api_key = api_key
if _api_key is not None and isinstance(_api_key, str):
# only show first 8 chars of api_key
_api_key = _api_key[:8] + "*" * 15
verbose_router_logger.debug(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
)
azure_client_params = {
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
"azure_ad_token_provider": azure_ad_token_provider,
}
if azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)
from litellm.llms.azure.azure import (
select_azure_base_url_or_endpoint,
)
# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params
)
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
), # type: ignore
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
if InitalizeOpenAISDKClient.should_initialize_sync_client(
litellm_router_instance=litellm_router_instance
):
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
),
verify=litellm.ssl_verify,
),
)
litellm_router_instance.cache.set_cache(
key=cache_key,
value=_client,
ttl=client_ttl,
local_only=True,
) # cache for 1 hr
else:
_api_key = api_key # type: ignore
if _api_key is not None and isinstance(_api_key, str):
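
All of the Azure-specific logic deleted above now lives behind BaseAzureLLM.initialize_azure_sdk_client (exercised by the new tests below). The credential precedence it encodes, restated as a simplified standalone sketch (not the verbatim implementation):

def pick_azure_auth(p: dict) -> str:
    # an explicit api_key wins outright
    if p.get("api_key"):
        return "api_key"
    # service-principal (Entra ID) credentials build a token provider
    if p.get("tenant_id") and p.get("client_id") and p.get("client_secret"):
        return "entra_id_token_provider"
    # username/password flow also builds a token provider
    if p.get("azure_username") and p.get("azure_password"):
        return "username_password_token_provider"
    # a literal token, possibly exchanged from an OIDC token
    token = p.get("azure_ad_token") or ""
    if token.startswith("oidc/"):
        return "oidc_exchanged_token"
    if token:
        return "static_azure_ad_token"
    # last resort, only when litellm.enable_azure_ad_token_refresh is set
    return "default_token_provider_if_refresh_enabled"

print(pick_azure_auth({"azure_ad_token": "oidc/test-token"}))  # -> oidc_exchanged_token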

View file

@ -191,6 +191,42 @@ class CallTypes(Enum):
retrieve_batch = "retrieve_batch"
pass_through = "pass_through_endpoint"
anthropic_messages = "anthropic_messages"
get_assistants = "get_assistants"
aget_assistants = "aget_assistants"
create_assistants = "create_assistants"
acreate_assistants = "acreate_assistants"
delete_assistant = "delete_assistant"
adelete_assistant = "adelete_assistant"
acreate_thread = "acreate_thread"
create_thread = "create_thread"
aget_thread = "aget_thread"
get_thread = "get_thread"
a_add_message = "a_add_message"
add_message = "add_message"
aget_messages = "aget_messages"
get_messages = "get_messages"
arun_thread = "arun_thread"
run_thread = "run_thread"
arun_thread_stream = "arun_thread_stream"
run_thread_stream = "run_thread_stream"
afile_retrieve = "afile_retrieve"
file_retrieve = "file_retrieve"
afile_delete = "afile_delete"
file_delete = "file_delete"
afile_list = "afile_list"
file_list = "file_list"
acreate_file = "acreate_file"
create_file = "create_file"
afile_content = "afile_content"
file_content = "file_content"
create_fine_tuning_job = "create_fine_tuning_job"
acreate_fine_tuning_job = "acreate_fine_tuning_job"
acancel_fine_tuning_job = "acancel_fine_tuning_job"
cancel_fine_tuning_job = "cancel_fine_tuning_job"
alist_fine_tuning_jobs = "alist_fine_tuning_jobs"
list_fine_tuning_jobs = "list_fine_tuning_jobs"
aretrieve_fine_tuning_job = "aretrieve_fine_tuning_job"
retrieve_fine_tuning_job = "retrieve_fine_tuning_job"
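
These CallTypes additions give every assistants/files/fine-tuning entrypoint a stable enum value, which the parametrized test below leans on. For instance, the async variants can be enumerated by name, which is exactly the filter that test applies:

from litellm.types.utils import CallTypes

async_variants = [
    ct for ct in CallTypes.__members__.values() if ct.name.startswith("a")
]
print(sorted(ct.value for ct in async_variants)[:5])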
CallTypesLiteral = Literal[

View file

@ -0,0 +1,371 @@
import json
import os
import sys
from typing import Callable, Optional
from unittest.mock import MagicMock, patch
import pytest
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the repo root (four levels up) to the system path
import litellm
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.types.utils import CallTypes
# Mock the necessary dependencies
@pytest.fixture
def setup_mocks():
with patch(
"litellm.llms.azure.common_utils.get_azure_ad_token_from_entrata_id"
) as mock_entrata_token, patch(
"litellm.llms.azure.common_utils.get_azure_ad_token_from_username_password"
) as mock_username_password_token, patch(
"litellm.llms.azure.common_utils.get_azure_ad_token_from_oidc"
) as mock_oidc_token, patch(
"litellm.llms.azure.common_utils.get_azure_ad_token_provider"
) as mock_token_provider, patch(
"litellm.llms.azure.common_utils.litellm"
) as mock_litellm, patch(
"litellm.llms.azure.common_utils.verbose_logger"
) as mock_logger, patch(
"litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint"
) as mock_select_url:
# Configure mocks
mock_litellm.AZURE_DEFAULT_API_VERSION = "2023-05-15"
mock_litellm.enable_azure_ad_token_refresh = False
mock_entrata_token.return_value = lambda: "mock-entrata-token"
mock_username_password_token.return_value = (
lambda: "mock-username-password-token"
)
mock_oidc_token.return_value = "mock-oidc-token"
mock_token_provider.return_value = lambda: "mock-default-token"
mock_select_url.side_effect = (
lambda azure_client_params, **kwargs: azure_client_params
)
yield {
"entrata_token": mock_entrata_token,
"username_password_token": mock_username_password_token,
"oidc_token": mock_oidc_token,
"token_provider": mock_token_provider,
"litellm": mock_litellm,
"logger": mock_logger,
"select_url": mock_select_url,
}
def test_initialize_with_api_key(setup_mocks):
# Test with api_key provided
result = BaseAzureLLM().initialize_azure_sdk_client(
litellm_params={},
api_key="test-api-key",
api_base="https://test.openai.azure.com",
model_name="gpt-4",
api_version="2023-06-01",
)
# Verify expected result
assert result["api_key"] == "test-api-key"
assert result["azure_endpoint"] == "https://test.openai.azure.com"
assert result["api_version"] == "2023-06-01"
assert "azure_ad_token" in result
assert result["azure_ad_token"] is None
def test_initialize_with_tenant_credentials(setup_mocks):
# Test with tenant_id, client_id, and client_secret provided
result = BaseAzureLLM().initialize_azure_sdk_client(
litellm_params={
"tenant_id": "test-tenant-id",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
},
api_key=None,
api_base="https://test.openai.azure.com",
model_name="gpt-4",
api_version=None,
)
# Verify that get_azure_ad_token_from_entrata_id was called
setup_mocks["entrata_token"].assert_called_once_with(
tenant_id="test-tenant-id",
client_id="test-client-id",
client_secret="test-client-secret",
)
# Verify expected result
assert result["api_key"] is None
assert result["azure_endpoint"] == "https://test.openai.azure.com"
assert "azure_ad_token_provider" in result
def test_initialize_with_username_password(setup_mocks):
# Test with azure_username, azure_password, and client_id provided
result = BaseAzureLLM().initialize_azure_sdk_client(
litellm_params={
"azure_username": "test-username",
"azure_password": "test-password",
"client_id": "test-client-id",
},
api_key=None,
api_base="https://test.openai.azure.com",
model_name="gpt-4",
api_version=None,
)
# Verify that get_azure_ad_token_from_username_password was called
setup_mocks["username_password_token"].assert_called_once_with(
azure_username="test-username",
azure_password="test-password",
client_id="test-client-id",
)
# Verify expected result
assert "azure_ad_token_provider" in result
def test_initialize_with_oidc_token(setup_mocks):
# Test with azure_ad_token that starts with "oidc/"
result = BaseAzureLLM().initialize_azure_sdk_client(
litellm_params={"azure_ad_token": "oidc/test-token"},
api_key=None,
api_base="https://test.openai.azure.com",
model_name="gpt-4",
api_version=None,
)
# Verify that get_azure_ad_token_from_oidc was called
setup_mocks["oidc_token"].assert_called_once_with("oidc/test-token")
# Verify expected result
assert result["azure_ad_token"] == "mock-oidc-token"
def test_initialize_with_enable_token_refresh(setup_mocks):
# Enable token refresh
setup_mocks["litellm"].enable_azure_ad_token_refresh = True
# Test with token refresh enabled
result = BaseAzureLLM().initialize_azure_sdk_client(
litellm_params={},
api_key=None,
api_base="https://test.openai.azure.com",
model_name="gpt-4",
api_version=None,
)
# Verify that get_azure_ad_token_provider was called
setup_mocks["token_provider"].assert_called_once()
# Verify expected result
assert "azure_ad_token_provider" in result
def test_initialize_with_token_refresh_error(setup_mocks):
# Enable token refresh but make it raise an error
setup_mocks["litellm"].enable_azure_ad_token_refresh = True
setup_mocks["token_provider"].side_effect = ValueError("Token provider error")
# Test with token refresh enabled but raising error
result = BaseAzureLLM().initialize_azure_sdk_client(
litellm_params={},
api_key=None,
api_base="https://test.openai.azure.com",
model_name="gpt-4",
api_version=None,
)
# Verify error was logged
setup_mocks["logger"].debug.assert_any_call(
"Azure AD Token Provider could not be used."
)
def test_api_version_from_env_var(setup_mocks):
# Test api_version from environment variable
with patch.dict(os.environ, {"AZURE_API_VERSION": "2023-07-01"}):
result = BaseAzureLLM().initialize_azure_sdk_client(
litellm_params={},
api_key="test-api-key",
api_base="https://test.openai.azure.com",
model_name="gpt-4",
api_version=None,
)
# Verify expected result
assert result["api_version"] == "2023-07-01"
def test_select_azure_base_url_called(setup_mocks):
# Test that select_azure_base_url_or_endpoint is called
result = BaseAzureLLM().initialize_azure_sdk_client(
litellm_params={},
api_key="test-api-key",
api_base="https://test.openai.azure.com",
model_name="gpt-4",
api_version="2023-06-01",
)
# Verify that select_azure_base_url_or_endpoint was called
setup_mocks["select_url"].assert_called_once()
@pytest.mark.parametrize(
"call_type",
[
call_type
for call_type in CallTypes.__members__.values()
if call_type.name.startswith("a")
and call_type.name
not in [
"amoderation",
"arerank",
"arealtime",
"anthropic_messages",
"add_message",
"arun_thread_stream",
]
],
)
@pytest.mark.asyncio
async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
from litellm.router import Router
# Create a router with an Azure model
azure_model_name = "azure/chatgpt-v-2"
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": azure_model_name,
"api_key": "test-api-key",
"api_version": os.getenv("AZURE_API_VERSION", "2023-05-15"),
"api_base": os.getenv(
"AZURE_API_BASE", "https://test.openai.azure.com"
),
},
}
],
)
# Prepare test input based on call type
test_inputs = {
"acompletion": {
"messages": [{"role": "user", "content": "Hello, how are you?"}]
},
"atext_completion": {"prompt": "Hello, how are you?"},
"aimage_generation": {"prompt": "Hello, how are you?"},
"aembedding": {"input": "Hello, how are you?"},
"arerank": {"input": "Hello, how are you?"},
"atranscription": {"file": "path/to/file"},
"aspeech": {"input": "Hello, how are you?", "voice": "female"},
"acreate_batch": {
"completion_window": 10,
"endpoint": "https://test.openai.azure.com",
"input_file_id": "123",
},
"aretrieve_batch": {"batch_id": "123"},
"aget_assistants": {"custom_llm_provider": "azure"},
"acreate_assistants": {"custom_llm_provider": "azure"},
"adelete_assistant": {"custom_llm_provider": "azure", "assistant_id": "123"},
"acreate_thread": {"custom_llm_provider": "azure"},
"aget_thread": {"custom_llm_provider": "azure", "thread_id": "123"},
"a_add_message": {
"custom_llm_provider": "azure",
"thread_id": "123",
"role": "user",
"content": "Hello, how are you?",
},
"aget_messages": {"custom_llm_provider": "azure", "thread_id": "123"},
"arun_thread": {
"custom_llm_provider": "azure",
"assistant_id": "123",
"thread_id": "123",
},
"acreate_file": {
"custom_llm_provider": "azure",
"file": MagicMock(),
"purpose": "assistants",
},
}
# Get appropriate input for this call type
input_kwarg = test_inputs.get(call_type.value, {})
patch_target = "litellm.main.azure_chat_completions.initialize_azure_sdk_client"
if call_type == CallTypes.atranscription:
patch_target = (
"litellm.main.azure_audio_transcriptions.initialize_azure_sdk_client"
)
elif call_type == CallTypes.arerank:
patch_target = (
"litellm.rerank_api.main.azure_rerank.initialize_azure_sdk_client"
)
elif call_type == CallTypes.acreate_batch or call_type == CallTypes.aretrieve_batch:
patch_target = (
"litellm.batches.main.azure_batches_instance.initialize_azure_sdk_client"
)
elif (
call_type == CallTypes.aget_assistants
or call_type == CallTypes.acreate_assistants
or call_type == CallTypes.adelete_assistant
or call_type == CallTypes.acreate_thread
or call_type == CallTypes.aget_thread
or call_type == CallTypes.a_add_message
or call_type == CallTypes.aget_messages
or call_type == CallTypes.arun_thread
):
patch_target = (
"litellm.assistants.main.azure_assistants_api.initialize_azure_sdk_client"
)
elif call_type == CallTypes.acreate_file or call_type == CallTypes.afile_content:
patch_target = (
"litellm.files.main.azure_files_instance.initialize_azure_sdk_client"
)
# Mock the initialize_azure_sdk_client function
with patch(patch_target) as mock_init_azure:
# No downstream call needs to succeed here; any failure is swallowed
# below, since the test only asserts that the init path was hit
# Call the appropriate router method
try:
get_attr = getattr(router, call_type.value, None)
if get_attr is None:
pytest.skip(
f"Skipping {call_type.value} because it is not supported on Router"
)
await getattr(router, call_type.value)(
model="gpt-3.5-turbo",
**input_kwarg,
num_retries=0,
azure_ad_token="oidc/test-token",
)
except Exception as e:
print(e)
# Verify initialize_azure_sdk_client was called
mock_init_azure.assert_called_once()
# Verify it was called with the right model name
calls = mock_init_azure.call_args_list
azure_calls = list(calls)  # every recorded call here targets the Azure init path
litellm_params = azure_calls[0].kwargs["litellm_params"]
print("litellm_params", litellm_params)
assert (
"azure_ad_token" in litellm_params
), "azure_ad_token not found in parameters"
assert (
litellm_params["azure_ad_token"] == "oidc/test-token"
), "azure_ad_token is not correct"
# More detailed verification (optional)
for call in azure_calls:
assert "api_key" in call.kwargs, "api_key not found in parameters"
assert "api_base" in call.kwargs, "api_base not found in parameters"

View file

@ -556,12 +556,11 @@ async def test_azure_instruct(
@pytest.mark.parametrize("max_retries", [0, 4])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("sync_mode", [True, False])
@patch("litellm.llms.azure.azure.select_azure_base_url_or_endpoint")
@patch("litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint")
@pytest.mark.asyncio
async def test_azure_embedding_max_retries_0(
mock_select_azure_base_url_or_endpoint, max_retries, stream, sync_mode
mock_select_azure_base_url_or_endpoint, max_retries, sync_mode
):
from litellm import aembedding, embedding
@ -569,7 +568,6 @@ async def test_azure_embedding_max_retries_0(
"model": "azure/azure-embedding-model",
"input": "Hello world",
"max_retries": max_retries,
"stream": stream,
}
try:
@ -581,6 +579,10 @@ async def test_azure_embedding_max_retries_0(
print(e)
mock_select_azure_base_url_or_endpoint.assert_called_once()
print(
"mock_select_azure_base_url_or_endpoint.call_args.kwargs",
mock_select_azure_base_url_or_endpoint.call_args.kwargs,
)
assert (
mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][
"max_retries"