[Feat-Proxy] Add Azure Assistants API - Create Assistant, Delete Assistant Support (#5777)

* update docs to show providers

* azure - move assistants in it's own file

* create new azure assistants file

* add azure create assistants

* add test for create / delete assistants

* azure add delete assistants support

* docs add Azure to support providers for assistants api

* fix linting errors

* fix standard logging merge conflict

* docs azure create assistants

* fix doc
This commit is contained in:
Ishaan Jaff 2024-09-18 16:27:33 -07:00 committed by GitHub
parent a109853d21
commit 7e07c37be7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 1172 additions and 897 deletions

View file

@ -17,7 +17,8 @@ from litellm import ImageResponse, OpenAIConfig
from litellm.caching import DualCache
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.utils import FileTypes
from litellm.types.utils import FileTypes # type: ignore
from litellm.types.utils import EmbeddingResponse
from litellm.utils import (
Choices,
CustomStreamWrapper,
@ -735,6 +736,11 @@ class AzureChatCompletion(BaseLLM):
azure_client._custom_query.setdefault(
"api-version", api_version
)
if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AzureOpenAI",
)
headers, response = self.make_sync_azure_openai_chat_completion_request(
azure_client=azure_client, data=data, timeout=timeout
@ -1015,12 +1021,12 @@ class AzureChatCompletion(BaseLLM):
async def aembedding(
self,
data: dict,
model_response: ModelResponse,
model_response: EmbeddingResponse,
azure_client_params: dict,
api_key: str,
input: list,
logging_obj: LiteLLMLoggingObj,
client: Optional[AsyncAzureOpenAI] = None,
logging_obj=None,
timeout=None,
):
response = None
@ -1067,9 +1073,9 @@ class AzureChatCompletion(BaseLLM):
api_base: str,
api_version: str,
timeout: float,
logging_obj=None,
model_response=None,
optional_params=None,
logging_obj: LiteLLMLoggingObj,
model_response: EmbeddingResponse,
optional_params: dict,
azure_ad_token: Optional[str] = None,
client=None,
aembedding=None,
@ -1407,8 +1413,8 @@ class AzureChatCompletion(BaseLLM):
azure_client_params: dict,
api_key: str,
input: list,
logging_obj: LiteLLMLoggingObj,
client=None,
logging_obj=None,
timeout=None,
):
response: Optional[dict] = None
@ -1471,14 +1477,14 @@ class AzureChatCompletion(BaseLLM):
self,
prompt: str,
timeout: float,
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
model: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
model_response: Optional[litellm.utils.ImageResponse] = None,
azure_ad_token: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
aimg_generation=None,
):
@ -1565,7 +1571,8 @@ class AzureChatCompletion(BaseLLM):
raise e
except Exception as e:
if hasattr(e, "status_code"):
raise AzureOpenAIError(status_code=e.status_code, message=str(e))
_status_code = getattr(e, "status_code")
raise AzureOpenAIError(status_code=_status_code, message=str(e))
else:
raise AzureOpenAIError(status_code=500, message=str(e))
@ -1847,831 +1854,6 @@ class AzureChatCompletion(BaseLLM):
return response
class AzureAssistantsAPI(BaseLLM):
def __init__(self) -> None:
super().__init__()
def get_azure_client(
self,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
) -> AzureOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
azure_openai_client = AzureOpenAI(**data) # type: ignore
else:
azure_openai_client = client
return azure_openai_client
def async_get_azure_client(
self,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
) -> AsyncAzureOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
azure_openai_client = AsyncAzureOpenAI(**data)
# azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
else:
azure_openai_client = client
return azure_openai_client
### ASSISTANTS ###
async def async_get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
) -> AsyncCursorPage[Assistant]:
azure_openai_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = await azure_openai_client.beta.assistants.list()
return response
# fmt: off
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
aget_assistants: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
...
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI],
aget_assistants: Optional[Literal[False]],
) -> SyncCursorPage[Assistant]:
...
# fmt: on
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client=None,
aget_assistants=None,
):
if aget_assistants is not None and aget_assistants == True:
return self.async_get_assistants(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
api_base=api_base,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
api_version=api_version,
)
response = azure_openai_client.beta.assistants.list()
return response
### MESSAGES ###
async def a_add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
) -> OpenAIMessage:
openai_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data # type: ignore
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
# fmt: off
@overload
def add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
a_add_message: Literal[True],
) -> Coroutine[None, None, OpenAIMessage]:
...
@overload
def add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI],
a_add_message: Optional[Literal[False]],
) -> OpenAIMessage:
...
# fmt: on
def add_message(
self,
thread_id: str,
message_data: dict,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client=None,
a_add_message: Optional[bool] = None,
):
if a_add_message is not None and a_add_message == True:
return self.a_add_message(
thread_id=thread_id,
message_data=message_data,
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
openai_client = self.get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data # type: ignore
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
async def async_get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
openai_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
# fmt: off
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
aget_messages: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
...
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI],
aget_messages: Optional[Literal[False]],
) -> SyncCursorPage[OpenAIMessage]:
...
# fmt: on
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client=None,
aget_messages=None,
):
if aget_messages is not None and aget_messages == True:
return self.async_get_messages(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
openai_client = self.get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
### THREADS ###
async def async_create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
openai_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = await openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
# fmt: off
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AsyncAzureOpenAI],
acreate_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AzureOpenAI],
acreate_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client=None,
acreate_thread=None,
):
"""
Here's an example:
```
from litellm.llms.OpenAI.openai import OpenAIAssistantsAPI, MessageData
# create thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
openai_api.create_thread(messages=[message])
```
"""
if acreate_thread is not None and acreate_thread == True:
return self.async_create_thread(
metadata=metadata,
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
messages=messages,
)
azure_openai_client = self.get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
async def async_get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
) -> Thread:
openai_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
# fmt: off
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
aget_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI],
aget_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client=None,
aget_thread=None,
):
if aget_thread is not None and aget_thread == True:
return self.async_get_thread(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
openai_client = self.get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
# def delete_thread(self):
# pass
### RUNS ###
async def arun_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
) -> Run:
openai_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
api_version=api_version,
azure_ad_token=azure_ad_token,
client=client,
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response
def async_run_thread_stream(
self,
client: AsyncAzureOpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
def run_thread_stream(
self,
client: AzureOpenAI,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
tools: Optional[Iterable[AssistantToolParam]],
event_handler: Optional[AssistantEventHandler],
) -> AssistantStreamManager[AssistantEventHandler]:
data = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"additional_instructions": additional_instructions,
"instructions": instructions,
"metadata": metadata,
"model": model,
"tools": tools,
}
if event_handler is not None:
data["event_handler"] = event_handler
return client.beta.threads.runs.stream(**data) # type: ignore
# fmt: off
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AsyncAzureOpenAI],
arun_thread: Literal[True],
) -> Coroutine[None, None, Run]:
...
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI],
arun_thread: Optional[Literal[False]],
) -> Run:
...
# fmt: on
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
azure_ad_token: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
if arun_thread is not None and arun_thread == True:
if stream is not None and stream == True:
azure_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
return self.async_run_thread_stream(
client=azure_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
openai_client = self.get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
if stream is not None and stream == True:
return self.run_thread_stream(
client=openai_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response
class AzureBatchesAPI(BaseLLM):
"""
Azure methods to support for batches