diff --git a/litellm/__init__.py b/litellm/__init__.py
index cfc96ede82..a9f2fe537a 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -783,7 +783,11 @@ from .llms.openai import (
     MistralConfig,
     DeepInfraConfig,
 )
-from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
+from .llms.azure import (
+    AzureOpenAIConfig,
+    AzureOpenAIError,
+    AzureOpenAIAssistantsAPIConfig,
+)
 from .llms.watsonx import IBMWatsonXAIConfig
 from .main import *  # type: ignore
 from .integrations import *
diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py
index 848a83e53c..1486b89849 100644
--- a/litellm/assistants/main.py
+++ b/litellm/assistants/main.py
@@ -4,21 +4,29 @@ from typing import Iterable
 from functools import partial
 import os, asyncio, contextvars
 import litellm
-from openai import OpenAI, AsyncOpenAI
+from openai import OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI
 from litellm import client
-from litellm.utils import supports_httpx_timeout, exception_type, get_llm_provider
+from litellm.utils import (
+    supports_httpx_timeout,
+    exception_type,
+    get_llm_provider,
+    get_secret,
+)
 from ..llms.openai import OpenAIAssistantsAPI
+from ..llms.azure import AzureAssistantsAPI
 from ..types.llms.openai import *
 from ..types.router import *
+from .utils import get_optional_params_add_message

 ####### ENVIRONMENT VARIABLES ###################
 openai_assistants_api = OpenAIAssistantsAPI()
+azure_assistants_api = AzureAssistantsAPI()

 ### ASSISTANTS ###


 async def aget_assistants(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     client: Optional[AsyncOpenAI] = None,
     **kwargs,
 ) -> AsyncCursorPage[Assistant]:
@@ -55,12 +63,21 @@ async def aget_assistants(


 def get_assistants(
-    custom_llm_provider: Literal["openai"],
-    client: Optional[OpenAI] = None,
+    custom_llm_provider: Literal["openai", "azure"],
+    client: Optional[Any] = None,
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
+    api_version: Optional[str] = None,
     **kwargs,
 ) -> SyncCursorPage[Assistant]:
-    aget_assistants = kwargs.pop("aget_assistants", None)
-    optional_params = GenericLiteLLMParams(**kwargs)
+    aget_assistants: Optional[bool] = kwargs.pop("aget_assistants", None)
+    if aget_assistants is not None and not isinstance(aget_assistants, bool):
+        raise Exception(
+            "Invalid value passed in for aget_assistants. Only bool or None allowed"
+        )
+    optional_params = GenericLiteLLMParams(
+        api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
+    )

     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
@@ -99,6 +116,7 @@ def get_assistants(
             or litellm.openai_key
             or os.getenv("OPENAI_API_KEY")
         )
+
         response = openai_assistants_api.get_assistants(
             api_base=api_base,
             api_key=api_key,
@@ -106,7 +124,43 @@ def get_assistants(
             max_retries=optional_params.max_retries,
             organization=organization,
             client=client,
-            aget_assistants=aget_assistants,
+            aget_assistants=aget_assistants,  # type: ignore
+        )  # type: ignore
+    elif custom_llm_provider == "azure":
+        api_base = (
+            optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        )  # type: ignore
+
+        api_version = (
+            optional_params.api_version
+            or litellm.api_version
+            or get_secret("AZURE_API_VERSION")
+        )  # type: ignore
+
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret("AZURE_OPENAI_API_KEY")
+            or get_secret("AZURE_API_KEY")
+        )  # type: ignore
+
+        extra_body = optional_params.get("extra_body", {})
+        azure_ad_token: Optional[str] = None
+        if extra_body is not None:
+            azure_ad_token = extra_body.pop("azure_ad_token", None)
+        else:
+            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+
+        response = azure_assistants_api.get_assistants(
+            api_base=api_base,
+            api_key=api_key,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            client=client,
+            aget_assistants=aget_assistants,  # type: ignore
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -127,7 +181,9 @@ def get_assistants(
 ### THREADS ###


-async def acreate_thread(custom_llm_provider: Literal["openai"], **kwargs) -> Thread:
+async def acreate_thread(
+    custom_llm_provider: Literal["openai", "azure"], **kwargs
+) -> Thread:
     loop = asyncio.get_event_loop()
     ### PASS ARGS TO GET ASSISTANTS ###
     kwargs["acreate_thread"] = True
@@ -161,7 +217,7 @@ async def acreate_thread(custom_llm_provider: Literal["openai"], **kwargs) -> Th


 def create_thread(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
     metadata: Optional[dict] = None,
     tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
@@ -241,6 +297,47 @@ def create_thread(
             client=client,
             acreate_thread=acreate_thread,
         )
+    elif custom_llm_provider == "azure":
+        api_base = (
+            optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        )  # type: ignore
+
+        api_version = (
+            optional_params.api_version
+            or litellm.api_version
+            or get_secret("AZURE_API_VERSION")
+        )  # type: ignore
+
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret("AZURE_OPENAI_API_KEY")
+            or get_secret("AZURE_API_KEY")
+        )  # type: ignore
+
+        extra_body = optional_params.get("extra_body", {})
+        azure_ad_token = None
+        if extra_body is not None:
+            azure_ad_token = extra_body.pop("azure_ad_token", None)
+        else:
+            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+
+        if isinstance(client, OpenAI):
+            client = None  # only pass client if it's AzureOpenAI
+
+        response = azure_assistants_api.create_thread(
+            messages=messages,
+            metadata=metadata,
+            api_base=api_base,
+            api_key=api_key,
+            azure_ad_token=azure_ad_token,
+            api_version=api_version,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            client=client,
+            acreate_thread=acreate_thread,
+        )  # type: ignore
     else:
         raise litellm.exceptions.BadRequestError(
             message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
@@ -254,11 +351,11 @@
             request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
         ),
     )
-    return response
+    return response  # type: ignore


 async def aget_thread(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     thread_id: str,
     client: Optional[AsyncOpenAI] = None,
     **kwargs,
@@ -296,9 +393,9 @@ async def aget_thread(


 def get_thread(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     thread_id: str,
-    client: Optional[OpenAI] = None,
+    client=None,
     **kwargs,
 ) -> Thread:
     """Get the thread object, given a thread_id"""
@@ -342,6 +439,7 @@ def get_thread(
             or litellm.openai_key
             or os.getenv("OPENAI_API_KEY")
         )
+
         response = openai_assistants_api.get_thread(
             thread_id=thread_id,
             api_base=api_base,
@@ -352,6 +450,46 @@ def get_thread(
             client=client,
             aget_thread=aget_thread,
         )
+    elif custom_llm_provider == "azure":
+        api_base = (
+            optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        )  # type: ignore
+
+        api_version = (
+            optional_params.api_version
+            or litellm.api_version
+            or get_secret("AZURE_API_VERSION")
+        )  # type: ignore
+
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret("AZURE_OPENAI_API_KEY")
+            or get_secret("AZURE_API_KEY")
+        )  # type: ignore
+
+        extra_body = optional_params.get("extra_body", {})
+        azure_ad_token = None
+        if extra_body is not None:
+            azure_ad_token = extra_body.pop("azure_ad_token", None)
+        else:
+            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+
+        if isinstance(client, OpenAI):
+            client = None  # only pass client if it's AzureOpenAI
+
+        response = azure_assistants_api.get_thread(
+            thread_id=thread_id,
+            api_base=api_base,
+            api_key=api_key,
+            azure_ad_token=azure_ad_token,
+            api_version=api_version,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            client=client,
+            aget_thread=aget_thread,
+        )
     else:
         raise litellm.exceptions.BadRequestError(
             message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
@@ -365,20 +503,20 @@ def get_thread(
             request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
         ),
     )
-    return response
+    return response  # type: ignore


 ### MESSAGES ###


 async def a_add_message(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     thread_id: str,
     role: Literal["user", "assistant"],
     content: str,
     attachments: Optional[List[Attachment]] = None,
     metadata: Optional[dict] = None,
-    client: Optional[AsyncOpenAI] = None,
+    client=None,
     **kwargs,
 ) -> OpenAIMessage:
     loop = asyncio.get_event_loop()
@@ -425,22 +563,30 @@ async def a_add_message(


 def add_message(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     thread_id: str,
     role: Literal["user", "assistant"],
     content: str,
     attachments: Optional[List[Attachment]] = None,
     metadata: Optional[dict] = None,
-    client: Optional[OpenAI] = None,
+    client=None,
     **kwargs,
 ) -> OpenAIMessage:
     ### COMMON OBJECTS ###
     a_add_message = kwargs.pop("a_add_message", None)
-    message_data = MessageData(
+    _message_data = MessageData(
         role=role, content=content, attachments=attachments, metadata=metadata
     )
     optional_params = GenericLiteLLMParams(**kwargs)

+    message_data = get_optional_params_add_message(
+        role=_message_data["role"],
+        content=_message_data["content"],
+        attachments=_message_data["attachments"],
+        metadata=_message_data["metadata"],
+        custom_llm_provider=custom_llm_provider,
+    )
+
     ### TIMEOUT LOGIC ###
     timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
     # set timeout for 10 minutes by default
@@ -489,6 +635,44 @@ def add_message(
             client=client,
             a_add_message=a_add_message,
         )
+    elif custom_llm_provider == "azure":
+        api_base = (
+            optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        )  # type: ignore
+
+        api_version = (
+            optional_params.api_version
+            or litellm.api_version
+            or get_secret("AZURE_API_VERSION")
+        )  # type: ignore
+
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret("AZURE_OPENAI_API_KEY")
+            or get_secret("AZURE_API_KEY")
+        )  # type: ignore
+
+        extra_body = optional_params.get("extra_body", {})
+        azure_ad_token = None
+        if extra_body is not None:
+            azure_ad_token = extra_body.pop("azure_ad_token", None)
+        else:
+            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+
+        response = azure_assistants_api.add_message(
+            thread_id=thread_id,
+            message_data=message_data,
+            api_base=api_base,
+            api_key=api_key,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            client=client,
+            a_add_message=a_add_message,
+        )
     else:
         raise litellm.exceptions.BadRequestError(
             message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
@@ -503,11 +687,11 @@ def add_message(
         ),
     )

-    return response
+    return response  # type: ignore


 async def aget_messages(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     thread_id: str,
     client: Optional[AsyncOpenAI] = None,
     **kwargs,
@@ -552,9 +736,9 @@ async def aget_messages(


 def get_messages(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     thread_id: str,
-    client: Optional[OpenAI] = None,
+    client: Optional[Any] = None,
     **kwargs,
 ) -> SyncCursorPage[OpenAIMessage]:
     aget_messages = kwargs.pop("aget_messages", None)
@@ -607,6 +791,43 @@ def get_messages(
             client=client,
             aget_messages=aget_messages,
         )
+    elif custom_llm_provider == "azure":
+        api_base = (
+            optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        )  # type: ignore
+
+        api_version = (
+            optional_params.api_version
+            or litellm.api_version
+            or get_secret("AZURE_API_VERSION")
+        )  # type: ignore
+
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret("AZURE_OPENAI_API_KEY")
+            or get_secret("AZURE_API_KEY")
+        )  # type: ignore
+
+        extra_body = optional_params.get("extra_body", {})
+        azure_ad_token = None
+        if extra_body is not None:
+            azure_ad_token = extra_body.pop("azure_ad_token", None)
+        else:
+            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+
+        response = azure_assistants_api.get_messages(
+            thread_id=thread_id,
+            api_base=api_base,
+            api_key=api_key,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            client=client,
+            aget_messages=aget_messages,
+        )
     else:
         raise litellm.exceptions.BadRequestError(
             message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
@@ -621,12 +842,12 @@ def get_messages(
         ),
     )

-    return response
+    return response  # type: ignore


 ### RUNS ###


 async def arun_thread(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     thread_id: str,
     assistant_id: str,
     additional_instructions: Optional[str] = None,
@@ -635,7 +856,7 @@ async def arun_thread(
     model: Optional[str] = None,
     stream: Optional[bool] = None,
     tools: Optional[Iterable[AssistantToolParam]] = None,
-    client: Optional[AsyncOpenAI] = None,
+    client: Optional[Any] = None,
     **kwargs,
 ) -> Run:
     loop = asyncio.get_event_loop()
@@ -685,7 +906,7 @@ async def arun_thread(


 def run_thread(
-    custom_llm_provider: Literal["openai"],
+    custom_llm_provider: Literal["openai", "azure"],
     thread_id: str,
     assistant_id: str,
     additional_instructions: Optional[str] = None,
@@ -694,7 +915,7 @@ def run_thread(
     model: Optional[str] = None,
     stream: Optional[bool] = None,
     tools: Optional[Iterable[AssistantToolParam]] = None,
-    client: Optional[OpenAI] = None,
+    client: Optional[Any] = None,
     **kwargs,
 ) -> Run:
     """Run a given thread + assistant."""
@@ -755,6 +976,50 @@ def run_thread(
             client=client,
             arun_thread=arun_thread,
         )
+    elif custom_llm_provider == "azure":
+        api_base = (
+            optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
+        )  # type: ignore
+
+        api_version = (
+            optional_params.api_version
+            or litellm.api_version
+            or get_secret("AZURE_API_VERSION")
+        )  # type: ignore
+
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret("AZURE_OPENAI_API_KEY")
+            or get_secret("AZURE_API_KEY")
+        )  # type: ignore
+
+        extra_body = optional_params.get("extra_body", {})
+        azure_ad_token = None
+        if extra_body is not None:
+            azure_ad_token = extra_body.pop("azure_ad_token", None)
+        else:
+            azure_ad_token = get_secret("AZURE_AD_TOKEN")  # type: ignore
+
+        response = azure_assistants_api.run_thread(
+            thread_id=thread_id,
+            assistant_id=assistant_id,
+            additional_instructions=additional_instructions,
+            instructions=instructions,
+            metadata=metadata,
+            model=model,
+            stream=stream,
+            tools=tools,
+            api_base=str(api_base) if api_base is not None else None,
+            api_key=str(api_key) if api_key is not None else None,
+            api_version=str(api_version) if api_version is not None else None,
+            azure_ad_token=str(azure_ad_token) if azure_ad_token is not None else None,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            client=client,
+            arun_thread=arun_thread,
+        )  # type: ignore
     else:
         raise litellm.exceptions.BadRequestError(
             message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
@@ -768,4 +1033,4 @@ def run_thread(
             request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
         ),
     )
-    return response
+    return response  # type: ignore
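Reviewer note: the Azure branches added to litellm/assistants/main.py above are exercised end to end like this (resource URL, key, and API version are placeholders; any Assistants-enabled Azure OpenAI deployment should work):

```python
import litellm

litellm.api_base = "https://my-resource.openai.azure.com"  # placeholder resource
litellm.azure_key = "my-azure-api-key"  # placeholder key

assistants = litellm.get_assistants(
    custom_llm_provider="azure", api_version="2024-02-15-preview"
)

thread = litellm.create_thread(
    custom_llm_provider="azure",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    api_version="2024-02-15-preview",
)

litellm.add_message(
    custom_llm_provider="azure",
    thread_id=thread.id,
    role="user",
    content="What's the weather like today?",
    api_version="2024-02-15-preview",
)

run = litellm.run_thread(
    custom_llm_provider="azure",
    thread_id=thread.id,
    assistant_id=assistants.data[0].id,
    api_version="2024-02-15-preview",
)
print(run.status)
```

Per the fallback chains above, the api version can also come from AZURE_API_VERSION, and the key from AZURE_API_KEY / AZURE_OPENAI_API_KEY.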
diff --git a/litellm/assistants/utils.py b/litellm/assistants/utils.py
new file mode 100644
index 0000000000..ca5a1293dc
--- /dev/null
+++ b/litellm/assistants/utils.py
@@ -0,0 +1,158 @@
+import litellm
+from typing import Optional, Union
+from ..types.llms.openai import *
+
+
+def get_optional_params_add_message(
+    role: Optional[str],
+    content: Optional[
+        Union[
+            str,
+            List[
+                Union[
+                    MessageContentTextObject,
+                    MessageContentImageFileObject,
+                    MessageContentImageURLObject,
+                ]
+            ],
+        ]
+    ],
+    attachments: Optional[List[Attachment]],
+    metadata: Optional[dict],
+    custom_llm_provider: str,
+    **kwargs,
+):
+    """
+    Azure doesn't support 'attachments' for creating a message
+
+    Reference - https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
+    """
+    passed_params = locals()
+    custom_llm_provider = passed_params.pop("custom_llm_provider")
+    special_params = passed_params.pop("kwargs")
+    for k, v in special_params.items():
+        passed_params[k] = v
+
+    default_params = {
+        "role": None,
+        "content": None,
+        "attachments": None,
+        "metadata": None,
+    }
+
+    non_default_params = {
+        k: v
+        for k, v in passed_params.items()
+        if (k in default_params and v != default_params[k])
+    }
+    optional_params = {}
+
+    ## raise exception if non-default value passed for non-openai/azure embedding calls
+    def _check_valid_arg(supported_params):
+        if len(non_default_params.keys()) > 0:
+            keys = list(non_default_params.keys())
+            for k in keys:
+                if (
+                    litellm.drop_params is True and k not in supported_params
+                ):  # drop the unsupported non-default values
+                    non_default_params.pop(k, None)
+                elif k not in supported_params:
+                    raise litellm.utils.UnsupportedParamsError(
+                        status_code=500,
+                        message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
+                            k, custom_llm_provider, supported_params
+                        ),
+                    )
+        return non_default_params
+
+    if custom_llm_provider == "openai":
+        optional_params = non_default_params
+    elif custom_llm_provider == "azure":
+        supported_params = (
+            litellm.AzureOpenAIAssistantsAPIConfig().get_supported_openai_create_message_params()
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.AzureOpenAIAssistantsAPIConfig().map_openai_params_create_message_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
+    for k in passed_params.keys():
+        if k not in default_params.keys():
+            optional_params[k] = passed_params[k]
+    return optional_params
+
+
+def get_optional_params_image_gen(
+    n: Optional[int] = None,
+    quality: Optional[str] = None,
+    response_format: Optional[str] = None,
+    size: Optional[str] = None,
+    style: Optional[str] = None,
+    user: Optional[str] = None,
+    custom_llm_provider: Optional[str] = None,
+    **kwargs,
+):
+    # retrieve all parameters passed to the function
+    passed_params = locals()
+    custom_llm_provider = passed_params.pop("custom_llm_provider")
+    special_params = passed_params.pop("kwargs")
+    for k, v in special_params.items():
+        passed_params[k] = v
+
+    default_params = {
+        "n": None,
+        "quality": None,
+        "response_format": None,
+        "size": None,
+        "style": None,
+        "user": None,
+    }
+
+    non_default_params = {
+        k: v
+        for k, v in passed_params.items()
+        if (k in default_params and v != default_params[k])
+    }
+    optional_params = {}
+
+    ## raise exception if non-default value passed for non-openai/azure embedding calls
+    def _check_valid_arg(supported_params):
+        if len(non_default_params.keys()) > 0:
+            keys = list(non_default_params.keys())
+            for k in keys:
+                if (
+                    litellm.drop_params is True and k not in supported_params
+                ):  # drop the unsupported non-default values
+                    non_default_params.pop(k, None)
+                elif k not in supported_params:
+                    raise litellm.utils.UnsupportedParamsError(
+                        status_code=500,
+                        message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
+                    )
+        return non_default_params
+
+    if (
+        custom_llm_provider == "openai"
+        or custom_llm_provider == "azure"
+        or custom_llm_provider in litellm.openai_compatible_providers
+    ):
+        optional_params = non_default_params
+    elif custom_llm_provider == "bedrock":
+        supported_params = ["size"]
+        _check_valid_arg(supported_params=supported_params)
+        if size is not None:
+            width, height = size.split("x")
+            optional_params["width"] = int(width)
+            optional_params["height"] = int(height)
+    elif custom_llm_provider == "vertex_ai":
+        supported_params = ["n"]
+        """
+        All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+        """
+        _check_valid_arg(supported_params=supported_params)
+        if n is not None:
+            optional_params["sampleCount"] = int(n)
+
+    for k in passed_params.keys():
+        if k not in default_params.keys():
+            optional_params[k] = passed_params[k]
+    return optional_params
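A sketch of what the new helper is meant to produce for Azure, assuming the file_ids mapping in AzureOpenAIAssistantsAPIConfig (next file) lands as written; the file id here is hypothetical:

```python
from litellm.assistants.utils import get_optional_params_add_message

# v2-style 'attachments' are flattened to Azure's older 'file_ids' param;
# role/content/metadata pass through unchanged.
params = get_optional_params_add_message(
    role="user",
    content="Summarize the attached file",
    attachments=[{"file_id": "file-abc123"}],  # hypothetical file id
    metadata=None,
    custom_llm_provider="azure",
)
assert params == {
    "role": "user",
    "content": "Summarize the attached file",
    "file_ids": ["file-abc123"],
}
```

With `litellm.drop_params = True`, unsupported keys are silently dropped instead of raising UnsupportedParamsError.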
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index a92a03803d..709385ef76 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -1,4 +1,5 @@
-from typing import Optional, Union, Any, Literal
+from typing import Optional, Union, Any, Literal, Coroutine, Iterable
+from typing_extensions import overload
 import types, requests
 from .base import BaseLLM
 from litellm.utils import (
@@ -19,6 +20,18 @@ from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTra
 from openai import AzureOpenAI, AsyncAzureOpenAI
 import uuid
 import os
+from ..types.llms.openai import (
+    AsyncCursorPage,
+    AssistantToolParam,
+    SyncCursorPage,
+    Assistant,
+    MessageData,
+    OpenAIMessage,
+    OpenAICreateThreadParamsMessage,
+    Thread,
+    Run,
+)


 class AzureOpenAIError(Exception):
@@ -199,6 +212,68 @@ class AzureOpenAIConfig:
         return ["europe", "sweden", "switzerland", "france", "uk"]


+class AzureOpenAIAssistantsAPIConfig:
+    """
+    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
+    """
+
+    def __init__(
+        self,
+    ) -> None:
+        pass
+
+    def get_supported_openai_create_message_params(self):
+        return [
+            "role",
+            "content",
+            "attachments",
+            "metadata",
+        ]
+
+    def map_openai_params_create_message_params(
+        self, non_default_params: dict, optional_params: dict
+    ):
+        for param, value in non_default_params.items():
+            if param == "role":
+                optional_params["role"] = value
+            elif param == "metadata":
+                optional_params["metadata"] = value
+            elif param == "content":  # only string accepted
+                if isinstance(value, str):
+                    optional_params["content"] = value
+                else:
+                    raise litellm.utils.UnsupportedParamsError(
+                        message="Azure only accepts content as a string.",
+                        status_code=400,
+                    )
+            elif (
+                param == "attachments"
+            ):  # this is a v2 param. Azure currently supports the old 'file_ids' param
+                file_ids: List[str] = []
+                if isinstance(value, list):
+                    for item in value:
+                        if "file_id" in item:
+                            file_ids.append(item["file_id"])
+                        else:
+                            if litellm.drop_params == True:
+                                pass
+                            else:
+                                raise litellm.utils.UnsupportedParamsError(
+                                    message="Azure doesn't support {}. To drop it from the call, set `litellm.drop_params = True`.".format(
+                                        value
+                                    ),
+                                    status_code=400,
+                                )
+                    if len(file_ids) > 0:
+                        optional_params["file_ids"] = file_ids
+                else:
+                    raise litellm.utils.UnsupportedParamsError(
+                        message="Invalid param. attachments should always be a list. Got={}, Expected=List. Raw value={}".format(
+                            type(value), value
+                        ),
+                        status_code=400,
+                    )
+        return optional_params
+
+
 def select_azure_base_url_or_endpoint(azure_client_params: dict):
     # azure_client_params = {
     #     "api_version": api_version,
@@ -1277,3 +1352,753 @@ class AzureChatCompletion(BaseLLM):
             response["x-ms-region"] = completion.headers["x-ms-region"]

         return response
+
+
+class AzureAssistantsAPI(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def get_azure_client(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AzureOpenAI] = None,
+    ) -> AzureOpenAI:
+        received_args = locals()
+        if client is None:
+            data = {}
+            for k, v in received_args.items():
+                if k == "self" or k == "client":
+                    pass
+                elif k == "api_base" and v is not None:
+                    data["azure_endpoint"] = v
+                elif v is not None:
+                    data[k] = v
+            azure_openai_client = AzureOpenAI(**data)  # type: ignore
+        else:
+            azure_openai_client = client
+
+        return azure_openai_client
+
+    def async_get_azure_client(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI] = None,
+    ) -> AsyncAzureOpenAI:
+        received_args = locals()
+        if client is None:
+            data = {}
+            for k, v in received_args.items():
+                if k == "self" or k == "client":
+                    pass
+                elif k == "api_base" and v is not None:
+                    data["azure_endpoint"] = v
+                elif v is not None:
+                    data[k] = v
+
+            azure_openai_client = AsyncAzureOpenAI(**data)  # type: ignore
+        else:
+            azure_openai_client = client
+
+        return azure_openai_client
+
+    ### ASSISTANTS ###
+
+    async def async_get_assistants(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI],
+    ) -> AsyncCursorPage[Assistant]:
+        azure_openai_client = self.async_get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        response = await azure_openai_client.beta.assistants.list()
+
+        return response
+
+    # fmt: off
+
+    @overload
+    def get_assistants(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI],
+        aget_assistants: Literal[True],
+    ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
+        ...
+
+    @overload
+    def get_assistants(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AzureOpenAI],
+        aget_assistants: Optional[Literal[False]],
+    ) -> SyncCursorPage[Assistant]:
+        ...
+
+    # fmt: on
+
+    def get_assistants(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client=None,
+        aget_assistants=None,
+    ):
+        if aget_assistants is not None and aget_assistants == True:
+            return self.async_get_assistants(
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                azure_ad_token=azure_ad_token,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+        azure_openai_client = self.get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+            api_version=api_version,
+        )
+
+        response = azure_openai_client.beta.assistants.list()
+
+        return response
+
+    ### MESSAGES ###
+
+    async def a_add_message(
+        self,
+        thread_id: str,
+        message_data: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI] = None,
+    ) -> OpenAIMessage:
+        openai_client = self.async_get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create(  # type: ignore
+            thread_id, **message_data  # type: ignore
+        )
+
+        response_obj: Optional[OpenAIMessage] = None
+        if getattr(thread_message, "status", None) is None:
+            thread_message.status = "completed"
+            response_obj = OpenAIMessage(**thread_message.dict())
+        else:
+            response_obj = OpenAIMessage(**thread_message.dict())
+        return response_obj
+
+    # fmt: off
+
+    @overload
+    def add_message(
+        self,
+        thread_id: str,
+        message_data: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI],
+        a_add_message: Literal[True],
+    ) -> Coroutine[None, None, OpenAIMessage]:
+        ...
+
+    @overload
+    def add_message(
+        self,
+        thread_id: str,
+        message_data: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AzureOpenAI],
+        a_add_message: Optional[Literal[False]],
+    ) -> OpenAIMessage:
+        ...
+
+    # fmt: on
+
+    def add_message(
+        self,
+        thread_id: str,
+        message_data: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client=None,
+        a_add_message: Optional[bool] = None,
+    ):
+        if a_add_message is not None and a_add_message == True:
+            return self.a_add_message(
+                thread_id=thread_id,
+                message_data=message_data,
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                azure_ad_token=azure_ad_token,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+        openai_client = self.get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(  # type: ignore
+            thread_id, **message_data  # type: ignore
+        )
+
+        response_obj: Optional[OpenAIMessage] = None
+        if getattr(thread_message, "status", None) is None:
+            thread_message.status = "completed"
+            response_obj = OpenAIMessage(**thread_message.dict())
+        else:
+            response_obj = OpenAIMessage(**thread_message.dict())
+        return response_obj
+
+    async def async_get_messages(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI] = None,
+    ) -> AsyncCursorPage[OpenAIMessage]:
+        openai_client = self.async_get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+        return response
+
+    # fmt: off
+
+    @overload
+    def get_messages(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI],
+        aget_messages: Literal[True],
+    ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
+        ...
+
+    @overload
+    def get_messages(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AzureOpenAI],
+        aget_messages: Optional[Literal[False]],
+    ) -> SyncCursorPage[OpenAIMessage]:
+        ...
+
+    # fmt: on
+
+    def get_messages(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client=None,
+        aget_messages=None,
+    ):
+        if aget_messages is not None and aget_messages == True:
+            return self.async_get_messages(
+                thread_id=thread_id,
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                azure_ad_token=azure_ad_token,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+        openai_client = self.get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        response = openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+        return response
+
+    ### THREADS ###
+
+    async def async_create_thread(
+        self,
+        metadata: Optional[dict],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI],
+        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+    ) -> Thread:
+        openai_client = self.async_get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        data = {}
+        if messages is not None:
+            data["messages"] = messages  # type: ignore
+        if metadata is not None:
+            data["metadata"] = metadata  # type: ignore
+
+        message_thread = await openai_client.beta.threads.create(**data)  # type: ignore
+
+        return Thread(**message_thread.dict())
+
+    # fmt: off
+
+    @overload
+    def create_thread(
+        self,
+        metadata: Optional[dict],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+        client: Optional[AsyncAzureOpenAI],
+        acreate_thread: Literal[True],
+    ) -> Coroutine[None, None, Thread]:
+        ...
+
+    @overload
+    def create_thread(
+        self,
+        metadata: Optional[dict],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+        client: Optional[AzureOpenAI],
+        acreate_thread: Optional[Literal[False]],
+    ) -> Thread:
+        ...
+
+    # fmt: on
+
+    def create_thread(
+        self,
+        metadata: Optional[dict],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+        client=None,
+        acreate_thread=None,
+    ):
+        """
+        Here's an example:
+        ```
+        from litellm.llms.azure import AzureAssistantsAPI, MessageData
+
+        # create thread
+        message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
+        azure_api.create_thread(messages=[message])
+        ```
+        """
+        if acreate_thread is not None and acreate_thread == True:
+            return self.async_create_thread(
+                metadata=metadata,
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                azure_ad_token=azure_ad_token,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+                messages=messages,
+            )
+        azure_openai_client = self.get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        data = {}
+        if messages is not None:
+            data["messages"] = messages  # type: ignore
+        if metadata is not None:
+            data["metadata"] = metadata  # type: ignore
+
+        message_thread = azure_openai_client.beta.threads.create(**data)  # type: ignore
+
+        return Thread(**message_thread.dict())
+
+    async def async_get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI],
+    ) -> Thread:
+        openai_client = self.async_get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+        return Thread(**response.dict())
+
+    # fmt: off
+
+    @overload
+    def get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI],
+        aget_thread: Literal[True],
+    ) -> Coroutine[None, None, Thread]:
+        ...
+
+    @overload
+    def get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AzureOpenAI],
+        aget_thread: Optional[Literal[False]],
+    ) -> Thread:
+        ...
+
+    # fmt: on
+
+    def get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client=None,
+        aget_thread=None,
+    ):
+        if aget_thread is not None and aget_thread == True:
+            return self.async_get_thread(
+                thread_id=thread_id,
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                azure_ad_token=azure_ad_token,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+        openai_client = self.get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        response = openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+        return Thread(**response.dict())
+
+    # def delete_thread(self):
+    #     pass
+
+    ### RUNS ###
+
+    async def arun_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI],
+    ) -> Run:
+        openai_client = self.async_get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            client=client,
+        )
+
+        response = await openai_client.beta.threads.runs.create_and_poll(  # type: ignore
+            thread_id=thread_id,
+            assistant_id=assistant_id,
+            additional_instructions=additional_instructions,
+            instructions=instructions,
+            metadata=metadata,
+            model=model,
+            tools=tools,
+        )
+
+        return response
+
+    # fmt: off
+
+    @overload
+    def run_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AsyncAzureOpenAI],
+        arun_thread: Literal[True],
+    ) -> Coroutine[None, None, Run]:
+        ...
+
+    @overload
+    def run_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client: Optional[AzureOpenAI],
+        arun_thread: Optional[Literal[False]],
+    ) -> Run:
+        ...
+
+    # fmt: on
+
+    def run_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        client=None,
+        arun_thread=None,
+    ):
+        if arun_thread is not None and arun_thread == True:
+            return self.arun_thread(
+                thread_id=thread_id,
+                assistant_id=assistant_id,
+                additional_instructions=additional_instructions,
+                instructions=instructions,
+                metadata=metadata,
+                model=model,
+                stream=stream,
+                tools=tools,
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                azure_ad_token=azure_ad_token,
+                timeout=timeout,
+                max_retries=max_retries,
+                client=client,
+            )
+        openai_client = self.get_azure_client(
+            api_key=api_key,
+            api_base=api_base,
+            api_version=api_version,
+            azure_ad_token=azure_ad_token,
+            timeout=timeout,
+            max_retries=max_retries,
+            client=client,
+        )
+
+        response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
+            thread_id=thread_id,
+            assistant_id=assistant_id,
+            additional_instructions=additional_instructions,
+            instructions=instructions,
+            metadata=metadata,
+            model=model,
+            tools=tools,
+        )
+
+        return response
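The class keeps the same sync/async split as OpenAIAssistantsAPI: a single public method whose `a*` flag decides whether a coroutine is returned, with the `@overload` stubs encoding that for type checkers. A minimal sketch of both paths (credentials are placeholders):

```python
import asyncio
from litellm.llms.azure import AzureAssistantsAPI

api = AzureAssistantsAPI()
common = dict(
    api_key="my-azure-api-key",  # placeholder
    api_base="https://my-resource.openai.azure.com",  # placeholder
    api_version="2024-02-15-preview",  # placeholder
    azure_ad_token=None,
    timeout=600.0,
    max_retries=2,
    client=None,
)

# Sync path: returns SyncCursorPage[Assistant] directly.
page = api.get_assistants(aget_assistants=False, **common)

# Async path: the flag routes to async_get_assistants, so a coroutine comes back.
page = asyncio.run(api.get_assistants(aget_assistants=True, **common))
```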
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index f561fa3a9b..03657f0ee9 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -2088,7 +2088,7 @@ class OpenAIAssistantsAPI(BaseLLM):
     async def a_add_message(
         self,
         thread_id: str,
-        message_data: MessageData,
+        message_data: dict,
         api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
@@ -2123,7 +2123,7 @@ class OpenAIAssistantsAPI(BaseLLM):
     def add_message(
         self,
         thread_id: str,
-        message_data: MessageData,
+        message_data: dict,
         api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
@@ -2138,7 +2138,7 @@ class OpenAIAssistantsAPI(BaseLLM):
     def add_message(
         self,
         thread_id: str,
-        message_data: MessageData,
+        message_data: dict,
         api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
@@ -2154,7 +2154,7 @@ class OpenAIAssistantsAPI(BaseLLM):
     def add_message(
         self,
         thread_id: str,
-        message_data: MessageData,
+        message_data: dict,
         api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
@@ -2552,7 +2552,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[AsyncOpenAI],
+        client,
         arun_thread: Literal[True],
     ) -> Coroutine[None, None, Run]:
         ...
@@ -2573,7 +2573,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI],
+        client,
         arun_thread: Optional[Literal[False]],
     ) -> Run:
         ...
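The MessageData-to-dict loosening in the hunks above exists because `add_message()` now sends the provider-mapped payload from `get_optional_params_add_message()`, which no longer has to match the MessageData shape. Roughly (ids hypothetical):

```python
# What the caller passes (v2 style, accepted by OpenAI as-is):
openai_payload = {
    "role": "user",
    "content": "Summarize the attached file",
    "attachments": [{"file_id": "file-abc123"}],  # hypothetical id
}

# What reaches AzureAssistantsAPI.add_message after mapping:
azure_payload = {
    "role": "user",
    "content": "Summarize the attached file",
    "file_ids": ["file-abc123"],
}
```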
diff --git a/litellm/router.py b/litellm/router.py
index ad95a0a9ea..6419877c49 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1910,7 +1910,7 @@ class Router:
         model: Optional[str] = None,
         stream: Optional[bool] = None,
         tools: Optional[Iterable[AssistantToolParam]] = None,
-        client: Optional[AsyncOpenAI] = None,
+        client: Optional[Any] = None,
         **kwargs,
     ) -> Run:
         return await litellm.arun_thread(
diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py
index 1e35806890..377b9a42f3 100644
--- a/litellm/tests/test_assistants.py
+++ b/litellm/tests/test_assistants.py
@@ -17,6 +17,7 @@ from litellm.llms.openai import (
     Thread,
     OpenAIMessage as Message,
     AsyncCursorPage,
+    SyncCursorPage,
 )

 """
@@ -27,27 +28,43 @@ V0 Scope:
 """


+@pytest.mark.parametrize("provider", ["openai", "azure"])
+@pytest.mark.parametrize(
+    "sync_mode",
+    [True, False],
+)
 @pytest.mark.asyncio
-async def test_async_get_assistants():
-    assistants = await litellm.aget_assistants(custom_llm_provider="openai")
-    assert isinstance(assistants, AsyncCursorPage)
+async def test_get_assistants(provider, sync_mode):
+    data = {
+        "custom_llm_provider": provider,
+    }
+    if provider == "azure":
+        data["api_version"] = "2024-02-15-preview"
+
+    if sync_mode == True:
+        assistants = litellm.get_assistants(**data)
+        assert isinstance(assistants, SyncCursorPage)
+    else:
+        assistants = await litellm.aget_assistants(**data)
+        assert isinstance(assistants, AsyncCursorPage)


+@pytest.mark.parametrize("provider", ["openai", "azure"])
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_create_thread_litellm(sync_mode) -> Thread:
+async def test_create_thread_litellm(sync_mode, provider) -> Thread:
     message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
+    data = {
+        "custom_llm_provider": provider,
+        "messages": [message],
+    }
+    if provider == "azure":
+        data["api_version"] = "2024-02-15-preview"

     if sync_mode:
-        new_thread = create_thread(
-            custom_llm_provider="openai",
-            messages=[message],  # type: ignore
-        )
+        new_thread = create_thread(**data)
     else:
-        new_thread = await litellm.acreate_thread(
-            custom_llm_provider="openai",
-            messages=[message],  # type: ignore
-        )
+        new_thread = await litellm.acreate_thread(**data)

     assert isinstance(
         new_thread, Thread
@@ -56,26 +73,28 @@ async def test_create_thread_litellm(sync_mode) -> Thread:
     return new_thread


+@pytest.mark.parametrize("provider", ["openai", "azure"])
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_get_thread_litellm(sync_mode):
-    new_thread = test_create_thread_litellm(sync_mode)
+async def test_get_thread_litellm(provider, sync_mode):
+    new_thread = test_create_thread_litellm(sync_mode, provider)

     if asyncio.iscoroutine(new_thread):
         _new_thread = await new_thread
     else:
         _new_thread = new_thread

+    data = {
+        "custom_llm_provider": provider,
+        "thread_id": _new_thread.id,
+    }
+    if provider == "azure":
+        data["api_version"] = "2024-02-15-preview"
+
     if sync_mode:
-        received_thread = get_thread(
-            custom_llm_provider="openai",
-            thread_id=_new_thread.id,
-        )
+        received_thread = get_thread(**data)
     else:
-        received_thread = await litellm.aget_thread(
-            custom_llm_provider="openai",
-            thread_id=_new_thread.id,
-        )
+        received_thread = await litellm.aget_thread(**data)

     assert isinstance(
         received_thread, Thread
@@ -83,11 +102,12 @@ async def test_get_thread_litellm(sync_mode):
     return new_thread


+@pytest.mark.parametrize("provider", ["openai", "azure"])
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_add_message_litellm(sync_mode):
+async def test_add_message_litellm(sync_mode, provider):
     message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
-    new_thread = test_create_thread_litellm(sync_mode)
+    new_thread = test_create_thread_litellm(sync_mode, provider)

     if asyncio.iscoroutine(new_thread):
         _new_thread = await new_thread
@@ -95,37 +115,38 @@ async def test_add_message_litellm(sync_mode):
         _new_thread = new_thread
     # add message to thread
     message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
+
+    data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message}
+    if provider == "azure":
+        data["api_version"] = "2024-02-15-preview"
     if sync_mode:
-        added_message = litellm.add_message(
-            thread_id=_new_thread.id, custom_llm_provider="openai", **message
-        )
+        added_message = litellm.add_message(**data)
     else:
-        added_message = await litellm.a_add_message(
-            thread_id=_new_thread.id, custom_llm_provider="openai", **message
-        )
+        added_message = await litellm.a_add_message(**data)

     print(f"added message: {added_message}")

     assert isinstance(added_message, Message)


+@pytest.mark.parametrize("provider", ["openai", "azure"])
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_run_thread_litellm(sync_mode):
+async def test_run_thread_litellm(sync_mode, provider):
     """
     - Get Assistants
     - Create thread
     - Create run w/ Assistants + Thread
     """
     if sync_mode:
-        assistants = litellm.get_assistants(custom_llm_provider="openai")
+        assistants = litellm.get_assistants(custom_llm_provider=provider)
     else:
-        assistants = await litellm.aget_assistants(custom_llm_provider="openai")
+        assistants = await litellm.aget_assistants(custom_llm_provider=provider)

     ## get the first assistant ###
     assistant_id = assistants.data[0].id

-    new_thread = test_create_thread_litellm(sync_mode=sync_mode)
+    new_thread = test_create_thread_litellm(sync_mode=sync_mode, provider=provider)

     if asyncio.iscoroutine(new_thread):
         _new_thread = await new_thread
@@ -137,35 +158,31 @@ async def test_run_thread_litellm(sync_mode):
     # add message to thread
     message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore

-    if sync_mode:
-        added_message = litellm.add_message(
-            thread_id=_new_thread.id, custom_llm_provider="openai", **message
-        )
+    data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message}

-        run = litellm.run_thread(
-            custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
-        )
+    if sync_mode:
+        added_message = litellm.add_message(**data)
+
+        run = litellm.run_thread(assistant_id=assistant_id, **data)

         if run.status == "completed":
             messages = litellm.get_messages(
-                thread_id=_new_thread.id, custom_llm_provider="openai"
+                thread_id=_new_thread.id, custom_llm_provider=provider
             )
             assert isinstance(messages.data[0], Message)
         else:
             pytest.fail("An unexpected error occurred when running the thread")
     else:
-        added_message = await litellm.a_add_message(
-            thread_id=_new_thread.id, custom_llm_provider="openai", **message
-        )
+        added_message = await litellm.a_add_message(**data)

         run = await litellm.arun_thread(
-            custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
+            custom_llm_provider=provider, thread_id=thread_id, assistant_id=assistant_id
         )

         if run.status == "completed":
             messages = await litellm.aget_messages(
-                thread_id=_new_thread.id, custom_llm_provider="openai"
+                thread_id=_new_thread.id, custom_llm_provider=provider
             )
             assert isinstance(messages.data[0], Message)
         else:
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index 33f3b256e1..885ed6053e 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -136,9 +136,43 @@ class Attachment(TypedDict, total=False):
     """The tools to add this file to."""


+class ImageFileObject(TypedDict):
+    file_id: Required[str]
+    detail: Optional[str]
+
+
+class ImageURLObject(TypedDict):
+    url: Required[str]
+    detail: Optional[str]
+
+
+class MessageContentTextObject(TypedDict):
+    type: Required[Literal["text"]]
+    text: str
+
+
+class MessageContentImageFileObject(TypedDict):
+    type: Literal["image_file"]
+    image_file: ImageFileObject
+
+
+class MessageContentImageURLObject(TypedDict):
+    type: Required[str]
+    image_url: ImageURLObject
+
+
 class MessageData(TypedDict):
     role: Literal["user", "assistant"]
-    content: str
+    content: Union[
+        str,
+        List[
+            Union[
+                MessageContentTextObject,
+                MessageContentImageFileObject,
+                MessageContentImageURLObject,
+            ]
+        ],
+    ]
     attachments: Optional[List[Attachment]]
     metadata: Optional[dict]