mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
[Feat-Proxy] Add Azure Assistants API - Create Assistant, Delete Assistant Support (#5777)
* update docs to show providers * azure - move assistants into its own file * create new azure assistants file * add azure create assistants * add test for create / delete assistants * azure add delete assistants support * docs add Azure to supported providers for assistants api * fix linting errors * fix standard logging merge conflict * docs azure create assistants * fix doc
This commit is contained in:
parent
a109853d21
commit
7e07c37be7
7 changed files with 1172 additions and 897 deletions
|
@ -13,6 +13,7 @@ from openai.types.beta.assistant_deleted import AssistantDeleted
|
|||
|
||||
import litellm
|
||||
from litellm import client
|
||||
from litellm.llms.AzureOpenAI import assistants
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import (
|
||||
exception_type,
|
||||
|
@ -21,7 +22,7 @@ from litellm.utils import (
|
|||
supports_httpx_timeout,
|
||||
)
|
||||
|
||||
from ..llms.AzureOpenAI.azure import AzureAssistantsAPI
|
||||
from ..llms.AzureOpenAI.assistants import AzureAssistantsAPI
|
||||
from ..llms.OpenAI.openai import OpenAIAssistantsAPI
|
||||
from ..types.llms.openai import *
|
||||
from ..types.router import *
|
||||
|
@ -210,8 +211,8 @@ async def acreate_assistants(
|
|||
loop = asyncio.get_event_loop()
|
||||
### PASS ARGS TO GET ASSISTANTS ###
|
||||
kwargs["async_create_assistants"] = True
|
||||
model = kwargs.pop("model", None)
|
||||
try:
|
||||
model = kwargs.pop("model", None)
|
||||
kwargs["client"] = client
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(create_assistants, custom_llm_provider, model, **kwargs)
|
||||
|
@ -258,7 +259,7 @@ def create_assistants(
|
|||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Assistant:
|
||||
) -> Union[Assistant, Coroutine[Any, Any, Assistant]]:
|
||||
async_create_assistants: Optional[bool] = kwargs.pop(
|
||||
"async_create_assistants", None
|
||||
)
|
||||
|
@ -288,7 +289,20 @@ def create_assistants(
|
|||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
response: Optional[Assistant] = None
|
||||
create_assistant_data = {
|
||||
"model": model,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"instructions": instructions,
|
||||
"tools": tools,
|
||||
"tool_resources": tool_resources,
|
||||
"metadata": metadata,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"response_format": response_format,
|
||||
}
|
||||
|
||||
response: Optional[Union[Coroutine[Any, Any, Assistant], Assistant]] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
|
@ -310,19 +324,6 @@ def create_assistants(
|
|||
or os.getenv("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
create_assistant_data = {
|
||||
"model": model,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"instructions": instructions,
|
||||
"tools": tools,
|
||||
"tool_resources": tool_resources,
|
||||
"metadata": metadata,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"response_format": response_format,
|
||||
}
|
||||
|
||||
response = openai_assistants_api.create_assistants(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
|
@ -333,6 +334,46 @@ def create_assistants(
|
|||
client=client,
|
||||
async_create_assistants=async_create_assistants, # type: ignore
|
||||
) # type: ignore
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = (
|
||||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||
) # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
if isinstance(client, OpenAI):
|
||||
client = None # only pass client if it's AzureOpenAI
|
||||
|
||||
response = azure_assistants_api.create_assistants(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
async_create_assistants=async_create_assistants,
|
||||
create_assistant_data=create_assistant_data,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' is supported.".format(
|
||||
|
@ -401,7 +442,7 @@ def delete_assistant(
|
|||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> AssistantDeleted:
|
||||
) -> Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]:
|
||||
optional_params = GenericLiteLLMParams(
|
||||
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
|
||||
)
|
||||
|
@ -432,7 +473,9 @@ def delete_assistant(
|
|||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
response: Optional[AssistantDeleted] = None
|
||||
response: Optional[
|
||||
Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]
|
||||
] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base
|
||||
|
@ -464,6 +507,46 @@ def delete_assistant(
|
|||
client=client,
|
||||
async_delete_assistants=async_delete_assistants,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = (
|
||||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||
) # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
if isinstance(client, OpenAI):
|
||||
client = None # only pass client if it's AzureOpenAI
|
||||
|
||||
response = azure_assistants_api.delete_assistant(
|
||||
assistant_id=assistant_id,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
async_delete_assistants=async_delete_assistants,
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'delete_assistant'. Only 'openai' is supported.".format(
|
||||
|
@ -575,6 +658,9 @@ def create_thread(
|
|||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
api_base: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
|
||||
response: Optional[Thread] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
|
@ -612,12 +698,6 @@ def create_thread(
|
|||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||
) # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
|
@ -626,8 +706,14 @@ def create_thread(
|
|||
or get_secret("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
api_version: Optional[str] = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token = None
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
|
@ -647,7 +733,7 @@ def create_thread(
|
|||
max_retries=optional_params.max_retries,
|
||||
client=client,
|
||||
acreate_thread=acreate_thread,
|
||||
) # type :ignore
|
||||
)
|
||||
else:
|
||||
raise litellm.exceptions.BadRequestError(
|
||||
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
|
||||
|
@ -727,7 +813,8 @@ def get_thread(
|
|||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
api_base: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
response: Optional[Thread] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
|
@ -765,7 +852,7 @@ def get_thread(
|
|||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||
) # type: ignore
|
||||
|
||||
api_version = (
|
||||
api_version: Optional[str] = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
|
@ -780,7 +867,7 @@ def get_thread(
|
|||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token = None
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
|
@ -912,7 +999,8 @@ def add_message(
|
|||
timeout = float(timeout) # type: ignore
|
||||
elif timeout is None:
|
||||
timeout = 600.0
|
||||
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
response: Optional[OpenAIMessage] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
|
@ -950,7 +1038,7 @@ def add_message(
|
|||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||
) # type: ignore
|
||||
|
||||
api_version = (
|
||||
api_version: Optional[str] = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
|
@ -965,7 +1053,7 @@ def add_message(
|
|||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token = None
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
|
@ -1071,6 +1159,8 @@ def get_messages(
|
|||
timeout = 600.0
|
||||
|
||||
response: Optional[SyncCursorPage[OpenAIMessage]] = None
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
if custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||
|
@ -1106,7 +1196,7 @@ def get_messages(
|
|||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||
) # type: ignore
|
||||
|
||||
api_version = (
|
||||
api_version: Optional[str] = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
|
@ -1121,7 +1211,7 @@ def get_messages(
|
|||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token = None
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue