refactor(azure.py): refactor to have client init work across all endpoints

commit cbc2e84044 (parent d99d60a182)
10 changed files with 296 additions and 129 deletions
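The unifying change: before this commit, each Azure endpoint built its AzureOpenAI / AsyncAzureOpenAI client ad hoc, scraping locals() into a kwargs dict (the received_args loops removed below). After it, every endpoint funnels through BaseAzureLLM.initialize_azure_sdk_client, with each public API threading a litellm_params dict down to the handler. A minimal sketch of the resulting call path, condensed from the hunks below (endpoint, key, and version values are placeholders):

from litellm.llms.azure.common_utils import BaseAzureLLM

handler = BaseAzureLLM()
client = handler.get_azure_openai_client(
    litellm_params={"timeout": 600, "max_retries": 2},  # keys consumed by initialize_azure_sdk_client
    api_key="placeholder-key",
    api_base="https://example-resource.openai.azure.com",
    api_version="2024-02-01",  # placeholder API version
    client=None,      # a caller-supplied client short-circuits construction
    _is_async=False,  # True yields an AsyncAzureOpenAI instead
)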
litellm/assistants/main.py

@@ -15,6 +15,7 @@ import litellm
 from litellm.types.router import GenericLiteLLMParams
 from litellm.utils import (
     exception_type,
+    get_litellm_params,
     get_llm_provider,
     get_secret,
     supports_httpx_timeout,
@@ -86,6 +87,7 @@ def get_assistants(
     optional_params = GenericLiteLLMParams(
         api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
     )
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -169,6 +171,7 @@ def get_assistants(
             max_retries=optional_params.max_retries,
             client=client,
             aget_assistants=aget_assistants,  # type: ignore
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -270,6 +273,7 @@ def create_assistants(
     optional_params = GenericLiteLLMParams(
         api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
     )
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -371,6 +375,7 @@ def create_assistants(
             client=client,
             async_create_assistants=async_create_assistants,
             create_assistant_data=create_assistant_data,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -445,6 +450,8 @@ def delete_assistant(
         api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
     )

+    litellm_params_dict = get_litellm_params(**kwargs)
+
     async_delete_assistants: Optional[bool] = kwargs.pop(
         "async_delete_assistants", None
     )
@@ -544,6 +551,7 @@ def delete_assistant(
             max_retries=optional_params.max_retries,
             client=client,
             async_delete_assistants=async_delete_assistants,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -639,6 +647,7 @@ def create_thread(
     """
     acreate_thread = kwargs.get("acreate_thread", None)
     optional_params = GenericLiteLLMParams(**kwargs)
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -731,6 +740,7 @@ def create_thread(
             max_retries=optional_params.max_retries,
             client=client,
             acreate_thread=acreate_thread,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -795,7 +805,7 @@ def get_thread(
     """Get the thread object, given a thread_id"""
     aget_thread = kwargs.pop("aget_thread", None)
     optional_params = GenericLiteLLMParams(**kwargs)
-
+    litellm_params_dict = get_litellm_params(**kwargs)
     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
     # set timeout for 10 minutes by default
@@ -884,6 +894,7 @@ def get_thread(
             max_retries=optional_params.max_retries,
             client=client,
             aget_thread=aget_thread,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -972,6 +983,7 @@ def add_message(
     _message_data = MessageData(
         role=role, content=content, attachments=attachments, metadata=metadata
     )
+    litellm_params_dict = get_litellm_params(**kwargs)
    optional_params = GenericLiteLLMParams(**kwargs)

     message_data = get_optional_params_add_message(
@@ -1068,6 +1080,7 @@ def add_message(
             max_retries=optional_params.max_retries,
             client=client,
             a_add_message=a_add_message,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -1139,6 +1152,7 @@ def get_messages(
 ) -> SyncCursorPage[OpenAIMessage]:
     aget_messages = kwargs.pop("aget_messages", None)
     optional_params = GenericLiteLLMParams(**kwargs)
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -1225,6 +1239,7 @@ def get_messages(
             max_retries=optional_params.max_retries,
             client=client,
             aget_messages=aget_messages,
+            litellm_params=litellm_params_dict,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -1337,6 +1352,7 @@ def run_thread(
     """Run a given thread + assistant."""
     arun_thread = kwargs.pop("arun_thread", None)
     optional_params = GenericLiteLLMParams(**kwargs)
+    litellm_params_dict = get_litellm_params(**kwargs)

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -1437,6 +1453,7 @@ def run_thread(
             max_retries=optional_params.max_retries,
             client=client,
             arun_thread=arun_thread,
+            litellm_params=litellm_params_dict,
         )  # type: ignore
     else:
         raise litellm.exceptions.BadRequestError(
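Every endpoint in this file gains the same two-step preamble: build GenericLiteLLMParams from kwargs as before, additionally collect the litellm-specific parameters with get_litellm_params, and forward the resulting dict to the Azure handler as litellm_params. A condensed, runnable sketch of that shared preamble (_collect_params is an illustrative name, not from the source):

from litellm.types.router import GenericLiteLLMParams
from litellm.utils import get_litellm_params


def _collect_params(**kwargs):
    # The preamble each endpoint now runs before dispatching to Azure
    optional_params = GenericLiteLLMParams(**kwargs)
    litellm_params_dict = get_litellm_params(**kwargs)
    # set timeout for 10 minutes by default
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    return optional_params, litellm_params_dict, timeout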
litellm/files/main.py

@@ -25,7 +25,7 @@ from litellm.types.llms.openai import (
     HttpxBinaryResponseContent,
 )
 from litellm.types.router import *
-from litellm.utils import supports_httpx_timeout
+from litellm.utils import get_litellm_params, supports_httpx_timeout

 ####### ENVIRONMENT VARIABLES ###################
 openai_files_instance = OpenAIFilesAPI()
@@ -546,6 +546,7 @@ def create_file(
     try:
         _is_async = kwargs.pop("acreate_file", False) is True
         optional_params = GenericLiteLLMParams(**kwargs)
+        litellm_params_dict = get_litellm_params(**kwargs)

         ### TIMEOUT LOGIC ###
         timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -630,6 +631,7 @@ def create_file(
                 timeout=timeout,
                 max_retries=optional_params.max_retries,
                 create_file_data=_create_file_request,
+                litellm_params=litellm_params_dict,
             )
         elif custom_llm_provider == "vertex_ai":
             api_base = optional_params.api_base or ""
litellm/llms/azure/assistants.py

@@ -18,10 +18,10 @@ from ...types.llms.openai import (
     SyncCursorPage,
     Thread,
 )
-from ..base import BaseLLM
+from .common_utils import BaseAzureLLM


-class AzureAssistantsAPI(BaseLLM):
+class AzureAssistantsAPI(BaseAzureLLM):
     def __init__(self) -> None:
         super().__init__()

@@ -34,18 +34,17 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AzureOpenAI] = None,
+        litellm_params: Optional[dict] = None,
     ) -> AzureOpenAI:
-        received_args = locals()
         if client is None:
-            data = {}
-            for k, v in received_args.items():
-                if k == "self" or k == "client":
-                    pass
-                elif k == "api_base" and v is not None:
-                    data["azure_endpoint"] = v
-                elif v is not None:
-                    data[k] = v
-            azure_openai_client = AzureOpenAI(**data)  # type: ignore
+            azure_client_params = self.initialize_azure_sdk_client(
+                litellm_params=litellm_params or {},
+                api_key=api_key,
+                api_base=api_base,
+                model_name="",
+                api_version=api_version,
+            )
+            azure_openai_client = AzureOpenAI(**azure_client_params)  # type: ignore
         else:
             azure_openai_client = client

@@ -60,18 +59,18 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI] = None,
+        litellm_params: Optional[dict] = None,
     ) -> AsyncAzureOpenAI:
-        received_args = locals()
         if client is None:
-            data = {}
-            for k, v in received_args.items():
-                if k == "self" or k == "client":
-                    pass
-                elif k == "api_base" and v is not None:
-                    data["azure_endpoint"] = v
-                elif v is not None:
-                    data[k] = v
-            azure_openai_client = AsyncAzureOpenAI(**data)
+            azure_client_params = self.initialize_azure_sdk_client(
+                litellm_params=litellm_params or {},
+                api_key=api_key,
+                api_base=api_base,
+                model_name="",
+                api_version=api_version,
+            )
+
+            azure_openai_client = AsyncAzureOpenAI(**azure_client_params)
             # azure_openai_client = AsyncAzureOpenAI(**data)  # type: ignore
         else:
             azure_openai_client = client

@@ -89,6 +88,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
+        litellm_params: Optional[dict] = None,
     ) -> AsyncCursorPage[Assistant]:
         azure_openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -98,6 +98,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await azure_openai_client.beta.assistants.list()
@@ -146,6 +147,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client=None,
         aget_assistants=None,
+        litellm_params: Optional[dict] = None,
     ):
         if aget_assistants is not None and aget_assistants is True:
             return self.async_get_assistants(
@@ -156,6 +158,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         azure_openai_client = self.get_azure_client(
             api_key=api_key,
@@ -165,6 +168,7 @@ class AzureAssistantsAPI(BaseLLM):
             max_retries=max_retries,
             client=client,
             api_version=api_version,
+            litellm_params=litellm_params,
         )

         response = azure_openai_client.beta.assistants.list()
@@ -184,6 +188,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI] = None,
+        litellm_params: Optional[dict] = None,
     ) -> OpenAIMessage:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -193,6 +198,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create(  # type: ignore
@@ -222,6 +228,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         a_add_message: Literal[True],
+        litellm_params: Optional[dict] = None,
     ) -> Coroutine[None, None, OpenAIMessage]:
         ...

@@ -238,6 +245,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AzureOpenAI],
         a_add_message: Optional[Literal[False]],
+        litellm_params: Optional[dict] = None,
     ) -> OpenAIMessage:
         ...

@@ -255,6 +263,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client=None,
         a_add_message: Optional[bool] = None,
+        litellm_params: Optional[dict] = None,
     ):
         if a_add_message is not None and a_add_message is True:
             return self.a_add_message(
@@ -267,6 +276,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         openai_client = self.get_azure_client(
             api_key=api_key,
@@ -300,6 +310,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI] = None,
+        litellm_params: Optional[dict] = None,
     ) -> AsyncCursorPage[OpenAIMessage]:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -309,6 +320,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
@@ -329,6 +341,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         aget_messages: Literal[True],
+        litellm_params: Optional[dict] = None,
     ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
         ...

@@ -344,6 +357,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AzureOpenAI],
         aget_messages: Optional[Literal[False]],
+        litellm_params: Optional[dict] = None,
     ) -> SyncCursorPage[OpenAIMessage]:
         ...

@@ -360,6 +374,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client=None,
         aget_messages=None,
+        litellm_params: Optional[dict] = None,
     ):
         if aget_messages is not None and aget_messages is True:
             return self.async_get_messages(
@@ -371,6 +386,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         openai_client = self.get_azure_client(
             api_key=api_key,
@@ -380,6 +396,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = openai_client.beta.threads.messages.list(thread_id=thread_id)
@@ -399,6 +416,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+        litellm_params: Optional[dict] = None,
     ) -> Thread:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -408,6 +426,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         data = {}
@@ -435,6 +454,7 @@ class AzureAssistantsAPI(BaseLLM):
         messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
         client: Optional[AsyncAzureOpenAI],
         acreate_thread: Literal[True],
+        litellm_params: Optional[dict] = None,
     ) -> Coroutine[None, None, Thread]:
         ...

@@ -451,6 +471,7 @@ class AzureAssistantsAPI(BaseLLM):
         messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
         client: Optional[AzureOpenAI],
         acreate_thread: Optional[Literal[False]],
+        litellm_params: Optional[dict] = None,
     ) -> Thread:
         ...

@@ -468,6 +489,7 @@ class AzureAssistantsAPI(BaseLLM):
         messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
         client=None,
         acreate_thread=None,
+        litellm_params: Optional[dict] = None,
     ):
         """
         Here's an example:
@@ -490,6 +512,7 @@ class AzureAssistantsAPI(BaseLLM):
                 max_retries=max_retries,
                 client=client,
                 messages=messages,
+                litellm_params=litellm_params,
             )
         azure_openai_client = self.get_azure_client(
             api_key=api_key,
@@ -499,6 +522,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         data = {}
@@ -521,6 +545,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
+        litellm_params: Optional[dict] = None,
     ) -> Thread:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -530,6 +555,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
@@ -550,6 +576,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         aget_thread: Literal[True],
+        litellm_params: Optional[dict] = None,
     ) -> Coroutine[None, None, Thread]:
         ...

@@ -565,6 +592,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AzureOpenAI],
         aget_thread: Optional[Literal[False]],
+        litellm_params: Optional[dict] = None,
     ) -> Thread:
         ...

@@ -581,6 +609,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client=None,
         aget_thread=None,
+        litellm_params: Optional[dict] = None,
     ):
         if aget_thread is not None and aget_thread is True:
             return self.async_get_thread(
@@ -592,6 +621,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         openai_client = self.get_azure_client(
             api_key=api_key,
@@ -601,6 +631,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = openai_client.beta.threads.retrieve(thread_id=thread_id)
@@ -629,6 +660,7 @@ class AzureAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
+        litellm_params: Optional[dict] = None,
     ) -> Run:
         openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -638,6 +670,7 @@ class AzureAssistantsAPI(BaseLLM):
             api_version=api_version,
             azure_ad_token=azure_ad_token,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await openai_client.beta.threads.runs.create_and_poll(  # type: ignore
@@ -645,7 +678,7 @@ class AzureAssistantsAPI(BaseLLM):
             assistant_id=assistant_id,
             additional_instructions=additional_instructions,
             instructions=instructions,
-            metadata=metadata,
+            metadata=metadata,  # type: ignore
             model=model,
             tools=tools,
         )
@@ -663,6 +696,7 @@ class AzureAssistantsAPI(BaseLLM):
         model: Optional[str],
         tools: Optional[Iterable[AssistantToolParam]],
         event_handler: Optional[AssistantEventHandler],
+        litellm_params: Optional[dict] = None,
     ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
         data = {
             "thread_id": thread_id,
@@ -688,6 +722,7 @@ class AzureAssistantsAPI(BaseLLM):
         model: Optional[str],
         tools: Optional[Iterable[AssistantToolParam]],
         event_handler: Optional[AssistantEventHandler],
+        litellm_params: Optional[dict] = None,
     ) -> AssistantStreamManager[AssistantEventHandler]:
         data = {
             "thread_id": thread_id,
@@ -769,6 +804,7 @@ class AzureAssistantsAPI(BaseLLM):
         client=None,
         arun_thread=None,
         event_handler: Optional[AssistantEventHandler] = None,
+        litellm_params: Optional[dict] = None,
     ):
         if arun_thread is not None and arun_thread is True:
             if stream is not None and stream is True:
@@ -780,6 +816,7 @@ class AzureAssistantsAPI(BaseLLM):
                     timeout=timeout,
                     max_retries=max_retries,
                     client=client,
+                    litellm_params=litellm_params,
                 )
                 return self.async_run_thread_stream(
                     client=azure_client,
@@ -791,13 +828,14 @@ class AzureAssistantsAPI(BaseLLM):
                     model=model,
                     tools=tools,
                     event_handler=event_handler,
+                    litellm_params=litellm_params,
                 )
             return self.arun_thread(
                 thread_id=thread_id,
                 assistant_id=assistant_id,
                 additional_instructions=additional_instructions,
                 instructions=instructions,
-                metadata=metadata,
+                metadata=metadata,  # type: ignore
                 model=model,
                 stream=stream,
                 tools=tools,
@@ -808,6 +846,7 @@ class AzureAssistantsAPI(BaseLLM):
                 timeout=timeout,
                 max_retries=max_retries,
                 client=client,
+                litellm_params=litellm_params,
             )
         openai_client = self.get_azure_client(
             api_key=api_key,
@@ -817,6 +856,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         if stream is not None and stream is True:
@@ -830,6 +870,7 @@ class AzureAssistantsAPI(BaseLLM):
                 model=model,
                 tools=tools,
                 event_handler=event_handler,
+                litellm_params=litellm_params,
             )

         response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
@@ -837,7 +878,7 @@ class AzureAssistantsAPI(BaseLLM):
             assistant_id=assistant_id,
             additional_instructions=additional_instructions,
             instructions=instructions,
-            metadata=metadata,
+            metadata=metadata,  # type: ignore
             model=model,
             tools=tools,
         )
@@ -855,6 +896,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         create_assistant_data: dict,
+        litellm_params: Optional[dict] = None,
     ) -> Assistant:
         azure_openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -864,6 +906,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await azure_openai_client.beta.assistants.create(
@@ -882,6 +925,7 @@ class AzureAssistantsAPI(BaseLLM):
         create_assistant_data: dict,
         client=None,
         async_create_assistants=None,
+        litellm_params: Optional[dict] = None,
     ):
         if async_create_assistants is not None and async_create_assistants is True:
             return self.async_create_assistants(
@@ -893,6 +937,7 @@ class AzureAssistantsAPI(BaseLLM):
                 max_retries=max_retries,
                 client=client,
                 create_assistant_data=create_assistant_data,
+                litellm_params=litellm_params,
             )
         azure_openai_client = self.get_azure_client(
             api_key=api_key,
@@ -902,6 +947,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = azure_openai_client.beta.assistants.create(**create_assistant_data)
@@ -918,6 +964,7 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client: Optional[AsyncAzureOpenAI],
         assistant_id: str,
+        litellm_params: Optional[dict] = None,
     ):
         azure_openai_client = self.async_get_azure_client(
             api_key=api_key,
@@ -927,6 +974,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = await azure_openai_client.beta.assistants.delete(
@@ -945,6 +993,7 @@ class AzureAssistantsAPI(BaseLLM):
         assistant_id: str,
         async_delete_assistants: Optional[bool] = None,
         client=None,
+        litellm_params: Optional[dict] = None,
     ):
         if async_delete_assistants is not None and async_delete_assistants is True:
             return self.async_delete_assistant(
@@ -956,6 +1005,7 @@ class AzureAssistantsAPI(BaseLLM):
                 max_retries=max_retries,
                 client=client,
                 assistant_id=assistant_id,
+                litellm_params=litellm_params,
             )
         azure_openai_client = self.get_azure_client(
             api_key=api_key,
@@ -965,6 +1015,7 @@ class AzureAssistantsAPI(BaseLLM):
             timeout=timeout,
             max_retries=max_retries,
             client=client,
+            litellm_params=litellm_params,
         )

         response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id)
litellm/llms/azure/chat/o1_handler.py

@@ -4,50 +4,70 @@ Handler file for calls to Azure OpenAI's o1/o3 family of models
 Written separately to handle faking streaming for o1 and o3 models.
 """

-from typing import Optional, Union
+from typing import Any, Callable, Optional, Union

 import httpx
 from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

+from litellm.types.llms.openai import Any
+from litellm.types.utils import ModelResponse
+
 from ...openai.openai import OpenAIChatCompletion
-from ..common_utils import get_azure_openai_client
+from ..common_utils import BaseAzureLLM


-class AzureOpenAIO1ChatCompletion(OpenAIChatCompletion):
-    def _get_openai_client(
+class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
+    def completion(
         self,
-        is_async: bool,
+        model_response: ModelResponse,
+        timeout: Union[float, httpx.Timeout],
+        optional_params: dict,
+        litellm_params: dict,
+        logging_obj: Any,
+        model: Optional[str] = None,
+        messages: Optional[list] = None,
+        print_verbose: Optional[Callable] = None,
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
         api_version: Optional[str] = None,
-        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
-        max_retries: Optional[int] = 2,
+        dynamic_params: Optional[bool] = None,
+        azure_ad_token: Optional[str] = None,
+        acompletion: bool = False,
+        logger_fn=None,
+        headers: Optional[dict] = None,
+        custom_prompt_dict: dict = {},
+        client=None,
         organization: Optional[str] = None,
-        client: Optional[
-            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
-        ] = None,
-    ) -> Optional[
-        Union[
-            OpenAI,
-            AsyncOpenAI,
-            AzureOpenAI,
-            AsyncAzureOpenAI,
-        ]
-    ]:
-
-        # Override to use Azure-specific client initialization
-        if not isinstance(client, AzureOpenAI) and not isinstance(
-            client, AsyncAzureOpenAI
-        ):
-            client = None
-
-        return get_azure_openai_client(
+        custom_llm_provider: Optional[str] = None,
+        drop_params: Optional[bool] = None,
+    ):
+        client = self.get_azure_openai_client(
+            litellm_params=litellm_params,
             api_key=api_key,
             api_base=api_base,
-            timeout=timeout,
-            max_retries=max_retries,
-            organization=organization,
             api_version=api_version,
             client=client,
-            _is_async=is_async,
+        )
+        return super().completion(
+            model_response=model_response,
+            timeout=timeout,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logging_obj=logging_obj,
+            model=model,
+            messages=messages,
+            print_verbose=print_verbose,
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            dynamic_params=dynamic_params,
+            azure_ad_token=azure_ad_token,
+            acompletion=acompletion,
+            logger_fn=logger_fn,
+            headers=headers,
+            custom_prompt_dict=custom_prompt_dict,
+            client=client,
+            organization=organization,
+            custom_llm_provider=custom_llm_provider,
+            drop_params=drop_params,
         )
litellm/llms/azure/common_utils.py

@@ -35,40 +35,6 @@ class AzureOpenAIError(BaseLLMException):
         )


-def get_azure_openai_client(
-    api_key: Optional[str],
-    api_base: Optional[str],
-    timeout: Union[float, httpx.Timeout],
-    max_retries: Optional[int],
-    api_version: Optional[str] = None,
-    organization: Optional[str] = None,
-    client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
-    _is_async: bool = False,
-) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
-    received_args = locals()
-    openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
-    if client is None:
-        data = {}
-        for k, v in received_args.items():
-            if k == "self" or k == "client" or k == "_is_async":
-                pass
-            elif k == "api_base" and v is not None:
-                data["azure_endpoint"] = v
-            elif v is not None:
-                data[k] = v
-        if "api_version" not in data:
-            data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
-
-        if _is_async is True:
-            openai_client = AsyncAzureOpenAI(**data)
-        else:
-            openai_client = AzureOpenAI(**data)  # type: ignore
-    else:
-        openai_client = client
-
-    return openai_client
-
-
 def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
     openai_headers = {}
     if "x-ratelimit-limit-requests" in headers:
@@ -277,6 +243,33 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):


 class BaseAzureLLM:
+    def get_azure_openai_client(
+        self,
+        litellm_params: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str] = None,
+        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        _is_async: bool = False,
+    ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
+        openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
+        if client is None:
+            azure_client_params = self.initialize_azure_sdk_client(
+                litellm_params=litellm_params,
+                api_key=api_key,
+                api_base=api_base,
+                model_name="",
+                api_version=api_version,
+            )
+            if _is_async is True:
+                openai_client = AsyncAzureOpenAI(**azure_client_params)
+            else:
+                openai_client = AzureOpenAI(**azure_client_params)  # type: ignore
+        else:
+            openai_client = client
+
+        return openai_client
+
     def initialize_azure_sdk_client(
         self,
         litellm_params: dict,
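This new method is the single construction point the files, fine-tuning, and o1 handlers now share. A usage sketch of the sync/async toggle (credentials and endpoint are placeholders):

from litellm.llms.azure.common_utils import BaseAzureLLM

base = BaseAzureLLM()
common = dict(
    litellm_params={},
    api_key="placeholder-key",
    api_base="https://example-resource.openai.azure.com",
    api_version="2024-02-01",
)
sync_client = base.get_azure_openai_client(_is_async=False, **common)  # AzureOpenAI
async_client = base.get_azure_openai_client(_is_async=True, **common)  # AsyncAzureOpenAI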
@@ -294,6 +287,8 @@ class BaseAzureLLM:
         client_secret = litellm_params.get("client_secret")
         azure_username = litellm_params.get("azure_username")
         azure_password = litellm_params.get("azure_password")
+        max_retries = litellm_params.get("max_retries")
+        timeout = litellm_params.get("timeout")
         if not api_key and tenant_id and client_id and client_secret:
             verbose_logger.debug("Using Azure AD Token Provider for Azure Auth")
             azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
@@ -338,6 +333,10 @@ class BaseAzureLLM:
             "azure_ad_token": azure_ad_token,
             "azure_ad_token_provider": azure_ad_token_provider,
         }
+        if max_retries is not None:
+            azure_client_params["max_retries"] = max_retries
+        if timeout is not None:
+            azure_client_params["timeout"] = timeout

         if azure_ad_token_provider is not None:
             azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
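For reference, these are the litellm_params keys initialize_azure_sdk_client reads, per the hunks above. max_retries and timeout are the two this commit adds; the others already drove Entra ID and username/password auth (all values below are placeholders):

litellm_params = {
    # Entra ID service-principal auth, used when no api_key is supplied
    "tenant_id": "00000000-0000-0000-0000-000000000000",
    "client_id": "00000000-0000-0000-0000-000000000000",
    "client_secret": "placeholder-secret",
    # Username/password auth
    "azure_username": "user@example.com",
    "azure_password": "placeholder-password",
    # New in this commit: forwarded onto the constructed client when set
    "max_retries": 4,
    "timeout": 30.0,
}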
litellm/llms/azure/files/handler.py

@@ -5,13 +5,12 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
 from openai.types.file_deleted import FileDeleted

 from litellm._logging import verbose_logger
-from litellm.llms.base import BaseLLM
 from litellm.types.llms.openai import *

-from ..common_utils import get_azure_openai_client
+from ..common_utils import BaseAzureLLM


-class AzureOpenAIFilesAPI(BaseLLM):
+class AzureOpenAIFilesAPI(BaseAzureLLM):
     """
     AzureOpenAI methods to support for batches
     - create_file()
@@ -45,14 +44,15 @@ class AzureOpenAIFilesAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:

         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
                 api_version=api_version,
-                timeout=timeout,
-                max_retries=max_retries,
                 client=client,
                 _is_async=_is_async,
             )
@@ -91,17 +91,16 @@ class AzureOpenAIFilesAPI(BaseLLM):
         max_retries: Optional[int],
         api_version: Optional[str] = None,
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ) -> Union[
         HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
     ]:
         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
-                timeout=timeout,
                 api_version=api_version,
-                max_retries=max_retries,
-                organization=None,
                 client=client,
                 _is_async=_is_async,
             )
@@ -144,14 +143,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
         max_retries: Optional[int],
         api_version: Optional[str] = None,
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ):
         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=None,
                 api_version=api_version,
                 client=client,
                 _is_async=_is_async,
@@ -197,14 +195,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
         organization: Optional[str] = None,
         api_version: Optional[str] = None,
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ):
         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=organization,
                 api_version=api_version,
                 client=client,
                 _is_async=_is_async,
@@ -252,14 +249,13 @@ class AzureOpenAIFilesAPI(BaseLLM):
         purpose: Optional[str] = None,
         api_version: Optional[str] = None,
         client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+        litellm_params: Optional[dict] = None,
     ):
         openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
-            get_azure_openai_client(
+            self.get_azure_openai_client(
+                litellm_params=litellm_params or {},
                 api_key=api_key,
                 api_base=api_base,
-                timeout=timeout,
-                max_retries=max_retries,
-                organization=None,  # openai param
                 api_version=api_version,
                 client=client,
                 _is_async=_is_async,
litellm/llms/azure/fine_tuning/handler.py

@@ -3,11 +3,11 @@ from typing import Optional, Union
 import httpx
 from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

-from litellm.llms.azure.files.handler import get_azure_openai_client
+from litellm.llms.azure.common_utils import BaseAzureLLM
 from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI


-class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
+class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
     """
     AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
     """
@@ -24,6 +24,7 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
         ] = None,
         _is_async: bool = False,
         api_version: Optional[str] = None,
+        litellm_params: Optional[dict] = None,
     ) -> Optional[
         Union[
             OpenAI,
@@ -36,12 +37,10 @@ class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI):
         if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI):
             client = None

-        return get_azure_openai_client(
+        return self.get_azure_openai_client(
+            litellm_params=litellm_params or {},
             api_key=api_key,
             api_base=api_base,
-            timeout=timeout,
-            max_retries=max_retries,
-            organization=organization,
             api_version=api_version,
             client=client,
             _is_async=_is_async,
litellm/llms/openai/fine_tuning/handler.py

@@ -27,6 +27,7 @@ class OpenAIFineTuningAPI:
         ] = None,
         _is_async: bool = False,
         api_version: Optional[str] = None,
+        litellm_params: Optional[dict] = None,
     ) -> Optional[
         Union[
             OpenAI,
litellm/types/utils.py

@@ -191,6 +191,42 @@ class CallTypes(Enum):
     retrieve_batch = "retrieve_batch"
     pass_through = "pass_through_endpoint"
     anthropic_messages = "anthropic_messages"
+    get_assistants = "get_assistants"
+    aget_assistants = "aget_assistants"
+    create_assistants = "create_assistants"
+    acreate_assistants = "acreate_assistants"
+    delete_assistant = "delete_assistant"
+    adelete_assistant = "adelete_assistant"
+    acreate_thread = "acreate_thread"
+    create_thread = "create_thread"
+    aget_thread = "aget_thread"
+    get_thread = "get_thread"
+    a_add_message = "a_add_message"
+    add_message = "add_message"
+    aget_messages = "aget_messages"
+    get_messages = "get_messages"
+    arun_thread = "arun_thread"
+    run_thread = "run_thread"
+    arun_thread_stream = "arun_thread_stream"
+    run_thread_stream = "run_thread_stream"
+    afile_retrieve = "afile_retrieve"
+    file_retrieve = "file_retrieve"
+    afile_delete = "afile_delete"
+    file_delete = "file_delete"
+    afile_list = "afile_list"
+    file_list = "file_list"
+    acreate_file = "acreate_file"
+    create_file = "create_file"
+    afile_content = "afile_content"
+    file_content = "file_content"
+    create_fine_tuning_job = "create_fine_tuning_job"
+    acreate_fine_tuning_job = "acreate_fine_tuning_job"
+    acancel_fine_tuning_job = "acancel_fine_tuning_job"
+    cancel_fine_tuning_job = "cancel_fine_tuning_job"
+    alist_fine_tuning_jobs = "alist_fine_tuning_jobs"
+    list_fine_tuning_jobs = "list_fine_tuning_jobs"
+    aretrieve_fine_tuning_job = "aretrieve_fine_tuning_job"
+    retrieve_fine_tuning_job = "retrieve_fine_tuning_job"


 CallTypesLiteral = Literal[
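Registering every assistants/files/fine-tuning entry point in CallTypes is what lets the test below parametrize over the enum itself instead of a hand-maintained list. For example:

from litellm.types.utils import CallTypes

# Mirrors the filter in the updated @pytest.mark.parametrize below
async_call_types = [
    ct for ct in CallTypes.__members__.values() if ct.name.startswith("a")
]
assert CallTypes.aget_assistants in async_call_types
assert CallTypes.acreate_file in async_call_types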
tests/litellm/llms/azure/test_azure_common_utils.py

@@ -216,16 +216,18 @@ def test_select_azure_base_url_called(setup_mocks):
 @pytest.mark.parametrize(
     "call_type",
     [
-        CallTypes.acompletion,
-        CallTypes.atext_completion,
-        CallTypes.aembedding,
-        CallTypes.atranscription,
-        CallTypes.aspeech,
-        CallTypes.aimage_generation,
-        # BATCHES ENDPOINTS
-        CallTypes.acreate_batch,
-        CallTypes.aretrieve_batch,
-        # ASSISTANT ENDPOINTS
+        call_type
+        for call_type in CallTypes.__members__.values()
+        if call_type.name.startswith("a")
+        and call_type.name
+        not in [
+            "amoderation",
+            "arerank",
+            "arealtime",
+            "anthropic_messages",
+            "add_message",
+            "arun_thread_stream",
+        ]
     ],
 )
 @pytest.mark.asyncio
@@ -267,6 +269,28 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
             "input_file_id": "123",
         },
         "aretrieve_batch": {"batch_id": "123"},
+        "aget_assistants": {"custom_llm_provider": "azure"},
+        "acreate_assistants": {"custom_llm_provider": "azure"},
+        "adelete_assistant": {"custom_llm_provider": "azure", "assistant_id": "123"},
+        "acreate_thread": {"custom_llm_provider": "azure"},
+        "aget_thread": {"custom_llm_provider": "azure", "thread_id": "123"},
+        "a_add_message": {
+            "custom_llm_provider": "azure",
+            "thread_id": "123",
+            "role": "user",
+            "content": "Hello, how are you?",
+        },
+        "aget_messages": {"custom_llm_provider": "azure", "thread_id": "123"},
+        "arun_thread": {
+            "custom_llm_provider": "azure",
+            "assistant_id": "123",
+            "thread_id": "123",
+        },
+        "acreate_file": {
+            "custom_llm_provider": "azure",
+            "file": MagicMock(),
+            "purpose": "assistants",
+        },
     }

     # Get appropriate input for this call type
@@ -285,12 +309,34 @@ async def test_ensure_initialize_azure_sdk_client_always_used(call_type):
         patch_target = (
             "litellm.batches.main.azure_batches_instance.initialize_azure_sdk_client"
         )
+    elif (
+        call_type == CallTypes.aget_assistants
+        or call_type == CallTypes.acreate_assistants
+        or call_type == CallTypes.adelete_assistant
+        or call_type == CallTypes.acreate_thread
+        or call_type == CallTypes.aget_thread
+        or call_type == CallTypes.a_add_message
+        or call_type == CallTypes.aget_messages
+        or call_type == CallTypes.arun_thread
+    ):
+        patch_target = (
+            "litellm.assistants.main.azure_assistants_api.initialize_azure_sdk_client"
+        )
+    elif call_type == CallTypes.acreate_file or call_type == CallTypes.afile_content:
+        patch_target = (
+            "litellm.files.main.azure_files_instance.initialize_azure_sdk_client"
+        )

     # Mock the initialize_azure_sdk_client function
     with patch(patch_target) as mock_init_azure:
         # Also mock async_function_with_fallbacks to prevent actual API calls
         # Call the appropriate router method
         try:
+            get_attr = getattr(router, call_type.value, None)
+            if get_attr is None:
+                pytest.skip(
+                    f"Skipping {call_type.value} because it is not supported on Router"
+                )
             await getattr(router, call_type.value)(
                 model="gpt-3.5-turbo",
                 **input_kwarg,