""" Azure Batches API Handler """ from typing import Any, Coroutine, Optional, Union, cast import httpx from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI from litellm.types.llms.openai import ( Batch, CancelBatchRequest, CreateBatchRequest, RetrieveBatchRequest, ) from litellm.types.utils import LiteLLMBatch from ..common_utils import BaseAzureLLM class AzureBatchesAPI(BaseAzureLLM): """ Azure methods to support for batches - create_batch() - retrieve_batch() - cancel_batch() - list_batch() """ def __init__(self) -> None: super().__init__() async def acreate_batch( self, create_batch_data: CreateBatchRequest, azure_client: AsyncAzureOpenAI, ) -> LiteLLMBatch: response = await azure_client.batches.create(**create_batch_data) return LiteLLMBatch(**response.model_dump()) def create_batch( self, _is_async: bool, create_batch_data: CreateBatchRequest, api_key: Optional[str], api_base: Optional[str], api_version: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, litellm_params: Optional[dict] = None, ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]: azure_client: Optional[ Union[AzureOpenAI, AsyncAzureOpenAI] ] = self.get_azure_openai_client( api_key=api_key, api_base=api_base, api_version=api_version, client=client, _is_async=_is_async, litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." ) if _is_async is True: if not isinstance(azure_client, AsyncAzureOpenAI): raise ValueError( "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." ) return self.acreate_batch( # type: ignore create_batch_data=create_batch_data, azure_client=azure_client ) response = cast(AzureOpenAI, azure_client).batches.create(**create_batch_data) return LiteLLMBatch(**response.model_dump()) async def aretrieve_batch( self, retrieve_batch_data: RetrieveBatchRequest, client: AsyncAzureOpenAI, ) -> LiteLLMBatch: response = await client.batches.retrieve(**retrieve_batch_data) return LiteLLMBatch(**response.model_dump()) def retrieve_batch( self, _is_async: bool, retrieve_batch_data: RetrieveBatchRequest, api_key: Optional[str], api_base: Optional[str], api_version: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], client: Optional[AzureOpenAI] = None, litellm_params: Optional[dict] = None, ): azure_client: Optional[ Union[AzureOpenAI, AsyncAzureOpenAI] ] = self.get_azure_openai_client( api_key=api_key, api_base=api_base, api_version=api_version, client=client, _is_async=_is_async, litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." ) if _is_async is True: if not isinstance(azure_client, AsyncAzureOpenAI): raise ValueError( "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." ) return self.aretrieve_batch( # type: ignore retrieve_batch_data=retrieve_batch_data, client=azure_client ) response = cast(AzureOpenAI, azure_client).batches.retrieve( **retrieve_batch_data ) return LiteLLMBatch(**response.model_dump()) async def acancel_batch( self, cancel_batch_data: CancelBatchRequest, client: AsyncAzureOpenAI, ) -> Batch: response = await client.batches.cancel(**cancel_batch_data) return response def cancel_batch( self, _is_async: bool, cancel_batch_data: CancelBatchRequest, api_key: Optional[str], api_base: Optional[str], api_version: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], client: Optional[AzureOpenAI] = None, litellm_params: Optional[dict] = None, ): azure_client: Optional[ Union[AzureOpenAI, AsyncAzureOpenAI] ] = self.get_azure_openai_client( api_key=api_key, api_base=api_base, api_version=api_version, client=client, _is_async=_is_async, litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." ) response = azure_client.batches.cancel(**cancel_batch_data) return response async def alist_batches( self, client: AsyncAzureOpenAI, after: Optional[str] = None, limit: Optional[int] = None, ): response = await client.batches.list(after=after, limit=limit) # type: ignore return response def list_batches( self, _is_async: bool, api_key: Optional[str], api_base: Optional[str], api_version: Optional[str], timeout: Union[float, httpx.Timeout], max_retries: Optional[int], after: Optional[str] = None, limit: Optional[int] = None, client: Optional[AzureOpenAI] = None, litellm_params: Optional[dict] = None, ): azure_client: Optional[ Union[AzureOpenAI, AsyncAzureOpenAI] ] = self.get_azure_openai_client( api_key=api_key, api_base=api_base, api_version=api_version, client=client, _is_async=_is_async, litellm_params=litellm_params or {}, ) if azure_client is None: raise ValueError( "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." ) if _is_async is True: if not isinstance(azure_client, AsyncAzureOpenAI): raise ValueError( "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." ) return self.alist_batches( # type: ignore client=azure_client, after=after, limit=limit ) response = azure_client.batches.list(after=after, limit=limit) # type: ignore return response