[Feat-Proxy] Add Azure Assistants API - Create Assistant, Delete Assistant Support (#5777)

* update docs to show providers

* azure - move assistants in it's own file

* create new azure assistants file

* add azure create assistants

* add test for create / delete assistants

* azure add delete assistants support

* docs add Azure to support providers for assistants api

* fix linting errors

* fix standard logging merge conflict

* docs azure create assistants

* fix doc
This commit is contained in:
Ishaan Jaff 2024-09-18 16:27:33 -07:00 committed by GitHub
parent a109853d21
commit 7e07c37be7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 1172 additions and 897 deletions

View file

@ -13,6 +13,7 @@ from openai.types.beta.assistant_deleted import AssistantDeleted
import litellm
from litellm import client
from litellm.llms.AzureOpenAI import assistants
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import (
exception_type,
@ -21,7 +22,7 @@ from litellm.utils import (
supports_httpx_timeout,
)
from ..llms.AzureOpenAI.azure import AzureAssistantsAPI
from ..llms.AzureOpenAI.assistants import AzureAssistantsAPI
from ..llms.OpenAI.openai import OpenAIAssistantsAPI
from ..types.llms.openai import *
from ..types.router import *
@ -210,8 +211,8 @@ async def acreate_assistants(
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["async_create_assistants"] = True
model = kwargs.pop("model", None)
try:
model = kwargs.pop("model", None)
kwargs["client"] = client
# Use a partial function to pass your keyword arguments
func = partial(create_assistants, custom_llm_provider, model, **kwargs)
@ -258,7 +259,7 @@ def create_assistants(
api_base: Optional[str] = None,
api_version: Optional[str] = None,
**kwargs,
) -> Assistant:
) -> Union[Assistant, Coroutine[Any, Any, Assistant]]:
async_create_assistants: Optional[bool] = kwargs.pop(
"async_create_assistants", None
)
@ -288,7 +289,20 @@ def create_assistants(
elif timeout is None:
timeout = 600.0
response: Optional[Assistant] = None
create_assistant_data = {
"model": model,
"name": name,
"description": description,
"instructions": instructions,
"tools": tools,
"tool_resources": tool_resources,
"metadata": metadata,
"temperature": temperature,
"top_p": top_p,
"response_format": response_format,
}
response: Optional[Union[Coroutine[Any, Any, Assistant], Assistant]] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@ -310,19 +324,6 @@ def create_assistants(
or os.getenv("OPENAI_API_KEY")
)
create_assistant_data = {
"model": model,
"name": name,
"description": description,
"instructions": instructions,
"tools": tools,
"tool_resources": tool_resources,
"metadata": metadata,
"temperature": temperature,
"top_p": top_p,
"response_format": response_format,
}
response = openai_assistants_api.create_assistants(
api_base=api_base,
api_key=api_key,
@ -333,6 +334,46 @@ def create_assistants(
client=client,
async_create_assistants=async_create_assistants, # type: ignore
) # type: ignore
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
if isinstance(client, OpenAI):
client = None # only pass client if it's AzureOpenAI
response = azure_assistants_api.create_assistants(
api_base=api_base,
api_key=api_key,
azure_ad_token=azure_ad_token,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
async_create_assistants=async_create_assistants,
create_assistant_data=create_assistant_data,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' is supported.".format(
@ -401,7 +442,7 @@ def delete_assistant(
api_base: Optional[str] = None,
api_version: Optional[str] = None,
**kwargs,
) -> AssistantDeleted:
) -> Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]:
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
@ -432,7 +473,9 @@ def delete_assistant(
elif timeout is None:
timeout = 600.0
response: Optional[AssistantDeleted] = None
response: Optional[
Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]
] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base
@ -464,6 +507,46 @@ def delete_assistant(
client=client,
async_delete_assistants=async_delete_assistants,
)
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
if isinstance(client, OpenAI):
client = None # only pass client if it's AzureOpenAI
response = azure_assistants_api.delete_assistant(
assistant_id=assistant_id,
api_base=api_base,
api_key=api_key,
azure_ad_token=azure_ad_token,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
client=client,
async_delete_assistants=async_delete_assistants,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'delete_assistant'. Only 'openai' is supported.".format(
@ -575,6 +658,9 @@ def create_thread(
elif timeout is None:
timeout = 600.0
api_base: Optional[str] = None
api_key: Optional[str] = None
response: Optional[Thread] = None
if custom_llm_provider == "openai":
api_base = (
@ -612,12 +698,6 @@ def create_thread(
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
@ -626,8 +706,14 @@ def create_thread(
or get_secret("AZURE_API_KEY")
) # type: ignore
api_version: Optional[str] = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
@ -647,7 +733,7 @@ def create_thread(
max_retries=optional_params.max_retries,
client=client,
acreate_thread=acreate_thread,
) # type :ignore
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
@ -727,7 +813,8 @@ def get_thread(
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
api_base: Optional[str] = None
api_key: Optional[str] = None
response: Optional[Thread] = None
if custom_llm_provider == "openai":
api_base = (
@ -765,7 +852,7 @@ def get_thread(
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
api_version: Optional[str] = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
@ -780,7 +867,7 @@ def get_thread(
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
@ -912,7 +999,8 @@ def add_message(
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
api_key: Optional[str] = None
api_base: Optional[str] = None
response: Optional[OpenAIMessage] = None
if custom_llm_provider == "openai":
api_base = (
@ -950,7 +1038,7 @@ def add_message(
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
api_version: Optional[str] = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
@ -965,7 +1053,7 @@ def add_message(
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else:
@ -1071,6 +1159,8 @@ def get_messages(
timeout = 600.0
response: Optional[SyncCursorPage[OpenAIMessage]] = None
api_key: Optional[str] = None
api_base: Optional[str] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@ -1106,7 +1196,7 @@ def get_messages(
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
) # type: ignore
api_version = (
api_version: Optional[str] = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
@ -1121,7 +1211,7 @@ def get_messages(
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token = None
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
else: