litellm-mirror/litellm/llms/azure/batches/handler.py

"""
Azure Batches API Handler
"""
from typing import Any, Coroutine, Optional, Union, cast

import httpx

import litellm
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
CreateBatchRequest,
RetrieveBatchRequest,
)
from litellm.types.utils import LiteLLMBatch


class AzureBatchesAPI:
    """
    Azure OpenAI methods for the Batches API:
    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batches()
    """

    def __init__(self) -> None:
        super().__init__()

    def get_azure_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
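        """
        Return an Azure OpenAI client to use for this request.

        If `client` was passed in, it is returned as-is. Otherwise the received
        arguments are mapped onto the AzureOpenAI constructor kwargs (note that
        `api_base` becomes `azure_endpoint`), `api_version` falls back to
        litellm.AZURE_DEFAULT_API_VERSION, and a sync or async client is built
        depending on `_is_async`.
        """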
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
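        # Reuse the caller-supplied client when one is given; build one otherwise.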
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
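            # Azure requires an explicit api_version; fall back to litellm's default.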
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client

    async def acreate_batch(
self,
create_batch_data: CreateBatchRequest,
azure_client: AsyncAzureOpenAI,
) -> LiteLLMBatch:
response = await azure_client.batches.create(**create_batch_data)
return LiteLLMBatch(**response.model_dump())

    def create_batch(
self,
_is_async: bool,
create_batch_data: CreateBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
        if azure_client is None:
            raise ValueError(
                "Azure OpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
            )
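        # Async path: return the coroutine so the caller can await it.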
if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "Azure OpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, azure_client=azure_client
)
response = cast(AzureOpenAI, azure_client).batches.create(**create_batch_data)
return LiteLLMBatch(**response.model_dump())

    async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
client: AsyncAzureOpenAI,
) -> LiteLLMBatch:
response = await client.batches.retrieve(**retrieve_batch_data)
return LiteLLMBatch(**response.model_dump())

    def retrieve_batch(
self,
_is_async: bool,
retrieve_batch_data: RetrieveBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
        if azure_client is None:
            raise ValueError(
                "Azure OpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
            )
if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "Azure OpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, client=azure_client
)
response = cast(AzureOpenAI, azure_client).batches.retrieve(
**retrieve_batch_data
)
return LiteLLMBatch(**response.model_dump())

    async def acancel_batch(
self,
cancel_batch_data: CancelBatchRequest,
client: AsyncAzureOpenAI,
) -> Batch:
response = await client.batches.cancel(**cancel_batch_data)
return response

    def cancel_batch(
self,
_is_async: bool,
cancel_batch_data: CancelBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
        if azure_client is None:
            raise ValueError(
                "Azure OpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
            )
        # Dispatch to the awaitable helper on the async path; without this,
        # calling .cancel() on an async client would return an unawaited coroutine.
        if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "Azure OpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
            return self.acancel_batch(  # type: ignore
                cancel_batch_data=cancel_batch_data, client=azure_client
            )
        response = cast(AzureOpenAI, azure_client).batches.cancel(**cancel_batch_data)
        return response

    async def alist_batches(
self,
client: AsyncAzureOpenAI,
after: Optional[str] = None,
limit: Optional[int] = None,
):
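        # `after` is a pagination cursor: the id of the last batch returned on
        # the previous page.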
response = await client.batches.list(after=after, limit=limit) # type: ignore
return response

    def list_batches(
self,
_is_async: bool,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
after: Optional[str] = None,
limit: Optional[int] = None,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
api_version=api_version,
client=client,
_is_async=_is_async,
)
)
        if azure_client is None:
            raise ValueError(
                "Azure OpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
            )
if _is_async is True:
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise ValueError(
                    "Azure OpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
return self.alist_batches( # type: ignore
client=azure_client, after=after, limit=limit
)
response = azure_client.batches.list(after=after, limit=limit) # type: ignore
return response
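

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the handler). In practice
# litellm's batches entrypoints drive this class; calling it directly looks
# roughly like the following. The env var names and the input_file_id are
# placeholder assumptions, not values this module defines.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    handler = AzureBatchesAPI()
    batch = handler.create_batch(
        _is_async=False,  # sync path: returns a LiteLLMBatch, not a coroutine
        create_batch_data=CreateBatchRequest(
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id="file-abc123",  # placeholder: an already-uploaded batch input file
        ),
        api_key=os.environ.get("AZURE_API_KEY"),
        api_base=os.environ.get("AZURE_API_BASE"),
        api_version=None,  # None falls back to litellm.AZURE_DEFAULT_API_VERSION
        timeout=600.0,
        max_retries=2,
    )
    print(batch.id, batch.status)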