mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
1025 lines
32 KiB
Python
1025 lines
32 KiB
Python
from typing import Any, Coroutine, Dict, Iterable, Literal, Optional, Union
|
|
|
|
import httpx
|
|
from openai import AsyncAzureOpenAI, AzureOpenAI
|
|
from typing_extensions import overload
|
|
|
|
from ...types.llms.openai import (
|
|
Assistant,
|
|
AssistantEventHandler,
|
|
AssistantStreamManager,
|
|
AssistantToolParam,
|
|
AsyncAssistantEventHandler,
|
|
AsyncAssistantStreamManager,
|
|
AsyncCursorPage,
|
|
OpenAICreateThreadParamsMessage,
|
|
OpenAIMessage,
|
|
Run,
|
|
SyncCursorPage,
|
|
Thread,
|
|
)
|
|
from .common_utils import BaseAzureLLM
|
|
|
|
|
|
class AzureAssistantsAPI(BaseAzureLLM):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def get_azure_client(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AzureOpenAI] = None,
|
|
litellm_params: Optional[dict] = None,
|
|
) -> AzureOpenAI:
|
|
if client is None:
|
|
azure_client_params = self.initialize_azure_sdk_client(
|
|
litellm_params=litellm_params or {},
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
model_name="",
|
|
api_version=api_version,
|
|
is_async=False,
|
|
)
|
|
azure_openai_client = AzureOpenAI(**azure_client_params) # type: ignore
|
|
else:
|
|
azure_openai_client = client
|
|
|
|
return azure_openai_client
|
|
|
|
def async_get_azure_client(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI] = None,
|
|
litellm_params: Optional[dict] = None,
|
|
) -> AsyncAzureOpenAI:
|
|
if client is None:
|
|
azure_client_params = self.initialize_azure_sdk_client(
|
|
litellm_params=litellm_params or {},
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
model_name="",
|
|
api_version=api_version,
|
|
is_async=True,
|
|
)
|
|
|
|
azure_openai_client = AsyncAzureOpenAI(**azure_client_params)
|
|
# azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
|
|
else:
|
|
azure_openai_client = client
|
|
|
|
return azure_openai_client
|
|
|
|
### ASSISTANTS ###
|
|
|
|
async def async_get_assistants(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> AsyncCursorPage[Assistant]:
|
|
azure_openai_client = self.async_get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = await azure_openai_client.beta.assistants.list()
|
|
|
|
return response
|
|
|
|
# fmt: off
|
|
|
|
@overload
|
|
def get_assistants(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
aget_assistants: Literal[True],
|
|
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
|
|
...
|
|
|
|
@overload
|
|
def get_assistants(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AzureOpenAI],
|
|
aget_assistants: Optional[Literal[False]],
|
|
) -> SyncCursorPage[Assistant]:
|
|
...
|
|
|
|
# fmt: on
|
|
|
|
def get_assistants(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client=None,
|
|
aget_assistants=None,
|
|
litellm_params: Optional[dict] = None,
|
|
):
|
|
if aget_assistants is not None and aget_assistants is True:
|
|
return self.async_get_assistants(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
azure_openai_client = self.get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
api_version=api_version,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = azure_openai_client.beta.assistants.list()
|
|
|
|
return response
|
|
|
|
### MESSAGES ###
|
|
|
|
async def a_add_message(
|
|
self,
|
|
thread_id: str,
|
|
message_data: dict,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI] = None,
|
|
litellm_params: Optional[dict] = None,
|
|
) -> OpenAIMessage:
|
|
openai_client = self.async_get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
|
|
thread_id, **message_data # type: ignore
|
|
)
|
|
|
|
response_obj: Optional[OpenAIMessage] = None
|
|
if getattr(thread_message, "status", None) is None:
|
|
thread_message.status = "completed"
|
|
response_obj = OpenAIMessage(**thread_message.dict())
|
|
else:
|
|
response_obj = OpenAIMessage(**thread_message.dict())
|
|
return response_obj
|
|
|
|
# fmt: off
|
|
|
|
@overload
|
|
def add_message(
|
|
self,
|
|
thread_id: str,
|
|
message_data: dict,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
a_add_message: Literal[True],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Coroutine[None, None, OpenAIMessage]:
|
|
...
|
|
|
|
@overload
|
|
def add_message(
|
|
self,
|
|
thread_id: str,
|
|
message_data: dict,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AzureOpenAI],
|
|
a_add_message: Optional[Literal[False]],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> OpenAIMessage:
|
|
...
|
|
|
|
# fmt: on
|
|
|
|
def add_message(
|
|
self,
|
|
thread_id: str,
|
|
message_data: dict,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client=None,
|
|
a_add_message: Optional[bool] = None,
|
|
litellm_params: Optional[dict] = None,
|
|
):
|
|
if a_add_message is not None and a_add_message is True:
|
|
return self.a_add_message(
|
|
thread_id=thread_id,
|
|
message_data=message_data,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
openai_client = self.get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
|
|
thread_id, **message_data # type: ignore
|
|
)
|
|
|
|
response_obj: Optional[OpenAIMessage] = None
|
|
if getattr(thread_message, "status", None) is None:
|
|
thread_message.status = "completed"
|
|
response_obj = OpenAIMessage(**thread_message.dict())
|
|
else:
|
|
response_obj = OpenAIMessage(**thread_message.dict())
|
|
return response_obj
|
|
|
|
async def async_get_messages(
|
|
self,
|
|
thread_id: str,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI] = None,
|
|
litellm_params: Optional[dict] = None,
|
|
) -> AsyncCursorPage[OpenAIMessage]:
|
|
openai_client = self.async_get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
|
|
|
|
return response
|
|
|
|
# fmt: off
|
|
|
|
@overload
|
|
def get_messages(
|
|
self,
|
|
thread_id: str,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
aget_messages: Literal[True],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
|
|
...
|
|
|
|
@overload
|
|
def get_messages(
|
|
self,
|
|
thread_id: str,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AzureOpenAI],
|
|
aget_messages: Optional[Literal[False]],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> SyncCursorPage[OpenAIMessage]:
|
|
...
|
|
|
|
# fmt: on
|
|
|
|
def get_messages(
|
|
self,
|
|
thread_id: str,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client=None,
|
|
aget_messages=None,
|
|
litellm_params: Optional[dict] = None,
|
|
):
|
|
if aget_messages is not None and aget_messages is True:
|
|
return self.async_get_messages(
|
|
thread_id=thread_id,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
openai_client = self.get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
|
|
|
|
return response
|
|
|
|
### THREADS ###
|
|
|
|
async def async_create_thread(
|
|
self,
|
|
metadata: Optional[dict],
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Thread:
|
|
openai_client = self.async_get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
data = {}
|
|
if messages is not None:
|
|
data["messages"] = messages # type: ignore
|
|
if metadata is not None:
|
|
data["metadata"] = metadata # type: ignore
|
|
|
|
message_thread = await openai_client.beta.threads.create(**data) # type: ignore
|
|
|
|
return Thread(**message_thread.dict())
|
|
|
|
# fmt: off
|
|
|
|
@overload
|
|
def create_thread(
|
|
self,
|
|
metadata: Optional[dict],
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
acreate_thread: Literal[True],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Coroutine[None, None, Thread]:
|
|
...
|
|
|
|
@overload
|
|
def create_thread(
|
|
self,
|
|
metadata: Optional[dict],
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
|
client: Optional[AzureOpenAI],
|
|
acreate_thread: Optional[Literal[False]],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Thread:
|
|
...
|
|
|
|
# fmt: on
|
|
|
|
def create_thread(
|
|
self,
|
|
metadata: Optional[dict],
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
|
client=None,
|
|
acreate_thread=None,
|
|
litellm_params: Optional[dict] = None,
|
|
):
|
|
"""
|
|
Here's an example:
|
|
```
|
|
from litellm.llms.openai.openai import OpenAIAssistantsAPI, MessageData
|
|
|
|
# create thread
|
|
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
|
|
openai_api.create_thread(messages=[message])
|
|
```
|
|
"""
|
|
if acreate_thread is not None and acreate_thread is True:
|
|
return self.async_create_thread(
|
|
metadata=metadata,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
messages=messages,
|
|
litellm_params=litellm_params,
|
|
)
|
|
azure_openai_client = self.get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
data = {}
|
|
if messages is not None:
|
|
data["messages"] = messages # type: ignore
|
|
if metadata is not None:
|
|
data["metadata"] = metadata # type: ignore
|
|
|
|
message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore
|
|
|
|
return Thread(**message_thread.dict())
|
|
|
|
async def async_get_thread(
|
|
self,
|
|
thread_id: str,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Thread:
|
|
openai_client = self.async_get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
|
|
|
|
return Thread(**response.dict())
|
|
|
|
# fmt: off
|
|
|
|
@overload
|
|
def get_thread(
|
|
self,
|
|
thread_id: str,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
aget_thread: Literal[True],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Coroutine[None, None, Thread]:
|
|
...
|
|
|
|
@overload
|
|
def get_thread(
|
|
self,
|
|
thread_id: str,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AzureOpenAI],
|
|
aget_thread: Optional[Literal[False]],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Thread:
|
|
...
|
|
|
|
# fmt: on
|
|
|
|
def get_thread(
|
|
self,
|
|
thread_id: str,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client=None,
|
|
aget_thread=None,
|
|
litellm_params: Optional[dict] = None,
|
|
):
|
|
if aget_thread is not None and aget_thread is True:
|
|
return self.async_get_thread(
|
|
thread_id=thread_id,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
openai_client = self.get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
|
|
|
|
return Thread(**response.dict())
|
|
|
|
# def delete_thread(self):
|
|
# pass
|
|
|
|
### RUNS ###
|
|
|
|
async def arun_thread(
|
|
self,
|
|
thread_id: str,
|
|
assistant_id: str,
|
|
additional_instructions: Optional[str],
|
|
instructions: Optional[str],
|
|
metadata: Optional[Dict],
|
|
model: Optional[str],
|
|
stream: Optional[bool],
|
|
tools: Optional[Iterable[AssistantToolParam]],
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Run:
|
|
openai_client = self.async_get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
|
thread_id=thread_id,
|
|
assistant_id=assistant_id,
|
|
additional_instructions=additional_instructions,
|
|
instructions=instructions,
|
|
metadata=metadata, # type: ignore
|
|
model=model,
|
|
tools=tools,
|
|
)
|
|
|
|
return response
|
|
|
|
def async_run_thread_stream(
|
|
self,
|
|
client: AsyncAzureOpenAI,
|
|
thread_id: str,
|
|
assistant_id: str,
|
|
additional_instructions: Optional[str],
|
|
instructions: Optional[str],
|
|
metadata: Optional[Dict],
|
|
model: Optional[str],
|
|
tools: Optional[Iterable[AssistantToolParam]],
|
|
event_handler: Optional[AssistantEventHandler],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
|
|
data: Dict[str, Any] = {
|
|
"thread_id": thread_id,
|
|
"assistant_id": assistant_id,
|
|
"additional_instructions": additional_instructions,
|
|
"instructions": instructions,
|
|
"metadata": metadata,
|
|
"model": model,
|
|
"tools": tools,
|
|
}
|
|
if event_handler is not None:
|
|
data["event_handler"] = event_handler
|
|
return client.beta.threads.runs.stream(**data) # type: ignore
|
|
|
|
def run_thread_stream(
|
|
self,
|
|
client: AzureOpenAI,
|
|
thread_id: str,
|
|
assistant_id: str,
|
|
additional_instructions: Optional[str],
|
|
instructions: Optional[str],
|
|
metadata: Optional[Dict],
|
|
model: Optional[str],
|
|
tools: Optional[Iterable[AssistantToolParam]],
|
|
event_handler: Optional[AssistantEventHandler],
|
|
litellm_params: Optional[dict] = None,
|
|
) -> AssistantStreamManager[AssistantEventHandler]:
|
|
data: Dict[str, Any] = {
|
|
"thread_id": thread_id,
|
|
"assistant_id": assistant_id,
|
|
"additional_instructions": additional_instructions,
|
|
"instructions": instructions,
|
|
"metadata": metadata,
|
|
"model": model,
|
|
"tools": tools,
|
|
}
|
|
if event_handler is not None:
|
|
data["event_handler"] = event_handler
|
|
return client.beta.threads.runs.stream(**data) # type: ignore
|
|
|
|
# fmt: off
|
|
|
|
@overload
|
|
def run_thread(
|
|
self,
|
|
thread_id: str,
|
|
assistant_id: str,
|
|
additional_instructions: Optional[str],
|
|
instructions: Optional[str],
|
|
metadata: Optional[Dict],
|
|
model: Optional[str],
|
|
stream: Optional[bool],
|
|
tools: Optional[Iterable[AssistantToolParam]],
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
arun_thread: Literal[True],
|
|
) -> Coroutine[None, None, Run]:
|
|
...
|
|
|
|
@overload
|
|
def run_thread(
|
|
self,
|
|
thread_id: str,
|
|
assistant_id: str,
|
|
additional_instructions: Optional[str],
|
|
instructions: Optional[str],
|
|
metadata: Optional[Dict],
|
|
model: Optional[str],
|
|
stream: Optional[bool],
|
|
tools: Optional[Iterable[AssistantToolParam]],
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AzureOpenAI],
|
|
arun_thread: Optional[Literal[False]],
|
|
) -> Run:
|
|
...
|
|
|
|
# fmt: on
|
|
|
|
def run_thread(
|
|
self,
|
|
thread_id: str,
|
|
assistant_id: str,
|
|
additional_instructions: Optional[str],
|
|
instructions: Optional[str],
|
|
metadata: Optional[Dict],
|
|
model: Optional[str],
|
|
stream: Optional[bool],
|
|
tools: Optional[Iterable[AssistantToolParam]],
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client=None,
|
|
arun_thread=None,
|
|
event_handler: Optional[AssistantEventHandler] = None,
|
|
litellm_params: Optional[dict] = None,
|
|
):
|
|
if arun_thread is not None and arun_thread is True:
|
|
if stream is not None and stream is True:
|
|
azure_client = self.async_get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
return self.async_run_thread_stream(
|
|
client=azure_client,
|
|
thread_id=thread_id,
|
|
assistant_id=assistant_id,
|
|
additional_instructions=additional_instructions,
|
|
instructions=instructions,
|
|
metadata=metadata,
|
|
model=model,
|
|
tools=tools,
|
|
event_handler=event_handler,
|
|
litellm_params=litellm_params,
|
|
)
|
|
return self.arun_thread(
|
|
thread_id=thread_id,
|
|
assistant_id=assistant_id,
|
|
additional_instructions=additional_instructions,
|
|
instructions=instructions,
|
|
metadata=metadata, # type: ignore
|
|
model=model,
|
|
stream=stream,
|
|
tools=tools,
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
openai_client = self.get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
if stream is not None and stream is True:
|
|
return self.run_thread_stream(
|
|
client=openai_client,
|
|
thread_id=thread_id,
|
|
assistant_id=assistant_id,
|
|
additional_instructions=additional_instructions,
|
|
instructions=instructions,
|
|
metadata=metadata,
|
|
model=model,
|
|
tools=tools,
|
|
event_handler=event_handler,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
|
thread_id=thread_id,
|
|
assistant_id=assistant_id,
|
|
additional_instructions=additional_instructions,
|
|
instructions=instructions,
|
|
metadata=metadata, # type: ignore
|
|
model=model,
|
|
tools=tools,
|
|
)
|
|
|
|
return response
|
|
|
|
# Create Assistant
|
|
async def async_create_assistants(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
create_assistant_data: dict,
|
|
litellm_params: Optional[dict] = None,
|
|
) -> Assistant:
|
|
azure_openai_client = self.async_get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = await azure_openai_client.beta.assistants.create(
|
|
**create_assistant_data
|
|
)
|
|
return response
|
|
|
|
def create_assistants(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
create_assistant_data: dict,
|
|
client=None,
|
|
async_create_assistants=None,
|
|
litellm_params: Optional[dict] = None,
|
|
):
|
|
if async_create_assistants is not None and async_create_assistants is True:
|
|
return self.async_create_assistants(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
create_assistant_data=create_assistant_data,
|
|
litellm_params=litellm_params,
|
|
)
|
|
azure_openai_client = self.get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = azure_openai_client.beta.assistants.create(**create_assistant_data)
|
|
return response
|
|
|
|
# Delete Assistant
|
|
async def async_delete_assistant(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
client: Optional[AsyncAzureOpenAI],
|
|
assistant_id: str,
|
|
litellm_params: Optional[dict] = None,
|
|
):
|
|
azure_openai_client = self.async_get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = await azure_openai_client.beta.assistants.delete(
|
|
assistant_id=assistant_id
|
|
)
|
|
return response
|
|
|
|
def delete_assistant(
|
|
self,
|
|
api_key: Optional[str],
|
|
api_base: Optional[str],
|
|
api_version: Optional[str],
|
|
azure_ad_token: Optional[str],
|
|
timeout: Union[float, httpx.Timeout],
|
|
max_retries: Optional[int],
|
|
assistant_id: str,
|
|
async_delete_assistants: Optional[bool] = None,
|
|
client=None,
|
|
litellm_params: Optional[dict] = None,
|
|
):
|
|
if async_delete_assistants is not None and async_delete_assistants is True:
|
|
return self.async_delete_assistant(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
assistant_id=assistant_id,
|
|
litellm_params=litellm_params,
|
|
)
|
|
azure_openai_client = self.get_azure_client(
|
|
api_key=api_key,
|
|
api_base=api_base,
|
|
api_version=api_version,
|
|
azure_ad_token=azure_ad_token,
|
|
timeout=timeout,
|
|
max_retries=max_retries,
|
|
client=client,
|
|
litellm_params=litellm_params,
|
|
)
|
|
|
|
response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id)
|
|
return response
|