refactor(azure.py): refactor to have client init work across all endpoints

Krrish Dholakia 2025-03-11 17:27:24 -07:00
parent d99d60a182
commit cbc2e84044
10 changed files with 296 additions and 129 deletions

View file

@@ -15,6 +15,7 @@ import litellm
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import (
exception_type,
get_litellm_params,
get_llm_provider,
get_secret,
supports_httpx_timeout,
@@ -86,6 +87,7 @@ def get_assistants(
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -169,6 +171,7 @@ def get_assistants(
max_retries=optional_params.max_retries,
client=client,
aget_assistants=aget_assistants, # type: ignore
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -270,6 +273,7 @@ def create_assistants(
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -371,6 +375,7 @@ def create_assistants(
client=client,
async_create_assistants=async_create_assistants,
create_assistant_data=create_assistant_data,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -445,6 +450,8 @@ def delete_assistant(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
litellm_params_dict = get_litellm_params(**kwargs)
async_delete_assistants: Optional[bool] = kwargs.pop(
"async_delete_assistants", None
)
@@ -544,6 +551,7 @@ def delete_assistant(
max_retries=optional_params.max_retries,
client=client,
async_delete_assistants=async_delete_assistants,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -639,6 +647,7 @@ def create_thread(
"""
acreate_thread = kwargs.get("acreate_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -731,6 +740,7 @@ def create_thread(
max_retries=optional_params.max_retries,
client=client,
acreate_thread=acreate_thread,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -795,7 +805,7 @@ def get_thread(
"""Get the thread object, given a thread_id"""
aget_thread = kwargs.pop("aget_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
@@ -884,6 +894,7 @@ def get_thread(
max_retries=optional_params.max_retries,
client=client,
aget_thread=aget_thread,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -972,6 +983,7 @@ def add_message(
_message_data = MessageData(
role=role, content=content, attachments=attachments, metadata=metadata
)
litellm_params_dict = get_litellm_params(**kwargs)
optional_params = GenericLiteLLMParams(**kwargs)
message_data = get_optional_params_add_message(
@@ -1068,6 +1080,7 @@ def add_message(
max_retries=optional_params.max_retries,
client=client,
a_add_message=a_add_message,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -1139,6 +1152,7 @@ def get_messages(
) -> SyncCursorPage[OpenAIMessage]:
aget_messages = kwargs.pop("aget_messages", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -1225,6 +1239,7 @@ def get_messages(
max_retries=optional_params.max_retries,
client=client,
aget_messages=aget_messages,
litellm_params=litellm_params_dict,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -1337,6 +1352,7 @@ def run_thread(
"""Run a given thread + assistant."""
arun_thread = kwargs.pop("arun_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -1437,6 +1453,7 @@ def run_thread(
max_retries=optional_params.max_retries,
client=client,
arun_thread=arun_thread,
litellm_params=litellm_params_dict,
) # type: ignore
else:
raise litellm.exceptions.BadRequestError(
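Every assistants entrypoint changed above follows the same two-step pattern: collect the client-level params once with get_litellm_params(**kwargs), then thread the resulting litellm_params_dict down to the provider handler. A minimal sketch of the call shape this enables, with placeholder endpoint and credentials (not from this commit):

import litellm

# Client-level kwargs (timeout, max_retries, Entra ID credentials, ...)
# are picked up by get_litellm_params(**kwargs) and now reach the
# AzureOpenAI client constructor instead of being dropped.
assistants = litellm.get_assistants(
    custom_llm_provider="azure",
    api_key="sk-...",
    api_base="https://example-resource.openai.azure.com",
    api_version="2024-02-15-preview",
    max_retries=3,
)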

View file

@@ -25,7 +25,7 @@ from litellm.types.llms.openai import (
HttpxBinaryResponseContent,
)
from litellm.types.router import *
from litellm.utils import supports_httpx_timeout
from litellm.utils import get_litellm_params, supports_httpx_timeout
####### ENVIRONMENT VARIABLES ###################
openai_files_instance = OpenAIFilesAPI()
@@ -546,6 +546,7 @@ def create_file(
try:
_is_async = kwargs.pop("acreate_file", False) is True
optional_params = GenericLiteLLMParams(**kwargs)
litellm_params_dict = get_litellm_params(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -630,6 +631,7 @@ def create_file(
timeout=timeout,
max_retries=optional_params.max_retries,
create_file_data=_create_file_request,
litellm_params=litellm_params_dict,
)
elif custom_llm_provider == "vertex_ai":
api_base = optional_params.api_base or ""
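create_file gets the same treatment: litellm_params_dict is built from kwargs and handed to the Azure files handler alongside the request. A hedged usage sketch, with placeholder file path and credentials:

import litellm

file_obj = litellm.create_file(
    file=open("batch_input.jsonl", "rb"),
    purpose="batch",
    custom_llm_provider="azure",
    api_base="https://example-resource.openai.azure.com",
    api_key="sk-...",
    timeout=120,  # forwarded via litellm_params to the AzureOpenAI client
)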

View file

@@ -18,10 +18,10 @@ from ...types.llms.openai import (
SyncCursorPage,
Thread,
)
from ..base import BaseLLM
from .common_utils import BaseAzureLLM
class AzureAssistantsAPI(BaseLLM):
class AzureAssistantsAPI(BaseAzureLLM):
def __init__(self) -> None:
super().__init__()
@@ -34,18 +34,17 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> AzureOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
azure_openai_client = AzureOpenAI(**data) # type: ignore
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
model_name="",
api_version=api_version,
)
azure_openai_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_openai_client = client
@@ -60,18 +59,18 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> AsyncAzureOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
azure_openai_client = AsyncAzureOpenAI(**data)
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
model_name="",
api_version=api_version,
)
azure_openai_client = AsyncAzureOpenAI(**azure_client_params)
# azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
else:
azure_openai_client = client
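Both client getters previously scraped locals() to assemble AzureOpenAI constructor kwargs; they now share initialize_azure_sdk_client, which (per the common_utils hunks below) also resolves Entra ID token providers, a default api_version, and per-request timeout/max_retries. A rough sketch of the shared path, with placeholder credentials and the module path inferred from the relative imports above:

from openai import AsyncAzureOpenAI, AzureOpenAI

from litellm.llms.azure.assistants import AzureAssistantsAPI

api = AzureAssistantsAPI()
params = api.initialize_azure_sdk_client(
    litellm_params={"timeout": 30, "max_retries": 2},
    api_key="sk-...",
    api_base="https://example-resource.openai.azure.com",  # mapped to azure_endpoint
    model_name="",
    api_version="2024-02-15-preview",
)
sync_client = AzureOpenAI(**params)        # get_azure_client path
async_client = AsyncAzureOpenAI(**params)  # async_get_azure_client path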
@@ -89,6 +88,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
litellm_params: Optional[dict] = None,
) -> AsyncCursorPage[Assistant]:
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
@@ -98,6 +98,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await azure_openai_client.beta.assistants.list()
@@ -146,6 +147,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
aget_assistants=None,
litellm_params: Optional[dict] = None,
):
if aget_assistants is not None and aget_assistants is True:
return self.async_get_assistants(
@@ -156,6 +158,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@@ -165,6 +168,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
api_version=api_version,
litellm_params=litellm_params,
)
response = azure_openai_client.beta.assistants.list()
@@ -184,6 +188,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> OpenAIMessage:
openai_client = self.async_get_azure_client(
api_key=api_key,
@@ -193,6 +198,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
@@ -222,6 +228,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
a_add_message: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, OpenAIMessage]:
...
@@ -238,6 +245,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AzureOpenAI],
a_add_message: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> OpenAIMessage:
...
@@ -255,6 +263,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
a_add_message: Optional[bool] = None,
litellm_params: Optional[dict] = None,
):
if a_add_message is not None and a_add_message is True:
return self.a_add_message(
@@ -267,6 +276,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@@ -300,6 +310,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
litellm_params: Optional[dict] = None,
) -> AsyncCursorPage[OpenAIMessage]:
openai_client = self.async_get_azure_client(
api_key=api_key,
@@ -309,6 +320,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
@@ -329,6 +341,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
aget_messages: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
...
@@ -344,6 +357,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AzureOpenAI],
aget_messages: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> SyncCursorPage[OpenAIMessage]:
...
@@ -360,6 +374,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
aget_messages=None,
litellm_params: Optional[dict] = None,
):
if aget_messages is not None and aget_messages is True:
return self.async_get_messages(
@@ -371,6 +386,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@@ -380,6 +396,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
@@ -399,6 +416,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
litellm_params: Optional[dict] = None,
) -> Thread:
openai_client = self.async_get_azure_client(
api_key=api_key,
@@ -408,6 +426,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
data = {}
@@ -435,6 +454,7 @@ class AzureAssistantsAPI(BaseLLM):
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AsyncAzureOpenAI],
acreate_thread: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, Thread]:
...
@@ -451,6 +471,7 @@ class AzureAssistantsAPI(BaseLLM):
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AzureOpenAI],
acreate_thread: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> Thread:
...
@@ -468,6 +489,7 @@ class AzureAssistantsAPI(BaseLLM):
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client=None,
acreate_thread=None,
litellm_params: Optional[dict] = None,
):
"""
Here's an example:
@@ -490,6 +512,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
messages=messages,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@@ -499,6 +522,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
data = {}
@@ -521,6 +545,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
litellm_params: Optional[dict] = None,
) -> Thread:
openai_client = self.async_get_azure_client(
api_key=api_key,
@@ -530,6 +555,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
@@ -550,6 +576,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
aget_thread: Literal[True],
litellm_params: Optional[dict] = None,
) -> Coroutine[None, None, Thread]:
...
@@ -565,6 +592,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AzureOpenAI],
aget_thread: Optional[Literal[False]],
litellm_params: Optional[dict] = None,
) -> Thread:
...
@@ -581,6 +609,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
aget_thread=None,
litellm_params: Optional[dict] = None,
):
if aget_thread is not None and aget_thread is True:
return self.async_get_thread(
@@ -592,6 +621,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@@ -601,6 +631,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
@@ -629,6 +660,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
litellm_params: Optional[dict] = None,
) -> Run:
openai_client = self.async_get_azure_client(
api_key=api_key,
@@ -638,6 +670,7 @@ class AzureAssistantsAPI(BaseLLM):
api_version=api_version,
azure_ad_token=azure_ad_token,
client=client,
litellm_params=litellm_params,
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
@@ -645,7 +678,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
metadata=metadata, # type: ignore
model=model,
tools=tools,
)
@@ -663,6 +696,7 @@ class AzureAssistantsAPI(BaseLLM):
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
litellm_params: Optional[dict] = None,
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
data = {
"thread_id": thread_id,
@@ -688,6 +722,7 @@ class AzureAssistantsAPI(BaseLLM):
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
litellm_params: Optional[dict] = None,
) -> AssistantStreamManager[AssistantEventHandler]:
data = {
"thread_id": thread_id,
@@ -769,6 +804,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
litellm_params: Optional[dict] = None,
):
if arun_thread is not None and arun_thread is True:
if stream is not None and stream is True:
@@ -780,6 +816,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
return self.async_run_thread_stream(
client=azure_client,
@@ -791,13 +828,14 @@ class AzureAssistantsAPI(BaseLLM):
model=model,
tools=tools,
event_handler=event_handler,
litellm_params=litellm_params,
)
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
metadata=metadata, # type: ignore
model=model,
stream=stream,
tools=tools,
@@ -808,6 +846,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
openai_client = self.get_azure_client(
api_key=api_key,
@@ -817,6 +856,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
if stream is not None and stream is True:
@@ -830,6 +870,7 @@ class AzureAssistantsAPI(BaseLLM):
model=model,
tools=tools,
event_handler=event_handler,
litellm_params=litellm_params,
)
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
@@ -837,7 +878,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
metadata=metadata, # type: ignore
model=model,
tools=tools,
)
@@ -855,6 +896,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
create_assistant_data: dict,
litellm_params: Optional[dict] = None,
) -> Assistant:
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
@@ -864,6 +906,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await azure_openai_client.beta.assistants.create(
@@ -882,6 +925,7 @@ class AzureAssistantsAPI(BaseLLM):
create_assistant_data: dict,
client=None,
async_create_assistants=None,
litellm_params: Optional[dict] = None,
):
if async_create_assistants is not None and async_create_assistants is True:
return self.async_create_assistants(
@@ -893,6 +937,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
create_assistant_data=create_assistant_data,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@@ -902,6 +947,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = azure_openai_client.beta.assistants.create(**create_assistant_data)
@@ -918,6 +964,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
assistant_id: str,
litellm_params: Optional[dict] = None,
):
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
@@ -927,6 +974,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = await azure_openai_client.beta.assistants.delete(
@@ -945,6 +993,7 @@ class AzureAssistantsAPI(BaseLLM):
assistant_id: str,
async_delete_assistants: Optional[bool] = None,
client=None,
litellm_params: Optional[dict] = None,
):
if async_delete_assistants is not None and async_delete_assistants is True:
return self.async_delete_assistant(
@@ -956,6 +1005,7 @@ class AzureAssistantsAPI(BaseLLM):
max_retries=max_retries,
client=client,
assistant_id=assistant_id,
litellm_params=litellm_params,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
@@ -965,6 +1015,7 @@ class AzureAssistantsAPI(BaseLLM):
timeout=timeout,
max_retries=max_retries,
client=client,
litellm_params=litellm_params,
)
response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id)

View file

@@ -4,50 +4,70 @@ Handler file for calls to Azure OpenAI's o1/o3 family of models
Written separately to handle faking streaming for o1 and o3 models.
"""
from typing import Optional, Union
from typing import Any, Callable, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from litellm.types.llms.openai import Any
from litellm.types.utils import ModelResponse
from ...openai.openai import OpenAIChatCompletion
from ..common_utils import get_azure_openai_client
from ..common_utils import BaseAzureLLM
class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
def _get_openai_client(
class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
def completion(
self,
is_async: bool,
model_response: ModelResponse,
timeout: Union[float, httpx.Timeout],
optional_params: dict,
litellm_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
max_retries: Optional[int] = 2,
dynamic_params: Optional[bool] = None,
azure_ad_token: Optional[str] = None,
acompletion: bool = False,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
client: Optional[
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
] = None,
) -> Optional[
Union[
OpenAI,
AsyncOpenAI,
AzureOpenAI,
AsyncAzureOpenAI,
]
]:
# Override to use Azure-specific client initialization
if not isinstance(client, AzureOpenAI) and not isinstance(
client, AsyncAzureOpenAI
):
client = None
return get_azure_openai_client(
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
client = self.get_azure_openai_client(
litellm_params=litellm_params,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=is_async,
)
return super().completion(
model_response=model_response,
timeout=timeout,
optional_params=optional_params,
litellm_params=litellm_params,
logging_obj=logging_obj,
model=model,
messages=messages,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
api_version=api_version,
dynamic_params=dynamic_params,
azure_ad_token=azure_ad_token,
acompletion=acompletion,
logger_fn=logger_fn,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
client=client,
organization=organization,
custom_llm_provider=custom_llm_provider,
drop_params=drop_params,
)
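Instead of overriding the private _get_openai_client hook, the o1/o3 handler now overrides completion itself: it builds the Azure client up front via get_azure_openai_client, then delegates everything else to OpenAIChatCompletion.completion. The public API is unchanged; a hedged sketch of a call that exercises this path, with a hypothetical deployment name and placeholder credentials:

import litellm

response = litellm.completion(
    model="azure/o1-mini",  # placeholder o1 deployment name
    messages=[{"role": "user", "content": "Say hello."}],
    api_base="https://example-resource.openai.azure.com",
    api_version="2024-12-01-preview",
    api_key="sk-...",
)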

View file

@@ -35,40 +35,6 @@ class AzureOpenAIError(BaseLLMException):
)
def get_azure_openai_client(
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
organization: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
if "x-ratelimit-limit-requests" in headers:
@@ -277,6 +243,33 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
class BaseAzureLLM:
def get_azure_openai_client(
self,
litellm_params: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
azure_client_params = self.initialize_azure_sdk_client(
litellm_params=litellm_params,
api_key=api_key,
api_base=api_base,
model_name="",
api_version=api_version,
)
if _is_async is True:
openai_client = AsyncAzureOpenAI(**azure_client_params)
else:
openai_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
openai_client = client
return openai_client
def initialize_azure_sdk_client(
self,
litellm_params: dict,
@@ -294,6 +287,8 @@ class BaseAzureLLM:
client_secret = litellm_params.get("client_secret")
azure_username = litellm_params.get("azure_username")
azure_password = litellm_params.get("azure_password")
max_retries = litellm_params.get("max_retries")
timeout = litellm_params.get("timeout")
if not api_key and tenant_id and client_id and client_secret:
verbose_logger.debug("Using Azure AD Token Provider for Azure Auth")
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
@@ -338,6 +333,10 @@ class BaseAzureLLM:
"azure_ad_token": azure_ad_token,
"azure_ad_token_provider": azure_ad_token_provider,
}
if max_retries is not None:
azure_client_params["max_retries"] = max_retries
if timeout is not None:
azure_client_params["timeout"] = timeout
if azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
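These two guards are what let per-request max_retries and timeout survive the trip from litellm_params into the AzureOpenAI constructor. A hedged sketch of the consolidated helper in use, with placeholder values (instantiating BaseAzureLLM directly is for illustration only; handlers normally inherit it):

from litellm.llms.azure.common_utils import BaseAzureLLM

client = BaseAzureLLM().get_azure_openai_client(
    litellm_params={"timeout": 60.0, "max_retries": 4},
    api_key="sk-...",
    api_base="https://example-resource.openai.azure.com",
    api_version="2024-02-15-preview",
    client=None,      # no pre-built client, so one is constructed
    _is_async=False,  # True returns AsyncAzureOpenAI instead
)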

View file

@@ -5,13 +5,12 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
from openai.types.file_deleted import FileDeleted
from litellm._logging import verbose_logger
from litellm.llms.base import BaseLLM
from litellm.types.llms.openai import *
from ..common_utils import get_azure_openai_client
from ..common_utils import BaseAzureLLM
class AzureOpenAIFilesAPI(BaseLLM):
class AzureOpenAIFilesAPI(BaseAzureLLM):
"""
AzureOpenAI methods to support batches
- create_file()
@@ -45,14 +44,15 @@ class AzureOpenAIFilesAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
@@ -91,17 +91,16 @@ class AzureOpenAIFilesAPI(BaseLLM):
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
) -> Union[
HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
organization=None,
client=client,
_is_async=_is_async,
)
@@ -144,14 +143,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=None,
api_version=api_version,
client=client,
_is_async=_is_async,
@@ -197,14 +195,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
organization: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,
@@ -252,14 +249,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
purpose: Optional[str] = None,
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
litellm_params: Optional[dict] = None,
):
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=None, # openai param
api_version=api_version,
client=client,
_is_async=_is_async,

View file

@@ -3,11 +3,11 @@ from typing import Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from litellm.llms.azure.files.handler import get_azure_openai_client
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
"""
AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
"""
@@ -24,6 +24,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
] = None,
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
) -> Optional[
Union[
OpenAI,
@@ -36,12 +37,10 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
client = None
return get_azure_openai_client(
return self.get_azure_openai_client(
litellm_params=litellm_params or {},
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
api_version=api_version,
client=client,
_is_async=_is_async,

View file

@@ -27,6 +27,7 @@ class OpenAIFineTuningAPI:
] = None,
_is_async: bool = False,
api_version: Optional[str] = None,
litellm_params: Optional[dict] = None,
) -> Optional[
Union[
OpenAI,

View file

@@ -191,6 +191,42 @@ class CallTypes(Enum):
retrieve_batch = "retrieve_batch"
pass_through = "pass_through_endpoint"
anthropic_messages = "anthropic_messages"
get_assistants = "get_assistants"
aget_assistants = "aget_assistants"
create_assistants = "create_assistants"
acreate_assistants = "acreate_assistants"
delete_assistant = "delete_assistant"
adelete_assistant = "adelete_assistant"
acreate_thread = "acreate_thread"
create_thread = "create_thread"
aget_thread = "aget_thread"
get_thread = "get_thread"
a_add_message = "a_add_message"
add_message = "add_message"
aget_messages = "aget_messages"
get_messages = "get_messages"
arun_thread = "arun_thread"
run_thread = "run_thread"
arun_thread_stream = "arun_thread_stream"
run_thread_stream = "run_thread_stream"
afile_retrieve = "afile_retrieve"
file_retrieve = "file_retrieve"
afile_delete = "afile_delete"
file_delete = "file_delete"
afile_list = "afile_list"
file_list = "file_list"
acreate_file = "acreate_file"
create_file = "create_file"
afile_content = "afile_content"
file_content = "file_content"
create_fine_tuning_job = "create_fine_tuning_job"
acreate_fine_tuning_job = "acreate_fine_tuning_job"
acancel_fine_tuning_job = "acancel_fine_tuning_job"
cancel_fine_tuning_job = "cancel_fine_tuning_job"
alist_fine_tuning_jobs = "alist_fine_tuning_jobs"
list_fine_tuning_jobs = "list_fine_tuning_jobs"
aretrieve_fine_tuning_job = "aretrieve_fine_tuning_job"
retrieve_fine_tuning_job = "retrieve_fine_tuning_job"
CallTypesLiteral = Literal[

View file

@@ -216,16 +216,18 @@ def test_select_azure_base_url_called(setup_mocks):
@pytest.mark.parametrize(
"call_type",
[
CallTypes.acompletion,
CallTypes.atext_completion,
CallTypes.aembedding,
CallTypes.atranscription,
CallTypes.aspeech,
CallTypes.aimage_generation,
# BATCHES ENDPOINTS
CallTypes.acreate_batch,
CallTypes.aretrieve_batch,
# ASSISTANT ENDPOINTS
call_type
for call_type in CallTypes.__members__.values()
if call_type.name.startswith("a")
and call_type.name
not in [
"amoderation",
"arerank",
"arealtime",
"anthropic_messages",
"add_message",
"arun_thread_stream",
]
],
)
@pytest.mark.asyncio
@@ -267,6 +269,28 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
"input_file_id": "123",
},
"aretrieve_batch": {"batch_id": "123"},
"aget_assistants": {"custom_llm_provider": "azure"},
"acreate_assistants": {"custom_llm_provider": "azure"},
"adelete_assistant": {"custom_llm_provider": "azure", "assistant_id": "123"},
"acreate_thread": {"custom_llm_provider": "azure"},
"aget_thread": {"custom_llm_provider": "azure", "thread_id": "123"},
"a_add_message": {
"custom_llm_provider": "azure",
"thread_id": "123",
"role": "user",
"content": "Hello, how are you?",
},
"aget_messages": {"custom_llm_provider": "azure", "thread_id": "123"},
"arun_thread": {
"custom_llm_provider": "azure",
"assistant_id": "123",
"thread_id": "123",
},
"acreate_file": {
"custom_llm_provider": "azure",
"file": MagicMock(),
"purpose": "assistants",
},
}
# Get appropriate input for this call type
@@ -285,12 +309,34 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
patch_target = (
"litellm.batches.main.azure_batches_instance.initialize_azure_sdk_client"
)
elif (
call_type == CallTypes.aget_assistants
or call_type == CallTypes.acreate_assistants
or call_type == CallTypes.adelete_assistant
or call_type == CallTypes.acreate_thread
or call_type == CallTypes.aget_thread
or call_type == CallTypes.a_add_message
or call_type == CallTypes.aget_messages
or call_type == CallTypes.arun_thread
):
patch_target = (
"litellm.assistants.main.azure_assistants_api.initialize_azure_sdk_client"
)
elif call_type == CallTypes.acreate_file or call_type == CallTypes.afile_content:
patch_target = (
"litellm.files.main.azure_files_instance.initialize_azure_sdk_client"
)
# Mock the initialize_azure_sdk_client function
with patch(patch_target) as mock_init_azure:
# Also mock async_function_with_fallbacks to prevent actual API calls
# Call the appropriate router method
try:
get_attr = getattr(router, call_type.value, None)
if get_attr is None:
pytest.skip(
f"Skipping {call_type.value} because it is not supported on Router"
)
await getattr(router, call_type.value)(
model="gpt-3.5-turbo",
**input_kwarg,