Merge pull request #9140 from BerriAI/litellm_router_client_init_migration
Delegate router azure client init logic to azure provider
commit ee53e41213
19 changed files with 895 additions and 732 deletions

@@ -15,6 +15,7 @@ import litellm
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import (
    exception_type,
    get_litellm_params,
    get_llm_provider,
    get_secret,
    supports_httpx_timeout,

@@ -86,6 +87,7 @@ def get_assistants(
    optional_params = GenericLiteLLMParams(
        api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
    )
    litellm_params_dict = get_litellm_params(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600

@@ -169,6 +171,7 @@ def get_assistants(
            max_retries=optional_params.max_retries,
            client=client,
            aget_assistants=aget_assistants,  # type: ignore
            litellm_params=litellm_params_dict,
        )
    else:
        raise litellm.exceptions.BadRequestError(

@@ -270,6 +273,7 @@ def create_assistants(
    optional_params = GenericLiteLLMParams(
        api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
    )
    litellm_params_dict = get_litellm_params(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600

@@ -371,6 +375,7 @@ def create_assistants(
            client=client,
            async_create_assistants=async_create_assistants,
            create_assistant_data=create_assistant_data,
            litellm_params=litellm_params_dict,
        )
    else:
        raise litellm.exceptions.BadRequestError(

@@ -445,6 +450,8 @@ def delete_assistant(
        api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
    )

    litellm_params_dict = get_litellm_params(**kwargs)

    async_delete_assistants: Optional[bool] = kwargs.pop(
        "async_delete_assistants", None
    )

@@ -544,6 +551,7 @@ def delete_assistant(
            max_retries=optional_params.max_retries,
            client=client,
            async_delete_assistants=async_delete_assistants,
            litellm_params=litellm_params_dict,
        )
    else:
        raise litellm.exceptions.BadRequestError(

@@ -639,6 +647,7 @@ def create_thread(
    """
    acreate_thread = kwargs.get("acreate_thread", None)
    optional_params = GenericLiteLLMParams(**kwargs)
    litellm_params_dict = get_litellm_params(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600

@@ -731,6 +740,7 @@ def create_thread(
            max_retries=optional_params.max_retries,
            client=client,
            acreate_thread=acreate_thread,
            litellm_params=litellm_params_dict,
        )
    else:
        raise litellm.exceptions.BadRequestError(

@@ -795,7 +805,7 @@ def get_thread(
    """Get the thread object, given a thread_id"""
    aget_thread = kwargs.pop("aget_thread", None)
    optional_params = GenericLiteLLMParams(**kwargs)

    litellm_params_dict = get_litellm_params(**kwargs)
    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    # set timeout for 10 minutes by default

@@ -884,6 +894,7 @@ def get_thread(
            max_retries=optional_params.max_retries,
            client=client,
            aget_thread=aget_thread,
            litellm_params=litellm_params_dict,
        )
    else:
        raise litellm.exceptions.BadRequestError(

@@ -972,6 +983,7 @@ def add_message(
    _message_data = MessageData(
        role=role, content=content, attachments=attachments, metadata=metadata
    )
    litellm_params_dict = get_litellm_params(**kwargs)
    optional_params = GenericLiteLLMParams(**kwargs)

    message_data = get_optional_params_add_message(

@@ -1068,6 +1080,7 @@ def add_message(
            max_retries=optional_params.max_retries,
            client=client,
            a_add_message=a_add_message,
            litellm_params=litellm_params_dict,
        )
    else:
        raise litellm.exceptions.BadRequestError(

@@ -1139,6 +1152,7 @@ def get_messages(
) -> SyncCursorPage[OpenAIMessage]:
    aget_messages = kwargs.pop("aget_messages", None)
    optional_params = GenericLiteLLMParams(**kwargs)
    litellm_params_dict = get_litellm_params(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600

@@ -1225,6 +1239,7 @@ def get_messages(
            max_retries=optional_params.max_retries,
            client=client,
            aget_messages=aget_messages,
            litellm_params=litellm_params_dict,
        )
    else:
        raise litellm.exceptions.BadRequestError(

@@ -1337,6 +1352,7 @@ def run_thread(
    """Run a given thread + assistant."""
    arun_thread = kwargs.pop("arun_thread", None)
    optional_params = GenericLiteLLMParams(**kwargs)
    litellm_params_dict = get_litellm_params(**kwargs)

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600

@@ -1437,6 +1453,7 @@ def run_thread(
            max_retries=optional_params.max_retries,
            client=client,
            arun_thread=arun_thread,
            litellm_params=litellm_params_dict,
        )  # type: ignore
    else:
        raise litellm.exceptions.BadRequestError(

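Every entrypoint in this file now follows the same two-step pattern: build a `litellm_params_dict` from the caller's kwargs, then forward it to the provider layer. A minimal sketch of that threading, not part of the diff; `provider_get_assistants` is a hypothetical stand-in for a provider method such as `azure_assistants_api.get_assistants`:

```python
from litellm.utils import get_litellm_params


def provider_get_assistants(litellm_params: dict) -> None:
    # The provider layer reads Azure credentials and timeouts from one dict
    # instead of a growing list of separate keyword arguments.
    print(litellm_params.get("tenant_id"), litellm_params.get("timeout"))


def get_assistants_sketch(**kwargs) -> None:
    litellm_params_dict = get_litellm_params(**kwargs)
    provider_get_assistants(litellm_params=litellm_params_dict)


get_assistants_sketch(tenant_id="my-tenant", timeout=30)  # illustrative values
```
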
@@ -111,6 +111,7 @@ def create_batch(
        proxy_server_request = kwargs.get("proxy_server_request", None)
        model_info = kwargs.get("model_info", None)
        _is_async = kwargs.pop("acreate_batch", False) is True
        litellm_params = get_litellm_params(**kwargs)
        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600

@@ -217,6 +218,7 @@ def create_batch(
                timeout=timeout,
                max_retries=optional_params.max_retries,
                create_batch_data=_create_batch_request,
                litellm_params=litellm_params,
            )
        elif custom_llm_provider == "vertex_ai":
            api_base = optional_params.api_base or ""

@@ -320,15 +322,12 @@ def retrieve_batch(
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)

        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        litellm_params = get_litellm_params(
            custom_llm_provider=custom_llm_provider,
            litellm_call_id=kwargs.get("litellm_call_id", None),
            litellm_trace_id=kwargs.get("litellm_trace_id"),
            litellm_metadata=kwargs.get("litellm_metadata"),
            **kwargs,
        )
        litellm_logging_obj.update_environment_variables(
            model=None,

@@ -424,6 +423,7 @@ def retrieve_batch(
                timeout=timeout,
                max_retries=optional_params.max_retries,
                retrieve_batch_data=_retrieve_batch_request,
                litellm_params=litellm_params,
            )
        elif custom_llm_provider == "vertex_ai":
            api_base = optional_params.api_base or ""

@@ -526,6 +526,10 @@ def list_batches(
    try:
        # set API KEY
        optional_params = GenericLiteLLMParams(**kwargs)
        litellm_params = get_litellm_params(
            custom_llm_provider=custom_llm_provider,
            **kwargs,
        )
        api_key = (
            optional_params.api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there

@@ -603,6 +607,7 @@ def list_batches(
                api_version=api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                litellm_params=litellm_params,
            )
        else:
            raise litellm.exceptions.BadRequestError(

@@ -678,6 +683,10 @@ def cancel_batch(
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        litellm_params = get_litellm_params(
            custom_llm_provider=custom_llm_provider,
            **kwargs,
        )
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

@@ -765,6 +774,7 @@ def cancel_batch(
                timeout=timeout,
                max_retries=optional_params.max_retries,
                cancel_batch_data=_cancel_batch_request,
                litellm_params=litellm_params,
            )
        else:
            raise litellm.exceptions.BadRequestError(

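The `### TIMEOUT LOGIC ###` line repeated across these hunks is a three-way fallback: an explicit `timeout`, then the legacy `request_timeout` kwarg, then a 600-second (10 minute) default. A worked check of the precedence, assuming plain numeric values:

```python
# Sketch of the repeated timeout fallback chain.
def resolve_timeout(optional_timeout, kwargs: dict):
    return optional_timeout or kwargs.get("request_timeout", 600) or 600


assert resolve_timeout(30, {}) == 30                          # explicit timeout wins
assert resolve_timeout(None, {"request_timeout": 45}) == 45   # legacy kwarg next
assert resolve_timeout(None, {}) == 600                       # default last
```
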
@@ -25,7 +25,7 @@ from litellm.types.llms.openai import (
    HttpxBinaryResponseContent,
)
from litellm.types.router import *
from litellm.utils import supports_httpx_timeout
from litellm.utils import get_litellm_params, supports_httpx_timeout

####### ENVIRONMENT VARIABLES ###################
openai_files_instance = OpenAIFilesAPI()

@@ -546,6 +546,7 @@ def create_file(
    try:
        _is_async = kwargs.pop("acreate_file", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)
        litellm_params_dict = get_litellm_params(**kwargs)

        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600

@@ -630,6 +631,7 @@ def create_file(
                timeout=timeout,
                max_retries=optional_params.max_retries,
                create_file_data=_create_file_request,
                litellm_params=litellm_params_dict,
            )
        elif custom_llm_provider == "vertex_ai":
            api_base = optional_params.api_base or ""

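`kwargs.pop("acreate_file", False) is True` does two things at once: it removes the flag so it is not forwarded into `get_litellm_params(**kwargs)` below it, and it normalizes the result to an exact boolean check. A small illustration:

```python
kwargs = {"acreate_file": True, "purpose": "batch"}
_is_async = kwargs.pop("acreate_file", False) is True
assert _is_async is True
assert kwargs == {"purpose": "batch"}  # flag is consumed, not forwarded

flag = 1
assert (flag is True) is False  # truthy, but not the bool True, so the sync path runs
```
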
@@ -58,6 +58,7 @@ def get_litellm_params(
    async_call: Optional[bool] = None,
    ssl_verify: Optional[bool] = None,
    merge_reasoning_content_in_choices: Optional[bool] = None,
    max_retries: Optional[int] = None,
    **kwargs,
) -> dict:
    litellm_params = {

@@ -99,5 +100,13 @@ def get_litellm_params(
        "async_call": async_call,
        "ssl_verify": ssl_verify,
        "merge_reasoning_content_in_choices": merge_reasoning_content_in_choices,
        "azure_ad_token": kwargs.get("azure_ad_token"),
        "tenant_id": kwargs.get("tenant_id"),
        "client_id": kwargs.get("client_id"),
        "client_secret": kwargs.get("client_secret"),
        "azure_username": kwargs.get("azure_username"),
        "azure_password": kwargs.get("azure_password"),
        "max_retries": max_retries,
        "timeout": kwargs.get("timeout"),
    }
    return litellm_params

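With these additions, Azure-specific credentials ride along in the returned `litellm_params` dict instead of being plucked out of kwargs at each call site. A quick sketch of the keys this hunk adds (real calls return many more entries; the values here are illustrative, not real credentials):

```python
from litellm.utils import get_litellm_params

params = get_litellm_params(
    tenant_id="my-tenant",
    client_id="my-client-id",
    client_secret="my-secret",
    max_retries=2,
    timeout=30,
)
# The azure provider can now initialize its SDK client from this one dict.
print(params["tenant_id"], params["client_id"], params["max_retries"], params["timeout"])
```
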
@@ -18,10 +18,10 @@ from ...types.llms.openai import (
    SyncCursorPage,
    Thread,
)
from ..base import BaseLLM
from .common_utils import BaseAzureLLM


class AzureAssistantsAPI(BaseLLM):
class AzureAssistantsAPI(BaseAzureLLM):
    def __init__(self) -> None:
        super().__init__()

@@ -34,18 +34,17 @@ class AzureAssistantsAPI(BaseLLM):
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ) -> AzureOpenAI:
        received_args = locals()
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client":
                    pass
                elif k == "api_base" and v is not None:
                    data["azure_endpoint"] = v
                elif v is not None:
                    data[k] = v
            azure_openai_client = AzureOpenAI(**data)  # type: ignore
            azure_client_params = self.initialize_azure_sdk_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                model_name="",
                api_version=api_version,
            )
            azure_openai_client = AzureOpenAI(**azure_client_params)  # type: ignore
        else:
            azure_openai_client = client

@@ -60,18 +59,18 @@ class AzureAssistantsAPI(BaseLLM):
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ) -> AsyncAzureOpenAI:
        received_args = locals()
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client":
                    pass
                elif k == "api_base" and v is not None:
                    data["azure_endpoint"] = v
                elif v is not None:
                    data[k] = v
            azure_openai_client = AsyncAzureOpenAI(**data)
            azure_client_params = self.initialize_azure_sdk_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                model_name="",
                api_version=api_version,
            )

            azure_openai_client = AsyncAzureOpenAI(**azure_client_params)
            # azure_openai_client = AsyncAzureOpenAI(**data)  # type: ignore
        else:
            azure_openai_client = client

@@ -89,6 +88,7 @@ class AzureAssistantsAPI(BaseLLM):
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        litellm_params: Optional[dict] = None,
    ) -> AsyncCursorPage[Assistant]:
        azure_openai_client = self.async_get_azure_client(
            api_key=api_key,

@@ -98,6 +98,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await azure_openai_client.beta.assistants.list()

@@ -146,6 +147,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client=None,
        aget_assistants=None,
        litellm_params: Optional[dict] = None,
    ):
        if aget_assistants is not None and aget_assistants is True:
            return self.async_get_assistants(

@@ -156,6 +158,7 @@ class AzureAssistantsAPI(BaseLLM):
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        azure_openai_client = self.get_azure_client(
            api_key=api_key,

@@ -165,6 +168,7 @@ class AzureAssistantsAPI(BaseLLM):
            max_retries=max_retries,
            client=client,
            api_version=api_version,
            litellm_params=litellm_params,
        )

        response = azure_openai_client.beta.assistants.list()

@@ -184,6 +188,7 @@ class AzureAssistantsAPI(BaseLLM):
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ) -> OpenAIMessage:
        openai_client = self.async_get_azure_client(
            api_key=api_key,

@@ -193,6 +198,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create(  # type: ignore

@@ -222,6 +228,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        a_add_message: Literal[True],
        litellm_params: Optional[dict] = None,
    ) -> Coroutine[None, None, OpenAIMessage]:
        ...

@@ -238,6 +245,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client: Optional[AzureOpenAI],
        a_add_message: Optional[Literal[False]],
        litellm_params: Optional[dict] = None,
    ) -> OpenAIMessage:
        ...

@@ -255,6 +263,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client=None,
        a_add_message: Optional[bool] = None,
        litellm_params: Optional[dict] = None,
    ):
        if a_add_message is not None and a_add_message is True:
            return self.a_add_message(

@@ -267,6 +276,7 @@ class AzureAssistantsAPI(BaseLLM):
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        openai_client = self.get_azure_client(
            api_key=api_key,

@@ -300,6 +310,7 @@ class AzureAssistantsAPI(BaseLLM):
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ) -> AsyncCursorPage[OpenAIMessage]:
        openai_client = self.async_get_azure_client(
            api_key=api_key,

@@ -309,6 +320,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await openai_client.beta.threads.messages.list(thread_id=thread_id)

@@ -329,6 +341,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        aget_messages: Literal[True],
        litellm_params: Optional[dict] = None,
    ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
        ...

@@ -344,6 +357,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client: Optional[AzureOpenAI],
        aget_messages: Optional[Literal[False]],
        litellm_params: Optional[dict] = None,
    ) -> SyncCursorPage[OpenAIMessage]:
        ...

@@ -360,6 +374,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client=None,
        aget_messages=None,
        litellm_params: Optional[dict] = None,
    ):
        if aget_messages is not None and aget_messages is True:
            return self.async_get_messages(

@@ -371,6 +386,7 @@ class AzureAssistantsAPI(BaseLLM):
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        openai_client = self.get_azure_client(
            api_key=api_key,

@@ -380,6 +396,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = openai_client.beta.threads.messages.list(thread_id=thread_id)

@@ -399,6 +416,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        litellm_params: Optional[dict] = None,
    ) -> Thread:
        openai_client = self.async_get_azure_client(
            api_key=api_key,

@@ -408,6 +426,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        data = {}

@@ -435,6 +454,7 @@ class AzureAssistantsAPI(BaseLLM):
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        client: Optional[AsyncAzureOpenAI],
        acreate_thread: Literal[True],
        litellm_params: Optional[dict] = None,
    ) -> Coroutine[None, None, Thread]:
        ...

@@ -451,6 +471,7 @@ class AzureAssistantsAPI(BaseLLM):
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        client: Optional[AzureOpenAI],
        acreate_thread: Optional[Literal[False]],
        litellm_params: Optional[dict] = None,
    ) -> Thread:
        ...

@@ -468,6 +489,7 @@ class AzureAssistantsAPI(BaseLLM):
        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
        client=None,
        acreate_thread=None,
        litellm_params: Optional[dict] = None,
    ):
        """
        Here's an example:

@@ -490,6 +512,7 @@ class AzureAssistantsAPI(BaseLLM):
                max_retries=max_retries,
                client=client,
                messages=messages,
                litellm_params=litellm_params,
            )
        azure_openai_client = self.get_azure_client(
            api_key=api_key,

@@ -499,6 +522,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        data = {}

@@ -521,6 +545,7 @@ class AzureAssistantsAPI(BaseLLM):
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        litellm_params: Optional[dict] = None,
    ) -> Thread:
        openai_client = self.async_get_azure_client(
            api_key=api_key,

@@ -530,6 +555,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await openai_client.beta.threads.retrieve(thread_id=thread_id)

@@ -550,6 +576,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        aget_thread: Literal[True],
        litellm_params: Optional[dict] = None,
    ) -> Coroutine[None, None, Thread]:
        ...

@@ -565,6 +592,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client: Optional[AzureOpenAI],
        aget_thread: Optional[Literal[False]],
        litellm_params: Optional[dict] = None,
    ) -> Thread:
        ...

@@ -581,6 +609,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client=None,
        aget_thread=None,
        litellm_params: Optional[dict] = None,
    ):
        if aget_thread is not None and aget_thread is True:
            return self.async_get_thread(

@@ -592,6 +621,7 @@ class AzureAssistantsAPI(BaseLLM):
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        openai_client = self.get_azure_client(
            api_key=api_key,

@@ -601,6 +631,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = openai_client.beta.threads.retrieve(thread_id=thread_id)

@@ -629,6 +660,7 @@ class AzureAssistantsAPI(BaseLLM):
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        litellm_params: Optional[dict] = None,
    ) -> Run:
        openai_client = self.async_get_azure_client(
            api_key=api_key,

@@ -638,6 +670,7 @@ class AzureAssistantsAPI(BaseLLM):
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            client=client,
            litellm_params=litellm_params,
        )

        response = await openai_client.beta.threads.runs.create_and_poll(  # type: ignore

@@ -645,7 +678,7 @@ class AzureAssistantsAPI(BaseLLM):
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            metadata=metadata,  # type: ignore
            model=model,
            tools=tools,
        )

@@ -663,6 +696,7 @@ class AzureAssistantsAPI(BaseLLM):
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
        litellm_params: Optional[dict] = None,
    ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
        data = {
            "thread_id": thread_id,

@@ -688,6 +722,7 @@ class AzureAssistantsAPI(BaseLLM):
        model: Optional[str],
        tools: Optional[Iterable[AssistantToolParam]],
        event_handler: Optional[AssistantEventHandler],
        litellm_params: Optional[dict] = None,
    ) -> AssistantStreamManager[AssistantEventHandler]:
        data = {
            "thread_id": thread_id,

@@ -769,6 +804,7 @@ class AzureAssistantsAPI(BaseLLM):
        client=None,
        arun_thread=None,
        event_handler: Optional[AssistantEventHandler] = None,
        litellm_params: Optional[dict] = None,
    ):
        if arun_thread is not None and arun_thread is True:
            if stream is not None and stream is True:

@@ -780,6 +816,7 @@ class AzureAssistantsAPI(BaseLLM):
                    timeout=timeout,
                    max_retries=max_retries,
                    client=client,
                    litellm_params=litellm_params,
                )
                return self.async_run_thread_stream(
                    client=azure_client,

@@ -791,13 +828,14 @@ class AzureAssistantsAPI(BaseLLM):
                    model=model,
                    tools=tools,
                    event_handler=event_handler,
                    litellm_params=litellm_params,
                )
            return self.arun_thread(
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                metadata=metadata,  # type: ignore
                model=model,
                stream=stream,
                tools=tools,

@@ -808,6 +846,7 @@ class AzureAssistantsAPI(BaseLLM):
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                litellm_params=litellm_params,
            )
        openai_client = self.get_azure_client(
            api_key=api_key,

@@ -817,6 +856,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        if stream is not None and stream is True:

@@ -830,6 +870,7 @@ class AzureAssistantsAPI(BaseLLM):
                model=model,
                tools=tools,
                event_handler=event_handler,
                litellm_params=litellm_params,
            )

        response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore

@@ -837,7 +878,7 @@ class AzureAssistantsAPI(BaseLLM):
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            metadata=metadata,  # type: ignore
            model=model,
            tools=tools,
        )

@@ -855,6 +896,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        create_assistant_data: dict,
        litellm_params: Optional[dict] = None,
    ) -> Assistant:
        azure_openai_client = self.async_get_azure_client(
            api_key=api_key,

@@ -864,6 +906,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await azure_openai_client.beta.assistants.create(

@@ -882,6 +925,7 @@ class AzureAssistantsAPI(BaseLLM):
        create_assistant_data: dict,
        client=None,
        async_create_assistants=None,
        litellm_params: Optional[dict] = None,
    ):
        if async_create_assistants is not None and async_create_assistants is True:
            return self.async_create_assistants(

@@ -893,6 +937,7 @@ class AzureAssistantsAPI(BaseLLM):
                max_retries=max_retries,
                client=client,
                create_assistant_data=create_assistant_data,
                litellm_params=litellm_params,
            )
        azure_openai_client = self.get_azure_client(
            api_key=api_key,

@@ -902,6 +947,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = azure_openai_client.beta.assistants.create(**create_assistant_data)

@@ -918,6 +964,7 @@ class AzureAssistantsAPI(BaseLLM):
        max_retries: Optional[int],
        client: Optional[AsyncAzureOpenAI],
        assistant_id: str,
        litellm_params: Optional[dict] = None,
    ):
        azure_openai_client = self.async_get_azure_client(
            api_key=api_key,

@@ -927,6 +974,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = await azure_openai_client.beta.assistants.delete(

@@ -945,6 +993,7 @@ class AzureAssistantsAPI(BaseLLM):
        assistant_id: str,
        async_delete_assistants: Optional[bool] = None,
        client=None,
        litellm_params: Optional[dict] = None,
    ):
        if async_delete_assistants is not None and async_delete_assistants is True:
            return self.async_delete_assistant(

@@ -956,6 +1005,7 @@ class AzureAssistantsAPI(BaseLLM):
                max_retries=max_retries,
                client=client,
                assistant_id=assistant_id,
                litellm_params=litellm_params,
            )
        azure_openai_client = self.get_azure_client(
            api_key=api_key,

@@ -965,6 +1015,7 @@ class AzureAssistantsAPI(BaseLLM):
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            litellm_params=litellm_params,
        )

        response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id)

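Both client getters above swap the same hand-rolled pattern, filtering `locals()` and renaming `api_base` to `azure_endpoint`, for a single call to `initialize_azure_sdk_client`. That helper lives in `BaseAzureLLM` and its body is not shown in this diff, so the sketch below only mirrors what the call sites imply: it takes the credentials plus `litellm_params` and returns the constructor kwargs for the OpenAI SDK client.

```python
from openai import AzureOpenAI

# Old pattern (removed above): filter locals() by hand.
def build_client_old(api_key, api_base, api_version, timeout):
    received_args = locals()  # captured before any other locals exist
    data = {}
    for k, v in received_args.items():
        if k == "api_base" and v is not None:
            data["azure_endpoint"] = v  # rename for the SDK constructor
        elif v is not None:
            data[k] = v
    return AzureOpenAI(**data)

# New pattern (added above), shown exactly as the call sites use it; the
# helper's return shape is an assumption based on how it is consumed:
#     azure_client_params = self.initialize_azure_sdk_client(
#         litellm_params=litellm_params or {},
#         api_key=api_key,
#         api_base=api_base,
#         model_name="",
#         api_version=api_version,
#     )
#     azure_openai_client = AzureOpenAI(**azure_client_params)
```
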
@@ -9,11 +9,7 @@ from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
from litellm.types.utils import FileTypes
from litellm.utils import TranscriptionResponse, convert_to_model_response_object

from .azure import (
    AzureChatCompletion,
    get_azure_ad_token_from_oidc,
    select_azure_base_url_or_endpoint,
)
from .azure import AzureChatCompletion


class AzureAudioTranscription(AzureChatCompletion):

@@ -32,29 +28,18 @@ class AzureAudioTranscription(AzureChatCompletion):
        client=None,
        azure_ad_token: Optional[str] = None,
        atranscription: bool = False,
        litellm_params: Optional[dict] = None,
    ) -> TranscriptionResponse:
        data = {"model": model, "file": audio_file, **optional_params}

        # init AzureOpenAI Client
        azure_client_params = {
            "api_version": api_version,
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "timeout": timeout,
        }

        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        azure_client_params = self.initialize_azure_sdk_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            model_name=model,
            api_version=api_version,
            api_base=api_base,
        )
        if api_key is not None:
            azure_client_params["api_key"] = api_key
        elif azure_ad_token is not None:
            if azure_ad_token.startswith("oidc/"):
                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
            azure_client_params["azure_ad_token"] = azure_ad_token

        if max_retries is not None:
            azure_client_params["max_retries"] = max_retries

        if atranscription is True:
            return self.async_audio_transcriptions(  # type: ignore

@@ -1,6 +1,5 @@
import asyncio
import json
import os
import time
from typing import Any, Callable, Dict, List, Literal, Optional, Union

@@ -8,7 +7,6 @@ import httpx  # type: ignore
from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI

import litellm
from litellm.caching.caching import DualCache
from litellm.constants import DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (

@@ -25,15 +23,18 @@ from litellm.types.utils import (
from litellm.utils import (
    CustomStreamWrapper,
    convert_to_model_response_object,
    get_secret,
    modify_url,
)

from ...types.llms.openai import HttpxBinaryResponseContent
from ..base import BaseLLM
from .common_utils import AzureOpenAIError, process_azure_headers

azure_ad_cache = DualCache()
from .common_utils import (
    AzureOpenAIError,
    BaseAzureLLM,
    get_azure_ad_token_from_oidc,
    process_azure_headers,
    select_azure_base_url_or_endpoint,
)


class AzureOpenAIAssistantsAPIConfig:

@@ -98,93 +99,6 @@ class AzureOpenAIAssistantsAPIConfig:
        return optional_params


def select_azure_base_url_or_endpoint(azure_client_params: dict):
    azure_endpoint = azure_client_params.get("azure_endpoint", None)
    if azure_endpoint is not None:
        # see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
        if "/openai/deployments" in azure_endpoint:
            # this is base_url, not an azure_endpoint
            azure_client_params["base_url"] = azure_endpoint
            azure_client_params.pop("azure_endpoint")

    return azure_client_params


def get_azure_ad_token_from_oidc(azure_ad_token: str):
    azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
    azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
    azure_authority_host = os.getenv(
        "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
    )

    if azure_client_id is None or azure_tenant_id is None:
        raise AzureOpenAIError(
            status_code=422,
            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
        )

    oidc_token = get_secret(azure_ad_token)

    if oidc_token is None:
        raise AzureOpenAIError(
            status_code=401,
            message="OIDC token could not be retrieved from secret manager.",
        )

    azure_ad_token_cache_key = json.dumps(
        {
            "azure_client_id": azure_client_id,
            "azure_tenant_id": azure_tenant_id,
            "azure_authority_host": azure_authority_host,
            "oidc_token": oidc_token,
        }
    )

    azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
    if azure_ad_token_access_token is not None:
        return azure_ad_token_access_token

    client = litellm.module_level_client
    req_token = client.post(
        f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
        data={
            "client_id": azure_client_id,
            "grant_type": "client_credentials",
            "scope": "https://cognitiveservices.azure.com/.default",
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": oidc_token,
        },
    )

    if req_token.status_code != 200:
        raise AzureOpenAIError(
            status_code=req_token.status_code,
            message=req_token.text,
        )

    azure_ad_token_json = req_token.json()
    azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
    azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)

    if azure_ad_token_access_token is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token access_token not returned"
        )

    if azure_ad_token_expires_in is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token expires_in not returned"
        )

    azure_ad_cache.set_cache(
        key=azure_ad_token_cache_key,
        value=azure_ad_token_access_token,
        ttl=azure_ad_token_expires_in,
    )

    return azure_ad_token_access_token


def _check_dynamic_azure_params(
    azure_client_params: dict,
    azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],

@@ -206,7 +120,7 @@ def _check_dynamic_azure_params(
    return False


class AzureChatCompletion(BaseLLM):
class AzureChatCompletion(BaseAzureLLM, BaseLLM):
    def __init__(self) -> None:
        super().__init__()

@@ -238,27 +152,16 @@ class AzureChatCompletion(BaseLLM):
        timeout: Union[float, httpx.Timeout],
        client: Optional[Any],
        client_type: Literal["sync", "async"],
        litellm_params: Optional[dict] = None,
    ):
        # init AzureOpenAI Client
        azure_client_params: Dict[str, Any] = {
            "api_version": api_version,
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "http_client": litellm.client_session,
            "max_retries": max_retries,
            "timeout": timeout,
        }
        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            model_name=model,
            api_version=api_version,
            api_base=api_base,
        )
        if api_key is not None:
            azure_client_params["api_key"] = api_key
        elif azure_ad_token is not None:
            if azure_ad_token.startswith("oidc/"):
                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
            azure_client_params["azure_ad_token"] = azure_ad_token
        elif azure_ad_token_provider is not None:
            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
        if client is None:
            if client_type == "sync":
                azure_client = AzureOpenAI(**azure_client_params)  # type: ignore

@@ -357,6 +260,13 @@ class AzureChatCompletion(BaseLLM):
                max_retries = DEFAULT_MAX_RETRIES
            json_mode: Optional[bool] = optional_params.pop("json_mode", False)

            azure_client_params = self.initialize_azure_sdk_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                model_name=model,
                api_version=api_version,
            )
            ### CHECK IF CLOUDFLARE AI GATEWAY ###
            ### if so - set the model as part of the base url
            if "gateway.ai.cloudflare.com" in api_base:

@@ -417,6 +327,7 @@ class AzureChatCompletion(BaseLLM):
                        timeout=timeout,
                        client=client,
                        max_retries=max_retries,
                        azure_client_params=azure_client_params,
                    )
                else:
                    return self.acompletion(

@@ -434,6 +345,7 @@ class AzureChatCompletion(BaseLLM):
                        logging_obj=logging_obj,
                        max_retries=max_retries,
                        convert_tool_call_to_json_mode=json_mode,
                        azure_client_params=azure_client_params,
                    )
            elif "stream" in optional_params and optional_params["stream"] is True:
                return self.streaming(

@@ -470,28 +382,6 @@ class AzureChatCompletion(BaseLLM):
                        status_code=422, message="max retries must be an int"
                    )
                # init AzureOpenAI Client
                azure_client_params = {
                    "api_version": api_version,
                    "azure_endpoint": api_base,
                    "azure_deployment": model,
                    "http_client": litellm.client_session,
                    "max_retries": max_retries,
                    "timeout": timeout,
                }
                azure_client_params = select_azure_base_url_or_endpoint(
                    azure_client_params=azure_client_params
                )
                if api_key is not None:
                    azure_client_params["api_key"] = api_key
                elif azure_ad_token is not None:
                    if azure_ad_token.startswith("oidc/"):
                        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                    azure_client_params["azure_ad_token"] = azure_ad_token
                elif azure_ad_token_provider is not None:
                    azure_client_params["azure_ad_token_provider"] = (
                        azure_ad_token_provider
                    )

                if (
                    client is None
                    or not isinstance(client, AzureOpenAI)

@@ -562,30 +452,10 @@ class AzureChatCompletion(BaseLLM):
        azure_ad_token_provider: Optional[Callable] = None,
        convert_tool_call_to_json_mode: Optional[bool] = None,
        client=None,  # this is the AsyncAzureOpenAI
        azure_client_params: dict = {},
    ):
        response = None
        try:
            # init AzureOpenAI Client
            azure_client_params = {
                "api_version": api_version,
                "azure_endpoint": api_base,
                "azure_deployment": model,
                "http_client": litellm.aclient_session,
                "max_retries": max_retries,
                "timeout": timeout,
            }
            azure_client_params = select_azure_base_url_or_endpoint(
                azure_client_params=azure_client_params
            )
            if api_key is not None:
                azure_client_params["api_key"] = api_key
            elif azure_ad_token is not None:
                if azure_ad_token.startswith("oidc/"):
                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                azure_client_params["azure_ad_token"] = azure_ad_token
            elif azure_ad_token_provider is not None:
                azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider

            # setting Azure client
            if client is None or dynamic_params:
                azure_client = AsyncAzureOpenAI(**azure_client_params)

@@ -742,28 +612,9 @@ class AzureChatCompletion(BaseLLM):
        azure_ad_token: Optional[str] = None,
        azure_ad_token_provider: Optional[Callable] = None,
        client=None,
        azure_client_params: dict = {},
    ):
        try:
            # init AzureOpenAI Client
            azure_client_params = {
                "api_version": api_version,
                "azure_endpoint": api_base,
                "azure_deployment": model,
                "http_client": litellm.aclient_session,
                "max_retries": max_retries,
                "timeout": timeout,
            }
            azure_client_params = select_azure_base_url_or_endpoint(
                azure_client_params=azure_client_params
            )
            if api_key is not None:
                azure_client_params["api_key"] = api_key
            elif azure_ad_token is not None:
                if azure_ad_token.startswith("oidc/"):
                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                azure_client_params["azure_ad_token"] = azure_ad_token
            elif azure_ad_token_provider is not None:
                azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
            if client is None or dynamic_params:
                azure_client = AsyncAzureOpenAI(**azure_client_params)
            else:

@@ -824,6 +675,7 @@ class AzureChatCompletion(BaseLLM):
    ):
        response = None
        try:

            if client is None:
                openai_aclient = AsyncAzureOpenAI(**azure_client_params)
            else:

@@ -875,6 +727,7 @@ class AzureChatCompletion(BaseLLM):
        client=None,
        aembedding=None,
        headers: Optional[dict] = None,
        litellm_params: Optional[dict] = None,
    ) -> EmbeddingResponse:
        if headers:
            optional_params["extra_headers"] = headers

@@ -890,29 +743,14 @@ class AzureChatCompletion(BaseLLM):
            )

        # init AzureOpenAI Client
        azure_client_params = {
            "api_version": api_version,
            "azure_endpoint": api_base,
            "azure_deployment": model,
            "max_retries": max_retries,
            "timeout": timeout,
        }
        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        )
        if aembedding:
            azure_client_params["http_client"] = litellm.aclient_session
        else:
            azure_client_params["http_client"] = litellm.client_session
        if api_key is not None:
            azure_client_params["api_key"] = api_key
        elif azure_ad_token is not None:
            if azure_ad_token.startswith("oidc/"):
                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
            azure_client_params["azure_ad_token"] = azure_ad_token
        elif azure_ad_token_provider is not None:
            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider

        azure_client_params = self.initialize_azure_sdk_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            model_name=model,
            api_version=api_version,
            api_base=api_base,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=input,

@@ -1272,6 +1110,7 @@ class AzureChatCompletion(BaseLLM):
        azure_ad_token_provider: Optional[Callable] = None,
        client=None,
        aimg_generation=None,
        litellm_params: Optional[dict] = None,
    ) -> ImageResponse:
        try:
            if model and len(model) > 0:

@@ -1296,25 +1135,13 @@ class AzureChatCompletion(BaseLLM):
                )

            # init AzureOpenAI Client
            azure_client_params: Dict[str, Any] = {
                "api_version": api_version,
                "azure_endpoint": api_base,
                "azure_deployment": model,
                "max_retries": max_retries,
                "timeout": timeout,
            }
            azure_client_params = select_azure_base_url_or_endpoint(
                azure_client_params=azure_client_params
            azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                model_name=model or "",
                api_version=api_version,
                api_base=api_base,
            )
            if api_key is not None:
                azure_client_params["api_key"] = api_key
            elif azure_ad_token is not None:
                if azure_ad_token.startswith("oidc/"):
                    azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                azure_client_params["azure_ad_token"] = azure_ad_token
            elif azure_ad_token_provider is not None:
                azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider

            if aimg_generation is True:
                return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers)  # type: ignore

@@ -1377,6 +1204,7 @@ class AzureChatCompletion(BaseLLM):
        azure_ad_token_provider: Optional[Callable] = None,
        aspeech: Optional[bool] = None,
        client=None,
        litellm_params: Optional[dict] = None,
    ) -> HttpxBinaryResponseContent:

        max_retries = optional_params.pop("max_retries", 2)

@@ -1395,6 +1223,7 @@ class AzureChatCompletion(BaseLLM):
                max_retries=max_retries,
                timeout=timeout,
                client=client,
                litellm_params=litellm_params,
            )  # type: ignore

        azure_client: AzureOpenAI = self._get_sync_azure_client(

@@ -1408,6 +1237,7 @@ class AzureChatCompletion(BaseLLM):
            timeout=timeout,
            client=client,
            client_type="sync",
            litellm_params=litellm_params,
        )  # type: ignore

        response = azure_client.audio.speech.create(

@@ -1432,6 +1262,7 @@ class AzureChatCompletion(BaseLLM):
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        client=None,
        litellm_params: Optional[dict] = None,
    ) -> HttpxBinaryResponseContent:

        azure_client: AsyncAzureOpenAI = self._get_sync_azure_client(

@@ -1445,6 +1276,7 @@ class AzureChatCompletion(BaseLLM):
            timeout=timeout,
            client=client,
            client_type="async",
            litellm_params=litellm_params,
        )  # type: ignore

        azure_response = await azure_client.audio.speech.create(

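One behavior these rewrites keep intact is the `oidc/` convention for `azure_ad_token`: a value starting with `oidc/` is a secret reference that must be exchanged for a real AAD access token before use. A small sketch of the dispatch, assuming the function's new home in `common_utils` (see the last file in this diff):

```python
from litellm.llms.azure.common_utils import get_azure_ad_token_from_oidc


def resolve_azure_ad_token(azure_ad_token: str) -> str:
    # "oidc/..." values are indirections, not tokens; anything else is
    # assumed to already be a bearer token.
    if azure_ad_token.startswith("oidc/"):
        return get_azure_ad_token_from_oidc(azure_ad_token)
    return azure_ad_token
```
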
@@ -6,7 +6,6 @@ from typing import Any, Coroutine, Optional, Union, cast

import httpx

import litellm
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
from litellm.types.llms.openai import (
    Batch,

@@ -16,8 +15,10 @@ from litellm.types.llms.openai import (
)
from litellm.types.utils import LiteLLMBatch

from ..common_utils import BaseAzureLLM

class AzureBatchesAPI:

class AzureBatchesAPI(BaseAzureLLM):
    """
    Azure methods to support for batches
    - create_batch()

@@ -29,38 +30,6 @@ class AzureBatchesAPI:
    def __init__(self) -> None:
        super().__init__()

    def get_azure_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        api_version: Optional[str] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
        received_args = locals()
        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
        if client is None:
            data = {}
            for k, v in received_args.items():
                if k == "self" or k == "client" or k == "_is_async":
                    pass
                elif k == "api_base" and v is not None:
                    data["azure_endpoint"] = v
                elif v is not None:
                    data[k] = v
            if "api_version" not in data:
                data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
            if _is_async is True:
                openai_client = AsyncAzureOpenAI(**data)
            else:
                openai_client = AzureOpenAI(**data)  # type: ignore
        else:
            openai_client = client

        return openai_client

    async def acreate_batch(
        self,
        create_batch_data: CreateBatchRequest,

@@ -79,16 +48,16 @@ class AzureBatchesAPI:
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            self.get_azure_openai_client(
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                api_version=api_version,
                max_retries=max_retries,
                client=client,
                _is_async=_is_async,
                litellm_params=litellm_params or {},
            )
        )
        if azure_client is None:

@@ -125,16 +94,16 @@ class AzureBatchesAPI:
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ):
        azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            self.get_azure_openai_client(
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                _is_async=_is_async,
                litellm_params=litellm_params or {},
            )
        )
        if azure_client is None:

@@ -173,16 +142,16 @@ class AzureBatchesAPI:
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[AzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ):
        azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            self.get_azure_openai_client(
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                _is_async=_is_async,
                litellm_params=litellm_params or {},
            )
        )
        if azure_client is None:

@@ -212,16 +181,16 @@ class AzureBatchesAPI:
        after: Optional[str] = None,
        limit: Optional[int] = None,
        client: Optional[AzureOpenAI] = None,
        litellm_params: Optional[dict] = None,
    ):
        azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            self.get_azure_openai_client(
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                api_version=api_version,
                client=client,
                _is_async=_is_async,
                litellm_params=litellm_params or {},
            )
        )
        if azure_client is None:

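`get_azure_openai_client` (now inherited from `BaseAzureLLM`) keeps the `_is_async` switch that the batch methods above pass straight through: one kwargs dict, two possible client classes. A minimal sketch of that branch:

```python
from typing import Union

from openai import AsyncAzureOpenAI, AzureOpenAI


def pick_client(data: dict, _is_async: bool) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
    # Same selection the batch methods rely on via _is_async=_is_async.
    if _is_async:
        return AsyncAzureOpenAI(**data)
    return AzureOpenAI(**data)
```
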
@@ -4,50 +4,68 @@ Handler file for calls to Azure OpenAI's o1/o3 family of models
Written separately to handle faking streaming for o1 and o3 models.
"""

from typing import Optional, Union
from typing import Any, Callable, Optional, Union

import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

from litellm.types.utils import ModelResponse

from ...openai.openai import OpenAIChatCompletion
from ..common_utils import get_azure_openai_client
from ..common_utils import BaseAzureLLM


class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
    def _get_openai_client(
class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
    def completion(
        self,
        is_async: bool,
        model_response: ModelResponse,
        timeout: Union[float, httpx.Timeout],
        optional_params: dict,
        litellm_params: dict,
        logging_obj: Any,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
        max_retries: Optional[int] = 2,
        dynamic_params: Optional[bool] = None,
        azure_ad_token: Optional[str] = None,
        acompletion: bool = False,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
    ) -> Optional[
        Union[
            OpenAI,
            AsyncOpenAI,
            AzureOpenAI,
            AsyncAzureOpenAI,
        ]
    ]:

        # Override to use Azure-specific client initialization
        if not isinstance(client, AzureOpenAI) and not isinstance(
            client, AsyncAzureOpenAI
        ):
            client = None

        return get_azure_openai_client(
        custom_llm_provider: Optional[str] = None,
        drop_params: Optional[bool] = None,
    ):
        client = self.get_azure_openai_client(
            litellm_params=litellm_params,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            api_version=api_version,
            client=client,
            _is_async=is_async,
        )
        return super().completion(
            model_response=model_response,
            timeout=timeout,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            model=model,
            messages=messages,
            print_verbose=print_verbose,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            dynamic_params=dynamic_params,
            azure_ad_token=azure_ad_token,
            acompletion=acompletion,
            logger_fn=logger_fn,
            headers=headers,
            custom_prompt_dict=custom_prompt_dict,
            client=client,
            organization=organization,
            custom_llm_provider=custom_llm_provider,
            drop_params=drop_params,
        )

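The new class order, `BaseAzureLLM` first and `OpenAIChatCompletion` second, lets the override build an Azure client and then hand the rest of the flow to the generic implementation via `super().completion(...)`. A toy sketch of that cooperative-inheritance shape (names here are illustrative, not litellm's):

```python
class ClientMixin:
    def get_client(self) -> str:
        return "azure-client"


class GenericCompletion:
    def completion(self, client: str) -> str:
        return f"completed with {client}"


class AzureO1Sketch(ClientMixin, GenericCompletion):
    def completion(self) -> str:  # type: ignore[override]
        # Build the provider-specific client, then reuse the generic flow.
        client = self.get_client()
        return super().completion(client=client)


print(AzureO1Sketch().completion())  # completed with azure-client
```
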
@ -1,3 +1,5 @@
|
|||
import json
|
||||
import os
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
@ -5,9 +7,15 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
|
|||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
azure_ad_cache = DualCache()
|
||||
|
||||
|
||||
class AzureOpenAIError(BaseLLMException):
|
||||
def __init__(
|
||||
|
@@ -27,39 +35,6 @@ class AzureOpenAIError(BaseLLMException):
        )


def get_azure_openai_client(
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    api_version: Optional[str] = None,
    organization: Optional[str] = None,
    client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    _is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
    received_args = locals()
    openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
    if client is None:
        data = {}
        for k, v in received_args.items():
            if k == "self" or k == "client" or k == "_is_async":
                pass
            elif k == "api_base" and v is not None:
                data["azure_endpoint"] = v
            elif v is not None:
                data[k] = v
        if "api_version" not in data:
            data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
        if _is_async is True:
            openai_client = AsyncAzureOpenAI(**data)
        else:
            openai_client = AzureOpenAI(**data)  # type: ignore
    else:
        openai_client = client

    return openai_client


def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
    openai_headers = {}
    if "x-ratelimit-limit-requests" in headers:
@@ -178,3 +153,199 @@ def get_azure_ad_token_from_username_password(
    verbose_logger.debug("token_provider %s", token_provider)

    return token_provider


def get_azure_ad_token_from_oidc(azure_ad_token: str):
    azure_client_id = os.getenv("AZURE_CLIENT_ID", None)
    azure_tenant_id = os.getenv("AZURE_TENANT_ID", None)
    azure_authority_host = os.getenv(
        "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
    )

    if azure_client_id is None or azure_tenant_id is None:
        raise AzureOpenAIError(
            status_code=422,
            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
        )

    oidc_token = get_secret_str(azure_ad_token)

    if oidc_token is None:
        raise AzureOpenAIError(
            status_code=401,
            message="OIDC token could not be retrieved from secret manager.",
        )

    azure_ad_token_cache_key = json.dumps(
        {
            "azure_client_id": azure_client_id,
            "azure_tenant_id": azure_tenant_id,
            "azure_authority_host": azure_authority_host,
            "oidc_token": oidc_token,
        }
    )

    azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
    if azure_ad_token_access_token is not None:
        return azure_ad_token_access_token

    client = litellm.module_level_client
    req_token = client.post(
        f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
        data={
            "client_id": azure_client_id,
            "grant_type": "client_credentials",
            "scope": "https://cognitiveservices.azure.com/.default",
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": oidc_token,
        },
    )

    if req_token.status_code != 200:
        raise AzureOpenAIError(
            status_code=req_token.status_code,
            message=req_token.text,
        )

    azure_ad_token_json = req_token.json()
    azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
    azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)

    if azure_ad_token_access_token is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token access_token not returned"
        )

    if azure_ad_token_expires_in is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token expires_in not returned"
        )

    azure_ad_cache.set_cache(
        key=azure_ad_token_cache_key,
        value=azure_ad_token_access_token,
        ttl=azure_ad_token_expires_in,
    )

    return azure_ad_token_access_token

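get_azure_ad_token_from_oidc exchanges a federated OIDC credential for an Azure AD access token and caches it in azure_ad_cache until it expires. A hedged sketch of a direct call; the env values and the oidc/ source are placeholders:

    import os

    os.environ["AZURE_CLIENT_ID"] = "<app-registration-id>"  # placeholder
    os.environ["AZURE_TENANT_ID"] = "<tenant-id>"  # placeholder

    # the "oidc/..." value is resolved through litellm's secret manager;
    # the exact source below is hypothetical
    access_token = get_azure_ad_token_from_oidc("oidc/github/token")
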
def select_azure_base_url_or_endpoint(azure_client_params: dict):
    azure_endpoint = azure_client_params.get("azure_endpoint", None)
    if azure_endpoint is not None:
        # see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
        if "/openai/deployments" in azure_endpoint:
            # this is base_url, not an azure_endpoint
            azure_client_params["base_url"] = azure_endpoint
            azure_client_params.pop("azure_endpoint")

    return azure_client_params

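A quick illustration of the helper above: an endpoint that already contains "/openai/deployments" is treated as a fully formed base_url rather than an azure_endpoint:

    params = {"azure_endpoint": "https://res.openai.azure.com/openai/deployments/gpt-4"}
    params = select_azure_base_url_or_endpoint(params)
    assert "base_url" in params and "azure_endpoint" not in params
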
class BaseAzureLLM:
    def get_azure_openai_client(
        self,
        litellm_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        _is_async: bool = False,
    ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
        if client is None:
            azure_client_params = self.initialize_azure_sdk_client(
                litellm_params=litellm_params,
                api_key=api_key,
                api_base=api_base,
                model_name="",
                api_version=api_version,
            )
            if _is_async is True:
                openai_client = AsyncAzureOpenAI(**azure_client_params)
            else:
                openai_client = AzureOpenAI(**azure_client_params)  # type: ignore
        else:
            openai_client = client

        return openai_client

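A minimal sketch of calling this entry point directly; in practice the provider handlers shown later in this diff are the callers, and the key, endpoint, and version below are placeholders:

    azure_client = BaseAzureLLM().get_azure_openai_client(
        litellm_params={"max_retries": 2, "timeout": 30.0},  # assumed example values
        api_key="<azure-api-key>",  # placeholder
        api_base="https://my-resource.openai.azure.com",  # placeholder
        api_version="2024-02-01",  # assumed API version
        client=None,
        _is_async=True,  # returns AsyncAzureOpenAI instead of AzureOpenAI
    )
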
    def initialize_azure_sdk_client(
        self,
        litellm_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        model_name: str,
        api_version: Optional[str],
    ) -> dict:
        azure_ad_token_provider: Optional[Callable[[], str]] = None
        # If we have api_key, then we have higher priority
        azure_ad_token = litellm_params.get("azure_ad_token")
        tenant_id = litellm_params.get("tenant_id")
        client_id = litellm_params.get("client_id")
        client_secret = litellm_params.get("client_secret")
        azure_username = litellm_params.get("azure_username")
        azure_password = litellm_params.get("azure_password")
        max_retries = litellm_params.get("max_retries")
        timeout = litellm_params.get("timeout")
        if not api_key and tenant_id and client_id and client_secret:
            verbose_logger.debug("Using Azure AD Token Provider for Azure Auth")
            azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
                tenant_id=tenant_id,
                client_id=client_id,
                client_secret=client_secret,
            )
        if azure_username and azure_password and client_id:
            azure_ad_token_provider = get_azure_ad_token_from_username_password(
                azure_username=azure_username,
                azure_password=azure_password,
                client_id=client_id,
            )

        if azure_ad_token is not None and azure_ad_token.startswith("oidc/"):
            azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
        elif (
            not api_key
            and azure_ad_token_provider is None
            and litellm.enable_azure_ad_token_refresh is True
        ):
            try:
                azure_ad_token_provider = get_azure_ad_token_provider()
            except ValueError:
                verbose_logger.debug("Azure AD Token Provider could not be used.")
        if api_version is None:
            api_version = os.getenv(
                "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
            )

        _api_key = api_key
        if _api_key is not None and isinstance(_api_key, str):
            # only show the first 8 chars of the api_key
            _api_key = _api_key[:8] + "*" * 15
        verbose_logger.debug(
            f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
        )
        azure_client_params = {
            "api_key": api_key,
            "azure_endpoint": api_base,
            "api_version": api_version,
            "azure_ad_token": azure_ad_token,
            "azure_ad_token_provider": azure_ad_token_provider,
            "http_client": litellm.client_session,
        }
        if max_retries is not None:
            azure_client_params["max_retries"] = max_retries
        if timeout is not None:
            azure_client_params["timeout"] = timeout

        if azure_ad_token_provider is not None:
            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
        # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
        # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        )

        return azure_client_params

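initialize_azure_sdk_client resolves credentials in priority order (api_key first, then an Entra ID or username/password token provider, then OIDC or ambient token refresh) and returns a kwargs dict ready to splat into the SDK constructor. A hedged sketch of the key-based path, with placeholder values:

    params = BaseAzureLLM().initialize_azure_sdk_client(
        litellm_params={"max_retries": 2, "timeout": 30.0},
        api_key="test-key",  # placeholder
        api_base="https://my-resource.openai.azure.com",  # placeholder
        model_name="gpt-4",  # used only in the debug log
        api_version="2024-02-01",  # assumed API version
    )
    client = AzureOpenAI(**params)  # or AsyncAzureOpenAI(**params)
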
@@ -5,13 +5,12 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.types.file_deleted import FileDeleted

from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.types.llms.openai import *

from ..common_utils import get_azure_openai_client
from ..common_utils import BaseAzureLLM


class AzureOpenAIFilesAPI(BaseLLM):
class AzureOpenAIFilesAPI(BaseAzureLLM):
    """
    AzureOpenAI methods to support for batches
    - create_file()
@@ -45,14 +44,15 @@ class AzureOpenAIFilesAPI(BaseLLM):
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:

        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            get_azure_openai_client(
            self.get_azure_openai_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                _is_async=_is_async,
            )
@@ -91,17 +91,16 @@ class AzureOpenAIFilesAPI(BaseLLM):
        max_retries: Optional[int],
        api_version: Optional[str] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            get_azure_openai_client(
            self.get_azure_openai_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                api_version=api_version,
                max_retries=max_retries,
                organization=None,
                client=client,
                _is_async=_is_async,
            )
@@ -144,14 +143,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
        max_retries: Optional[int],
        api_version: Optional[str] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ):
        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            get_azure_openai_client(
            self.get_azure_openai_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=None,
                api_version=api_version,
                client=client,
                _is_async=_is_async,
@@ -197,14 +195,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
        organization: Optional[str] = None,
        api_version: Optional[str] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ):
        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            get_azure_openai_client(
            self.get_azure_openai_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                api_version=api_version,
                client=client,
                _is_async=_is_async,
@@ -252,14 +249,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
        purpose: Optional[str] = None,
        api_version: Optional[str] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
        litellm_params: Optional[dict] = None,
    ):
        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
            get_azure_openai_client(
            self.get_azure_openai_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=None,  # openai param
                api_version=api_version,
                client=client,
                _is_async=_is_async,

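Every Files API method above now resolves its client through self.get_azure_openai_client. A hedged sketch of the async file-upload path as a caller might exercise it; endpoint and key are placeholders:

    import litellm

    file_obj = await litellm.acreate_file(
        file=open("data.jsonl", "rb"),
        purpose="fine-tune",
        custom_llm_provider="azure",
        api_base="https://my-resource.openai.azure.com",  # placeholder
        api_key="<azure-api-key>",  # placeholder
    )
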
@@ -3,11 +3,11 @@ from typing import Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

from litellm.llms.azure.files.handler import get_azure_openai_client
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI


class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
    """
    AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
    """
@@ -24,6 +24,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
        ] = None,
        _is_async: bool = False,
        api_version: Optional[str] = None,
        litellm_params: Optional[dict] = None,
    ) -> Optional[
        Union[
            OpenAI,
@@ -36,12 +37,10 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
        if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
            client = None

        return get_azure_openai_client(
        return self.get_azure_openai_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            api_version=api_version,
            client=client,
            _is_async=_is_async,

@@ -27,6 +27,7 @@ class OpenAIFineTuningAPI:
        ] = None,
        _is_async: bool = False,
        api_version: Optional[str] = None,
        litellm_params: Optional[dict] = None,
    ) -> Optional[
        Union[
            OpenAI,

@@ -1162,6 +1162,14 @@ def completion(  # type: ignore # noqa: PLR0915
        merge_reasoning_content_in_choices=kwargs.get(
            "merge_reasoning_content_in_choices", None
        ),
        azure_ad_token=kwargs.get("azure_ad_token"),
        tenant_id=kwargs.get("tenant_id"),
        client_id=kwargs.get("client_id"),
        client_secret=kwargs.get("client_secret"),
        azure_username=kwargs.get("azure_username"),
        azure_password=kwargs.get("azure_password"),
        max_retries=max_retries,
        timeout=timeout,
    )
    logging.update_environment_variables(
        model=model,
@@ -3350,6 +3358,7 @@ def embedding(  # noqa: PLR0915
            }
        }
    )

    litellm_params_dict = get_litellm_params(**kwargs)

    logging: Logging = litellm_logging_obj  # type: ignore
@@ -3411,6 +3420,7 @@ def embedding(  # noqa: PLR0915
            aembedding=aembedding,
            max_retries=max_retries,
            headers=headers or extra_headers,
            litellm_params=litellm_params_dict,
        )
    elif (
        model in litellm.open_ai_embedding_models
@@ -4537,6 +4547,8 @@ def image_generation(  # noqa: PLR0915
        **non_default_params,
    )

    litellm_params_dict = get_litellm_params(**kwargs)

    logging: Logging = litellm_logging_obj
    logging.update_environment_variables(
        model=model,
@@ -4607,6 +4619,7 @@ def image_generation(  # noqa: PLR0915
            aimg_generation=aimg_generation,
            client=client,
            headers=headers,
            litellm_params=litellm_params_dict,
        )
    elif (
        custom_llm_provider == "openai"
@@ -5002,6 +5015,7 @@ def transcription(
        custom_llm_provider=custom_llm_provider,
        drop_params=drop_params,
    )
    litellm_params_dict = get_litellm_params(**kwargs)

    litellm_logging_obj.update_environment_variables(
        model=model,
@@ -5055,6 +5069,7 @@ def transcription(
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            max_retries=max_retries,
            litellm_params=litellm_params_dict,
        )
    elif (
        custom_llm_provider == "openai"
@@ -5157,7 +5172,7 @@ async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent:


@client
def speech(
def speech(  # noqa: PLR0915
    model: str,
    input: str,
    voice: Optional[Union[str, dict]] = None,
@@ -5198,7 +5213,7 @@ def speech(

    if max_retries is None:
        max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES

    litellm_params_dict = get_litellm_params(**kwargs)
    logging_obj = kwargs.get("litellm_logging_obj", None)
    logging_obj.update_environment_variables(
        model=model,
@@ -5315,6 +5330,7 @@ def speech(
        timeout=timeout,
        client=client,  # pass AsyncOpenAI, OpenAI client
        aspeech=aspeech,
        litellm_params=litellm_params_dict,
    )
    elif custom_llm_provider == "vertex_ai" or custom_llm_provider == "vertex_ai_beta":

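The pattern across these litellm.main entry points is uniform: get_litellm_params(**kwargs) collects call-level settings, now including the Azure credential fields wired in above, into one dict that is passed down so the provider can initialize its own client. A hedged sketch (field names are taken from this diff; treating the return value as a plain dict is an assumption):

    litellm_params_dict = get_litellm_params(
        azure_ad_token="oidc/test-token",  # placeholder
        tenant_id="<tenant-id>",  # placeholder
        client_id="<client-id>",
        client_secret="<client-secret>",
    )
    # downstream: BaseAzureLLM().initialize_azure_sdk_client(
    #     litellm_params=litellm_params_dict, ...)
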
@@ -71,7 +71,6 @@ from litellm.router_utils.batch_utils import (
    _get_router_metadata_variable_name,
    replace_model_in_jsonl,
)
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
from litellm.router_utils.clientside_credential_handler import (
    get_dynamic_litellm_params,
    is_clientside_credential,
@@ -5360,36 +5359,12 @@ class Router:
                client = self.cache.get_cache(
                    key=cache_key, local_only=True, parent_otel_span=parent_otel_span
                )
                if client is None:
                    """
                    Re-initialize the client
                    """
                    InitalizeOpenAISDKClient.set_client(
                        litellm_router_instance=self, model=deployment
                    )
                    client = self.cache.get_cache(
                        key=cache_key,
                        local_only=True,
                        parent_otel_span=parent_otel_span,
                    )
                return client
            else:
                cache_key = f"{model_id}_async_client"
                client = self.cache.get_cache(
                    key=cache_key, local_only=True, parent_otel_span=parent_otel_span
                )
                # if client is None:
                #     """
                #     Re-initialize the client
                #     """
                #     InitalizeOpenAISDKClient.set_client(
                #         litellm_router_instance=self, model=deployment
                #     )
                #     client = self.cache.get_cache(
                #         key=cache_key,
                #         local_only=True,
                #         parent_otel_span=parent_otel_span,
                #     )
                return client
        else:
            if kwargs.get("stream") is True:
@@ -5397,32 +5372,12 @@ class Router:
                client = self.cache.get_cache(
                    key=cache_key, parent_otel_span=parent_otel_span
                )
                if client is None:
                    """
                    Re-initialize the client
                    """
                    InitalizeOpenAISDKClient.set_client(
                        litellm_router_instance=self, model=deployment
                    )
                    client = self.cache.get_cache(
                        key=cache_key, parent_otel_span=parent_otel_span
                    )
                return client
            else:
                cache_key = f"{model_id}_client"
                client = self.cache.get_cache(
                    key=cache_key, parent_otel_span=parent_otel_span
                )
                if client is None:
                    """
                    Re-initialize the client
                    """
                    InitalizeOpenAISDKClient.set_client(
                        litellm_router_instance=self, model=deployment
                    )
                    client = self.cache.get_cache(
                        key=cache_key, parent_otel_span=parent_otel_span
                    )
                return client

    def _pre_call_checks(  # noqa: PLR0915

@@ -1,6 +1,5 @@
import asyncio
import os
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Optional

import httpx
import openai
@@ -8,14 +7,6 @@ import openai
import litellm
from litellm import get_secret, get_secret_str
from litellm._logging import verbose_router_logger
from litellm.llms.azure.azure import get_azure_ad_token_from_oidc
from litellm.llms.azure.common_utils import (
    get_azure_ad_token_from_entrata_id,
    get_azure_ad_token_from_username_password,
)
from litellm.secret_managers.get_azure_ad_token_provider import (
    get_azure_ad_token_provider,
)
from litellm.utils import calculate_max_parallel_requests

if TYPE_CHECKING:
@@ -200,274 +191,6 @@ class InitalizeOpenAISDKClient:
                organization_env_name = organization.replace("os.environ/", "")
                organization = get_secret_str(organization_env_name)
                litellm_params["organization"] = organization
            azure_ad_token_provider: Optional[Callable[[], str]] = None
            # If we have api_key, then we have higher priority
            if not api_key and litellm_params.get("tenant_id"):
                verbose_router_logger.debug(
                    "Using Azure AD Token Provider for Azure Auth"
                )
                azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
                    tenant_id=litellm_params.get("tenant_id"),
                    client_id=litellm_params.get("client_id"),
                    client_secret=litellm_params.get("client_secret"),
                )
            if litellm_params.get("azure_username") and litellm_params.get(
                "azure_password"
            ):
                azure_ad_token_provider = get_azure_ad_token_from_username_password(
                    azure_username=litellm_params.get("azure_username"),
                    azure_password=litellm_params.get("azure_password"),
                    client_id=litellm_params.get("client_id"),
                )

            if custom_llm_provider == "azure" or custom_llm_provider == "azure_text":
                if api_base is None or not isinstance(api_base, str):
                    filtered_litellm_params = {
                        k: v
                        for k, v in model["litellm_params"].items()
                        if k != "api_key"
                    }
                    _filtered_model = {
                        "model_name": model["model_name"],
                        "litellm_params": filtered_litellm_params,
                    }
                    raise ValueError(
                        f"api_base is required for Azure OpenAI. Set it on your config. Model - {_filtered_model}"
                    )
                azure_ad_token = litellm_params.get("azure_ad_token")
                if azure_ad_token is not None:
                    if azure_ad_token.startswith("oidc/"):
                        azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                elif (
                    not api_key
                    and azure_ad_token_provider is None
                    and litellm.enable_azure_ad_token_refresh is True
                ):
                    try:
                        azure_ad_token_provider = get_azure_ad_token_provider()
                    except ValueError:
                        verbose_router_logger.debug(
                            "Azure AD Token Provider could not be used."
                        )
                if api_version is None:
                    api_version = os.getenv(
                        "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
                    )

                if "gateway.ai.cloudflare.com" in api_base:
                    if not api_base.endswith("/"):
                        api_base += "/"
                    azure_model = model_name.replace("azure/", "")
                    api_base += f"{azure_model}"
                    cache_key = f"{model_id}_async_client"
                    _client = openai.AsyncAzureOpenAI(
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        azure_ad_token_provider=azure_ad_token_provider,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=timeout,  # type: ignore
                        max_retries=max_retries,  # type: ignore
                        http_client=httpx.AsyncClient(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),  # type: ignore
                    )
                    litellm_router_instance.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr

                    if InitalizeOpenAISDKClient.should_initialize_sync_client(
                        litellm_router_instance=litellm_router_instance
                    ):
                        cache_key = f"{model_id}_client"
                        _client = openai.AzureOpenAI(  # type: ignore
                            api_key=api_key,
                            azure_ad_token=azure_ad_token,
                            azure_ad_token_provider=azure_ad_token_provider,
                            base_url=api_base,
                            api_version=api_version,
                            timeout=timeout,  # type: ignore
                            max_retries=max_retries,  # type: ignore
                            http_client=httpx.Client(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),  # type: ignore
                        )
                        litellm_router_instance.cache.set_cache(
                            key=cache_key,
                            value=_client,
                            ttl=client_ttl,
                            local_only=True,
                        )  # cache for 1 hr
                    # streaming clients can have diff timeouts
                    cache_key = f"{model_id}_stream_async_client"
                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        api_key=api_key,
                        azure_ad_token=azure_ad_token,
                        azure_ad_token_provider=azure_ad_token_provider,
                        base_url=api_base,
                        api_version=api_version,
                        timeout=stream_timeout,  # type: ignore
                        max_retries=max_retries,  # type: ignore
                        http_client=httpx.AsyncClient(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),  # type: ignore
                    )
                    litellm_router_instance.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr

                    if InitalizeOpenAISDKClient.should_initialize_sync_client(
                        litellm_router_instance=litellm_router_instance
                    ):
                        cache_key = f"{model_id}_stream_client"
                        _client = openai.AzureOpenAI(  # type: ignore
                            api_key=api_key,
                            azure_ad_token=azure_ad_token,
                            azure_ad_token_provider=azure_ad_token_provider,
                            base_url=api_base,
                            api_version=api_version,
                            timeout=stream_timeout,  # type: ignore
                            max_retries=max_retries,  # type: ignore
                            http_client=httpx.Client(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),  # type: ignore
                        )
                        litellm_router_instance.cache.set_cache(
                            key=cache_key,
                            value=_client,
                            ttl=client_ttl,
                            local_only=True,
                        )  # cache for 1 hr
                else:
                    _api_key = api_key
                    if _api_key is not None and isinstance(_api_key, str):
                        # only show first 5 chars of api_key
                        _api_key = _api_key[:8] + "*" * 15
                    verbose_router_logger.debug(
                        f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
                    )
                    azure_client_params = {
                        "api_key": api_key,
                        "azure_endpoint": api_base,
                        "api_version": api_version,
                        "azure_ad_token": azure_ad_token,
                        "azure_ad_token_provider": azure_ad_token_provider,
                    }

                    if azure_ad_token_provider is not None:
                        azure_client_params["azure_ad_token_provider"] = (
                            azure_ad_token_provider
                        )
                    from litellm.llms.azure.azure import (
                        select_azure_base_url_or_endpoint,
                    )

                    # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
                    # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
                    azure_client_params = select_azure_base_url_or_endpoint(
                        azure_client_params
                    )

                    cache_key = f"{model_id}_async_client"
                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        **azure_client_params,
                        timeout=timeout,  # type: ignore
                        max_retries=max_retries,  # type: ignore
                        http_client=httpx.AsyncClient(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),  # type: ignore
                    )
                    litellm_router_instance.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr
                    if InitalizeOpenAISDKClient.should_initialize_sync_client(
                        litellm_router_instance=litellm_router_instance
                    ):
                        cache_key = f"{model_id}_client"
                        _client = openai.AzureOpenAI(  # type: ignore
                            **azure_client_params,
                            timeout=timeout,  # type: ignore
                            max_retries=max_retries,  # type: ignore
                            http_client=httpx.Client(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),  # type: ignore
                        )
                        litellm_router_instance.cache.set_cache(
                            key=cache_key,
                            value=_client,
                            ttl=client_ttl,
                            local_only=True,
                        )  # cache for 1 hr

                    # streaming clients should have diff timeouts
                    cache_key = f"{model_id}_stream_async_client"
                    _client = openai.AsyncAzureOpenAI(  # type: ignore
                        **azure_client_params,
                        timeout=stream_timeout,  # type: ignore
                        max_retries=max_retries,  # type: ignore
                        http_client=httpx.AsyncClient(
                            limits=httpx.Limits(
                                max_connections=1000, max_keepalive_connections=100
                            ),
                            verify=litellm.ssl_verify,
                        ),
                    )
                    litellm_router_instance.cache.set_cache(
                        key=cache_key,
                        value=_client,
                        ttl=client_ttl,
                        local_only=True,
                    )  # cache for 1 hr

                    if InitalizeOpenAISDKClient.should_initialize_sync_client(
                        litellm_router_instance=litellm_router_instance
                    ):
                        cache_key = f"{model_id}_stream_client"
                        _client = openai.AzureOpenAI(  # type: ignore
                            **azure_client_params,
                            timeout=stream_timeout,  # type: ignore
                            max_retries=max_retries,  # type: ignore
                            http_client=httpx.Client(
                                limits=httpx.Limits(
                                    max_connections=1000, max_keepalive_connections=100
                                ),
                                verify=litellm.ssl_verify,
                            ),
                        )
                        litellm_router_instance.cache.set_cache(
                            key=cache_key,
                            value=_client,
                            ttl=client_ttl,
                            local_only=True,
                        )  # cache for 1 hr

            else:
                _api_key = api_key  # type: ignore
                if _api_key is not None and isinstance(_api_key, str):

@@ -191,6 +191,42 @@ class CallTypes(Enum):
    retrieve_batch = "retrieve_batch"
    pass_through = "pass_through_endpoint"
    anthropic_messages = "anthropic_messages"
    get_assistants = "get_assistants"
    aget_assistants = "aget_assistants"
    create_assistants = "create_assistants"
    acreate_assistants = "acreate_assistants"
    delete_assistant = "delete_assistant"
    adelete_assistant = "adelete_assistant"
    acreate_thread = "acreate_thread"
    create_thread = "create_thread"
    aget_thread = "aget_thread"
    get_thread = "get_thread"
    a_add_message = "a_add_message"
    add_message = "add_message"
    aget_messages = "aget_messages"
    get_messages = "get_messages"
    arun_thread = "arun_thread"
    run_thread = "run_thread"
    arun_thread_stream = "arun_thread_stream"
    run_thread_stream = "run_thread_stream"
    afile_retrieve = "afile_retrieve"
    file_retrieve = "file_retrieve"
    afile_delete = "afile_delete"
    file_delete = "file_delete"
    afile_list = "afile_list"
    file_list = "file_list"
    acreate_file = "acreate_file"
    create_file = "create_file"
    afile_content = "afile_content"
    file_content = "file_content"
    create_fine_tuning_job = "create_fine_tuning_job"
    acreate_fine_tuning_job = "acreate_fine_tuning_job"
    acancel_fine_tuning_job = "acancel_fine_tuning_job"
    cancel_fine_tuning_job = "cancel_fine_tuning_job"
    alist_fine_tuning_jobs = "alist_fine_tuning_jobs"
    list_fine_tuning_jobs = "list_fine_tuning_jobs"
    aretrieve_fine_tuning_job = "aretrieve_fine_tuning_job"
    retrieve_fine_tuning_job = "retrieve_fine_tuning_job"


CallTypesLiteral = Literal[

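The new enum members give every assistants, files, and fine-tuning entry point a stable identifier, which the test below leans on to enumerate the async call types. A one-line check, grounded in the values above:

    from litellm.types.utils import CallTypes

    assert CallTypes.acreate_file.value == "acreate_file"
    assert CallTypes.arun_thread.value == "arun_thread"
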
tests/litellm/llms/azure/test_azure_common_utils.py (371 lines, new file)

@@ -0,0 +1,371 @@
import json
import os
import sys
from typing import Callable, Optional
from unittest.mock import MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../../../..")
)  # Adds the parent directory to the system path
import litellm
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.types.utils import CallTypes


# Mock the necessary dependencies
@pytest.fixture
def setup_mocks():
    with patch(
        "litellm.llms.azure.common_utils.get_azure_ad_token_from_entrata_id"
    ) as mock_entrata_token, patch(
        "litellm.llms.azure.common_utils.get_azure_ad_token_from_username_password"
    ) as mock_username_password_token, patch(
        "litellm.llms.azure.common_utils.get_azure_ad_token_from_oidc"
    ) as mock_oidc_token, patch(
        "litellm.llms.azure.common_utils.get_azure_ad_token_provider"
    ) as mock_token_provider, patch(
        "litellm.llms.azure.common_utils.litellm"
    ) as mock_litellm, patch(
        "litellm.llms.azure.common_utils.verbose_logger"
    ) as mock_logger, patch(
        "litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint"
    ) as mock_select_url:

        # Configure mocks
        mock_litellm.AZURE_DEFAULT_API_VERSION = "2023-05-15"
        mock_litellm.enable_azure_ad_token_refresh = False

        mock_entrata_token.return_value = lambda: "mock-entrata-token"
        mock_username_password_token.return_value = (
            lambda: "mock-username-password-token"
        )
        mock_oidc_token.return_value = "mock-oidc-token"
        mock_token_provider.return_value = lambda: "mock-default-token"

        mock_select_url.side_effect = (
            lambda azure_client_params, **kwargs: azure_client_params
        )

        yield {
            "entrata_token": mock_entrata_token,
            "username_password_token": mock_username_password_token,
            "oidc_token": mock_oidc_token,
            "token_provider": mock_token_provider,
            "litellm": mock_litellm,
            "logger": mock_logger,
            "select_url": mock_select_url,
        }


def test_initialize_with_api_key(setup_mocks):
    # Test with api_key provided
    result = BaseAzureLLM().initialize_azure_sdk_client(
        litellm_params={},
        api_key="test-api-key",
        api_base="https://test.openai.azure.com",
        model_name="gpt-4",
        api_version="2023-06-01",
    )

    # Verify expected result
    assert result["api_key"] == "test-api-key"
    assert result["azure_endpoint"] == "https://test.openai.azure.com"
    assert result["api_version"] == "2023-06-01"
    assert "azure_ad_token" in result
    assert result["azure_ad_token"] is None


def test_initialize_with_tenant_credentials(setup_mocks):
    # Test with tenant_id, client_id, and client_secret provided
    result = BaseAzureLLM().initialize_azure_sdk_client(
        litellm_params={
            "tenant_id": "test-tenant-id",
            "client_id": "test-client-id",
            "client_secret": "test-client-secret",
        },
        api_key=None,
        api_base="https://test.openai.azure.com",
        model_name="gpt-4",
        api_version=None,
    )

    # Verify that get_azure_ad_token_from_entrata_id was called
    setup_mocks["entrata_token"].assert_called_once_with(
        tenant_id="test-tenant-id",
        client_id="test-client-id",
        client_secret="test-client-secret",
    )

    # Verify expected result
    assert result["api_key"] is None
    assert result["azure_endpoint"] == "https://test.openai.azure.com"
    assert "azure_ad_token_provider" in result


def test_initialize_with_username_password(setup_mocks):
    # Test with azure_username, azure_password, and client_id provided
    result = BaseAzureLLM().initialize_azure_sdk_client(
        litellm_params={
            "azure_username": "test-username",
            "azure_password": "test-password",
            "client_id": "test-client-id",
        },
        api_key=None,
        api_base="https://test.openai.azure.com",
        model_name="gpt-4",
        api_version=None,
    )

    # Verify that get_azure_ad_token_from_username_password was called
    setup_mocks["username_password_token"].assert_called_once_with(
        azure_username="test-username",
        azure_password="test-password",
        client_id="test-client-id",
    )

    # Verify expected result
    assert "azure_ad_token_provider" in result


def test_initialize_with_oidc_token(setup_mocks):
    # Test with azure_ad_token that starts with "oidc/"
    result = BaseAzureLLM().initialize_azure_sdk_client(
        litellm_params={"azure_ad_token": "oidc/test-token"},
        api_key=None,
        api_base="https://test.openai.azure.com",
        model_name="gpt-4",
        api_version=None,
    )

    # Verify that get_azure_ad_token_from_oidc was called
    setup_mocks["oidc_token"].assert_called_once_with("oidc/test-token")

    # Verify expected result
    assert result["azure_ad_token"] == "mock-oidc-token"


def test_initialize_with_enable_token_refresh(setup_mocks):
    # Enable token refresh
    setup_mocks["litellm"].enable_azure_ad_token_refresh = True

    # Test with token refresh enabled
    result = BaseAzureLLM().initialize_azure_sdk_client(
        litellm_params={},
        api_key=None,
        api_base="https://test.openai.azure.com",
        model_name="gpt-4",
        api_version=None,
    )

    # Verify that get_azure_ad_token_provider was called
    setup_mocks["token_provider"].assert_called_once()

    # Verify expected result
    assert "azure_ad_token_provider" in result


def test_initialize_with_token_refresh_error(setup_mocks):
    # Enable token refresh but make it raise an error
    setup_mocks["litellm"].enable_azure_ad_token_refresh = True
    setup_mocks["token_provider"].side_effect = ValueError("Token provider error")

    # Test with token refresh enabled but raising error
    result = BaseAzureLLM().initialize_azure_sdk_client(
        litellm_params={},
        api_key=None,
        api_base="https://test.openai.azure.com",
        model_name="gpt-4",
        api_version=None,
    )

    # Verify error was logged
    setup_mocks["logger"].debug.assert_any_call(
        "Azure AD Token Provider could not be used."
    )


def test_api_version_from_env_var(setup_mocks):
    # Test api_version from environment variable
    with patch.dict(os.environ, {"AZURE_API_VERSION": "2023-07-01"}):
        result = BaseAzureLLM().initialize_azure_sdk_client(
            litellm_params={},
            api_key="test-api-key",
            api_base="https://test.openai.azure.com",
            model_name="gpt-4",
            api_version=None,
        )

        # Verify expected result
        assert result["api_version"] == "2023-07-01"


def test_select_azure_base_url_called(setup_mocks):
    # Test that select_azure_base_url_or_endpoint is called
    result = BaseAzureLLM().initialize_azure_sdk_client(
        litellm_params={},
        api_key="test-api-key",
        api_base="https://test.openai.azure.com",
        model_name="gpt-4",
        api_version="2023-06-01",
    )

    # Verify that select_azure_base_url_or_endpoint was called
    setup_mocks["select_url"].assert_called_once()


@pytest.mark.parametrize(
    "call_type",
    [
        call_type
        for call_type in CallTypes.__members__.values()
        if call_type.name.startswith("a")
        and call_type.name
        not in [
            "amoderation",
            "arerank",
            "arealtime",
            "anthropic_messages",
            "add_message",
            "arun_thread_stream",
        ]
    ],
)
@pytest.mark.asyncio
async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
    from litellm.router import Router

    # Create a router with an Azure model
    azure_model_name = "azure/chatgpt-v-2"
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": azure_model_name,
                    "api_key": "test-api-key",
                    "api_version": os.getenv("AZURE_API_VERSION", "2023-05-15"),
                    "api_base": os.getenv(
                        "AZURE_API_BASE", "https://test.openai.azure.com"
                    ),
                },
            }
        ],
    )

    # Prepare test input based on call type
    test_inputs = {
        "acompletion": {
            "messages": [{"role": "user", "content": "Hello, how are you?"}]
        },
        "atext_completion": {"prompt": "Hello, how are you?"},
        "aimage_generation": {"prompt": "Hello, how are you?"},
        "aembedding": {"input": "Hello, how are you?"},
        "arerank": {"input": "Hello, how are you?"},
        "atranscription": {"file": "path/to/file"},
        "aspeech": {"input": "Hello, how are you?", "voice": "female"},
        "acreate_batch": {
            "completion_window": 10,
            "endpoint": "https://test.openai.azure.com",
            "input_file_id": "123",
        },
        "aretrieve_batch": {"batch_id": "123"},
        "aget_assistants": {"custom_llm_provider": "azure"},
        "acreate_assistants": {"custom_llm_provider": "azure"},
        "adelete_assistant": {"custom_llm_provider": "azure", "assistant_id": "123"},
        "acreate_thread": {"custom_llm_provider": "azure"},
        "aget_thread": {"custom_llm_provider": "azure", "thread_id": "123"},
        "a_add_message": {
            "custom_llm_provider": "azure",
            "thread_id": "123",
            "role": "user",
            "content": "Hello, how are you?",
        },
        "aget_messages": {"custom_llm_provider": "azure", "thread_id": "123"},
        "arun_thread": {
            "custom_llm_provider": "azure",
            "assistant_id": "123",
            "thread_id": "123",
        },
        "acreate_file": {
            "custom_llm_provider": "azure",
            "file": MagicMock(),
            "purpose": "assistants",
        },
    }

    # Get appropriate input for this call type
    input_kwarg = test_inputs.get(call_type.value, {})

    patch_target = "litellm.main.azure_chat_completions.initialize_azure_sdk_client"
    if call_type == CallTypes.atranscription:
        patch_target = (
            "litellm.main.azure_audio_transcriptions.initialize_azure_sdk_client"
        )
    elif call_type == CallTypes.arerank:
        patch_target = (
            "litellm.rerank_api.main.azure_rerank.initialize_azure_sdk_client"
        )
    elif call_type == CallTypes.acreate_batch or call_type == CallTypes.aretrieve_batch:
        patch_target = (
            "litellm.batches.main.azure_batches_instance.initialize_azure_sdk_client"
        )
    elif (
        call_type == CallTypes.aget_assistants
        or call_type == CallTypes.acreate_assistants
        or call_type == CallTypes.adelete_assistant
        or call_type == CallTypes.acreate_thread
        or call_type == CallTypes.aget_thread
        or call_type == CallTypes.a_add_message
        or call_type == CallTypes.aget_messages
        or call_type == CallTypes.arun_thread
    ):
        patch_target = (
            "litellm.assistants.main.azure_assistants_api.initialize_azure_sdk_client"
        )
    elif call_type == CallTypes.acreate_file or call_type == CallTypes.afile_content:
        patch_target = (
            "litellm.files.main.azure_files_instance.initialize_azure_sdk_client"
        )

    # Mock the initialize_azure_sdk_client function
    with patch(patch_target) as mock_init_azure:
        # Also mock async_function_with_fallbacks to prevent actual API calls
        # Call the appropriate router method
        try:
            get_attr = getattr(router, call_type.value, None)
            if get_attr is None:
                pytest.skip(
                    f"Skipping {call_type.value} because it is not supported on Router"
                )
            await getattr(router, call_type.value)(
                model="gpt-3.5-turbo",
                **input_kwarg,
                num_retries=0,
                azure_ad_token="oidc/test-token",
            )
        except Exception as e:
            print(e)

        # Verify initialize_azure_sdk_client was called
        mock_init_azure.assert_called_once()

        # Verify it was called with the right model name
        calls = mock_init_azure.call_args_list
        azure_calls = [call for call in calls]

        litellm_params = azure_calls[0].kwargs["litellm_params"]
        print("litellm_params", litellm_params)

        assert (
            "azure_ad_token" in litellm_params
        ), "azure_ad_token not found in parameters"
        assert (
            litellm_params["azure_ad_token"] == "oidc/test-token"
        ), "azure_ad_token is not correct"

        # More detailed verification (optional)
        for call in azure_calls:
            assert "api_key" in call.kwargs, "api_key not found in parameters"
            assert "api_base" in call.kwargs, "api_base not found in parameters"

@@ -556,12 +556,11 @@ async def test_azure_instruct(


@pytest.mark.parametrize("max_retries", [0, 4])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("sync_mode", [True, False])
@patch("litellm.llms.azure.azure.select_azure_base_url_or_endpoint")
@patch("litellm.llms.azure.common_utils.select_azure_base_url_or_endpoint")
@pytest.mark.asyncio
async def test_azure_embedding_max_retries_0(
    mock_select_azure_base_url_or_endpoint, max_retries, stream, sync_mode
    mock_select_azure_base_url_or_endpoint, max_retries, sync_mode
):
    from litellm import aembedding, embedding

@@ -569,7 +568,6 @@ async def test_azure_embedding_max_retries_0(
        "model": "azure/azure-embedding-model",
        "input": "Hello world",
        "max_retries": max_retries,
        "stream": stream,
    }

    try:
@@ -581,6 +579,10 @@ async def test_azure_embedding_max_retries_0(
        print(e)

    mock_select_azure_base_url_or_endpoint.assert_called_once()
    print(
        "mock_select_azure_base_url_or_endpoint.call_args.kwargs",
        mock_select_azure_base_url_or_endpoint.call_args.kwargs,
    )
    assert (
        mock_select_azure_base_url_or_endpoint.call_args.kwargs["azure_client_params"][
            "max_retries"