mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge pull request #3996 from BerriAI/litellm_azure_assistants_api_support
feat(assistants/main.py): Azure Assistants API support
This commit is contained in:
commit
5ee3b0f30f
8 changed files with 1390 additions and 87 deletions
|
@ -1,4 +1,5 @@
|
|||
from typing import Optional, Union, Any, Literal
|
||||
from typing import Optional, Union, Any, Literal, Coroutine, Iterable
|
||||
from typing_extensions import overload
|
||||
import types, requests
|
||||
from .base import BaseLLM
|
||||
from litellm.utils import (
|
||||
|
@ -19,6 +20,18 @@ from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTra
|
|||
from openai import AzureOpenAI, AsyncAzureOpenAI
|
||||
import uuid
|
||||
import os
|
||||
from ..types.llms.openai import (
|
||||
AsyncCursorPage,
|
||||
AssistantToolParam,
|
||||
SyncCursorPage,
|
||||
Assistant,
|
||||
MessageData,
|
||||
OpenAIMessage,
|
||||
OpenAICreateThreadParamsMessage,
|
||||
Thread,
|
||||
AssistantToolParam,
|
||||
Run,
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAIError(Exception):
|
||||
|
@ -199,6 +212,68 @@ class AzureOpenAIConfig:
|
|||
return ["europe", "sweden", "switzerland", "france", "uk"]
|
||||
|
||||
|
||||
class AzureOpenAIAssistantsAPIConfig:
|
||||
"""
|
||||
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def get_supported_openai_create_message_params(self):
|
||||
return [
|
||||
"role",
|
||||
"content",
|
||||
"attachments",
|
||||
"metadata",
|
||||
]
|
||||
|
||||
def map_openai_params_create_message_params(
|
||||
self, non_default_params: dict, optional_params: dict
|
||||
):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "role":
|
||||
optional_params["role"] = value
|
||||
if param == "metadata":
|
||||
optional_params["metadata"] = value
|
||||
elif param == "content": # only string accepted
|
||||
if isinstance(value, str):
|
||||
optional_params["content"] = value
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Azure only accepts content as a string.",
|
||||
status_code=400,
|
||||
)
|
||||
elif (
|
||||
param == "attachments"
|
||||
): # this is a v2 param. Azure currently supports the old 'file_id's param
|
||||
file_ids: List[str] = []
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
if "file_id" in item:
|
||||
file_ids.append(item["file_id"])
|
||||
else:
|
||||
if litellm.drop_params == True:
|
||||
pass
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Azure doesn't support {}. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
value
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Invalid param. attachments should always be a list. Got={}, Expected=List. Raw value={}".format(
|
||||
type(value), value
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
return optional_params
|
||||
|
||||
|
||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
||||
# azure_client_params = {
|
||||
# "api_version": api_version,
|
||||
|
@ -1277,3 +1352,753 @@ class AzureChatCompletion(BaseLLM):
|
|||
response["x-ms-region"] = completion.headers["x-ms-region"]
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class AzureAssistantsAPI(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def get_azure_client(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI] = None,
|
||||
) -> AzureOpenAI:
|
||||
received_args = locals()
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["azure_endpoint"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
azure_openai_client = AzureOpenAI(**data) # type: ignore
|
||||
else:
|
||||
azure_openai_client = client
|
||||
|
||||
return azure_openai_client
|
||||
|
||||
def async_get_azure_client(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI] = None,
|
||||
) -> AsyncAzureOpenAI:
|
||||
received_args = locals()
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["azure_endpoint"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
|
||||
azure_openai_client = AsyncAzureOpenAI(**data)
|
||||
# azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
|
||||
else:
|
||||
azure_openai_client = client
|
||||
|
||||
return azure_openai_client
|
||||
|
||||
### ASSISTANTS ###
|
||||
|
||||
async def async_get_assistants(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
) -> AsyncCursorPage[Assistant]:
|
||||
azure_openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = await azure_openai_client.beta.assistants.list()
|
||||
|
||||
return response
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def get_assistants(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
aget_assistants: Literal[True],
|
||||
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_assistants(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI],
|
||||
aget_assistants: Optional[Literal[False]],
|
||||
) -> SyncCursorPage[Assistant]:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def get_assistants(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client=None,
|
||||
aget_assistants=None,
|
||||
):
|
||||
if aget_assistants is not None and aget_assistants == True:
|
||||
return self.async_get_assistants(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
azure_openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
api_version=api_version,
|
||||
)
|
||||
|
||||
response = azure_openai_client.beta.assistants.list()
|
||||
|
||||
return response
|
||||
|
||||
### MESSAGES ###
|
||||
|
||||
async def a_add_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
message_data: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI] = None,
|
||||
) -> OpenAIMessage:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
|
||||
thread_id, **message_data # type: ignore
|
||||
)
|
||||
|
||||
response_obj: Optional[OpenAIMessage] = None
|
||||
if getattr(thread_message, "status", None) is None:
|
||||
thread_message.status = "completed"
|
||||
response_obj = OpenAIMessage(**thread_message.dict())
|
||||
else:
|
||||
response_obj = OpenAIMessage(**thread_message.dict())
|
||||
return response_obj
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def add_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
message_data: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
a_add_message: Literal[True],
|
||||
) -> Coroutine[None, None, OpenAIMessage]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def add_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
message_data: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI],
|
||||
a_add_message: Optional[Literal[False]],
|
||||
) -> OpenAIMessage:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
message_data: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client=None,
|
||||
a_add_message: Optional[bool] = None,
|
||||
):
|
||||
if a_add_message is not None and a_add_message == True:
|
||||
return self.a_add_message(
|
||||
thread_id=thread_id,
|
||||
message_data=message_data,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
|
||||
thread_id, **message_data # type: ignore
|
||||
)
|
||||
|
||||
response_obj: Optional[OpenAIMessage] = None
|
||||
if getattr(thread_message, "status", None) is None:
|
||||
thread_message.status = "completed"
|
||||
response_obj = OpenAIMessage(**thread_message.dict())
|
||||
else:
|
||||
response_obj = OpenAIMessage(**thread_message.dict())
|
||||
return response_obj
|
||||
|
||||
async def async_get_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI] = None,
|
||||
) -> AsyncCursorPage[OpenAIMessage]:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
|
||||
|
||||
return response
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def get_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
aget_messages: Literal[True],
|
||||
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI],
|
||||
aget_messages: Optional[Literal[False]],
|
||||
) -> SyncCursorPage[OpenAIMessage]:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client=None,
|
||||
aget_messages=None,
|
||||
):
|
||||
if aget_messages is not None and aget_messages == True:
|
||||
return self.async_get_messages(
|
||||
thread_id=thread_id,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
|
||||
|
||||
return response
|
||||
|
||||
### THREADS ###
|
||||
|
||||
async def async_create_thread(
|
||||
self,
|
||||
metadata: Optional[dict],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
) -> Thread:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
data = {}
|
||||
if messages is not None:
|
||||
data["messages"] = messages # type: ignore
|
||||
if metadata is not None:
|
||||
data["metadata"] = metadata # type: ignore
|
||||
|
||||
message_thread = await openai_client.beta.threads.create(**data) # type: ignore
|
||||
|
||||
return Thread(**message_thread.dict())
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def create_thread(
|
||||
self,
|
||||
metadata: Optional[dict],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
acreate_thread: Literal[True],
|
||||
) -> Coroutine[None, None, Thread]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def create_thread(
|
||||
self,
|
||||
metadata: Optional[dict],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
client: Optional[AzureOpenAI],
|
||||
acreate_thread: Optional[Literal[False]],
|
||||
) -> Thread:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def create_thread(
|
||||
self,
|
||||
metadata: Optional[dict],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
client=None,
|
||||
acreate_thread=None,
|
||||
):
|
||||
"""
|
||||
Here's an example:
|
||||
```
|
||||
from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
|
||||
|
||||
# create thread
|
||||
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
|
||||
openai_api.create_thread(messages=[message])
|
||||
```
|
||||
"""
|
||||
if acreate_thread is not None and acreate_thread == True:
|
||||
return self.async_create_thread(
|
||||
metadata=metadata,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
messages=messages,
|
||||
)
|
||||
azure_openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
data = {}
|
||||
if messages is not None:
|
||||
data["messages"] = messages # type: ignore
|
||||
if metadata is not None:
|
||||
data["metadata"] = metadata # type: ignore
|
||||
|
||||
message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore
|
||||
|
||||
return Thread(**message_thread.dict())
|
||||
|
||||
async def async_get_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
) -> Thread:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
|
||||
|
||||
return Thread(**response.dict())
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def get_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
aget_thread: Literal[True],
|
||||
) -> Coroutine[None, None, Thread]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI],
|
||||
aget_thread: Optional[Literal[False]],
|
||||
) -> Thread:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def get_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client=None,
|
||||
aget_thread=None,
|
||||
):
|
||||
if aget_thread is not None and aget_thread == True:
|
||||
return self.async_get_thread(
|
||||
thread_id=thread_id,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
|
||||
|
||||
return Thread(**response.dict())
|
||||
|
||||
# def delete_thread(self):
|
||||
# pass
|
||||
|
||||
### RUNS ###
|
||||
|
||||
async def arun_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
) -> Run:
|
||||
openai_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
thread_id="thread_OHLZkEj5xJLxdk0REZ4cl9sP",
|
||||
assistant_id="asst_nIzr656D1GIVMLHOKD76bN2T",
|
||||
additional_instructions=None,
|
||||
instructions=None,
|
||||
metadata=None,
|
||||
model=None,
|
||||
tools=None,
|
||||
)
|
||||
|
||||
# response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
# thread_id=thread_id,
|
||||
# assistant_id=assistant_id,
|
||||
# additional_instructions=additional_instructions,
|
||||
# instructions=instructions,
|
||||
# metadata=metadata,
|
||||
# model=model,
|
||||
# tools=tools,
|
||||
# )
|
||||
|
||||
return response
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def run_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AsyncAzureOpenAI],
|
||||
arun_thread: Literal[True],
|
||||
) -> Coroutine[None, None, Run]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def run_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client: Optional[AzureOpenAI],
|
||||
arun_thread: Optional[Literal[False]],
|
||||
) -> Run:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def run_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
azure_ad_token: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
client=None,
|
||||
arun_thread=None,
|
||||
):
|
||||
if arun_thread is not None and arun_thread == True:
|
||||
return self.arun_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
openai_client = self.get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue