refactor(azure.py): refactor to have client init work across all endpoints

Author: Krrish Dholakia
Date:   2025-03-11 17:27:24 -07:00
Parent: d99d60a182
Commit: cbc2e84044

10 changed files with 296 additions and 129 deletions
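The gist of the diff below: every Azure endpoint (assistants, files, fine-tuning, and the o1/o3 chat handler) used to build its AzureOpenAI / AsyncAzureOpenAI client by scraping its own locals() into constructor kwargs. This commit routes all of them through the shared BaseAzureLLM.get_azure_openai_client() / initialize_azure_sdk_client() helpers and threads a litellm_params dict from the public entrypoints down to client construction. A minimal usage sketch of the consolidated helper, with placeholder credentials (not taken from the diff):

from litellm.llms.azure.common_utils import BaseAzureLLM

# Placeholder endpoint/key, for illustration only.
client = BaseAzureLLM().get_azure_openai_client(
    litellm_params={"timeout": 600.0, "max_retries": 2},
    api_key="azure-key-placeholder",
    api_base="https://my-resource.openai.azure.com",
    api_version="2024-02-15-preview",
    client=None,      # pass an existing client to reuse it unchanged
    _is_async=False,  # True builds AsyncAzureOpenAI instead
)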

View file

@@ -15,6 +15,7 @@ import litellm
 from litellm.types.router import GenericLiteLLMParams
 from litellm.utils import (
     exception_type,
+    get_litellm_params,
     get_llm_provider,
     get_secret,
     supports_httpx_timeout,
@@ -86,6 +87,7 @@ def get_assistants(
     optional_params = GenericLiteLLMParams(
         api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
     )
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -169,6 +171,7 @@ def get_assistants(
             max_retries=optional_params.max_retries,
             client=client,
             aget_assistants=aget_assistants,  # type: ignore
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -270,6 +273,7 @@ def create_assistants(
     optional_params = GenericLiteLLMParams(
         api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
     )
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -371,6 +375,7 @@ def create_assistants(
             client=client,
             async_create_assistants=async_create_assistants,
             create_assistant_data=create_assistant_data,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -445,6 +450,8 @@ def delete_assistant(
         api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
     )
+    litellm_params_dict = get_litellm_params(**kwargs)
+
     async_delete_assistants: Optional[bool] = kwargs.pop(
         "async_delete_assistants", None
     )
@@ -544,6 +551,7 @@ def delete_assistant(
             max_retries=optional_params.max_retries,
             client=client,
             async_delete_assistants=async_delete_assistants,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -639,6 +647,7 @@ def create_thread(
     """
     acreate_thread = kwargs.get("acreate_thread", None)
     optional_params = GenericLiteLLMParams(**kwargs)
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -731,6 +740,7 @@ def create_thread(
             max_retries=optional_params.max_retries,
             client=client,
             acreate_thread=acreate_thread,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -795,7 +805,7 @@ def get_thread(
     """Get the thread object, given a thread_id"""
     aget_thread = kwargs.pop("aget_thread", None)
     optional_params = GenericLiteLLMParams(**kwargs)
-
+    litellm_params_dict = get_litellm_params(**kwargs)
     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
     # set timeout for 10 minutes by default
@@ -884,6 +894,7 @@ def get_thread(
             max_retries=optional_params.max_retries,
             client=client,
             aget_thread=aget_thread,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -972,6 +983,7 @@ def add_message(
     _message_data = MessageData(
         role=role, content=content, attachments=attachments, metadata=metadata
     )
+    litellm_params_dict = get_litellm_params(**kwargs)
     optional_params = GenericLiteLLMParams(**kwargs)

     message_data = get_optional_params_add_message(
@@ -1068,6 +1080,7 @@ def add_message(
             max_retries=optional_params.max_retries,
             client=client,
             a_add_message=a_add_message,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -1139,6 +1152,7 @@ def get_messages(
 ) -> SyncCursorPage[OpenAIMessage]:
     aget_messages = kwargs.pop("aget_messages", None)
     optional_params = GenericLiteLLMParams(**kwargs)
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -1225,6 +1239,7 @@ def get_messages(
             max_retries=optional_params.max_retries,
             client=client,
             aget_messages=aget_messages,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -1337,6 +1352,7 @@ def run_thread(
     """Run a given thread + assistant."""
     arun_thread = kwargs.pop("arun_thread", None)
     optional_params = GenericLiteLLMParams(**kwargs)
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -1437,6 +1453,7 @@ def run_thread(
             max_retries=optional_params.max_retries,
             client=client,
             arun_thread=arun_thread,
+            litellm_params=litellm_params_dict,
         )  # type: ignore
     else:
         raise litellm.exceptions.BadRequestError(
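Every endpoint above repeats the same two-step pattern: normalize the caller's kwargs into a litellm_params dict, then forward that dict to the provider layer next to the existing arguments. Sketched with illustrative kwargs:

# Illustrative kwargs; get_litellm_params (from litellm.utils) filters the
# caller's kwargs down to the fields the provider layer understands.
kwargs = {"timeout": 600.0, "max_retries": 2}
litellm_params_dict = get_litellm_params(**kwargs)

# Each provider branch then forwards it, e.g.:
# response = azure_assistants_api.get_assistants(
#     ..., litellm_params=litellm_params_dict
# )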

View file

@@ -25,7 +25,7 @@ from litellm.types.llms.openai import (
     HttpxBinaryResponseContent,
 )
 from litellm.types.router import *
-from litellm.utils import supports_httpx_timeout
+from litellm.utils import get_litellm_params, supports_httpx_timeout

 ####### ENVIRONMENT VARIABLES ###################
 openai_files_instance = OpenAIFilesAPI()
@@ -546,6 +546,7 @@ def create_file(
     try:
         _is_async = kwargs.pop("acreate_file", False) is True
         optional_params = GenericLiteLLMParams(**kwargs)
+        litellm_params_dict = get_litellm_params(**kwargs)

         ### TIMEOUT LOGIC ###
         timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -630,6 +631,7 @@ def create_file(
             timeout=timeout,
             max_retries=optional_params.max_retries,
             create_file_data=_create_file_request,
+            litellm_params=litellm_params_dict,
         )
     elif custom_llm_provider == "vertex_ai":
         api_base = optional_params.api_base or ""

View file

@@ -18,10 +18,10 @@ from ...types.llms.openai import (
     SyncCursorPage,
     Thread,
 )
-from ..base import BaseLLM
+from .common_utils import BaseAzureLLM


-class AzureAssistantsAPI(BaseLLM):
+class AzureAssistantsAPI(BaseAzureLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -34,18 +34,17 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AzureOpenAI] = None,
+        litellm_params: Optional[dict] = None,
     ) -> AzureOpenAI:
-        received_args = locals()
         if client is None:
-            data = {}
-            for k, v in received_args.items():
-                if k == "self" or k == "client":
-                    pass
-                elif k == "api_base" and v is not None:
-                    data["azure_endpoint"] = v
-                elif v is not None:
-                    data[k] = v
-            azure_openai_client = AzureOpenAI(**data)  # type: ignore
+            azure_client_params = self.initialize_azure_sdk_client(
+                litellm_params=litellm_params or {},
+                api_key=api_key,
+                api_base=api_base,
+                model_name="",
+                api_version=api_version,
+            )
+            azure_openai_client = AzureOpenAI(**azure_client_params)  # type: ignore
         else:
             azure_openai_client = client
@@ -60,18 +59,18 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI] = None,
+        litellm_params: Optional[dict] = None,
     ) -> AsyncAzureOpenAI:
-        received_args = locals()
         if client is None:
-            data = {}
-            for k, v in received_args.items():
-                if k == "self" or k == "client":
-                    pass
-                elif k == "api_base" and v is not None:
-                    data["azure_endpoint"] = v
-                elif v is not None:
-                    data[k] = v
-            azure_openai_client = AsyncAzureOpenAI(**data)
+            azure_client_params = self.initialize_azure_sdk_client(
+                litellm_params=litellm_params or {},
+                api_key=api_key,
+                api_base=api_base,
+                model_name="",
+                api_version=api_version,
+            )
+            azure_openai_client = AsyncAzureOpenAI(**azure_client_params)
             # azure_openai_client = AsyncAzureOpenAI(**data)  # type: ignore
         else:
             azure_openai_client = client
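For context on what the removed loop did: it mechanically renamed api_base to the azure_endpoint kwarg the OpenAI SDK expects and dropped None values, but it could only see the method's own parameters, so settings carried in litellm_params (Entra ID credentials, router-level timeout and retry config) never reached the client. A self-contained re-enactment with placeholder values:

received_args = {
    "api_key": "azure-key-placeholder",
    "api_base": "https://my-resource.openai.azure.com",
    "api_version": "2024-02-15-preview",
    "client": None,
}
data = {}
for k, v in received_args.items():
    if k in ("self", "client"):
        continue
    if k == "api_base" and v is not None:
        data["azure_endpoint"] = v  # the SDK's name for the Azure endpoint
    elif v is not None:
        data[k] = v
print(data)
# {'api_key': '...', 'azure_endpoint': 'https://...', 'api_version': '...'}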
@@ -89,6 +88,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
+        litellm_params: Optional[dict] = None,
     ) -> AsyncCursorPage[Assistant]:
         azure_openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -98,6 +98,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await azure_openai_client.beta.assistants.list()
@@ -146,6 +147,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client=None,
         aget_assistants=None,
+        litellm_params: Optional[dict] = None,
     ):
         if aget_assistants is not None and aget_assistants is True:
             return self.async_get_assistants(
@@ -156,6 +158,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         azure_openai_client = self.get_azure_client(
             api_key=api_key,
@@ -165,6 +168,7 @@ class AzureAssistantsAPI(BaseLLM):
             max_retries=max_retries,
             client=client,
             api_version=api_version,
+            litellm_params=litellm_params,
         )

         response = azure_openai_client.beta.assistants.list()
@@ -184,6 +188,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI] = None,
+        litellm_params: Optional[dict] = None,
     ) -> OpenAIMessage:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -193,6 +198,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create(  # type: ignore
@@ -222,6 +228,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         a_add_message: Literal[True],
+        litellm_params: Optional[dict] = None,
     ) -> Coroutine[None, None, OpenAIMessage]:
         ...
@@ -238,6 +245,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AzureOpenAI],
         a_add_message: Optional[Literal[False]],
+        litellm_params: Optional[dict] = None,
     ) -> OpenAIMessage:
         ...
@@ -255,6 +263,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client=None,
         a_add_message: Optional[bool] = None,
+        litellm_params: Optional[dict] = None,
     ):
         if a_add_message is not None and a_add_message is True:
             return self.a_add_message(
@@ -267,6 +276,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         openai_client = self.get_azure_client(
             api_key=api_key,
@@ -300,6 +310,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI] = None,
+        litellm_params: Optional[dict] = None,
     ) -> AsyncCursorPage[OpenAIMessage]:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -309,6 +320,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
@@ -329,6 +341,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         aget_messages: Literal[True],
+        litellm_params: Optional[dict] = None,
     ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
         ...
@@ -344,6 +357,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AzureOpenAI],
         aget_messages: Optional[Literal[False]],
+        litellm_params: Optional[dict] = None,
     ) -> SyncCursorPage[OpenAIMessage]:
         ...
@@ -360,6 +374,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client=None,
         aget_messages=None,
+        litellm_params: Optional[dict] = None,
     ):
         if aget_messages is not None and aget_messages is True:
             return self.async_get_messages(
@@ -371,6 +386,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         openai_client = self.get_azure_client(
             api_key=api_key,
@@ -380,6 +396,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = openai_client.beta.threads.messages.list(thread_id=thread_id)
@@ -399,6 +416,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+        litellm_params: Optional[dict] = None,
     ) -> Thread:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -408,6 +426,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         data = {}
@@ -435,6 +454,7 @@ class AzureAssistantsAPI(BaseLLM):
         messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
         client: Optional[AsyncAzureOpenAI],
         acreate_thread: Literal[True],
+        litellm_params: Optional[dict] = None,
     ) -> Coroutine[None, None, Thread]:
         ...
@@ -451,6 +471,7 @@ class AzureAssistantsAPI(BaseLLM):
         messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
         client: Optional[AzureOpenAI],
         acreate_thread: Optional[Literal[False]],
+        litellm_params: Optional[dict] = None,
     ) -> Thread:
         ...
@@ -468,6 +489,7 @@ class AzureAssistantsAPI(BaseLLM):
         messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
         client=None,
         acreate_thread=None,
+        litellm_params: Optional[dict] = None,
     ):
         """
         Here's an example:
@@ -490,6 +512,7 @@ class AzureAssistantsAPI(BaseLLM):
                 max_retries=max_retries,
                 client=client,
                 messages=messages,
+                litellm_params=litellm_params,
             )
         azure_openai_client = self.get_azure_client(
             api_key=api_key,
@@ -499,6 +522,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         data = {}
@@ -521,6 +545,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
+        litellm_params: Optional[dict] = None,
     ) -> Thread:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -530,6 +555,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
@@ -550,6 +576,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         aget_thread: Literal[True],
+        litellm_params: Optional[dict] = None,
     ) -> Coroutine[None, None, Thread]:
         ...
@@ -565,6 +592,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AzureOpenAI],
         aget_thread: Optional[Literal[False]],
+        litellm_params: Optional[dict] = None,
     ) -> Thread:
         ...
@@ -581,6 +609,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client=None,
         aget_thread=None,
+        litellm_params: Optional[dict] = None,
     ):
         if aget_thread is not None and aget_thread is True:
             return self.async_get_thread(
@@ -592,6 +621,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         openai_client = self.get_azure_client(
             api_key=api_key,
@@ -601,6 +631,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = openai_client.beta.threads.retrieve(thread_id=thread_id)
@@ -629,6 +660,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
+        litellm_params: Optional[dict] = None,
     ) -> Run:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -638,6 +670,7 @@ class AzureAssistantsAPI(BaseLLM):
             api_version=api_version,
             azure_ad_token=azure_ad_token,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await openai_client.beta.threads.runs.create_and_poll(  # type: ignore
@@ -645,7 +678,7 @@ class AzureAssistantsAPI(BaseLLM):
             assistant_id=assistant_id,
             additional_instructions=additional_instructions,
             instructions=instructions,
-            metadata=metadata,
+            metadata=metadata,  # type: ignore
             model=model,
             tools=tools,
         )
@@ -663,6 +696,7 @@ class AzureAssistantsAPI(BaseLLM):
         model: Optional[str],
         tools: Optional[Iterable[AssistantToolParam]],
         event_handler: Optional[AssistantEventHandler],
+        litellm_params: Optional[dict] = None,
     ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
         data = {
             "thread_id": thread_id,
@@ -688,6 +722,7 @@ class AzureAssistantsAPI(BaseLLM):
         model: Optional[str],
         tools: Optional[Iterable[AssistantToolParam]],
         event_handler: Optional[AssistantEventHandler],
+        litellm_params: Optional[dict] = None,
     ) -> AssistantStreamManager[AssistantEventHandler]:
         data = {
             "thread_id": thread_id,
@@ -769,6 +804,7 @@ class AzureAssistantsAPI(BaseLLM):
         client=None,
         arun_thread=None,
         event_handler: Optional[AssistantEventHandler] = None,
+        litellm_params: Optional[dict] = None,
     ):
         if arun_thread is not None and arun_thread is True:
             if stream is not None and stream is True:
@@ -780,6 +816,7 @@ class AzureAssistantsAPI(BaseLLM):
                     timeout=timeout,
                     max_retries=max_retries,
                     client=client,
+                    litellm_params=litellm_params,
                 )
                 return self.async_run_thread_stream(
                     client=azure_client,
@@ -791,13 +828,14 @@ class AzureAssistantsAPI(BaseLLM):
                     model=model,
                     tools=tools,
                     event_handler=event_handler,
+                    litellm_params=litellm_params,
                 )
             return self.arun_thread(
                 thread_id=thread_id,
                 assistant_id=assistant_id,
                 additional_instructions=additional_instructions,
                 instructions=instructions,
-                metadata=metadata,
+                metadata=metadata,  # type: ignore
                 model=model,
                 stream=stream,
                 tools=tools,
@@ -808,6 +846,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         openai_client = self.get_azure_client(
             api_key=api_key,
@@ -817,6 +856,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         if stream is not None and stream is True:
@@ -830,6 +870,7 @@ class AzureAssistantsAPI(BaseLLM):
                 model=model,
                 tools=tools,
                 event_handler=event_handler,
+                litellm_params=litellm_params,
             )

         response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
@@ -837,7 +878,7 @@ class AzureAssistantsAPI(BaseLLM):
             assistant_id=assistant_id,
             additional_instructions=additional_instructions,
             instructions=instructions,
-            metadata=metadata,
+            metadata=metadata,  # type: ignore
             model=model,
             tools=tools,
         )
@@ -855,6 +896,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         create_assistant_data: dict,
+        litellm_params: Optional[dict] = None,
     ) -> Assistant:
         azure_openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -864,6 +906,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await azure_openai_client.beta.assistants.create(
@@ -882,6 +925,7 @@ class AzureAssistantsAPI(BaseLLM):
         create_assistant_data: dict,
         client=None,
         async_create_assistants=None,
+        litellm_params: Optional[dict] = None,
     ):
         if async_create_assistants is not None and async_create_assistants is True:
             return self.async_create_assistants(
@@ -893,6 +937,7 @@ class AzureAssistantsAPI(BaseLLM):
                 max_retries=max_retries,
                 client=client,
                 create_assistant_data=create_assistant_data,
+                litellm_params=litellm_params,
             )
         azure_openai_client = self.get_azure_client(
             api_key=api_key,
@@ -902,6 +947,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = azure_openai_client.beta.assistants.create(**create_assistant_data)
@@ -918,6 +964,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         assistant_id: str,
+        litellm_params: Optional[dict] = None,
     ):
         azure_openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -927,6 +974,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await azure_openai_client.beta.assistants.delete(
@@ -945,6 +993,7 @@ class AzureAssistantsAPI(BaseLLM):
         assistant_id: str,
         async_delete_assistants: Optional[bool] = None,
         client=None,
+        litellm_params: Optional[dict] = None,
     ):
         if async_delete_assistants is not None and async_delete_assistants is True:
             return self.async_delete_assistant(
@@ -956,6 +1005,7 @@ class AzureAssistantsAPI(BaseLLM):
                 max_retries=max_retries,
                 client=client,
                 assistant_id=assistant_id,
+                litellm_params=litellm_params,
             )
         azure_openai_client = self.get_azure_client(
             api_key=api_key,
@@ -965,6 +1015,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id)

View file

@@ -4,50 +4,70 @@ Handler file for calls to Azure OpenAI's o1/o3 family of models
 Written separately to handle faking streaming for o1 and o3 models.
 """

-from typing import Optional, Union
+from typing import Any, Callable, Optional, Union

 import httpx
 from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

+from litellm.types.llms.openai import Any
+from litellm.types.utils import ModelResponse
+
 from ...openai.openai import OpenAIChatCompletion
-from ..common_utils import get_azure_openai_client
+from ..common_utils import BaseAzureLLM


-class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
-    def _get_openai_client(
-        self,
-        is_async: bool,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        api_version: Optional[str] = None,
-        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
-        max_retries: Optional[int] = 2,
-        organization: Optional[str] = None,
-        client: Optional[
-            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
-        ] = None,
-    ) -> Optional[
-        Union[
-            OpenAI,
-            AsyncOpenAI,
-            AzureOpenAI,
-            AsyncAzureOpenAI,
-        ]
-    ]:
-        # Override to use Azure-specific client initialization
-        if not isinstance(client, AzureOpenAI) and not isinstance(
-            client, AsyncAzureOpenAI
-        ):
-            client = None
-
-        return get_azure_openai_client(
-            api_key=api_key,
-            api_base=api_base,
-            timeout=timeout,
-            max_retries=max_retries,
-            organization=organization,
-            api_version=api_version,
-            client=client,
-            _is_async=is_async,
-        )
+class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
+    def completion(
+        self,
+        model_response: ModelResponse,
+        timeout: Union[float, httpx.Timeout],
+        optional_params: dict,
+        litellm_params: dict,
+        logging_obj: Any,
+        model: Optional[str] = None,
+        messages: Optional[list] = None,
+        print_verbose: Optional[Callable] = None,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        dynamic_params: Optional[bool] = None,
+        azure_ad_token: Optional[str] = None,
+        acompletion: bool = False,
+        logger_fn=None,
+        headers: Optional[dict] = None,
+        custom_prompt_dict: dict = {},
+        client=None,
+        organization: Optional[str] = None,
+        custom_llm_provider: Optional[str] = None,
+        drop_params: Optional[bool] = None,
+    ):
+        client = self.get_azure_openai_client(
+            litellm_params=litellm_params,
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            client=client,
+        )
+        return super().completion(
+            model_response=model_response,
+            timeout=timeout,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logging_obj=logging_obj,
+            model=model,
+            messages=messages,
+            print_verbose=print_verbose,
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            dynamic_params=dynamic_params,
+            azure_ad_token=azure_ad_token,
+            acompletion=acompletion,
+            logger_fn=logger_fn,
+            headers=headers,
+            custom_prompt_dict=custom_prompt_dict,
+            client=client,
+            organization=organization,
+            custom_llm_provider=custom_llm_provider,
+            drop_params=drop_params,
+        )
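The override point moves up a level in this file: rather than swapping the client factory (_get_openai_client), the subclass now overrides completion() itself, resolves an Azure client through the shared BaseAzureLLM helper, and delegates to super().completion(...) with that client pinned. A condensed sketch of the control flow, with stand-in types rather than the real classes:

# Sketch only: stand-ins for OpenAIChatCompletion and BaseAzureLLM.
class AzureO1Sketch:
    def get_azure_openai_client(self, **kwargs):
        return "azure-client"  # stands in for the BaseAzureLLM helper

    def _base_completion(self, client):
        return f"completion via {client}"  # stands in for the parent completion()

    def completion(self, litellm_params, api_key=None, api_base=None,
                   api_version=None, client=None):
        # 1) resolve the client up front, 2) delegate with it pinned
        client = client or self.get_azure_openai_client(
            litellm_params=litellm_params,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
        )
        return self._base_completion(client=client)

print(AzureO1Sketch().completion(litellm_params={}))  # completion via azure-client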

View file

@@ -35,40 +35,6 @@ class AzureOpenAIError(BaseLLMException):
         )


-def get_azure_openai_client(
-    api_key: Optional[str],
-    api_base: Optional[str],
-    timeout: Union[float, httpx.Timeout],
-    max_retries: Optional[int],
-    api_version: Optional[str] = None,
-    organization: Optional[str] = None,
-    client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
-    _is_async: bool = False,
-) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
-    received_args = locals()
-    openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
-    if client is None:
-        data = {}
-        for k, v in received_args.items():
-            if k == "self" or k == "client" or k == "_is_async":
-                pass
-            elif k == "api_base" and v is not None:
-                data["azure_endpoint"] = v
-            elif v is not None:
-                data[k] = v
-        if "api_version" not in data:
-            data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
-        if _is_async is True:
-            openai_client = AsyncAzureOpenAI(**data)
-        else:
-            openai_client = AzureOpenAI(**data)  # type: ignore
-    else:
-        openai_client = client
-
-    return openai_client
-
-
 def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
     openai_headers = {}
     if "x-ratelimit-limit-requests" in headers:
@@ -277,6 +243,33 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):

 class BaseAzureLLM:
+    def get_azure_openai_client(
+        self,
+        litellm_params: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str] = None,
+        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        _is_async: bool = False,
+    ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
+        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
+        if client is None:
+            azure_client_params = self.initialize_azure_sdk_client(
+                litellm_params=litellm_params,
+                api_key=api_key,
+                api_base=api_base,
+                model_name="",
+                api_version=api_version,
+            )
+            if _is_async is True:
+                openai_client = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                openai_client = AzureOpenAI(**azure_client_params)  # type: ignore
+        else:
+            openai_client = client
+        return openai_client
+
     def initialize_azure_sdk_client(
         self,
         litellm_params: dict,
@@ -294,6 +287,8 @@ class BaseAzureLLM:
         client_secret = litellm_params.get("client_secret")
         azure_username = litellm_params.get("azure_username")
         azure_password = litellm_params.get("azure_password")
+        max_retries = litellm_params.get("max_retries")
+        timeout = litellm_params.get("timeout")
         if not api_key and tenant_id and client_id and client_secret:
             verbose_logger.debug("Using Azure AD Token Provider for Azure Auth")
             azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
@@ -338,6 +333,10 @@ class BaseAzureLLM:
             "azure_ad_token": azure_ad_token,
             "azure_ad_token_provider": azure_ad_token_provider,
         }
+        if max_retries is not None:
+            azure_client_params["max_retries"] = max_retries
+        if timeout is not None:
+            azure_client_params["timeout"] = timeout

         if azure_ad_token_provider is not None:
             azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
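With the two new litellm_params.get(...) reads, transport settings ride along with auth settings instead of being separate arguments at every call site. Assuming the signature shown above, with placeholder values:

from litellm.llms.azure.common_utils import BaseAzureLLM

params = BaseAzureLLM().initialize_azure_sdk_client(
    litellm_params={"timeout": 30.0, "max_retries": 4},
    api_key="azure-key-placeholder",
    api_base="https://my-resource.openai.azure.com",
    model_name="",
    api_version="2024-02-15-preview",
)
# params now carries timeout=30.0 and max_retries=4 next to the auth and
# endpoint fields, so AzureOpenAI(**params) applies them uniformly.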

View file

@@ -5,13 +5,12 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
 from openai.types.file_deleted import FileDeleted

 from litellm._logging import verbose_logger
-from litellm.llms.base import BaseLLM
 from litellm.types.llms.openai import *

-from ..common_utils import get_azure_openai_client
+from ..common_utils import BaseAzureLLM


-class AzureOpenAIFilesAPI(BaseLLM):
+class AzureOpenAIFilesAPI(BaseAzureLLM):
     """
     AzureOpenAI methods to support for batches
     - create_file()
@@ -45,14 +44,15 @@ class AzureOpenAIFilesAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
                 api_version=api_version,
-                timeout=timeout,
-                max_retries=max_retries,
                 client=client,
                 _is_async=_is_async,
             )
@@ -91,17 +91,16 @@ class AzureOpenAIFilesAPI(BaseLLM):
         max_retries: Optional[int],
         api_version: Optional[str] = None,
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ) -> Union[
         HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
     ]:
         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
-                timeout=timeout,
                 api_version=api_version,
-                max_retries=max_retries,
-                organization=None,
                 client=client,
                 _is_async=_is_async,
             )
@@ -144,14 +143,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
         max_retries: Optional[int],
         api_version: Optional[str] = None,
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ):
         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=None,
                 api_version=api_version,
                 client=client,
                 _is_async=_is_async,
@@ -197,14 +195,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
         organization: Optional[str] = None,
         api_version: Optional[str] = None,
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ):
         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
                 api_version=api_version,
                 client=client,
                 _is_async=_is_async,
@@ -252,14 +249,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
         purpose: Optional[str] = None,
         api_version: Optional[str] = None,
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ):
         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=None,  # openai param
                 api_version=api_version,
                 client=client,
                 _is_async=_is_async,

View file

@@ -3,11 +3,11 @@ from typing import Optional, Union
 import httpx
 from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

-from litellm.llms.azure.files.handler import get_azure_openai_client
+from litellm.llms.azure.common_utils import BaseAzureLLM
 from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI


-class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
+class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
     """
     AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
     """
@@ -24,6 +24,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
         ] = None,
         _is_async: bool = False,
         api_version: Optional[str] = None,
+        litellm_params: Optional[dict] = None,
     ) -> Optional[
         Union[
             OpenAI,
@@ -36,12 +37,10 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
         if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
             client = None

-        return get_azure_openai_client(
+        return self.get_azure_openai_client(
+            litellm_params=litellm_params or {},
             api_key=api_key,
             api_base=api_base,
-            timeout=timeout,
-            max_retries=max_retries,
-            organization=organization,
             api_version=api_version,
             client=client,
             _is_async=_is_async,

View file

@@ -27,6 +27,7 @@ class OpenAIFineTuningAPI:
         ] = None,
         _is_async: bool = False,
         api_version: Optional[str] = None,
+        litellm_params: Optional[dict] = None,
     ) -> Optional[
         Union[
             OpenAI,
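This base-class change is purely a signature widening: the OpenAI implementation accepts (and ignores) litellm_params so the Azure subclass shown above can receive it through a uniform call site. The pattern in miniature, with hypothetical classes:

class Base:
    def get_client(self, litellm_params=None):  # accepted, unused here
        return "openai-client"

class AzureSub(Base):
    def get_client(self, litellm_params=None):
        return f"azure-client({litellm_params})"

for impl in (Base(), AzureSub()):
    print(impl.get_client(litellm_params={"timeout": 30}))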

View file

@@ -191,6 +191,42 @@ class CallTypes(Enum):
     retrieve_batch = "retrieve_batch"
     pass_through = "pass_through_endpoint"
     anthropic_messages = "anthropic_messages"
+    get_assistants = "get_assistants"
+    aget_assistants = "aget_assistants"
+    create_assistants = "create_assistants"
+    acreate_assistants = "acreate_assistants"
+    delete_assistant = "delete_assistant"
+    adelete_assistant = "adelete_assistant"
+    acreate_thread = "acreate_thread"
+    create_thread = "create_thread"
+    aget_thread = "aget_thread"
+    get_thread = "get_thread"
+    a_add_message = "a_add_message"
+    add_message = "add_message"
+    aget_messages = "aget_messages"
+    get_messages = "get_messages"
+    arun_thread = "arun_thread"
+    run_thread = "run_thread"
+    arun_thread_stream = "arun_thread_stream"
+    run_thread_stream = "run_thread_stream"
+    afile_retrieve = "afile_retrieve"
+    file_retrieve = "file_retrieve"
+    afile_delete = "afile_delete"
+    file_delete = "file_delete"
+    afile_list = "afile_list"
+    file_list = "file_list"
+    acreate_file = "acreate_file"
+    create_file = "create_file"
+    afile_content = "afile_content"
+    file_content = "file_content"
+    create_fine_tuning_job = "create_fine_tuning_job"
+    acreate_fine_tuning_job = "acreate_fine_tuning_job"
+    acancel_fine_tuning_job = "acancel_fine_tuning_job"
+    cancel_fine_tuning_job = "cancel_fine_tuning_job"
+    alist_fine_tuning_jobs = "alist_fine_tuning_jobs"
+    list_fine_tuning_jobs = "list_fine_tuning_jobs"
+    aretrieve_fine_tuning_job = "aretrieve_fine_tuning_job"
+    retrieve_fine_tuning_job = "retrieve_fine_tuning_job"


 CallTypesLiteral = Literal[
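The new members feed the test change below, which stops hand-listing call types and instead sweeps every async member of the enum. A standalone, illustrative version of that filter:

from litellm.types.utils import CallTypes

async_members = [
    ct.value
    for ct in CallTypes.__members__.values()
    if ct.name.startswith("a")
    and ct.name not in {"amoderation", "arerank", "arealtime", "add_message"}
]
print(async_members[:5])  # e.g. ['acompletion', 'aembedding', ...]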

View file

@@ -216,16 +216,18 @@ def test_select_azure_base_url_called(setup_mocks):
 @pytest.mark.parametrize(
     "call_type",
     [
-        CallTypes.acompletion,
-        CallTypes.atext_completion,
-        CallTypes.aembedding,
-        CallTypes.atranscription,
-        CallTypes.aspeech,
-        CallTypes.aimage_generation,
-        # BATCHES ENDPOINTS
-        CallTypes.acreate_batch,
-        CallTypes.aretrieve_batch,
-        # ASSISTANT ENDPOINTS
+        call_type
+        for call_type in CallTypes.__members__.values()
+        if call_type.name.startswith("a")
+        and call_type.name
+        not in [
+            "amoderation",
+            "arerank",
+            "arealtime",
+            "anthropic_messages",
+            "add_message",
+            "arun_thread_stream",
+        ]
     ],
 )
 @pytest.mark.asyncio
@@ -267,6 +269,28 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
             "input_file_id": "123",
         },
         "aretrieve_batch": {"batch_id": "123"},
+        "aget_assistants": {"custom_llm_provider": "azure"},
+        "acreate_assistants": {"custom_llm_provider": "azure"},
+        "adelete_assistant": {"custom_llm_provider": "azure", "assistant_id": "123"},
+        "acreate_thread": {"custom_llm_provider": "azure"},
+        "aget_thread": {"custom_llm_provider": "azure", "thread_id": "123"},
+        "a_add_message": {
+            "custom_llm_provider": "azure",
+            "thread_id": "123",
+            "role": "user",
+            "content": "Hello, how are you?",
+        },
+        "aget_messages": {"custom_llm_provider": "azure", "thread_id": "123"},
+        "arun_thread": {
+            "custom_llm_provider": "azure",
+            "assistant_id": "123",
+            "thread_id": "123",
+        },
+        "acreate_file": {
+            "custom_llm_provider": "azure",
+            "file": MagicMock(),
+            "purpose": "assistants",
+        },
     }

     # Get appropriate input for this call type
@@ -285,12 +309,34 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
         patch_target = (
             "litellm.batches.main.azure_batches_instance.initialize_azure_sdk_client"
         )
+    elif (
+        call_type == CallTypes.aget_assistants
+        or call_type == CallTypes.acreate_assistants
+        or call_type == CallTypes.adelete_assistant
+        or call_type == CallTypes.acreate_thread
+        or call_type == CallTypes.aget_thread
+        or call_type == CallTypes.a_add_message
+        or call_type == CallTypes.aget_messages
+        or call_type == CallTypes.arun_thread
+    ):
+        patch_target = (
+            "litellm.assistants.main.azure_assistants_api.initialize_azure_sdk_client"
+        )
+    elif call_type == CallTypes.acreate_file or call_type == CallTypes.afile_content:
+        patch_target = (
+            "litellm.files.main.azure_files_instance.initialize_azure_sdk_client"
+        )

     # Mock the initialize_azure_sdk_client function
     with patch(patch_target) as mock_init_azure:
         # Also mock async_function_with_fallbacks to prevent actual API calls

         # Call the appropriate router method
         try:
+            get_attr = getattr(router, call_type.value, None)
+            if get_attr is None:
+                pytest.skip(
+                    f"Skipping {call_type.value} because it is not supported on Router"
+                )
             await getattr(router, call_type.value)(
                 model="gpt-3.5-turbo",
                 **input_kwarg,