From 7e07c37be7cf345d9fa87a27bb8750d380faf46c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Wed, 18 Sep 2024 16:27:33 -0700 Subject: [PATCH] [Feat-Proxy] Add Azure Assistants API - Create Assistant, Delete Assistant Support (#5777) * update docs to show providers * azure - move assistants in it's own file * create new azure assistants file * add azure create assistants * add test for create / delete assistants * azure add delete assistants support * docs add Azure to support providers for assistants api * fix linting errors * fix standard logging merge conflict * docs azure create assistants * fix doc --- docs/my-website/docs/assistants.md | 33 + docs/my-website/docs/batches.md | 2 +- litellm/assistants/main.py | 160 +++- litellm/llms/AzureOpenAI/assistants.py | 975 +++++++++++++++++++++++++ litellm/llms/AzureOpenAI/azure.py | 852 +-------------------- litellm/proxy/proxy_config.yaml | 22 +- litellm/tests/test_assistants.py | 25 +- 7 files changed, 1172 insertions(+), 897 deletions(-) create mode 100644 litellm/llms/AzureOpenAI/assistants.py diff --git a/docs/my-website/docs/assistants.md b/docs/my-website/docs/assistants.md index fb30a132f..5e68e8dde 100644 --- a/docs/my-website/docs/assistants.md +++ b/docs/my-website/docs/assistants.md @@ -7,6 +7,7 @@ Covers Threads, Messages, Assistants. LiteLLM currently covers: - Create Assistants +- Delete Assistants - Get Assistants - Create Thread - Get Thread @@ -14,6 +15,12 @@ LiteLLM currently covers: - Get Messages - Run Thread + +## **Supported Providers**: +- [OpenAI](#quick-start) +- [Azure OpenAI](#azure-openai) +- [OpenAI-Compatible APIs](#openai-compatible-apis) + ## Quick Start Call an existing Assistant. @@ -283,6 +290,32 @@ curl -X POST 'http://0.0.0.0:4000/threads/{thread_id}/runs' \ ## [👉 Proxy API Reference](https://litellm-api.up.railway.app/#/assistants) + +## Azure OpenAI + +**config** +```yaml +assistant_settings: + custom_llm_provider: azure + litellm_params: + api_key: os.environ/AZURE_API_KEY + api_base: os.environ/AZURE_API_BASE +``` + +**curl** + +```bash +curl -X POST "http://localhost:4000/v1/assistants" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + "name": "Math Tutor", + "tools": [{"type": "code_interpreter"}], + "model": "" + }' +``` + ## OpenAI-Compatible APIs To call openai-compatible Assistants API's (eg. Astra Assistants API), just add `openai/` to the model name: diff --git a/docs/my-website/docs/batches.md b/docs/my-website/docs/batches.md index 144873928..eac6a629a 100644 --- a/docs/my-website/docs/batches.md +++ b/docs/my-website/docs/batches.md @@ -5,7 +5,7 @@ import TabItem from '@theme/TabItem'; Covers Batches, Files -Supported Providers: +## **Supported Providers**: - Azure OpenAI - OpenAI diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py index 0ea5860ae..2d467cd71 100644 --- a/litellm/assistants/main.py +++ b/litellm/assistants/main.py @@ -13,6 +13,7 @@ from openai.types.beta.assistant_deleted import AssistantDeleted import litellm from litellm import client +from litellm.llms.AzureOpenAI import assistants from litellm.types.router import GenericLiteLLMParams from litellm.utils import ( exception_type, @@ -21,7 +22,7 @@ from litellm.utils import ( supports_httpx_timeout, ) -from ..llms.AzureOpenAI.azure import AzureAssistantsAPI +from ..llms.AzureOpenAI.assistants import AzureAssistantsAPI from ..llms.OpenAI.openai import OpenAIAssistantsAPI from ..types.llms.openai import * from ..types.router import * @@ -210,8 +211,8 @@ async def acreate_assistants( loop = asyncio.get_event_loop() ### PASS ARGS TO GET ASSISTANTS ### kwargs["async_create_assistants"] = True + model = kwargs.pop("model", None) try: - model = kwargs.pop("model", None) kwargs["client"] = client # Use a partial function to pass your keyword arguments func = partial(create_assistants, custom_llm_provider, model, **kwargs) @@ -258,7 +259,7 @@ def create_assistants( api_base: Optional[str] = None, api_version: Optional[str] = None, **kwargs, -) -> Assistant: +) -> Union[Assistant, Coroutine[Any, Any, Assistant]]: async_create_assistants: Optional[bool] = kwargs.pop( "async_create_assistants", None ) @@ -288,7 +289,20 @@ def create_assistants( elif timeout is None: timeout = 600.0 - response: Optional[Assistant] = None + create_assistant_data = { + "model": model, + "name": name, + "description": description, + "instructions": instructions, + "tools": tools, + "tool_resources": tool_resources, + "metadata": metadata, + "temperature": temperature, + "top_p": top_p, + "response_format": response_format, + } + + response: Optional[Union[Coroutine[Any, Any, Assistant], Assistant]] = None if custom_llm_provider == "openai": api_base = ( optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there @@ -310,19 +324,6 @@ def create_assistants( or os.getenv("OPENAI_API_KEY") ) - create_assistant_data = { - "model": model, - "name": name, - "description": description, - "instructions": instructions, - "tools": tools, - "tool_resources": tool_resources, - "metadata": metadata, - "temperature": temperature, - "top_p": top_p, - "response_format": response_format, - } - response = openai_assistants_api.create_assistants( api_base=api_base, api_key=api_key, @@ -333,6 +334,46 @@ def create_assistants( client=client, async_create_assistants=async_create_assistants, # type: ignore ) # type: ignore + elif custom_llm_provider == "azure": + api_base = ( + optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") + ) # type: ignore + + api_version = ( + optional_params.api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) # type: ignore + + api_key = ( + optional_params.api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) # type: ignore + + extra_body = optional_params.get("extra_body", {}) + azure_ad_token: Optional[str] = None + if extra_body is not None: + azure_ad_token = extra_body.pop("azure_ad_token", None) + else: + azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore + + if isinstance(client, OpenAI): + client = None # only pass client if it's AzureOpenAI + + response = azure_assistants_api.create_assistants( + api_base=api_base, + api_key=api_key, + azure_ad_token=azure_ad_token, + api_version=api_version, + timeout=timeout, + max_retries=optional_params.max_retries, + client=client, + async_create_assistants=async_create_assistants, + create_assistant_data=create_assistant_data, + ) else: raise litellm.exceptions.BadRequestError( message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' is supported.".format( @@ -401,7 +442,7 @@ def delete_assistant( api_base: Optional[str] = None, api_version: Optional[str] = None, **kwargs, -) -> AssistantDeleted: +) -> Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]: optional_params = GenericLiteLLMParams( api_key=api_key, api_base=api_base, api_version=api_version, **kwargs ) @@ -432,7 +473,9 @@ def delete_assistant( elif timeout is None: timeout = 600.0 - response: Optional[AssistantDeleted] = None + response: Optional[ + Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]] + ] = None if custom_llm_provider == "openai": api_base = ( optional_params.api_base @@ -464,6 +507,46 @@ def delete_assistant( client=client, async_delete_assistants=async_delete_assistants, ) + elif custom_llm_provider == "azure": + api_base = ( + optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") + ) # type: ignore + + api_version = ( + optional_params.api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) # type: ignore + + api_key = ( + optional_params.api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) # type: ignore + + extra_body = optional_params.get("extra_body", {}) + azure_ad_token: Optional[str] = None + if extra_body is not None: + azure_ad_token = extra_body.pop("azure_ad_token", None) + else: + azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore + + if isinstance(client, OpenAI): + client = None # only pass client if it's AzureOpenAI + + response = azure_assistants_api.delete_assistant( + assistant_id=assistant_id, + api_base=api_base, + api_key=api_key, + azure_ad_token=azure_ad_token, + api_version=api_version, + timeout=timeout, + max_retries=optional_params.max_retries, + client=client, + async_delete_assistants=async_delete_assistants, + ) else: raise litellm.exceptions.BadRequestError( message="LiteLLM doesn't support {} for 'delete_assistant'. Only 'openai' is supported.".format( @@ -575,6 +658,9 @@ def create_thread( elif timeout is None: timeout = 600.0 + api_base: Optional[str] = None + api_key: Optional[str] = None + response: Optional[Thread] = None if custom_llm_provider == "openai": api_base = ( @@ -612,12 +698,6 @@ def create_thread( optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") ) # type: ignore - api_version = ( - optional_params.api_version - or litellm.api_version - or get_secret("AZURE_API_VERSION") - ) # type: ignore - api_key = ( optional_params.api_key or litellm.api_key @@ -626,8 +706,14 @@ def create_thread( or get_secret("AZURE_API_KEY") ) # type: ignore + api_version: Optional[str] = ( + optional_params.api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) # type: ignore + extra_body = optional_params.get("extra_body", {}) - azure_ad_token = None + azure_ad_token: Optional[str] = None if extra_body is not None: azure_ad_token = extra_body.pop("azure_ad_token", None) else: @@ -647,7 +733,7 @@ def create_thread( max_retries=optional_params.max_retries, client=client, acreate_thread=acreate_thread, - ) # type :ignore + ) else: raise litellm.exceptions.BadRequestError( message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format( @@ -727,7 +813,8 @@ def get_thread( timeout = float(timeout) # type: ignore elif timeout is None: timeout = 600.0 - + api_base: Optional[str] = None + api_key: Optional[str] = None response: Optional[Thread] = None if custom_llm_provider == "openai": api_base = ( @@ -765,7 +852,7 @@ def get_thread( optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") ) # type: ignore - api_version = ( + api_version: Optional[str] = ( optional_params.api_version or litellm.api_version or get_secret("AZURE_API_VERSION") @@ -780,7 +867,7 @@ def get_thread( ) # type: ignore extra_body = optional_params.get("extra_body", {}) - azure_ad_token = None + azure_ad_token: Optional[str] = None if extra_body is not None: azure_ad_token = extra_body.pop("azure_ad_token", None) else: @@ -912,7 +999,8 @@ def add_message( timeout = float(timeout) # type: ignore elif timeout is None: timeout = 600.0 - + api_key: Optional[str] = None + api_base: Optional[str] = None response: Optional[OpenAIMessage] = None if custom_llm_provider == "openai": api_base = ( @@ -950,7 +1038,7 @@ def add_message( optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") ) # type: ignore - api_version = ( + api_version: Optional[str] = ( optional_params.api_version or litellm.api_version or get_secret("AZURE_API_VERSION") @@ -965,7 +1053,7 @@ def add_message( ) # type: ignore extra_body = optional_params.get("extra_body", {}) - azure_ad_token = None + azure_ad_token: Optional[str] = None if extra_body is not None: azure_ad_token = extra_body.pop("azure_ad_token", None) else: @@ -1071,6 +1159,8 @@ def get_messages( timeout = 600.0 response: Optional[SyncCursorPage[OpenAIMessage]] = None + api_key: Optional[str] = None + api_base: Optional[str] = None if custom_llm_provider == "openai": api_base = ( optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there @@ -1106,7 +1196,7 @@ def get_messages( optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") ) # type: ignore - api_version = ( + api_version: Optional[str] = ( optional_params.api_version or litellm.api_version or get_secret("AZURE_API_VERSION") @@ -1121,7 +1211,7 @@ def get_messages( ) # type: ignore extra_body = optional_params.get("extra_body", {}) - azure_ad_token = None + azure_ad_token: Optional[str] = None if extra_body is not None: azure_ad_token = extra_body.pop("azure_ad_token", None) else: diff --git a/litellm/llms/AzureOpenAI/assistants.py b/litellm/llms/AzureOpenAI/assistants.py new file mode 100644 index 000000000..a1458b1f6 --- /dev/null +++ b/litellm/llms/AzureOpenAI/assistants.py @@ -0,0 +1,975 @@ +import uuid +from typing import Any, Callable, Coroutine, Iterable, List, Literal, Optional, Union + +import httpx +from openai import AsyncAzureOpenAI, AzureOpenAI +from typing_extensions import overload + +import litellm +from litellm.types.utils import FileTypes # type: ignore + +from ...types.llms.openai import ( + Assistant, + AssistantEventHandler, + AssistantStreamManager, + AssistantToolParam, + AsyncAssistantEventHandler, + AsyncAssistantStreamManager, + AsyncCursorPage, + OpenAICreateThreadParamsMessage, + OpenAIMessage, + Run, + SyncCursorPage, + Thread, +) +from ..base import BaseLLM + + +class AzureAssistantsAPI(BaseLLM): + def __init__(self) -> None: + super().__init__() + + def get_azure_client( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI] = None, + ) -> AzureOpenAI: + received_args = locals() + if client is None: + data = {} + for k, v in received_args.items(): + if k == "self" or k == "client": + pass + elif k == "api_base" and v is not None: + data["azure_endpoint"] = v + elif v is not None: + data[k] = v + azure_openai_client = AzureOpenAI(**data) # type: ignore + else: + azure_openai_client = client + + return azure_openai_client + + def async_get_azure_client( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI] = None, + ) -> AsyncAzureOpenAI: + received_args = locals() + if client is None: + data = {} + for k, v in received_args.items(): + if k == "self" or k == "client": + pass + elif k == "api_base" and v is not None: + data["azure_endpoint"] = v + elif v is not None: + data[k] = v + azure_openai_client = AsyncAzureOpenAI(**data) + # azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore + else: + azure_openai_client = client + + return azure_openai_client + + ### ASSISTANTS ### + + async def async_get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + ) -> AsyncCursorPage[Assistant]: + azure_openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + response = await azure_openai_client.beta.assistants.list() + + return response + + # fmt: off + + @overload + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + aget_assistants: Literal[True], + ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]: + ... + + @overload + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + aget_assistants: Optional[Literal[False]], + ) -> SyncCursorPage[Assistant]: + ... + + # fmt: on + + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + aget_assistants=None, + ): + if aget_assistants is not None and aget_assistants == True: + return self.async_get_assistants( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + azure_openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + api_version=api_version, + ) + + response = azure_openai_client.beta.assistants.list() + + return response + + ### MESSAGES ### + + async def a_add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI] = None, + ) -> OpenAIMessage: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore + thread_id, **message_data # type: ignore + ) + + response_obj: Optional[OpenAIMessage] = None + if getattr(thread_message, "status", None) is None: + thread_message.status = "completed" + response_obj = OpenAIMessage(**thread_message.dict()) + else: + response_obj = OpenAIMessage(**thread_message.dict()) + return response_obj + + # fmt: off + + @overload + def add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + a_add_message: Literal[True], + ) -> Coroutine[None, None, OpenAIMessage]: + ... + + @overload + def add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + a_add_message: Optional[Literal[False]], + ) -> OpenAIMessage: + ... + + # fmt: on + + def add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + a_add_message: Optional[bool] = None, + ): + if a_add_message is not None and a_add_message == True: + return self.a_add_message( + thread_id=thread_id, + message_data=message_data, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore + thread_id, **message_data # type: ignore + ) + + response_obj: Optional[OpenAIMessage] = None + if getattr(thread_message, "status", None) is None: + thread_message.status = "completed" + response_obj = OpenAIMessage(**thread_message.dict()) + else: + response_obj = OpenAIMessage(**thread_message.dict()) + return response_obj + + async def async_get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI] = None, + ) -> AsyncCursorPage[OpenAIMessage]: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + response = await openai_client.beta.threads.messages.list(thread_id=thread_id) + + return response + + # fmt: off + + @overload + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + aget_messages: Literal[True], + ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]: + ... + + @overload + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + aget_messages: Optional[Literal[False]], + ) -> SyncCursorPage[OpenAIMessage]: + ... + + # fmt: on + + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + aget_messages=None, + ): + if aget_messages is not None and aget_messages == True: + return self.async_get_messages( + thread_id=thread_id, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + response = openai_client.beta.threads.messages.list(thread_id=thread_id) + + return response + + ### THREADS ### + + async def async_create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + ) -> Thread: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + data = {} + if messages is not None: + data["messages"] = messages # type: ignore + if metadata is not None: + data["metadata"] = metadata # type: ignore + + message_thread = await openai_client.beta.threads.create(**data) # type: ignore + + return Thread(**message_thread.dict()) + + # fmt: off + + @overload + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client: Optional[AsyncAzureOpenAI], + acreate_thread: Literal[True], + ) -> Coroutine[None, None, Thread]: + ... + + @overload + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client: Optional[AzureOpenAI], + acreate_thread: Optional[Literal[False]], + ) -> Thread: + ... + + # fmt: on + + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client=None, + acreate_thread=None, + ): + """ + Here's an example: + ``` + from litellm.llms.OpenAI.openai import OpenAIAssistantsAPI, MessageData + + # create thread + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} + openai_api.create_thread(messages=[message]) + ``` + """ + if acreate_thread is not None and acreate_thread == True: + return self.async_create_thread( + metadata=metadata, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + messages=messages, + ) + azure_openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + data = {} + if messages is not None: + data["messages"] = messages # type: ignore + if metadata is not None: + data["metadata"] = metadata # type: ignore + + message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore + + return Thread(**message_thread.dict()) + + async def async_get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + ) -> Thread: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + response = await openai_client.beta.threads.retrieve(thread_id=thread_id) + + return Thread(**response.dict()) + + # fmt: off + + @overload + def get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + aget_thread: Literal[True], + ) -> Coroutine[None, None, Thread]: + ... + + @overload + def get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + aget_thread: Optional[Literal[False]], + ) -> Thread: + ... + + # fmt: on + + def get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + aget_thread=None, + ): + if aget_thread is not None and aget_thread == True: + return self.async_get_thread( + thread_id=thread_id, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + response = openai_client.beta.threads.retrieve(thread_id=thread_id) + + return Thread(**response.dict()) + + # def delete_thread(self): + # pass + + ### RUNS ### + + async def arun_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + ) -> Run: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + api_version=api_version, + azure_ad_token=azure_ad_token, + client=client, + ) + + response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + ) + + return response + + def async_run_thread_stream( + self, + client: AsyncAzureOpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]: + data = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + "metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + + def run_thread_stream( + self, + client: AzureOpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + ) -> AssistantStreamManager[AssistantEventHandler]: + data = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + "metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + + # fmt: off + + @overload + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + arun_thread: Literal[True], + ) -> Coroutine[None, None, Run]: + ... + + @overload + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + arun_thread: Optional[Literal[False]], + ) -> Run: + ... + + # fmt: on + + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + arun_thread=None, + event_handler: Optional[AssistantEventHandler] = None, + ): + if arun_thread is not None and arun_thread == True: + if stream is not None and stream == True: + azure_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + return self.async_run_thread_stream( + client=azure_client, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + event_handler=event_handler, + ) + return self.arun_thread( + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + stream=stream, + tools=tools, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + if stream is not None and stream == True: + return self.run_thread_stream( + client=openai_client, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + event_handler=event_handler, + ) + + response = openai_client.beta.threads.runs.create_and_poll( # type: ignore + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + ) + + return response + + # Create Assistant + async def async_create_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + create_assistant_data: dict, + ) -> Assistant: + azure_openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + response = await azure_openai_client.beta.assistants.create( + **create_assistant_data + ) + return response + + def create_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + create_assistant_data: dict, + client=None, + async_create_assistants=None, + ): + if async_create_assistants is not None and async_create_assistants == True: + return self.async_create_assistants( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + create_assistant_data=create_assistant_data, + ) + azure_openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + response = azure_openai_client.beta.assistants.create(**create_assistant_data) + return response + + # Delete Assistant + async def async_delete_assistant( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + assistant_id: str, + ): + azure_openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + response = await azure_openai_client.beta.assistants.delete( + assistant_id=assistant_id + ) + return response + + def delete_assistant( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + assistant_id: str, + async_delete_assistants: Optional[bool] = None, + client=None, + ): + if async_delete_assistants is not None and async_delete_assistants == True: + return self.async_delete_assistant( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + assistant_id=assistant_id, + ) + azure_openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id) + return response diff --git a/litellm/llms/AzureOpenAI/azure.py b/litellm/llms/AzureOpenAI/azure.py index b1a9e1549..be3e2cbee 100644 --- a/litellm/llms/AzureOpenAI/azure.py +++ b/litellm/llms/AzureOpenAI/azure.py @@ -17,7 +17,8 @@ from litellm import ImageResponse, OpenAIConfig from litellm.caching import DualCache from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.types.utils import FileTypes +from litellm.types.utils import FileTypes # type: ignore +from litellm.types.utils import EmbeddingResponse from litellm.utils import ( Choices, CustomStreamWrapper, @@ -735,6 +736,11 @@ class AzureChatCompletion(BaseLLM): azure_client._custom_query.setdefault( "api-version", api_version ) + if not isinstance(azure_client, AzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AzureOpenAI", + ) headers, response = self.make_sync_azure_openai_chat_completion_request( azure_client=azure_client, data=data, timeout=timeout @@ -1015,12 +1021,12 @@ class AzureChatCompletion(BaseLLM): async def aembedding( self, data: dict, - model_response: ModelResponse, + model_response: EmbeddingResponse, azure_client_params: dict, api_key: str, input: list, + logging_obj: LiteLLMLoggingObj, client: Optional[AsyncAzureOpenAI] = None, - logging_obj=None, timeout=None, ): response = None @@ -1067,9 +1073,9 @@ class AzureChatCompletion(BaseLLM): api_base: str, api_version: str, timeout: float, - logging_obj=None, - model_response=None, - optional_params=None, + logging_obj: LiteLLMLoggingObj, + model_response: EmbeddingResponse, + optional_params: dict, azure_ad_token: Optional[str] = None, client=None, aembedding=None, @@ -1407,8 +1413,8 @@ class AzureChatCompletion(BaseLLM): azure_client_params: dict, api_key: str, input: list, + logging_obj: LiteLLMLoggingObj, client=None, - logging_obj=None, timeout=None, ): response: Optional[dict] = None @@ -1471,14 +1477,14 @@ class AzureChatCompletion(BaseLLM): self, prompt: str, timeout: float, + optional_params: dict, + logging_obj: LiteLLMLoggingObj, model: Optional[str] = None, api_key: Optional[str] = None, api_base: Optional[str] = None, api_version: Optional[str] = None, model_response: Optional[litellm.utils.ImageResponse] = None, azure_ad_token: Optional[str] = None, - logging_obj=None, - optional_params=None, client=None, aimg_generation=None, ): @@ -1565,7 +1571,8 @@ class AzureChatCompletion(BaseLLM): raise e except Exception as e: if hasattr(e, "status_code"): - raise AzureOpenAIError(status_code=e.status_code, message=str(e)) + _status_code = getattr(e, "status_code") + raise AzureOpenAIError(status_code=_status_code, message=str(e)) else: raise AzureOpenAIError(status_code=500, message=str(e)) @@ -1847,831 +1854,6 @@ class AzureChatCompletion(BaseLLM): return response -class AzureAssistantsAPI(BaseLLM): - def __init__(self) -> None: - super().__init__() - - def get_azure_client( - self, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AzureOpenAI] = None, - ) -> AzureOpenAI: - received_args = locals() - if client is None: - data = {} - for k, v in received_args.items(): - if k == "self" or k == "client": - pass - elif k == "api_base" and v is not None: - data["azure_endpoint"] = v - elif v is not None: - data[k] = v - azure_openai_client = AzureOpenAI(**data) # type: ignore - else: - azure_openai_client = client - - return azure_openai_client - - def async_get_azure_client( - self, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI] = None, - ) -> AsyncAzureOpenAI: - received_args = locals() - if client is None: - data = {} - for k, v in received_args.items(): - if k == "self" or k == "client": - pass - elif k == "api_base" and v is not None: - data["azure_endpoint"] = v - elif v is not None: - data[k] = v - - azure_openai_client = AsyncAzureOpenAI(**data) - # azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore - else: - azure_openai_client = client - - return azure_openai_client - - ### ASSISTANTS ### - - async def async_get_assistants( - self, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI], - ) -> AsyncCursorPage[Assistant]: - azure_openai_client = self.async_get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - response = await azure_openai_client.beta.assistants.list() - - return response - - # fmt: off - - @overload - def get_assistants( - self, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI], - aget_assistants: Literal[True], - ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]: - ... - - @overload - def get_assistants( - self, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AzureOpenAI], - aget_assistants: Optional[Literal[False]], - ) -> SyncCursorPage[Assistant]: - ... - - # fmt: on - - def get_assistants( - self, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client=None, - aget_assistants=None, - ): - if aget_assistants is not None and aget_assistants == True: - return self.async_get_assistants( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - azure_openai_client = self.get_azure_client( - api_key=api_key, - api_base=api_base, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - api_version=api_version, - ) - - response = azure_openai_client.beta.assistants.list() - - return response - - ### MESSAGES ### - - async def a_add_message( - self, - thread_id: str, - message_data: dict, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI] = None, - ) -> OpenAIMessage: - openai_client = self.async_get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore - thread_id, **message_data # type: ignore - ) - - response_obj: Optional[OpenAIMessage] = None - if getattr(thread_message, "status", None) is None: - thread_message.status = "completed" - response_obj = OpenAIMessage(**thread_message.dict()) - else: - response_obj = OpenAIMessage(**thread_message.dict()) - return response_obj - - # fmt: off - - @overload - def add_message( - self, - thread_id: str, - message_data: dict, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI], - a_add_message: Literal[True], - ) -> Coroutine[None, None, OpenAIMessage]: - ... - - @overload - def add_message( - self, - thread_id: str, - message_data: dict, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AzureOpenAI], - a_add_message: Optional[Literal[False]], - ) -> OpenAIMessage: - ... - - # fmt: on - - def add_message( - self, - thread_id: str, - message_data: dict, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client=None, - a_add_message: Optional[bool] = None, - ): - if a_add_message is not None and a_add_message == True: - return self.a_add_message( - thread_id=thread_id, - message_data=message_data, - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - openai_client = self.get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore - thread_id, **message_data # type: ignore - ) - - response_obj: Optional[OpenAIMessage] = None - if getattr(thread_message, "status", None) is None: - thread_message.status = "completed" - response_obj = OpenAIMessage(**thread_message.dict()) - else: - response_obj = OpenAIMessage(**thread_message.dict()) - return response_obj - - async def async_get_messages( - self, - thread_id: str, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI] = None, - ) -> AsyncCursorPage[OpenAIMessage]: - openai_client = self.async_get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - response = await openai_client.beta.threads.messages.list(thread_id=thread_id) - - return response - - # fmt: off - - @overload - def get_messages( - self, - thread_id: str, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI], - aget_messages: Literal[True], - ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]: - ... - - @overload - def get_messages( - self, - thread_id: str, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AzureOpenAI], - aget_messages: Optional[Literal[False]], - ) -> SyncCursorPage[OpenAIMessage]: - ... - - # fmt: on - - def get_messages( - self, - thread_id: str, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client=None, - aget_messages=None, - ): - if aget_messages is not None and aget_messages == True: - return self.async_get_messages( - thread_id=thread_id, - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - openai_client = self.get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - response = openai_client.beta.threads.messages.list(thread_id=thread_id) - - return response - - ### THREADS ### - - async def async_create_thread( - self, - metadata: Optional[dict], - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI], - messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], - ) -> Thread: - openai_client = self.async_get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - data = {} - if messages is not None: - data["messages"] = messages # type: ignore - if metadata is not None: - data["metadata"] = metadata # type: ignore - - message_thread = await openai_client.beta.threads.create(**data) # type: ignore - - return Thread(**message_thread.dict()) - - # fmt: off - - @overload - def create_thread( - self, - metadata: Optional[dict], - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], - client: Optional[AsyncAzureOpenAI], - acreate_thread: Literal[True], - ) -> Coroutine[None, None, Thread]: - ... - - @overload - def create_thread( - self, - metadata: Optional[dict], - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], - client: Optional[AzureOpenAI], - acreate_thread: Optional[Literal[False]], - ) -> Thread: - ... - - # fmt: on - - def create_thread( - self, - metadata: Optional[dict], - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], - client=None, - acreate_thread=None, - ): - """ - Here's an example: - ``` - from litellm.llms.OpenAI.openai import OpenAIAssistantsAPI, MessageData - - # create thread - message: MessageData = {"role": "user", "content": "Hey, how's it going?"} - openai_api.create_thread(messages=[message]) - ``` - """ - if acreate_thread is not None and acreate_thread == True: - return self.async_create_thread( - metadata=metadata, - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - messages=messages, - ) - azure_openai_client = self.get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - data = {} - if messages is not None: - data["messages"] = messages # type: ignore - if metadata is not None: - data["metadata"] = metadata # type: ignore - - message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore - - return Thread(**message_thread.dict()) - - async def async_get_thread( - self, - thread_id: str, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI], - ) -> Thread: - openai_client = self.async_get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - response = await openai_client.beta.threads.retrieve(thread_id=thread_id) - - return Thread(**response.dict()) - - # fmt: off - - @overload - def get_thread( - self, - thread_id: str, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI], - aget_thread: Literal[True], - ) -> Coroutine[None, None, Thread]: - ... - - @overload - def get_thread( - self, - thread_id: str, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AzureOpenAI], - aget_thread: Optional[Literal[False]], - ) -> Thread: - ... - - # fmt: on - - def get_thread( - self, - thread_id: str, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client=None, - aget_thread=None, - ): - if aget_thread is not None and aget_thread == True: - return self.async_get_thread( - thread_id=thread_id, - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - openai_client = self.get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - response = openai_client.beta.threads.retrieve(thread_id=thread_id) - - return Thread(**response.dict()) - - # def delete_thread(self): - # pass - - ### RUNS ### - - async def arun_thread( - self, - thread_id: str, - assistant_id: str, - additional_instructions: Optional[str], - instructions: Optional[str], - metadata: Optional[object], - model: Optional[str], - stream: Optional[bool], - tools: Optional[Iterable[AssistantToolParam]], - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI], - ) -> Run: - openai_client = self.async_get_azure_client( - api_key=api_key, - api_base=api_base, - timeout=timeout, - max_retries=max_retries, - api_version=api_version, - azure_ad_token=azure_ad_token, - client=client, - ) - - response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore - thread_id=thread_id, - assistant_id=assistant_id, - additional_instructions=additional_instructions, - instructions=instructions, - metadata=metadata, - model=model, - tools=tools, - ) - - return response - - def async_run_thread_stream( - self, - client: AsyncAzureOpenAI, - thread_id: str, - assistant_id: str, - additional_instructions: Optional[str], - instructions: Optional[str], - metadata: Optional[object], - model: Optional[str], - tools: Optional[Iterable[AssistantToolParam]], - event_handler: Optional[AssistantEventHandler], - ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]: - data = { - "thread_id": thread_id, - "assistant_id": assistant_id, - "additional_instructions": additional_instructions, - "instructions": instructions, - "metadata": metadata, - "model": model, - "tools": tools, - } - if event_handler is not None: - data["event_handler"] = event_handler - return client.beta.threads.runs.stream(**data) # type: ignore - - def run_thread_stream( - self, - client: AzureOpenAI, - thread_id: str, - assistant_id: str, - additional_instructions: Optional[str], - instructions: Optional[str], - metadata: Optional[object], - model: Optional[str], - tools: Optional[Iterable[AssistantToolParam]], - event_handler: Optional[AssistantEventHandler], - ) -> AssistantStreamManager[AssistantEventHandler]: - data = { - "thread_id": thread_id, - "assistant_id": assistant_id, - "additional_instructions": additional_instructions, - "instructions": instructions, - "metadata": metadata, - "model": model, - "tools": tools, - } - if event_handler is not None: - data["event_handler"] = event_handler - return client.beta.threads.runs.stream(**data) # type: ignore - - # fmt: off - - @overload - def run_thread( - self, - thread_id: str, - assistant_id: str, - additional_instructions: Optional[str], - instructions: Optional[str], - metadata: Optional[object], - model: Optional[str], - stream: Optional[bool], - tools: Optional[Iterable[AssistantToolParam]], - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AsyncAzureOpenAI], - arun_thread: Literal[True], - ) -> Coroutine[None, None, Run]: - ... - - @overload - def run_thread( - self, - thread_id: str, - assistant_id: str, - additional_instructions: Optional[str], - instructions: Optional[str], - metadata: Optional[object], - model: Optional[str], - stream: Optional[bool], - tools: Optional[Iterable[AssistantToolParam]], - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AzureOpenAI], - arun_thread: Optional[Literal[False]], - ) -> Run: - ... - - # fmt: on - - def run_thread( - self, - thread_id: str, - assistant_id: str, - additional_instructions: Optional[str], - instructions: Optional[str], - metadata: Optional[object], - model: Optional[str], - stream: Optional[bool], - tools: Optional[Iterable[AssistantToolParam]], - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - azure_ad_token: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client=None, - arun_thread=None, - event_handler: Optional[AssistantEventHandler] = None, - ): - if arun_thread is not None and arun_thread == True: - if stream is not None and stream == True: - azure_client = self.async_get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - return self.async_run_thread_stream( - client=azure_client, - thread_id=thread_id, - assistant_id=assistant_id, - additional_instructions=additional_instructions, - instructions=instructions, - metadata=metadata, - model=model, - tools=tools, - event_handler=event_handler, - ) - return self.arun_thread( - thread_id=thread_id, - assistant_id=assistant_id, - additional_instructions=additional_instructions, - instructions=instructions, - metadata=metadata, - model=model, - stream=stream, - tools=tools, - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - openai_client = self.get_azure_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - timeout=timeout, - max_retries=max_retries, - client=client, - ) - - if stream is not None and stream == True: - return self.run_thread_stream( - client=openai_client, - thread_id=thread_id, - assistant_id=assistant_id, - additional_instructions=additional_instructions, - instructions=instructions, - metadata=metadata, - model=model, - tools=tools, - event_handler=event_handler, - ) - - response = openai_client.beta.threads.runs.create_and_poll( # type: ignore - thread_id=thread_id, - assistant_id=assistant_id, - additional_instructions=additional_instructions, - instructions=instructions, - metadata=metadata, - model=model, - tools=tools, - ) - - return response - - class AzureBatchesAPI(BaseLLM): """ Azure methods to support for batches diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 50f395f4c..431f8816b 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,19 +1,4 @@ model_list: - - model_name: gemini-vision - litellm_params: - model: vertex_ai/gemini-1.5-pro - api_base: https://exampleopenaiendpoint-production.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001 - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" - vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json" - - model_name: gemini-vision - litellm_params: - model: vertex_ai/gemini-1.0-pro-vision-001 - api_base: https://exampleopenaiendpoint-production-c715.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001 - vertex_project: "adroit-crow-413218" - vertex_location: "us-central1" - vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json" - - model_name: fake-azure-endpoint litellm_params: model: openai/429 @@ -21,6 +6,13 @@ model_list: api_base: https://exampleopenaiendpoint-production.up.railway.app +assistant_settings: + custom_llm_provider: azure + litellm_params: + api_key: os.environ/AZURE_API_KEY + api_base: os.environ/AZURE_API_BASE + + general_settings: master_key: sk-1234 diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py index 0806697d7..266bf65f4 100644 --- a/litellm/tests/test_assistants.py +++ b/litellm/tests/test_assistants.py @@ -59,25 +59,28 @@ async def test_get_assistants(provider, sync_mode): assert isinstance(assistants, AsyncCursorPage) -@pytest.mark.parametrize("provider", ["openai"]) +@pytest.mark.parametrize("provider", ["azure", "openai"]) @pytest.mark.parametrize( "sync_mode", [True, False], ) -@pytest.mark.asyncio +@pytest.mark.asyncio() +@pytest.mark.flaky(retries=3, delay=1) async def test_create_delete_assistants(provider, sync_mode): - data = { - "custom_llm_provider": provider, - } + model = "gpt-4-turbo" + if provider == "azure": + os.environ["AZURE_API_VERSION"] = "2024-05-01-preview" + model = "chatgpt-v-2" if sync_mode == True: assistant = litellm.create_assistants( - custom_llm_provider="openai", - model="gpt-4-turbo", + custom_llm_provider=provider, + model=model, instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.", name="Math Tutor", tools=[{"type": "code_interpreter"}], ) + print("New assistants", assistant) assert isinstance(assistant, Assistant) assert ( @@ -88,14 +91,14 @@ async def test_create_delete_assistants(provider, sync_mode): # delete the created assistant response = litellm.delete_assistant( - custom_llm_provider="openai", assistant_id=assistant.id + custom_llm_provider=provider, assistant_id=assistant.id ) print("Response deleting assistant", response) assert response.id == assistant.id else: assistant = await litellm.acreate_assistants( - custom_llm_provider="openai", - model="gpt-4-turbo", + custom_llm_provider=provider, + model=model, instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.", name="Math Tutor", tools=[{"type": "code_interpreter"}], @@ -109,7 +112,7 @@ async def test_create_delete_assistants(provider, sync_mode): assert assistant.id is not None response = await litellm.adelete_assistant( - custom_llm_provider="openai", assistant_id=assistant.id + custom_llm_provider=provider, assistant_id=assistant.id ) print("Response deleting assistant", response) assert response.id == assistant.id