diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 1c50474ccb..c7e524f2b0 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -20,11 +20,16 @@ import httpx import litellm from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj -from litellm.llms.azure.azure import AzureBatchesAPI +from litellm.llms.azure.batches.handler import AzureBatchesAPI from litellm.llms.openai.openai import OpenAIBatchesAPI from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction from litellm.secret_managers.main import get_secret_str -from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest +from litellm.types.llms.openai import ( + Batch, + CancelBatchRequest, + CreateBatchRequest, + RetrieveBatchRequest, +) from litellm.types.router import GenericLiteLLMParams from litellm.utils import client, get_litellm_params, supports_httpx_timeout @@ -582,9 +587,163 @@ def list_batches( raise e -def cancel_batch(): - pass +async def acancel_batch( + batch_id: str, + custom_llm_provider: Literal["openai", "azure"] = "openai", + metadata: Optional[Dict[str, str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Batch: + """ + Async: Cancels a batch. + + LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel + """ + try: + loop = asyncio.get_event_loop() + kwargs["acancel_batch"] = True + + # Use a partial function to pass your keyword arguments + func = partial( + cancel_batch, + batch_id, + custom_llm_provider, + metadata, + extra_headers, + extra_body, + **kwargs, + ) + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response + + return response + except Exception as e: + raise e -async def acancel_batch(): - pass +def cancel_batch( + batch_id: str, + custom_llm_provider: Literal["openai", "azure"] = "openai", + metadata: Optional[Dict[str, str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +) -> Union[Batch, Coroutine[Any, Any, Batch]]: + """ + Cancels a batch. + + LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel + """ + try: + optional_params = GenericLiteLLMParams(**kwargs) + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) is False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + _cancel_batch_request = CancelBatchRequest( + batch_id=batch_id, + extra_headers=extra_headers, + extra_body=extra_body, + ) + + _is_async = kwargs.pop("acancel_batch", False) is True + api_base: Optional[str] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None + ) + api_key = ( + optional_params.api_key + or litellm.api_key + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + + response = openai_batches_instance.cancel_batch( + _is_async=_is_async, + cancel_batch_data=_cancel_batch_request, + api_base=api_base, + api_key=api_key, + organization=organization, + timeout=timeout, + max_retries=optional_params.max_retries, + ) + elif custom_llm_provider == "azure": + api_base = ( + optional_params.api_base + or litellm.api_base + or get_secret_str("AZURE_API_BASE") + ) + api_version = ( + optional_params.api_version + or litellm.api_version + or get_secret_str("AZURE_API_VERSION") + ) + + api_key = ( + optional_params.api_key + or litellm.api_key + or litellm.azure_key + or get_secret_str("AZURE_OPENAI_API_KEY") + or get_secret_str("AZURE_API_KEY") + ) + + extra_body = optional_params.get("extra_body", {}) + if extra_body is not None: + extra_body.pop("azure_ad_token", None) + else: + get_secret_str("AZURE_AD_TOKEN") # type: ignore + + response = azure_batches_instance.cancel_batch( + _is_async=_is_async, + api_base=api_base, + api_key=api_key, + api_version=api_version, + timeout=timeout, + max_retries=optional_params.max_retries, + cancel_batch_data=_cancel_batch_request, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'cancel_batch'. Only 'openai' and 'azure' are supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="cancel_batch", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + except Exception as e: + raise e diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index b38c7abbcb..837d425b82 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -2,7 +2,7 @@ import asyncio import json import os import time -from typing import Any, Callable, Coroutine, List, Literal, Optional, Union +from typing import Any, Callable, List, Literal, Optional, Union import httpx # type: ignore from openai import AsyncAzureOpenAI, AzureOpenAI @@ -28,13 +28,7 @@ from litellm.utils import ( modify_url, ) -from ...types.llms.openai import ( - Batch, - CancelBatchRequest, - CreateBatchRequest, - HttpxBinaryResponseContent, - RetrieveBatchRequest, -) +from ...types.llms.openai import HttpxBinaryResponseContent from ..base import BaseLLM from .common_utils import AzureOpenAIError, process_azure_headers @@ -1613,216 +1607,3 @@ class AzureChatCompletion(BaseLLM): response["x-ms-region"] = completion.headers["x-ms-region"] return response - - -class AzureBatchesAPI(BaseLLM): - """ - Azure methods to support for batches - - create_batch() - - retrieve_batch() - - cancel_batch() - - list_batch() - """ - - def __init__(self) -> None: - super().__init__() - - def get_azure_openai_client( - self, - api_key: Optional[str], - api_base: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - api_version: Optional[str] = None, - client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, - _is_async: bool = False, - ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]: - received_args = locals() - openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None - if client is None: - data = {} - for k, v in received_args.items(): - if k == "self" or k == "client" or k == "_is_async": - pass - elif k == "api_base" and v is not None: - data["azure_endpoint"] = v - elif v is not None: - data[k] = v - if "api_version" not in data: - data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION - if _is_async is True: - openai_client = AsyncAzureOpenAI(**data) - else: - openai_client = AzureOpenAI(**data) # type: ignore - else: - openai_client = client - - return openai_client - - async def acreate_batch( - self, - create_batch_data: CreateBatchRequest, - azure_client: AsyncAzureOpenAI, - ) -> Batch: - response = await azure_client.batches.create(**create_batch_data) - return response - - def create_batch( - self, - _is_async: bool, - create_batch_data: CreateBatchRequest, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, - ) -> Union[Batch, Coroutine[Any, Any, Batch]]: - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - timeout=timeout, - api_version=api_version, - max_retries=max_retries, - client=client, - _is_async=_is_async, - ) - ) - if azure_client is None: - raise ValueError( - "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." - ) - - if _is_async is True: - if not isinstance(azure_client, AsyncAzureOpenAI): - raise ValueError( - "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." - ) - return self.acreate_batch( # type: ignore - create_batch_data=create_batch_data, azure_client=azure_client - ) - response = azure_client.batches.create(**create_batch_data) - return response - - async def aretrieve_batch( - self, - retrieve_batch_data: RetrieveBatchRequest, - client: AsyncAzureOpenAI, - ) -> Batch: - response = await client.batches.retrieve(**retrieve_batch_data) - return response - - def retrieve_batch( - self, - _is_async: bool, - retrieve_batch_data: RetrieveBatchRequest, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - client: Optional[AzureOpenAI] = None, - ): - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - api_version=api_version, - timeout=timeout, - max_retries=max_retries, - client=client, - _is_async=_is_async, - ) - ) - if azure_client is None: - raise ValueError( - "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." - ) - - if _is_async is True: - if not isinstance(azure_client, AsyncAzureOpenAI): - raise ValueError( - "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." - ) - return self.aretrieve_batch( # type: ignore - retrieve_batch_data=retrieve_batch_data, client=azure_client - ) - response = azure_client.batches.retrieve(**retrieve_batch_data) - return response - - def cancel_batch( - self, - _is_async: bool, - cancel_batch_data: CancelBatchRequest, - api_key: Optional[str], - api_base: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - organization: Optional[str], - client: Optional[AzureOpenAI] = None, - ): - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - timeout=timeout, - max_retries=max_retries, - client=client, - _is_async=_is_async, - ) - ) - if azure_client is None: - raise ValueError( - "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." - ) - response = azure_client.batches.cancel(**cancel_batch_data) - return response - - async def alist_batches( - self, - client: AsyncAzureOpenAI, - after: Optional[str] = None, - limit: Optional[int] = None, - ): - response = await client.batches.list(after=after, limit=limit) # type: ignore - return response - - def list_batches( - self, - _is_async: bool, - api_key: Optional[str], - api_base: Optional[str], - api_version: Optional[str], - timeout: Union[float, httpx.Timeout], - max_retries: Optional[int], - after: Optional[str] = None, - limit: Optional[int] = None, - client: Optional[AzureOpenAI] = None, - ): - azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( - self.get_azure_openai_client( - api_key=api_key, - api_base=api_base, - timeout=timeout, - max_retries=max_retries, - api_version=api_version, - client=client, - _is_async=_is_async, - ) - ) - if azure_client is None: - raise ValueError( - "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." - ) - - if _is_async is True: - if not isinstance(azure_client, AsyncAzureOpenAI): - raise ValueError( - "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." - ) - return self.alist_batches( # type: ignore - client=azure_client, after=after, limit=limit - ) - response = azure_client.batches.list(after=after, limit=limit) # type: ignore - return response diff --git a/litellm/llms/azure/batches/handler.py b/litellm/llms/azure/batches/handler.py new file mode 100644 index 0000000000..5fae527670 --- /dev/null +++ b/litellm/llms/azure/batches/handler.py @@ -0,0 +1,238 @@ +""" +Azure Batches API Handler +""" + +from typing import Any, Coroutine, Optional, Union + +import httpx + +import litellm +from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI +from litellm.types.llms.openai import ( + Batch, + CancelBatchRequest, + CreateBatchRequest, + RetrieveBatchRequest, +) + + +class AzureBatchesAPI: + """ + Azure methods to support for batches + - create_batch() + - retrieve_batch() + - cancel_batch() + - list_batch() + """ + + def __init__(self) -> None: + super().__init__() + + def get_azure_openai_client( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + api_version: Optional[str] = None, + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + _is_async: bool = False, + ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]: + received_args = locals() + openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None + if client is None: + data = {} + for k, v in received_args.items(): + if k == "self" or k == "client" or k == "_is_async": + pass + elif k == "api_base" and v is not None: + data["azure_endpoint"] = v + elif v is not None: + data[k] = v + if "api_version" not in data: + data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION + if _is_async is True: + openai_client = AsyncAzureOpenAI(**data) + else: + openai_client = AzureOpenAI(**data) # type: ignore + else: + openai_client = client + + return openai_client + + async def acreate_batch( + self, + create_batch_data: CreateBatchRequest, + azure_client: AsyncAzureOpenAI, + ) -> Batch: + response = await azure_client.batches.create(**create_batch_data) + return response + + def create_batch( + self, + _is_async: bool, + create_batch_data: CreateBatchRequest, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + ) -> Union[Batch, Coroutine[Any, Any, Batch]]: + azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + api_version=api_version, + max_retries=max_retries, + client=client, + _is_async=_is_async, + ) + ) + if azure_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(azure_client, AsyncAzureOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.acreate_batch( # type: ignore + create_batch_data=create_batch_data, azure_client=azure_client + ) + response = azure_client.batches.create(**create_batch_data) + return response + + async def aretrieve_batch( + self, + retrieve_batch_data: RetrieveBatchRequest, + client: AsyncAzureOpenAI, + ) -> Batch: + response = await client.batches.retrieve(**retrieve_batch_data) + return response + + def retrieve_batch( + self, + _is_async: bool, + retrieve_batch_data: RetrieveBatchRequest, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI] = None, + ): + azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + timeout=timeout, + max_retries=max_retries, + client=client, + _is_async=_is_async, + ) + ) + if azure_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(azure_client, AsyncAzureOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.aretrieve_batch( # type: ignore + retrieve_batch_data=retrieve_batch_data, client=azure_client + ) + response = azure_client.batches.retrieve(**retrieve_batch_data) + return response + + async def acancel_batch( + self, + cancel_batch_data: CancelBatchRequest, + client: AsyncAzureOpenAI, + ) -> Batch: + response = await client.batches.cancel(**cancel_batch_data) + return response + + def cancel_batch( + self, + _is_async: bool, + cancel_batch_data: CancelBatchRequest, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI] = None, + ): + azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + timeout=timeout, + max_retries=max_retries, + client=client, + _is_async=_is_async, + ) + ) + if azure_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + response = azure_client.batches.cancel(**cancel_batch_data) + return response + + async def alist_batches( + self, + client: AsyncAzureOpenAI, + after: Optional[str] = None, + limit: Optional[int] = None, + ): + response = await client.batches.list(after=after, limit=limit) # type: ignore + return response + + def list_batches( + self, + _is_async: bool, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + after: Optional[str] = None, + limit: Optional[int] = None, + client: Optional[AzureOpenAI] = None, + ): + azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + api_version=api_version, + client=client, + _is_async=_is_async, + ) + ) + if azure_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(azure_client, AsyncAzureOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.alist_batches( # type: ignore + client=azure_client, after=after, limit=limit + ) + response = azure_client.batches.list(after=after, limit=limit) # type: ignore + return response diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index e73c1d55ec..c7170601fd 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -1799,6 +1799,15 @@ class OpenAIBatchesAPI(BaseLLM): response = openai_client.batches.retrieve(**retrieve_batch_data) return response + async def acancel_batch( + self, + cancel_batch_data: CancelBatchRequest, + openai_client: AsyncOpenAI, + ) -> Batch: + verbose_logger.debug("async cancelling batch, args= %s", cancel_batch_data) + response = await openai_client.batches.cancel(**cancel_batch_data) + return response + def cancel_batch( self, _is_async: bool, @@ -1823,6 +1832,16 @@ class OpenAIBatchesAPI(BaseLLM): raise ValueError( "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." ) + + if _is_async is True: + if not isinstance(openai_client, AsyncOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.acancel_batch( # type: ignore + cancel_batch_data=cancel_batch_data, openai_client=openai_client + ) + response = openai_client.batches.cancel(**cancel_batch_data) return response diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 96eabefcc3..118f9964e9 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ b/litellm/proxy/batches_endpoints/endpoints.py @@ -11,7 +11,11 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response import litellm from litellm._logging import verbose_proxy_logger -from litellm.batches.main import CreateBatchRequest, RetrieveBatchRequest +from litellm.batches.main import ( + CancelBatchRequest, + CreateBatchRequest, + RetrieveBatchRequest, +) from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.http_parsing_utils import _read_request_body @@ -353,6 +357,116 @@ async def list_batches( raise handle_exception_on_proxy(e) +@router.post( + "/{provider}/v1/batches/{batch_id:path}/cancel", + dependencies=[Depends(user_api_key_auth)], + tags=["batch"], +) +@router.post( + "/v1/batches/{batch_id:path}/cancel", + dependencies=[Depends(user_api_key_auth)], + tags=["batch"], +) +@router.post( + "/batches/{batch_id:path}/cancel", + dependencies=[Depends(user_api_key_auth)], + tags=["batch"], +) +async def cancel_batch( + request: Request, + batch_id: str, + fastapi_response: Response, + provider: Optional[str] = None, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Cancel a batch. + This is the equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel + + Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/cancel + + Example Curl + ``` + curl http://localhost:4000/v1/batches/batch_abc123/cancel \ + -H "Authorization: Bearer sk-1234" \ + -H "Content-Type: application/json" \ + -X POST + + ``` + """ + from litellm.proxy.proxy_server import ( + add_litellm_data_to_request, + general_settings, + get_custom_headers, + proxy_config, + proxy_logging_obj, + version, + ) + + data: Dict = {} + try: + data = await _read_request_body(request=request) + verbose_proxy_logger.debug( + "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), + ) + + # Include original request and headers in the data + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) + + custom_llm_provider = ( + provider or data.pop("custom_llm_provider", None) or "openai" + ) + _cancel_batch_data = CancelBatchRequest(batch_id=batch_id, **data) + response = await litellm.acancel_batch( + custom_llm_provider=custom_llm_provider, # type: ignore + **_cancel_batch_data + ) + + ### ALERTING ### + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + request_data=data, + ) + ) + + return response + except Exception as e: + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format( + str(e) + ) + ) + raise handle_exception_on_proxy(e) + + ###################################################################### # END OF /v1/batches Endpoints Implementation diff --git a/tests/batches_tests/test_openai_batches_and_files.py b/tests/batches_tests/test_openai_batches_and_files.py index 78867458a6..1bdba0da57 100644 --- a/tests/batches_tests/test_openai_batches_and_files.py +++ b/tests/batches_tests/test_openai_batches_and_files.py @@ -144,6 +144,13 @@ async def test_create_batch(provider): with open(result_file_name, "wb") as file: file.write(result) + # Cancel Batch + cancel_batch_response = await litellm.acancel_batch( + batch_id=create_batch_response.id, + custom_llm_provider=provider, + ) + print("cancel_batch_response=", cancel_batch_response) + pass @@ -282,6 +289,13 @@ async def test_async_create_batch(provider): with open(result_file_name, "wb") as file: file.write(file_content.content) + # Cancel Batch + cancel_batch_response = await litellm.acancel_batch( + batch_id=create_batch_response.id, + custom_llm_provider=provider, + ) + print("cancel_batch_response=", cancel_batch_response) + def cleanup_azure_files(): """ @@ -301,18 +315,6 @@ def cleanup_azure_files(): assert delete_file_response.id == _file.id -def test_retrieve_batch(): - pass - - -def test_cancel_batch(): - pass - - -def test_list_batch(): - pass - - @pytest.mark.asyncio async def test_avertex_batch_prediction(): load_vertex_ai_credentials() diff --git a/tests/openai_misc_endpoints_tests/test_openai_batches_endpoint.py b/tests/openai_misc_endpoints_tests/test_openai_batches_endpoint.py index d8adc1fac5..84f09c34cf 100644 --- a/tests/openai_misc_endpoints_tests/test_openai_batches_endpoint.py +++ b/tests/openai_misc_endpoints_tests/test_openai_batches_endpoint.py @@ -14,85 +14,57 @@ import time BASE_URL = "http://localhost:4000" # Replace with your actual base URL API_KEY = "sk-1234" # Replace with your actual API key +from openai import OpenAI -async def create_batch(session, input_file_id, endpoint, completion_window): - url = f"{BASE_URL}/v1/batches" - headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"} - payload = { - "input_file_id": input_file_id, - "endpoint": endpoint, - "completion_window": completion_window, - } - - async with session.post(url, headers=headers, json=payload) as response: - assert response.status == 200, f"Expected status 200, got {response.status}" - result = await response.json() - print(f"Batch creation successful. Batch ID: {result.get('id', 'N/A')}") - return result - - -async def get_batch_by_id(session, batch_id): - url = f"{BASE_URL}/v1/batches/{batch_id}" - headers = {"Authorization": f"Bearer {API_KEY}"} - - async with session.get(url, headers=headers) as response: - if response.status == 200: - result = await response.json() - return result - else: - print(f"Error: Failed to get batch. Status code: {response.status}") - return None - - -async def list_batches(session): - url = f"{BASE_URL}/v1/batches" - headers = {"Authorization": f"Bearer {API_KEY}"} - - async with session.get(url, headers=headers) as response: - if response.status == 200: - result = await response.json() - return result - else: - print(f"Error: Failed to get batch. Status code: {response.status}") - return None +client = OpenAI(base_url=BASE_URL, api_key=API_KEY) @pytest.mark.asyncio async def test_batches_operations(): - async with aiohttp.ClientSession() as session: - # Test file upload and get file_id - file_id = await upload_file(session, purpose="batch") + _current_dir = os.path.dirname(os.path.abspath(__file__)) + input_file_path = os.path.join(_current_dir, "input.jsonl") + file_obj = client.files.create( + file=open(input_file_path, "rb"), + purpose="batch", + ) - create_batch_response = await create_batch( - session, file_id, "/v1/chat/completions", "24h" - ) - batch_id = create_batch_response.get("id") - assert batch_id is not None + batch = client.batches.create( + input_file_id=file_obj.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) - # Test get batch - get_batch_response = await get_batch_by_id(session, batch_id) - print("response from get batch", get_batch_response) + assert batch.id is not None - assert get_batch_response["id"] == batch_id - assert get_batch_response["input_file_id"] == file_id + # Test get batch + _retrieved_batch = client.batches.retrieve(batch_id=batch.id) + print("response from get batch", _retrieved_batch) - # test LIST Batches - list_batch_response = await list_batches(session) - print("response from list batch", list_batch_response) + assert _retrieved_batch.id == batch.id + assert _retrieved_batch.input_file_id == file_obj.id - assert list_batch_response is not None - assert len(list_batch_response["data"]) > 0 + # Test list batches + _list_batches = client.batches.list() + print("response from list batches", _list_batches) - element_0 = list_batch_response["data"][0] - assert element_0["id"] is not None + assert _list_batches is not None + assert len(_list_batches.data) > 0 - # Test delete file - await delete_file(session, file_id) + # Clean up + # Test cancel batch + _canceled_batch = client.batches.cancel(batch_id=batch.id) + print("response from cancel batch", _canceled_batch) + assert _canceled_batch.status is not None + assert ( + _canceled_batch.status == "cancelling" or _canceled_batch.status == "cancelled" + ) -from openai import OpenAI + # finally delete the file + _deleted_file = client.files.delete(file_id=file_obj.id) + print("response from delete file", _deleted_file) -client = OpenAI(base_url=BASE_URL, api_key=API_KEY) + assert _deleted_file.deleted is True def create_batch_oai_sdk(filepath: str, custom_llm_provider: str) -> str: