diff --git a/litellm/files/main.py b/litellm/files/main.py index 836f22f967..b3fbd775f6 100644 --- a/litellm/files/main.py +++ b/litellm/files/main.py @@ -14,7 +14,8 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union import httpx import litellm -from litellm import client +from litellm import client, get_secret +from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI from litellm.llms.openai import FileDeleted, FileObject, OpenAIFilesAPI from litellm.types.llms.openai import ( Batch, @@ -28,6 +29,7 @@ from litellm.utils import supports_httpx_timeout ####### ENVIRONMENT VARIABLES ################### openai_files_instance = OpenAIFilesAPI() +azure_files_instance = AzureOpenAIFilesAPI() ################################################# @@ -402,7 +404,7 @@ def file_list( async def acreate_file( file: FileTypes, purpose: Literal["assistants", "batch", "fine-tune"], - custom_llm_provider: Literal["openai"] = "openai", + custom_llm_provider: Literal["openai", "azure"] = "openai", extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, **kwargs, @@ -455,7 +457,31 @@ def create_file( LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files """ try: + _is_async = kwargs.pop("acreate_file", False) is True optional_params = GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + _create_file_request = CreateFileRequest( + file=file, + purpose=purpose, + extra_headers=extra_headers, + extra_body=extra_body, + ) if custom_llm_provider == "openai": # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there api_base = ( @@ -477,32 +503,6 @@ def create_file( or litellm.openai_key or os.getenv("OPENAI_API_KEY") ) - ### TIMEOUT LOGIC ### - timeout = ( - optional_params.timeout or kwargs.get("request_timeout", 600) or 600 - ) - # set timeout for 10 minutes by default - - if ( - timeout is not None - and isinstance(timeout, httpx.Timeout) - and supports_httpx_timeout(custom_llm_provider) == False - ): - read_timeout = timeout.read or 600 - timeout = read_timeout # default 10 min timeout - elif timeout is not None and not isinstance(timeout, httpx.Timeout): - timeout = float(timeout) # type: ignore - elif timeout is None: - timeout = 600.0 - - _create_file_request = CreateFileRequest( - file=file, - purpose=purpose, - extra_headers=extra_headers, - extra_body=extra_body, - ) - - _is_async = kwargs.pop("acreate_file", False) is True response = openai_files_instance.create_file( _is_async=_is_async, @@ -513,6 +513,38 @@ def create_file( organization=organization, create_file_data=_create_file_request, ) + elif custom_llm_provider == "azure": + api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore + api_version = ( + optional_params.api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) # type: ignore + + api_key = ( + optional_params.api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) # type: ignore + + extra_body = optional_params.get("extra_body", {}) + azure_ad_token: Optional[str] = None + if extra_body is not None: + azure_ad_token = extra_body.pop("azure_ad_token", None) + else: + azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore + + response = azure_files_instance.create_file( + _is_async=_is_async, + api_base=api_base, + api_key=api_key, + api_version=api_version, + timeout=timeout, + max_retries=optional_params.max_retries, + create_file_data=_create_file_request, + ) else: raise litellm.exceptions.BadRequestError( message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format( diff --git a/litellm/llms/files_apis/azure.py b/litellm/llms/files_apis/azure.py new file mode 100644 index 0000000000..d46d9225a4 --- /dev/null +++ b/litellm/llms/files_apis/azure.py @@ -0,0 +1,315 @@ +from typing import Any, Coroutine, Dict, List, Optional, Union + +import httpx +from openai import AsyncAzureOpenAI, AzureOpenAI +from openai.types.file_deleted import FileDeleted + +import litellm +from litellm._logging import verbose_logger +from litellm.llms.base import BaseLLM +from litellm.types.llms.openai import * + + +def get_azure_openai_client( + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + api_version: Optional[str] = None, + organization: Optional[str] = None, + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + _is_async: bool = False, +) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]: + received_args = locals() + openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None + if client is None: + data = {} + for k, v in received_args.items(): + if k == "self" or k == "client" or k == "_is_async": + pass + elif k == "api_base" and v is not None: + data["azure_endpoint"] = v + elif v is not None: + data[k] = v + if "api_version" not in data: + data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION + if _is_async is True: + openai_client = AsyncAzureOpenAI(**data) + else: + openai_client = AzureOpenAI(**data) # type: ignore + else: + openai_client = client + + return openai_client + + +class AzureOpenAIFilesAPI(BaseLLM): + """ + AzureOpenAI methods to support for batches + - create_file() + - retrieve_file() + - list_files() + - delete_file() + - file_content() + - update_file() + """ + + def __init__(self) -> None: + super().__init__() + + async def acreate_file( + self, + create_file_data: CreateFileRequest, + openai_client: AsyncAzureOpenAI, + ) -> FileObject: + verbose_logger.debug("create_file_data=%s", create_file_data) + response = await openai_client.files.create(**create_file_data) + verbose_logger.debug("create_file_response=%s", response) + return response + + def create_file( + self, + _is_async: bool, + create_file_data: CreateFileRequest, + api_base: str, + api_key: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]: + openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + timeout=timeout, + max_retries=max_retries, + client=client, + _is_async=_is_async, + ) + ) + if openai_client is None: + raise ValueError( + "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncAzureOpenAI): + raise ValueError( + "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client." + ) + return self.acreate_file( # type: ignore + create_file_data=create_file_data, openai_client=openai_client + ) + response = openai_client.files.create(**create_file_data) + return response + + async def afile_content( + self, + file_content_request: FileContentRequest, + openai_client: AsyncAzureOpenAI, + ) -> HttpxBinaryResponseContent: + response = await openai_client.files.content(**file_content_request) + return response + + def file_content( + self, + _is_async: bool, + file_content_request: FileContentRequest, + api_base: str, + api_key: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + api_version: Optional[str] = None, + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + ) -> Union[ + HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent] + ]: + openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + get_azure_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + api_version=api_version, + max_retries=max_retries, + organization=organization, + client=client, + _is_async=_is_async, + ) + ) + if openai_client is None: + raise ValueError( + "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncAzureOpenAI): + raise ValueError( + "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client." + ) + return self.afile_content( # type: ignore + file_content_request=file_content_request, + openai_client=openai_client, + ) + response = openai_client.files.content(**file_content_request) + + return response + + async def aretrieve_file( + self, + file_id: str, + openai_client: AsyncAzureOpenAI, + ) -> FileObject: + response = await openai_client.files.retrieve(file_id=file_id) + return response + + def retrieve_file( + self, + _is_async: bool, + file_id: str, + api_base: str, + api_key: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + api_version: Optional[str] = None, + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + ): + openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + get_azure_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + api_version=api_version, + client=client, + _is_async=_is_async, + ) + ) + if openai_client is None: + raise ValueError( + "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncAzureOpenAI): + raise ValueError( + "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client." + ) + return self.aretrieve_file( # type: ignore + file_id=file_id, + openai_client=openai_client, + ) + response = openai_client.files.retrieve(file_id=file_id) + + return response + + async def adelete_file( + self, + file_id: str, + openai_client: AsyncAzureOpenAI, + ) -> FileDeleted: + response = await openai_client.files.delete(file_id=file_id) + return response + + def delete_file( + self, + _is_async: bool, + file_id: str, + api_base: str, + api_key: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + api_version: Optional[str] = None, + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + ): + openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + get_azure_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + api_version=api_version, + client=client, + _is_async=_is_async, + ) + ) + if openai_client is None: + raise ValueError( + "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncAzureOpenAI): + raise ValueError( + "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client." + ) + return self.adelete_file( # type: ignore + file_id=file_id, + openai_client=openai_client, + ) + response = openai_client.files.delete(file_id=file_id) + + return response + + async def alist_files( + self, + openai_client: AsyncAzureOpenAI, + purpose: Optional[str] = None, + ): + if isinstance(purpose, str): + response = await openai_client.files.list(purpose=purpose) + else: + response = await openai_client.files.list() + return response + + def list_files( + self, + _is_async: bool, + api_base: str, + api_key: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + purpose: Optional[str] = None, + api_version: Optional[str] = None, + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + ): + openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + get_azure_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + api_version=api_version, + client=client, + _is_async=_is_async, + ) + ) + if openai_client is None: + raise ValueError( + "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(openai_client, AsyncAzureOpenAI): + raise ValueError( + "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client." + ) + return self.alist_files( # type: ignore + purpose=purpose, + openai_client=openai_client, + ) + + if isinstance(purpose, str): + response = openai_client.files.list(purpose=purpose) + else: + response = openai_client.files.list() + + return response