forked from phoenix/litellm-mirror
[Feat-Proxy] Add Azure Assistants API - Create Assistant, Delete Assistant Support (#5777)
* update docs to show providers * azure - move assistants in it's own file * create new azure assistants file * add azure create assistants * add test for create / delete assistants * azure add delete assistants support * docs add Azure to support providers for assistants api * fix linting errors * fix standard logging merge conflict * docs azure create assistants * fix doc
This commit is contained in:
parent
a109853d21
commit
7e07c37be7
7 changed files with 1172 additions and 897 deletions
|
@ -7,6 +7,7 @@ Covers Threads, Messages, Assistants.
|
||||||
|
|
||||||
LiteLLM currently covers:
|
LiteLLM currently covers:
|
||||||
- Create Assistants
|
- Create Assistants
|
||||||
|
- Delete Assistants
|
||||||
- Get Assistants
|
- Get Assistants
|
||||||
- Create Thread
|
- Create Thread
|
||||||
- Get Thread
|
- Get Thread
|
||||||
|
@ -14,6 +15,12 @@ LiteLLM currently covers:
|
||||||
- Get Messages
|
- Get Messages
|
||||||
- Run Thread
|
- Run Thread
|
||||||
|
|
||||||
|
|
||||||
|
## **Supported Providers**:
|
||||||
|
- [OpenAI](#quick-start)
|
||||||
|
- [Azure OpenAI](#azure-openai)
|
||||||
|
- [OpenAI-Compatible APIs](#openai-compatible-apis)
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
Call an existing Assistant.
|
Call an existing Assistant.
|
||||||
|
@ -283,6 +290,32 @@ curl -X POST 'http://0.0.0.0:4000/threads/{thread_id}/runs' \
|
||||||
|
|
||||||
## [👉 Proxy API Reference](https://litellm-api.up.railway.app/#/assistants)
|
## [👉 Proxy API Reference](https://litellm-api.up.railway.app/#/assistants)
|
||||||
|
|
||||||
|
|
||||||
|
## Azure OpenAI
|
||||||
|
|
||||||
|
**config**
|
||||||
|
```yaml
|
||||||
|
assistant_settings:
|
||||||
|
custom_llm_provider: azure
|
||||||
|
litellm_params:
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
api_base: os.environ/AZURE_API_BASE
|
||||||
|
```
|
||||||
|
|
||||||
|
**curl**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST "http://localhost:4000/v1/assistants" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer sk-1234" \
|
||||||
|
-d '{
|
||||||
|
"instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
|
||||||
|
"name": "Math Tutor",
|
||||||
|
"tools": [{"type": "code_interpreter"}],
|
||||||
|
"model": "<my-azure-deployment-name>"
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
## OpenAI-Compatible APIs
|
## OpenAI-Compatible APIs
|
||||||
|
|
||||||
To call openai-compatible Assistants API's (eg. Astra Assistants API), just add `openai/` to the model name:
|
To call openai-compatible Assistants API's (eg. Astra Assistants API), just add `openai/` to the model name:
|
||||||
|
|
|
@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
Covers Batches, Files
|
Covers Batches, Files
|
||||||
|
|
||||||
Supported Providers:
|
## **Supported Providers**:
|
||||||
- Azure OpenAI
|
- Azure OpenAI
|
||||||
- OpenAI
|
- OpenAI
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from openai.types.beta.assistant_deleted import AssistantDeleted
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import client
|
from litellm import client
|
||||||
|
from litellm.llms.AzureOpenAI import assistants
|
||||||
from litellm.types.router import GenericLiteLLMParams
|
from litellm.types.router import GenericLiteLLMParams
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
exception_type,
|
exception_type,
|
||||||
|
@ -21,7 +22,7 @@ from litellm.utils import (
|
||||||
supports_httpx_timeout,
|
supports_httpx_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..llms.AzureOpenAI.azure import AzureAssistantsAPI
|
from ..llms.AzureOpenAI.assistants import AzureAssistantsAPI
|
||||||
from ..llms.OpenAI.openai import OpenAIAssistantsAPI
|
from ..llms.OpenAI.openai import OpenAIAssistantsAPI
|
||||||
from ..types.llms.openai import *
|
from ..types.llms.openai import *
|
||||||
from ..types.router import *
|
from ..types.router import *
|
||||||
|
@ -210,8 +211,8 @@ async def acreate_assistants(
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
### PASS ARGS TO GET ASSISTANTS ###
|
### PASS ARGS TO GET ASSISTANTS ###
|
||||||
kwargs["async_create_assistants"] = True
|
kwargs["async_create_assistants"] = True
|
||||||
try:
|
|
||||||
model = kwargs.pop("model", None)
|
model = kwargs.pop("model", None)
|
||||||
|
try:
|
||||||
kwargs["client"] = client
|
kwargs["client"] = client
|
||||||
# Use a partial function to pass your keyword arguments
|
# Use a partial function to pass your keyword arguments
|
||||||
func = partial(create_assistants, custom_llm_provider, model, **kwargs)
|
func = partial(create_assistants, custom_llm_provider, model, **kwargs)
|
||||||
|
@ -258,7 +259,7 @@ def create_assistants(
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Assistant:
|
) -> Union[Assistant, Coroutine[Any, Any, Assistant]]:
|
||||||
async_create_assistants: Optional[bool] = kwargs.pop(
|
async_create_assistants: Optional[bool] = kwargs.pop(
|
||||||
"async_create_assistants", None
|
"async_create_assistants", None
|
||||||
)
|
)
|
||||||
|
@ -288,7 +289,20 @@ def create_assistants(
|
||||||
elif timeout is None:
|
elif timeout is None:
|
||||||
timeout = 600.0
|
timeout = 600.0
|
||||||
|
|
||||||
response: Optional[Assistant] = None
|
create_assistant_data = {
|
||||||
|
"model": model,
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"instructions": instructions,
|
||||||
|
"tools": tools,
|
||||||
|
"tool_resources": tool_resources,
|
||||||
|
"metadata": metadata,
|
||||||
|
"temperature": temperature,
|
||||||
|
"top_p": top_p,
|
||||||
|
"response_format": response_format,
|
||||||
|
}
|
||||||
|
|
||||||
|
response: Optional[Union[Coroutine[Any, Any, Assistant], Assistant]] = None
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
api_base = (
|
api_base = (
|
||||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
@ -310,19 +324,6 @@ def create_assistants(
|
||||||
or os.getenv("OPENAI_API_KEY")
|
or os.getenv("OPENAI_API_KEY")
|
||||||
)
|
)
|
||||||
|
|
||||||
create_assistant_data = {
|
|
||||||
"model": model,
|
|
||||||
"name": name,
|
|
||||||
"description": description,
|
|
||||||
"instructions": instructions,
|
|
||||||
"tools": tools,
|
|
||||||
"tool_resources": tool_resources,
|
|
||||||
"metadata": metadata,
|
|
||||||
"temperature": temperature,
|
|
||||||
"top_p": top_p,
|
|
||||||
"response_format": response_format,
|
|
||||||
}
|
|
||||||
|
|
||||||
response = openai_assistants_api.create_assistants(
|
response = openai_assistants_api.create_assistants(
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
@ -333,6 +334,46 @@ def create_assistants(
|
||||||
client=client,
|
client=client,
|
||||||
async_create_assistants=async_create_assistants, # type: ignore
|
async_create_assistants=async_create_assistants, # type: ignore
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
elif custom_llm_provider == "azure":
|
||||||
|
api_base = (
|
||||||
|
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_version = (
|
||||||
|
optional_params.api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key
|
||||||
|
or litellm.azure_key
|
||||||
|
or get_secret("AZURE_OPENAI_API_KEY")
|
||||||
|
or get_secret("AZURE_API_KEY")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
extra_body = optional_params.get("extra_body", {})
|
||||||
|
azure_ad_token: Optional[str] = None
|
||||||
|
if extra_body is not None:
|
||||||
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
|
else:
|
||||||
|
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
|
if isinstance(client, OpenAI):
|
||||||
|
client = None # only pass client if it's AzureOpenAI
|
||||||
|
|
||||||
|
response = azure_assistants_api.create_assistants(
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
api_version=api_version,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
client=client,
|
||||||
|
async_create_assistants=async_create_assistants,
|
||||||
|
create_assistant_data=create_assistant_data,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' is supported.".format(
|
||||||
|
@ -401,7 +442,7 @@ def delete_assistant(
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AssistantDeleted:
|
) -> Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]:
|
||||||
optional_params = GenericLiteLLMParams(
|
optional_params = GenericLiteLLMParams(
|
||||||
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
|
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
|
||||||
)
|
)
|
||||||
|
@ -432,7 +473,9 @@ def delete_assistant(
|
||||||
elif timeout is None:
|
elif timeout is None:
|
||||||
timeout = 600.0
|
timeout = 600.0
|
||||||
|
|
||||||
response: Optional[AssistantDeleted] = None
|
response: Optional[
|
||||||
|
Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]
|
||||||
|
] = None
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
api_base = (
|
api_base = (
|
||||||
optional_params.api_base
|
optional_params.api_base
|
||||||
|
@ -464,6 +507,46 @@ def delete_assistant(
|
||||||
client=client,
|
client=client,
|
||||||
async_delete_assistants=async_delete_assistants,
|
async_delete_assistants=async_delete_assistants,
|
||||||
)
|
)
|
||||||
|
elif custom_llm_provider == "azure":
|
||||||
|
api_base = (
|
||||||
|
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_version = (
|
||||||
|
optional_params.api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
api_key = (
|
||||||
|
optional_params.api_key
|
||||||
|
or litellm.api_key
|
||||||
|
or litellm.azure_key
|
||||||
|
or get_secret("AZURE_OPENAI_API_KEY")
|
||||||
|
or get_secret("AZURE_API_KEY")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
extra_body = optional_params.get("extra_body", {})
|
||||||
|
azure_ad_token: Optional[str] = None
|
||||||
|
if extra_body is not None:
|
||||||
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
|
else:
|
||||||
|
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||||
|
|
||||||
|
if isinstance(client, OpenAI):
|
||||||
|
client = None # only pass client if it's AzureOpenAI
|
||||||
|
|
||||||
|
response = azure_assistants_api.delete_assistant(
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
api_base=api_base,
|
||||||
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
api_version=api_version,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=optional_params.max_retries,
|
||||||
|
client=client,
|
||||||
|
async_delete_assistants=async_delete_assistants,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'delete_assistant'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'delete_assistant'. Only 'openai' is supported.".format(
|
||||||
|
@ -575,6 +658,9 @@ def create_thread(
|
||||||
elif timeout is None:
|
elif timeout is None:
|
||||||
timeout = 600.0
|
timeout = 600.0
|
||||||
|
|
||||||
|
api_base: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
|
||||||
response: Optional[Thread] = None
|
response: Optional[Thread] = None
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
api_base = (
|
api_base = (
|
||||||
|
@ -612,12 +698,6 @@ def create_thread(
|
||||||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
api_version = (
|
|
||||||
optional_params.api_version
|
|
||||||
or litellm.api_version
|
|
||||||
or get_secret("AZURE_API_VERSION")
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
api_key = (
|
api_key = (
|
||||||
optional_params.api_key
|
optional_params.api_key
|
||||||
or litellm.api_key
|
or litellm.api_key
|
||||||
|
@ -626,8 +706,14 @@ def create_thread(
|
||||||
or get_secret("AZURE_API_KEY")
|
or get_secret("AZURE_API_KEY")
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
|
api_version: Optional[str] = (
|
||||||
|
optional_params.api_version
|
||||||
|
or litellm.api_version
|
||||||
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
extra_body = optional_params.get("extra_body", {})
|
extra_body = optional_params.get("extra_body", {})
|
||||||
azure_ad_token = None
|
azure_ad_token: Optional[str] = None
|
||||||
if extra_body is not None:
|
if extra_body is not None:
|
||||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
else:
|
else:
|
||||||
|
@ -647,7 +733,7 @@ def create_thread(
|
||||||
max_retries=optional_params.max_retries,
|
max_retries=optional_params.max_retries,
|
||||||
client=client,
|
client=client,
|
||||||
acreate_thread=acreate_thread,
|
acreate_thread=acreate_thread,
|
||||||
) # type :ignore
|
)
|
||||||
else:
|
else:
|
||||||
raise litellm.exceptions.BadRequestError(
|
raise litellm.exceptions.BadRequestError(
|
||||||
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
|
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
|
||||||
|
@ -727,7 +813,8 @@ def get_thread(
|
||||||
timeout = float(timeout) # type: ignore
|
timeout = float(timeout) # type: ignore
|
||||||
elif timeout is None:
|
elif timeout is None:
|
||||||
timeout = 600.0
|
timeout = 600.0
|
||||||
|
api_base: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
response: Optional[Thread] = None
|
response: Optional[Thread] = None
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
api_base = (
|
api_base = (
|
||||||
|
@ -765,7 +852,7 @@ def get_thread(
|
||||||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
api_version = (
|
api_version: Optional[str] = (
|
||||||
optional_params.api_version
|
optional_params.api_version
|
||||||
or litellm.api_version
|
or litellm.api_version
|
||||||
or get_secret("AZURE_API_VERSION")
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
@ -780,7 +867,7 @@ def get_thread(
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
extra_body = optional_params.get("extra_body", {})
|
extra_body = optional_params.get("extra_body", {})
|
||||||
azure_ad_token = None
|
azure_ad_token: Optional[str] = None
|
||||||
if extra_body is not None:
|
if extra_body is not None:
|
||||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
else:
|
else:
|
||||||
|
@ -912,7 +999,8 @@ def add_message(
|
||||||
timeout = float(timeout) # type: ignore
|
timeout = float(timeout) # type: ignore
|
||||||
elif timeout is None:
|
elif timeout is None:
|
||||||
timeout = 600.0
|
timeout = 600.0
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
api_base: Optional[str] = None
|
||||||
response: Optional[OpenAIMessage] = None
|
response: Optional[OpenAIMessage] = None
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
api_base = (
|
api_base = (
|
||||||
|
@ -950,7 +1038,7 @@ def add_message(
|
||||||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
api_version = (
|
api_version: Optional[str] = (
|
||||||
optional_params.api_version
|
optional_params.api_version
|
||||||
or litellm.api_version
|
or litellm.api_version
|
||||||
or get_secret("AZURE_API_VERSION")
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
@ -965,7 +1053,7 @@ def add_message(
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
extra_body = optional_params.get("extra_body", {})
|
extra_body = optional_params.get("extra_body", {})
|
||||||
azure_ad_token = None
|
azure_ad_token: Optional[str] = None
|
||||||
if extra_body is not None:
|
if extra_body is not None:
|
||||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
else:
|
else:
|
||||||
|
@ -1071,6 +1159,8 @@ def get_messages(
|
||||||
timeout = 600.0
|
timeout = 600.0
|
||||||
|
|
||||||
response: Optional[SyncCursorPage[OpenAIMessage]] = None
|
response: Optional[SyncCursorPage[OpenAIMessage]] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
api_base: Optional[str] = None
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
api_base = (
|
api_base = (
|
||||||
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
||||||
|
@ -1106,7 +1196,7 @@ def get_messages(
|
||||||
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
api_version = (
|
api_version: Optional[str] = (
|
||||||
optional_params.api_version
|
optional_params.api_version
|
||||||
or litellm.api_version
|
or litellm.api_version
|
||||||
or get_secret("AZURE_API_VERSION")
|
or get_secret("AZURE_API_VERSION")
|
||||||
|
@ -1121,7 +1211,7 @@ def get_messages(
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
extra_body = optional_params.get("extra_body", {})
|
extra_body = optional_params.get("extra_body", {})
|
||||||
azure_ad_token = None
|
azure_ad_token: Optional[str] = None
|
||||||
if extra_body is not None:
|
if extra_body is not None:
|
||||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||||
else:
|
else:
|
||||||
|
|
975
litellm/llms/AzureOpenAI/assistants.py
Normal file
975
litellm/llms/AzureOpenAI/assistants.py
Normal file
|
@ -0,0 +1,975 @@
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Callable, Coroutine, Iterable, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||||
|
from typing_extensions import overload
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.types.utils import FileTypes # type: ignore
|
||||||
|
|
||||||
|
from ...types.llms.openai import (
|
||||||
|
Assistant,
|
||||||
|
AssistantEventHandler,
|
||||||
|
AssistantStreamManager,
|
||||||
|
AssistantToolParam,
|
||||||
|
AsyncAssistantEventHandler,
|
||||||
|
AsyncAssistantStreamManager,
|
||||||
|
AsyncCursorPage,
|
||||||
|
OpenAICreateThreadParamsMessage,
|
||||||
|
OpenAIMessage,
|
||||||
|
Run,
|
||||||
|
SyncCursorPage,
|
||||||
|
Thread,
|
||||||
|
)
|
||||||
|
from ..base import BaseLLM
|
||||||
|
|
||||||
|
|
||||||
|
class AzureAssistantsAPI(BaseLLM):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_azure_client(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AzureOpenAI] = None,
|
||||||
|
) -> AzureOpenAI:
|
||||||
|
received_args = locals()
|
||||||
|
if client is None:
|
||||||
|
data = {}
|
||||||
|
for k, v in received_args.items():
|
||||||
|
if k == "self" or k == "client":
|
||||||
|
pass
|
||||||
|
elif k == "api_base" and v is not None:
|
||||||
|
data["azure_endpoint"] = v
|
||||||
|
elif v is not None:
|
||||||
|
data[k] = v
|
||||||
|
azure_openai_client = AzureOpenAI(**data) # type: ignore
|
||||||
|
else:
|
||||||
|
azure_openai_client = client
|
||||||
|
|
||||||
|
return azure_openai_client
|
||||||
|
|
||||||
|
def async_get_azure_client(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI] = None,
|
||||||
|
) -> AsyncAzureOpenAI:
|
||||||
|
received_args = locals()
|
||||||
|
if client is None:
|
||||||
|
data = {}
|
||||||
|
for k, v in received_args.items():
|
||||||
|
if k == "self" or k == "client":
|
||||||
|
pass
|
||||||
|
elif k == "api_base" and v is not None:
|
||||||
|
data["azure_endpoint"] = v
|
||||||
|
elif v is not None:
|
||||||
|
data[k] = v
|
||||||
|
azure_openai_client = AsyncAzureOpenAI(**data)
|
||||||
|
# azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
|
||||||
|
else:
|
||||||
|
azure_openai_client = client
|
||||||
|
|
||||||
|
return azure_openai_client
|
||||||
|
|
||||||
|
### ASSISTANTS ###
|
||||||
|
|
||||||
|
async def async_get_assistants(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
) -> AsyncCursorPage[Assistant]:
|
||||||
|
azure_openai_client = self.async_get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await azure_openai_client.beta.assistants.list()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_assistants(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
aget_assistants: Literal[True],
|
||||||
|
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_assistants(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AzureOpenAI],
|
||||||
|
aget_assistants: Optional[Literal[False]],
|
||||||
|
) -> SyncCursorPage[Assistant]:
|
||||||
|
...
|
||||||
|
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def get_assistants(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client=None,
|
||||||
|
aget_assistants=None,
|
||||||
|
):
|
||||||
|
if aget_assistants is not None and aget_assistants == True:
|
||||||
|
return self.async_get_assistants(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
azure_openai_client = self.get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
api_version=api_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = azure_openai_client.beta.assistants.list()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
### MESSAGES ###
|
||||||
|
|
||||||
|
async def a_add_message(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
message_data: dict,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI] = None,
|
||||||
|
) -> OpenAIMessage:
|
||||||
|
openai_client = self.async_get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
|
||||||
|
thread_id, **message_data # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
response_obj: Optional[OpenAIMessage] = None
|
||||||
|
if getattr(thread_message, "status", None) is None:
|
||||||
|
thread_message.status = "completed"
|
||||||
|
response_obj = OpenAIMessage(**thread_message.dict())
|
||||||
|
else:
|
||||||
|
response_obj = OpenAIMessage(**thread_message.dict())
|
||||||
|
return response_obj
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def add_message(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
message_data: dict,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
a_add_message: Literal[True],
|
||||||
|
) -> Coroutine[None, None, OpenAIMessage]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def add_message(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
message_data: dict,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AzureOpenAI],
|
||||||
|
a_add_message: Optional[Literal[False]],
|
||||||
|
) -> OpenAIMessage:
|
||||||
|
...
|
||||||
|
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def add_message(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
message_data: dict,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client=None,
|
||||||
|
a_add_message: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
if a_add_message is not None and a_add_message == True:
|
||||||
|
return self.a_add_message(
|
||||||
|
thread_id=thread_id,
|
||||||
|
message_data=message_data,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
openai_client = self.get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore
|
||||||
|
thread_id, **message_data # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
response_obj: Optional[OpenAIMessage] = None
|
||||||
|
if getattr(thread_message, "status", None) is None:
|
||||||
|
thread_message.status = "completed"
|
||||||
|
response_obj = OpenAIMessage(**thread_message.dict())
|
||||||
|
else:
|
||||||
|
response_obj = OpenAIMessage(**thread_message.dict())
|
||||||
|
return response_obj
|
||||||
|
|
||||||
|
async def async_get_messages(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI] = None,
|
||||||
|
) -> AsyncCursorPage[OpenAIMessage]:
|
||||||
|
openai_client = self.async_get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_messages(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
aget_messages: Literal[True],
|
||||||
|
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_messages(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AzureOpenAI],
|
||||||
|
aget_messages: Optional[Literal[False]],
|
||||||
|
) -> SyncCursorPage[OpenAIMessage]:
|
||||||
|
...
|
||||||
|
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def get_messages(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client=None,
|
||||||
|
aget_messages=None,
|
||||||
|
):
|
||||||
|
if aget_messages is not None and aget_messages == True:
|
||||||
|
return self.async_get_messages(
|
||||||
|
thread_id=thread_id,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
openai_client = self.get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.threads.messages.list(thread_id=thread_id)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
### THREADS ###
|
||||||
|
|
||||||
|
async def async_create_thread(
|
||||||
|
self,
|
||||||
|
metadata: Optional[dict],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||||
|
) -> Thread:
|
||||||
|
openai_client = self.async_get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
if messages is not None:
|
||||||
|
data["messages"] = messages # type: ignore
|
||||||
|
if metadata is not None:
|
||||||
|
data["metadata"] = metadata # type: ignore
|
||||||
|
|
||||||
|
message_thread = await openai_client.beta.threads.create(**data) # type: ignore
|
||||||
|
|
||||||
|
return Thread(**message_thread.dict())
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def create_thread(
|
||||||
|
self,
|
||||||
|
metadata: Optional[dict],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
acreate_thread: Literal[True],
|
||||||
|
) -> Coroutine[None, None, Thread]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def create_thread(
|
||||||
|
self,
|
||||||
|
metadata: Optional[dict],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||||
|
client: Optional[AzureOpenAI],
|
||||||
|
acreate_thread: Optional[Literal[False]],
|
||||||
|
) -> Thread:
|
||||||
|
...
|
||||||
|
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def create_thread(
|
||||||
|
self,
|
||||||
|
metadata: Optional[dict],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||||
|
client=None,
|
||||||
|
acreate_thread=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Here's an example:
|
||||||
|
```
|
||||||
|
from litellm.llms.OpenAI.openai import OpenAIAssistantsAPI, MessageData
|
||||||
|
|
||||||
|
# create thread
|
||||||
|
message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
|
||||||
|
openai_api.create_thread(messages=[message])
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if acreate_thread is not None and acreate_thread == True:
|
||||||
|
return self.async_create_thread(
|
||||||
|
metadata=metadata,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
azure_openai_client = self.get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
if messages is not None:
|
||||||
|
data["messages"] = messages # type: ignore
|
||||||
|
if metadata is not None:
|
||||||
|
data["metadata"] = metadata # type: ignore
|
||||||
|
|
||||||
|
message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore
|
||||||
|
|
||||||
|
return Thread(**message_thread.dict())
|
||||||
|
|
||||||
|
async def async_get_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
) -> Thread:
|
||||||
|
openai_client = self.async_get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
|
||||||
|
|
||||||
|
return Thread(**response.dict())
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
aget_thread: Literal[True],
|
||||||
|
) -> Coroutine[None, None, Thread]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AzureOpenAI],
|
||||||
|
aget_thread: Optional[Literal[False]],
|
||||||
|
) -> Thread:
|
||||||
|
...
|
||||||
|
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def get_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client=None,
|
||||||
|
aget_thread=None,
|
||||||
|
):
|
||||||
|
if aget_thread is not None and aget_thread == True:
|
||||||
|
return self.async_get_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
openai_client = self.get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.threads.retrieve(thread_id=thread_id)
|
||||||
|
|
||||||
|
return Thread(**response.dict())
|
||||||
|
|
||||||
|
# def delete_thread(self):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
### RUNS ###
|
||||||
|
|
||||||
|
async def arun_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str,
|
||||||
|
additional_instructions: Optional[str],
|
||||||
|
instructions: Optional[str],
|
||||||
|
metadata: Optional[object],
|
||||||
|
model: Optional[str],
|
||||||
|
stream: Optional[bool],
|
||||||
|
tools: Optional[Iterable[AssistantToolParam]],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
) -> Run:
|
||||||
|
openai_client = self.async_get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
additional_instructions=additional_instructions,
|
||||||
|
instructions=instructions,
|
||||||
|
metadata=metadata,
|
||||||
|
model=model,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def async_run_thread_stream(
|
||||||
|
self,
|
||||||
|
client: AsyncAzureOpenAI,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str,
|
||||||
|
additional_instructions: Optional[str],
|
||||||
|
instructions: Optional[str],
|
||||||
|
metadata: Optional[object],
|
||||||
|
model: Optional[str],
|
||||||
|
tools: Optional[Iterable[AssistantToolParam]],
|
||||||
|
event_handler: Optional[AssistantEventHandler],
|
||||||
|
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
|
||||||
|
data = {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"assistant_id": assistant_id,
|
||||||
|
"additional_instructions": additional_instructions,
|
||||||
|
"instructions": instructions,
|
||||||
|
"metadata": metadata,
|
||||||
|
"model": model,
|
||||||
|
"tools": tools,
|
||||||
|
}
|
||||||
|
if event_handler is not None:
|
||||||
|
data["event_handler"] = event_handler
|
||||||
|
return client.beta.threads.runs.stream(**data) # type: ignore
|
||||||
|
|
||||||
|
def run_thread_stream(
|
||||||
|
self,
|
||||||
|
client: AzureOpenAI,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str,
|
||||||
|
additional_instructions: Optional[str],
|
||||||
|
instructions: Optional[str],
|
||||||
|
metadata: Optional[object],
|
||||||
|
model: Optional[str],
|
||||||
|
tools: Optional[Iterable[AssistantToolParam]],
|
||||||
|
event_handler: Optional[AssistantEventHandler],
|
||||||
|
) -> AssistantStreamManager[AssistantEventHandler]:
|
||||||
|
data = {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"assistant_id": assistant_id,
|
||||||
|
"additional_instructions": additional_instructions,
|
||||||
|
"instructions": instructions,
|
||||||
|
"metadata": metadata,
|
||||||
|
"model": model,
|
||||||
|
"tools": tools,
|
||||||
|
}
|
||||||
|
if event_handler is not None:
|
||||||
|
data["event_handler"] = event_handler
|
||||||
|
return client.beta.threads.runs.stream(**data) # type: ignore
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def run_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str,
|
||||||
|
additional_instructions: Optional[str],
|
||||||
|
instructions: Optional[str],
|
||||||
|
metadata: Optional[object],
|
||||||
|
model: Optional[str],
|
||||||
|
stream: Optional[bool],
|
||||||
|
tools: Optional[Iterable[AssistantToolParam]],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
arun_thread: Literal[True],
|
||||||
|
) -> Coroutine[None, None, Run]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def run_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str,
|
||||||
|
additional_instructions: Optional[str],
|
||||||
|
instructions: Optional[str],
|
||||||
|
metadata: Optional[object],
|
||||||
|
model: Optional[str],
|
||||||
|
stream: Optional[bool],
|
||||||
|
tools: Optional[Iterable[AssistantToolParam]],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AzureOpenAI],
|
||||||
|
arun_thread: Optional[Literal[False]],
|
||||||
|
) -> Run:
|
||||||
|
...
|
||||||
|
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
assistant_id: str,
|
||||||
|
additional_instructions: Optional[str],
|
||||||
|
instructions: Optional[str],
|
||||||
|
metadata: Optional[object],
|
||||||
|
model: Optional[str],
|
||||||
|
stream: Optional[bool],
|
||||||
|
tools: Optional[Iterable[AssistantToolParam]],
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client=None,
|
||||||
|
arun_thread=None,
|
||||||
|
event_handler: Optional[AssistantEventHandler] = None,
|
||||||
|
):
|
||||||
|
if arun_thread is not None and arun_thread == True:
|
||||||
|
if stream is not None and stream == True:
|
||||||
|
azure_client = self.async_get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
return self.async_run_thread_stream(
|
||||||
|
client=azure_client,
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
additional_instructions=additional_instructions,
|
||||||
|
instructions=instructions,
|
||||||
|
metadata=metadata,
|
||||||
|
model=model,
|
||||||
|
tools=tools,
|
||||||
|
event_handler=event_handler,
|
||||||
|
)
|
||||||
|
return self.arun_thread(
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
additional_instructions=additional_instructions,
|
||||||
|
instructions=instructions,
|
||||||
|
metadata=metadata,
|
||||||
|
model=model,
|
||||||
|
stream=stream,
|
||||||
|
tools=tools,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
openai_client = self.get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
if stream is not None and stream == True:
|
||||||
|
return self.run_thread_stream(
|
||||||
|
client=openai_client,
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
additional_instructions=additional_instructions,
|
||||||
|
instructions=instructions,
|
||||||
|
metadata=metadata,
|
||||||
|
model=model,
|
||||||
|
tools=tools,
|
||||||
|
event_handler=event_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||||
|
thread_id=thread_id,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
additional_instructions=additional_instructions,
|
||||||
|
instructions=instructions,
|
||||||
|
metadata=metadata,
|
||||||
|
model=model,
|
||||||
|
tools=tools,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Create Assistant
|
||||||
|
async def async_create_assistants(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
create_assistant_data: dict,
|
||||||
|
) -> Assistant:
|
||||||
|
azure_openai_client = self.async_get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await azure_openai_client.beta.assistants.create(
|
||||||
|
**create_assistant_data
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def create_assistants(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
create_assistant_data: dict,
|
||||||
|
client=None,
|
||||||
|
async_create_assistants=None,
|
||||||
|
):
|
||||||
|
if async_create_assistants is not None and async_create_assistants == True:
|
||||||
|
return self.async_create_assistants(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
create_assistant_data=create_assistant_data,
|
||||||
|
)
|
||||||
|
azure_openai_client = self.get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = azure_openai_client.beta.assistants.create(**create_assistant_data)
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Delete Assistant
|
||||||
|
async def async_delete_assistant(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
client: Optional[AsyncAzureOpenAI],
|
||||||
|
assistant_id: str,
|
||||||
|
):
|
||||||
|
azure_openai_client = self.async_get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await azure_openai_client.beta.assistants.delete(
|
||||||
|
assistant_id=assistant_id
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def delete_assistant(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str],
|
||||||
|
api_base: Optional[str],
|
||||||
|
api_version: Optional[str],
|
||||||
|
azure_ad_token: Optional[str],
|
||||||
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
max_retries: Optional[int],
|
||||||
|
assistant_id: str,
|
||||||
|
async_delete_assistants: Optional[bool] = None,
|
||||||
|
client=None,
|
||||||
|
):
|
||||||
|
if async_delete_assistants is not None and async_delete_assistants == True:
|
||||||
|
return self.async_delete_assistant(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
assistant_id=assistant_id,
|
||||||
|
)
|
||||||
|
azure_openai_client = self.get_azure_client(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
api_version=api_version,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
timeout=timeout,
|
||||||
|
max_retries=max_retries,
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id)
|
||||||
|
return response
|
|
@ -17,7 +17,8 @@ from litellm import ImageResponse, OpenAIConfig
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from litellm.types.utils import FileTypes
|
from litellm.types.utils import FileTypes # type: ignore
|
||||||
|
from litellm.types.utils import EmbeddingResponse
|
||||||
from litellm.utils import (
|
from litellm.utils import (
|
||||||
Choices,
|
Choices,
|
||||||
CustomStreamWrapper,
|
CustomStreamWrapper,
|
||||||
|
@ -735,6 +736,11 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_client._custom_query.setdefault(
|
azure_client._custom_query.setdefault(
|
||||||
"api-version", api_version
|
"api-version", api_version
|
||||||
)
|
)
|
||||||
|
if not isinstance(azure_client, AzureOpenAI):
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=500,
|
||||||
|
message="azure_client is not an instance of AzureOpenAI",
|
||||||
|
)
|
||||||
|
|
||||||
headers, response = self.make_sync_azure_openai_chat_completion_request(
|
headers, response = self.make_sync_azure_openai_chat_completion_request(
|
||||||
azure_client=azure_client, data=data, timeout=timeout
|
azure_client=azure_client, data=data, timeout=timeout
|
||||||
|
@ -1015,12 +1021,12 @@ class AzureChatCompletion(BaseLLM):
|
||||||
async def aembedding(
|
async def aembedding(
|
||||||
self,
|
self,
|
||||||
data: dict,
|
data: dict,
|
||||||
model_response: ModelResponse,
|
model_response: EmbeddingResponse,
|
||||||
azure_client_params: dict,
|
azure_client_params: dict,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
input: list,
|
input: list,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
client: Optional[AsyncAzureOpenAI] = None,
|
client: Optional[AsyncAzureOpenAI] = None,
|
||||||
logging_obj=None,
|
|
||||||
timeout=None,
|
timeout=None,
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
|
@ -1067,9 +1073,9 @@ class AzureChatCompletion(BaseLLM):
|
||||||
api_base: str,
|
api_base: str,
|
||||||
api_version: str,
|
api_version: str,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
logging_obj=None,
|
logging_obj: LiteLLMLoggingObj,
|
||||||
model_response=None,
|
model_response: EmbeddingResponse,
|
||||||
optional_params=None,
|
optional_params: dict,
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
client=None,
|
client=None,
|
||||||
aembedding=None,
|
aembedding=None,
|
||||||
|
@ -1407,8 +1413,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_client_params: dict,
|
azure_client_params: dict,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
input: list,
|
input: list,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
client=None,
|
client=None,
|
||||||
logging_obj=None,
|
|
||||||
timeout=None,
|
timeout=None,
|
||||||
):
|
):
|
||||||
response: Optional[dict] = None
|
response: Optional[dict] = None
|
||||||
|
@ -1471,14 +1477,14 @@ class AzureChatCompletion(BaseLLM):
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
|
optional_params: dict,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
model_response: Optional[litellm.utils.ImageResponse] = None,
|
model_response: Optional[litellm.utils.ImageResponse] = None,
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
logging_obj=None,
|
|
||||||
optional_params=None,
|
|
||||||
client=None,
|
client=None,
|
||||||
aimg_generation=None,
|
aimg_generation=None,
|
||||||
):
|
):
|
||||||
|
@ -1565,7 +1571,8 @@ class AzureChatCompletion(BaseLLM):
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if hasattr(e, "status_code"):
|
if hasattr(e, "status_code"):
|
||||||
raise AzureOpenAIError(status_code=e.status_code, message=str(e))
|
_status_code = getattr(e, "status_code")
|
||||||
|
raise AzureOpenAIError(status_code=_status_code, message=str(e))
|
||||||
else:
|
else:
|
||||||
raise AzureOpenAIError(status_code=500, message=str(e))
|
raise AzureOpenAIError(status_code=500, message=str(e))
|
||||||
|
|
||||||
|
@ -1847,831 +1854,6 @@ class AzureChatCompletion(BaseLLM):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
class AzureAssistantsAPI(BaseLLM):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def get_azure_client(
|
|
||||||
self,
|
|
||||||
api_key: Optional[str],
|
|
||||||
api_base: Optional[str],
|
|
||||||
api_version: Optional[str],
|
|
||||||
azure_ad_token: Optional[str],
|
|
||||||
timeout: Union[float, httpx.Timeout],
|
|
||||||
max_retries: Optional[int],
|
|
||||||
client: Optional[AzureOpenAI] = None,
|
|
||||||
) -> AzureOpenAI:
|
|
||||||
received_args = locals()
|
|
||||||
if client is None:
|
|
||||||
data = {}
|
|
||||||
for k, v in received_args.items():
|
|
||||||
if k == "self" or k == "client":
|
|
||||||
pass
|
|
||||||
elif k == "api_base" and v is not None:
|
|
||||||
data["azure_endpoint"] = v
|
|
||||||
elif v is not None:
|
|
||||||
data[k] = v
|
|
||||||
azure_openai_client = AzureOpenAI(**data) # type: ignore
|
|
||||||
else:
|
|
||||||
azure_openai_client = client
|
|
||||||
|
|
||||||
return azure_openai_client
|
|
||||||
|
|
||||||
def async_get_azure_client(
|
|
||||||
self,
|
|
||||||
api_key: Optional[str],
|
|
||||||
api_base: Optional[str],
|
|
||||||
api_version: Optional[str],
|
|
||||||
azure_ad_token: Optional[str],
|
|
||||||
timeout: Union[float, httpx.Timeout],
|
|
||||||
max_retries: Optional[int],
|
|
||||||
client: Optional[AsyncAzureOpenAI] = None,
|
|
||||||
) -> AsyncAzureOpenAI:
|
|
||||||
received_args = locals()
|
|
||||||
if client is None:
|
|
||||||
data = {}
|
|
||||||
for k, v in received_args.items():
|
|
||||||
if k == "self" or k == "client":
|
|
||||||
pass
|
|
||||||
elif k == "api_base" and v is not None:
|
|
||||||
data["azure_endpoint"] = v
|
|
||||||
elif v is not None:
|
|
||||||
data[k] = v
|
|
||||||
|
|
||||||
azure_openai_client = AsyncAzureOpenAI(**data)
|
|
||||||
# azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore
|
|
||||||
else:
|
|
||||||
azure_openai_client = client
|
|
||||||
|
|
||||||
return azure_openai_client
|
|
||||||
|
|
||||||
### ASSISTANTS ###
|
|
||||||
|
|
||||||
async def async_get_assistants(
|
|
||||||
self,
|
|
||||||
api_key: Optional[str],
|
|
||||||
api_base: Optional[str],
|
|
||||||
api_version: Optional[str],
|
|
||||||
azure_ad_token: Optional[str],
|
|
||||||
timeout: Union[float, httpx.Timeout],
|
|
||||||
max_retries: Optional[int],
|
|
||||||
client: Optional[AsyncAzureOpenAI],
|
|
||||||
) -> AsyncCursorPage[Assistant]:
|
|
||||||
azure_openai_client = self.async_get_azure_client(
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
api_version=api_version,
|
|
||||||
azure_ad_token=azure_ad_token,
|
|
||||||
timeout=timeout,
|
|
||||||
max_retries=max_retries,
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await azure_openai_client.beta.assistants.list()
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get_assistants(
|
|
||||||
self,
|
|
||||||
api_key: Optional[str],
|
|
||||||
api_base: Optional[str],
|
|
||||||
api_version: Optional[str],
|
|
||||||
azure_ad_token: Optional[str],
|
|
||||||
timeout: Union[float, httpx.Timeout],
|
|
||||||
max_retries: Optional[int],
|
|
||||||
client: Optional[AsyncAzureOpenAI],
|
|
||||||
aget_assistants: Literal[True],
|
|
||||||
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
|
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get_assistants(
|
|
||||||
self,
|
|
||||||
api_key: Optional[str],
|
|
||||||
api_base: Optional[str],
|
|
||||||
api_version: Optional[str],
|
|
||||||
azure_ad_token: Optional[str],
|
|
||||||
timeout: Union[float, httpx.Timeout],
|
|
||||||
max_retries: Optional[int],
|
|
||||||
client: Optional[AzureOpenAI],
|
|
||||||
aget_assistants: Optional[Literal[False]],
|
|
||||||
) -> SyncCursorPage[Assistant]:
|
|
||||||
...
|
|
||||||
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def get_assistants(
|
|
||||||
self,
|
|
||||||
api_key: Optional[str],
|
|
||||||
api_base: Optional[str],
|
|
||||||
api_version: Optional[str],
|
|
||||||
azure_ad_token: Optional[str],
|
|
||||||
timeout: Union[float, httpx.Timeout],
|
|
||||||
max_retries: Optional[int],
|
|
||||||
client=None,
|
|
||||||
aget_assistants=None,
|
|
||||||
):
|
|
||||||
if aget_assistants is not None and aget_assistants == True:
|
|
||||||
return self.async_get_assistants(
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
api_version=api_version,
|
|
||||||
azure_ad_token=azure_ad_token,
|
|
||||||
timeout=timeout,
|
|
||||||
max_retries=max_retries,
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
azure_openai_client = self.get_azure_client(
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
azure_ad_token=azure_ad_token,
|
|
||||||
timeout=timeout,
|
|
||||||
max_retries=max_retries,
|
|
||||||
client=client,
|
|
||||||
api_version=api_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = azure_openai_client.beta.assistants.list()
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
### MESSAGES ###
|
|
||||||
|
|
||||||
async def a_add_message(
|
|
||||||
self,
|
|
||||||
thread_id: str,
|
|
||||||
message_data: dict,
|
|
||||||
api_key: Optional[str],
|
|
||||||
api_base: Optional[str],
|
|
||||||
api_version: Optional[str],
|
|
||||||
azure_ad_token: Optional[str],
|
|
||||||
timeout: Union[float, httpx.Timeout],
|
|
||||||
max_retries: Optional[int],
|
|
||||||
client: Optional[AsyncAzureOpenAI] = None,
|
|
||||||
) -> OpenAIMessage:
|
|
||||||
openai_client = self.async_get_azure_client(
|
|
||||||
api_key=api_key,
|
|
||||||
api_base=api_base,
|
|
||||||
api_version=api_version,
|
|
||||||
azure_ad_token=azure_ad_token,
|
|
||||||
timeout=timeout,
|
|
||||||
max_retries=max_retries,
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
|
|
||||||
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
|
|
||||||
thread_id, **message_data # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
response_obj: Optional[OpenAIMessage] = None
|
|
||||||
if getattr(thread_message, "status", None) is None:
|
|
||||||
thread_message.status = "completed"
|
|
||||||
response_obj = OpenAIMessage(**thread_message.dict())
|
|
||||||
else:
|
|
||||||
response_obj = OpenAIMessage(**thread_message.dict())
|
|
||||||
return response_obj
|
|
||||||
|
|
||||||
# fmt: off

@overload
def add_message(
    self,
    thread_id: str,
    message_data: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    a_add_message: Literal[True],
) -> Coroutine[None, None, OpenAIMessage]:
    ...

@overload
def add_message(
    self,
    thread_id: str,
    message_data: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI],
    a_add_message: Optional[Literal[False]],
) -> OpenAIMessage:
    ...

# fmt: on

def add_message(
    self,
    thread_id: str,
    message_data: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client=None,
    a_add_message: Optional[bool] = None,
):
    """
    Append a message to an Azure assistants thread.

    Dispatches on `a_add_message`: when truthy, returns the coroutine from
    `self.a_add_message()` (async path); otherwise performs the call
    synchronously and returns an OpenAIMessage.
    """
    # `x is not None and x == True` collapsed to a truthiness check:
    # None and False are both falsy, so behavior is unchanged.
    if a_add_message:
        return self.a_add_message(
            thread_id=thread_id,
            message_data=message_data,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
    openai_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )

    thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(  # type: ignore
        thread_id, **message_data  # type: ignore
    )

    # Default a missing status to "completed". (The original if/else built
    # the identical OpenAIMessage in both branches - collapsed here.)
    if getattr(thread_message, "status", None) is None:
        thread_message.status = "completed"
    return OpenAIMessage(**thread_message.dict())
|
|
||||||
async def async_get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
    """
    List the messages of a thread (async).

    Returns an async cursor page of OpenAIMessage objects for `thread_id`;
    credentials/timeout/retries are resolved by the client factory.
    """
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )
    return await azure_client.beta.threads.messages.list(thread_id=thread_id)
|
|
||||||
# fmt: off

@overload
def get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    aget_messages: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
    ...

@overload
def get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI],
    aget_messages: Optional[Literal[False]],
) -> SyncCursorPage[OpenAIMessage]:
    ...

# fmt: on

def get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client=None,
    aget_messages=None,
):
    """
    List the messages in a thread.

    Dispatches on `aget_messages`: when truthy, returns the coroutine from
    `self.async_get_messages()`; otherwise performs a synchronous list call.
    """
    # `x is not None and x == True` collapsed to a truthiness check:
    # None and False are both falsy, so behavior is unchanged.
    if aget_messages:
        return self.async_get_messages(
            thread_id=thread_id,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
    openai_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )

    return openai_client.beta.threads.messages.list(thread_id=thread_id)
|
|
||||||
### THREADS ###
|
|
||||||
|
|
||||||
async def async_create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
    """
    Create a new assistants thread (async), optionally seeded with
    initial messages and/or metadata.
    """
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )

    # Forward only the optional fields the caller actually supplied.
    data = {
        key: value
        for key, value in (("messages", messages), ("metadata", metadata))
        if value is not None
    }

    message_thread = await azure_client.beta.threads.create(**data)  # type: ignore

    return Thread(**message_thread.dict())
|
|
||||||
# fmt: off

@overload
def create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    client: Optional[AsyncAzureOpenAI],
    acreate_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
    ...

@overload
def create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    client: Optional[AzureOpenAI],
    acreate_thread: Optional[Literal[False]],
) -> Thread:
    ...

# fmt: on

def create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
    client=None,
    acreate_thread=None,
):
    """
    Create a new assistants thread, optionally seeded with messages/metadata.

    Dispatches on `acreate_thread`: when truthy, returns the coroutine from
    `self.async_create_thread()`; otherwise creates the thread synchronously.

    Here's an example:
    ```
    from litellm.llms.OpenAI.openai import OpenAIAssistantsAPI, MessageData

    # create thread
    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
    openai_api.create_thread(messages=[message])
    ```
    """
    # `x is not None and x == True` collapsed to a truthiness check:
    # None and False are both falsy, so behavior is unchanged.
    if acreate_thread:
        return self.async_create_thread(
            metadata=metadata,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
            messages=messages,
        )
    azure_openai_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )

    # Forward only the optional fields the caller actually supplied.
    data = {}
    if messages is not None:
        data["messages"] = messages  # type: ignore
    if metadata is not None:
        data["metadata"] = metadata  # type: ignore

    message_thread = azure_openai_client.beta.threads.create(**data)  # type: ignore

    return Thread(**message_thread.dict())
|
|
||||||
async def async_get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
) -> Thread:
    """Retrieve a single assistants thread by ID (async)."""
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )

    retrieved = await azure_client.beta.threads.retrieve(thread_id=thread_id)

    return Thread(**retrieved.dict())
|
|
||||||
# fmt: off

@overload
def get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    aget_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
    ...

@overload
def get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI],
    aget_thread: Optional[Literal[False]],
) -> Thread:
    ...

# fmt: on

def get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client=None,
    aget_thread=None,
):
    """
    Retrieve a single assistants thread by ID.

    Dispatches on `aget_thread`: when truthy, returns the coroutine from
    `self.async_get_thread()`; otherwise retrieves the thread synchronously.
    """
    # `x is not None and x == True` collapsed to a truthiness check:
    # None and False are both falsy, so behavior is unchanged.
    if aget_thread:
        return self.async_get_thread(
            thread_id=thread_id,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
    openai_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )

    response = openai_client.beta.threads.retrieve(thread_id=thread_id)

    return Thread(**response.dict())
|
|
||||||
# def delete_thread(self):
|
|
||||||
# pass
|
|
||||||
|
|
||||||
### RUNS ###
|
|
||||||
|
|
||||||
async def arun_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
) -> Run:
    """
    Run an assistant on a thread and poll until the run completes (async).

    NOTE(review): `stream` is accepted but unused here - streaming runs are
    dispatched earlier, in `run_thread`, before this method is reached.
    """
    azure_client = self.async_get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )

    return await azure_client.beta.threads.runs.create_and_poll(  # type: ignore
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
|
|
||||||
def async_run_thread_stream(
    self,
    client: AsyncAzureOpenAI,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    tools: Optional[Iterable[AssistantToolParam]],
    event_handler: Optional[AssistantEventHandler],
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
    """Start a streaming run on a thread using the async Azure client."""
    request_kwargs = dict(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
    # Only forward event_handler when the caller provided one.
    if event_handler is not None:
        request_kwargs["event_handler"] = event_handler
    return client.beta.threads.runs.stream(**request_kwargs)  # type: ignore
|
|
||||||
def run_thread_stream(
    self,
    client: AzureOpenAI,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    tools: Optional[Iterable[AssistantToolParam]],
    event_handler: Optional[AssistantEventHandler],
) -> AssistantStreamManager[AssistantEventHandler]:
    """Start a streaming run on a thread using the sync Azure client."""
    request_kwargs = dict(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
    # Only forward event_handler when the caller provided one.
    if event_handler is not None:
        request_kwargs["event_handler"] = event_handler
    return client.beta.threads.runs.stream(**request_kwargs)  # type: ignore
|
|
||||||
# fmt: off

@overload
def run_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AsyncAzureOpenAI],
    arun_thread: Literal[True],
) -> Coroutine[None, None, Run]:
    ...

@overload
def run_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client: Optional[AzureOpenAI],
    arun_thread: Optional[Literal[False]],
) -> Run:
    ...

# fmt: on

def run_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str],
    azure_ad_token: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    client=None,
    arun_thread=None,
    event_handler: Optional[AssistantEventHandler] = None,
):
    """
    Run an assistant on a thread.

    Dispatch matrix:
    - `arun_thread` truthy + `stream` truthy  -> async stream manager
    - `arun_thread` truthy                    -> coroutine that polls to completion
    - `stream` truthy                         -> sync stream manager
    - otherwise                               -> sync poll to completion (Run)
    """
    # `x is not None and x == True` collapsed to truthiness checks:
    # None and False are both falsy, so behavior is unchanged.
    if arun_thread:
        if stream:
            azure_client = self.async_get_azure_client(
                api_key=api_key,
                api_base=api_base,
                api_version=api_version,
                azure_ad_token=azure_ad_token,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
            )
            return self.async_run_thread_stream(
                client=azure_client,
                thread_id=thread_id,
                assistant_id=assistant_id,
                additional_instructions=additional_instructions,
                instructions=instructions,
                metadata=metadata,
                model=model,
                tools=tools,
                event_handler=event_handler,
            )
        return self.arun_thread(
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            stream=stream,
            tools=tools,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            azure_ad_token=azure_ad_token,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )
    openai_client = self.get_azure_client(
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        azure_ad_token=azure_ad_token,
        timeout=timeout,
        max_retries=max_retries,
        client=client,
    )

    if stream:
        return self.run_thread_stream(
            client=openai_client,
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            tools=tools,
            event_handler=event_handler,
        )

    response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )

    return response
||||||
|
|
||||||
|
|
||||||
class AzureBatchesAPI(BaseLLM):
|
class AzureBatchesAPI(BaseLLM):
|
||||||
"""
|
"""
|
||||||
Azure methods to support for batches
|
Azure methods to support for batches
|
||||||
|
|
|
@ -1,19 +1,4 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: gemini-vision
|
|
||||||
litellm_params:
|
|
||||||
model: vertex_ai/gemini-1.5-pro
|
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001
|
|
||||||
vertex_project: "adroit-crow-413218"
|
|
||||||
vertex_location: "us-central1"
|
|
||||||
vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
|
|
||||||
- model_name: gemini-vision
|
|
||||||
litellm_params:
|
|
||||||
model: vertex_ai/gemini-1.0-pro-vision-001
|
|
||||||
api_base: https://exampleopenaiendpoint-production-c715.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001
|
|
||||||
vertex_project: "adroit-crow-413218"
|
|
||||||
vertex_location: "us-central1"
|
|
||||||
vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json"
|
|
||||||
|
|
||||||
- model_name: fake-azure-endpoint
|
- model_name: fake-azure-endpoint
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: openai/429
|
model: openai/429
|
||||||
|
@ -21,6 +6,13 @@ model_list:
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||||
|
|
||||||
|
|
||||||
|
assistant_settings:
|
||||||
|
custom_llm_provider: azure
|
||||||
|
litellm_params:
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
api_base: os.environ/AZURE_API_BASE
|
||||||
|
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
|
|
||||||
|
|
|
@ -59,25 +59,28 @@ async def test_get_assistants(provider, sync_mode):
|
||||||
assert isinstance(assistants, AsyncCursorPage)
|
assert isinstance(assistants, AsyncCursorPage)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("provider", ["openai"])
|
@pytest.mark.parametrize("provider", ["azure", "openai"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"sync_mode",
|
"sync_mode",
|
||||||
[True, False],
|
[True, False],
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio()
|
||||||
|
@pytest.mark.flaky(retries=3, delay=1)
|
||||||
async def test_create_delete_assistants(provider, sync_mode):
|
async def test_create_delete_assistants(provider, sync_mode):
|
||||||
data = {
|
model = "gpt-4-turbo"
|
||||||
"custom_llm_provider": provider,
|
if provider == "azure":
|
||||||
}
|
os.environ["AZURE_API_VERSION"] = "2024-05-01-preview"
|
||||||
|
model = "chatgpt-v-2"
|
||||||
|
|
||||||
if sync_mode == True:
|
if sync_mode == True:
|
||||||
assistant = litellm.create_assistants(
|
assistant = litellm.create_assistants(
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider=provider,
|
||||||
model="gpt-4-turbo",
|
model=model,
|
||||||
instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
|
instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
|
||||||
name="Math Tutor",
|
name="Math Tutor",
|
||||||
tools=[{"type": "code_interpreter"}],
|
tools=[{"type": "code_interpreter"}],
|
||||||
)
|
)
|
||||||
|
|
||||||
print("New assistants", assistant)
|
print("New assistants", assistant)
|
||||||
assert isinstance(assistant, Assistant)
|
assert isinstance(assistant, Assistant)
|
||||||
assert (
|
assert (
|
||||||
|
@ -88,14 +91,14 @@ async def test_create_delete_assistants(provider, sync_mode):
|
||||||
|
|
||||||
# delete the created assistant
|
# delete the created assistant
|
||||||
response = litellm.delete_assistant(
|
response = litellm.delete_assistant(
|
||||||
custom_llm_provider="openai", assistant_id=assistant.id
|
custom_llm_provider=provider, assistant_id=assistant.id
|
||||||
)
|
)
|
||||||
print("Response deleting assistant", response)
|
print("Response deleting assistant", response)
|
||||||
assert response.id == assistant.id
|
assert response.id == assistant.id
|
||||||
else:
|
else:
|
||||||
assistant = await litellm.acreate_assistants(
|
assistant = await litellm.acreate_assistants(
|
||||||
custom_llm_provider="openai",
|
custom_llm_provider=provider,
|
||||||
model="gpt-4-turbo",
|
model=model,
|
||||||
instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
|
instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
|
||||||
name="Math Tutor",
|
name="Math Tutor",
|
||||||
tools=[{"type": "code_interpreter"}],
|
tools=[{"type": "code_interpreter"}],
|
||||||
|
@ -109,7 +112,7 @@ async def test_create_delete_assistants(provider, sync_mode):
|
||||||
assert assistant.id is not None
|
assert assistant.id is not None
|
||||||
|
|
||||||
response = await litellm.adelete_assistant(
|
response = await litellm.adelete_assistant(
|
||||||
custom_llm_provider="openai", assistant_id=assistant.id
|
custom_llm_provider=provider, assistant_id=assistant.id
|
||||||
)
|
)
|
||||||
print("Response deleting assistant", response)
|
print("Response deleting assistant", response)
|
||||||
assert response.id == assistant.id
|
assert response.id == assistant.id
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue