(Feat) add `/v1/batches/{batch_id:path}/cancel` endpoint (#7406)

* use a single file for Azure batches handling

* add cancel_batch endpoint

* add cancel batch support for OpenAI

* add cancel_batch endpoint

* add cancel batches to test

* remove unused imports

* add test_batches_operations

* update test_batches_operations
Ishaan Jaff 2024-12-24 20:23:50 -08:00 committed by GitHub
parent 440009fb32
commit 54cb64d03d
7 changed files with 589 additions and 304 deletions
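
For reference, a minimal sketch of exercising the new cancel flow from the Python side, assuming `OPENAI_API_KEY` is set and a batch already exists; the batch id is a placeholder:

```python
# Minimal sketch: cancel an existing batch via the new acancel_batch helper.
# "batch_abc123" is a placeholder id; OPENAI_API_KEY is assumed to be set.
import asyncio

import litellm


async def main():
    cancelled = await litellm.acancel_batch(
        batch_id="batch_abc123",
        custom_llm_provider="openai",
    )
    # A cancelled batch reports "cancelling" and later "cancelled"
    print(cancelled.status)


asyncio.run(main())
```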

View file

@@ -20,11 +20,16 @@ import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
-from litellm.llms.azure.azure import AzureBatchesAPI
+from litellm.llms.azure.batches.handler import AzureBatchesAPI
from litellm.llms.openai.openai import OpenAIBatchesAPI
from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest
+from litellm.types.llms.openai import (
+    Batch,
+    CancelBatchRequest,
+    CreateBatchRequest,
+    RetrieveBatchRequest,
+)
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import client, get_litellm_params, supports_httpx_timeout
@@ -582,9 +587,163 @@ def list_batches(
raise e


-def cancel_batch():
-    pass


async def acancel_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
"""
Async: Cancels a batch.
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
"""
try:
loop = asyncio.get_event_loop()
kwargs["acancel_batch"] = True
# Use a partial function to pass your keyword arguments
func = partial(
cancel_batch,
batch_id,
custom_llm_provider,
metadata,
extra_headers,
extra_body,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response
except Exception as e:
raise e
-async def acancel_batch():
-    pass


def cancel_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
"""
Cancels a batch.
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
_cancel_batch_request = CancelBatchRequest(
batch_id=batch_id,
extra_headers=extra_headers,
extra_body=extra_body,
)
_is_async = kwargs.pop("acancel_batch", False) is True
api_base: Optional[str] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_batches_instance.cancel_batch(
_is_async=_is_async,
cancel_batch_data=_cancel_batch_request,
api_base=api_base,
api_key=api_key,
organization=organization,
timeout=timeout,
max_retries=optional_params.max_retries,
)
elif custom_llm_provider == "azure":
api_base = (
optional_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
extra_body = optional_params.get("extra_body", {})
if extra_body is not None:
extra_body.pop("azure_ad_token", None)
else:
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.cancel_batch(
_is_async=_is_async,
api_base=api_base,
api_key=api_key,
api_version=api_version,
timeout=timeout,
max_retries=optional_params.max_retries,
cancel_batch_data=_cancel_batch_request,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'cancel_batch'. Only 'openai' and 'azure' are supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="cancel_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
except Exception as e:
raise e
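
A sketch of the synchronous entry point above against the Azure branch, assuming `cancel_batch` is exported at the package root like the other batch helpers and that `AZURE_API_BASE`, `AZURE_API_KEY` and `AZURE_API_VERSION` are set (the fallbacks this function reads); the batch id is a placeholder:

```python
# Sketch: synchronous cancel through the Azure code path.
# Relies on AZURE_API_BASE / AZURE_API_KEY / AZURE_API_VERSION from the environment.
import litellm

cancelled = litellm.cancel_batch(
    batch_id="batch_abc123",  # placeholder
    custom_llm_provider="azure",
)
print(cancelled.status)
```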

View file

@@ -2,7 +2,7 @@ import asyncio
import json
import os
import time
-from typing import Any, Callable, Coroutine, List, Literal, Optional, Union
+from typing import Any, Callable, List, Literal, Optional, Union
import httpx  # type: ignore
from openai import AsyncAzureOpenAI, AzureOpenAI
@@ -28,13 +28,7 @@ from litellm.utils import (
    modify_url,
)
-from ...types.llms.openai import (
-    Batch,
-    CancelBatchRequest,
-    CreateBatchRequest,
-    HttpxBinaryResponseContent,
-    RetrieveBatchRequest,
-)
+from ...types.llms.openai import HttpxBinaryResponseContent
from ..base import BaseLLM
from .common_utils import AzureOpenAIError, process_azure_headers
@@ -1613,216 +1607,3 @@ class AzureChatCompletion(BaseLLM):
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
class AzureBatchesAPI(BaseLLM):
"""
Azure methods to support for batches
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
def __init__(self) -> None:
super().__init__()
def get_azure_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_batch(
self,
create_batch_data: CreateBatchRequest,
azure_client: AsyncAzureOpenAI,
) -> Batch:
response = await azure_client.batches.create(**create_batch_data)
return response
def create_batch(
self,
_is_async: bool,
create_batch_data: CreateBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, azure_client=azure_client
)
response = azure_client.batches.create(**create_batch_data)
return response
async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
client: AsyncAzureOpenAI,
) -> Batch:
response = await client.batches.retrieve(**retrieve_batch_data)
return response
def retrieve_batch(
self,
_is_async: bool,
retrieve_batch_data: RetrieveBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, client=azure_client
)
response = azure_client.batches.retrieve(**retrieve_batch_data)
return response
def cancel_batch(
self,
_is_async: bool,
cancel_batch_data: CancelBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
response = azure_client.batches.cancel(**cancel_batch_data)
return response
async def alist_batches(
self,
client: AsyncAzureOpenAI,
after: Optional[str] = None,
limit: Optional[int] = None,
):
response = await client.batches.list(after=after, limit=limit) # type: ignore
return response
def list_batches(
self,
_is_async: bool,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
after: Optional[str] = None,
limit: Optional[int] = None,
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
api_version=api_version,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.alist_batches( # type: ignore
client=azure_client, after=after, limit=limit
)
response = azure_client.batches.list(after=after, limit=limit) # type: ignore
return response

View file

@@ -0,0 +1,238 @@
"""
Azure Batches API Handler
"""
from typing import Any, Coroutine, Optional, Union
import httpx
import litellm
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
CreateBatchRequest,
RetrieveBatchRequest,
)
class AzureBatchesAPI:
"""
Azure methods to support for batches
- create_batch()
- retrieve_batch()
- cancel_batch()
- list_batch()
"""
def __init__(self) -> None:
super().__init__()
def get_azure_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
api_version: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
_is_async: bool = False,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
received_args = locals()
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client" or k == "_is_async":
pass
elif k == "api_base" and v is not None:
data["azure_endpoint"] = v
elif v is not None:
data[k] = v
if "api_version" not in data:
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
if _is_async is True:
openai_client = AsyncAzureOpenAI(**data)
else:
openai_client = AzureOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
async def acreate_batch(
self,
create_batch_data: CreateBatchRequest,
azure_client: AsyncAzureOpenAI,
) -> Batch:
response = await azure_client.batches.create(**create_batch_data)
return response
def create_batch(
self,
_is_async: bool,
create_batch_data: CreateBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
api_version=api_version,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acreate_batch( # type: ignore
create_batch_data=create_batch_data, azure_client=azure_client
)
response = azure_client.batches.create(**create_batch_data)
return response
async def aretrieve_batch(
self,
retrieve_batch_data: RetrieveBatchRequest,
client: AsyncAzureOpenAI,
) -> Batch:
response = await client.batches.retrieve(**retrieve_batch_data)
return response
def retrieve_batch(
self,
_is_async: bool,
retrieve_batch_data: RetrieveBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.aretrieve_batch( # type: ignore
retrieve_batch_data=retrieve_batch_data, client=azure_client
)
response = azure_client.batches.retrieve(**retrieve_batch_data)
return response
async def acancel_batch(
self,
cancel_batch_data: CancelBatchRequest,
client: AsyncAzureOpenAI,
) -> Batch:
response = await client.batches.cancel(**cancel_batch_data)
return response
def cancel_batch(
self,
_is_async: bool,
cancel_batch_data: CancelBatchRequest,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
response = azure_client.batches.cancel(**cancel_batch_data)
return response
async def alist_batches(
self,
client: AsyncAzureOpenAI,
after: Optional[str] = None,
limit: Optional[int] = None,
):
response = await client.batches.list(after=after, limit=limit) # type: ignore
return response
def list_batches(
self,
_is_async: bool,
api_key: Optional[str],
api_base: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
after: Optional[str] = None,
limit: Optional[int] = None,
client: Optional[AzureOpenAI] = None,
):
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
self.get_azure_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
api_version=api_version,
client=client,
_is_async=_is_async,
)
)
if azure_client is None:
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(azure_client, AsyncAzureOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.alist_batches( # type: ignore
client=azure_client, after=after, limit=limit
)
response = azure_client.batches.list(after=after, limit=limit) # type: ignore
return response
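
A sketch of driving the new handler directly, with placeholder credentials and batch id (in normal use `litellm.cancel_batch` builds these arguments):

```python
# Sketch: call the Azure batches handler without going through litellm.cancel_batch.
# All credential values and the batch id below are placeholders.
from litellm.llms.azure.batches.handler import AzureBatchesAPI
from litellm.types.llms.openai import CancelBatchRequest

azure_batches = AzureBatchesAPI()
response = azure_batches.cancel_batch(
    _is_async=False,
    cancel_batch_data=CancelBatchRequest(batch_id="batch_abc123"),
    api_key="my-azure-api-key",
    api_base="https://my-endpoint.openai.azure.com",
    api_version="2024-06-01",
    timeout=600.0,
    max_retries=2,
)
print(response)
```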

View file

@@ -1799,6 +1799,15 @@ class OpenAIBatchesAPI(BaseLLM):
response = openai_client.batches.retrieve(**retrieve_batch_data)
return response
async def acancel_batch(
self,
cancel_batch_data: CancelBatchRequest,
openai_client: AsyncOpenAI,
) -> Batch:
verbose_logger.debug("async cancelling batch, args= %s", cancel_batch_data)
response = await openai_client.batches.cancel(**cancel_batch_data)
return response
def cancel_batch(
self,
_is_async: bool,
@@ -1823,6 +1832,16 @@ class OpenAIBatchesAPI(BaseLLM):
raise ValueError(
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
)
if _is_async is True:
if not isinstance(openai_client, AsyncOpenAI):
raise ValueError(
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
)
return self.acancel_batch( # type: ignore
cancel_batch_data=cancel_batch_data, openai_client=openai_client
)
response = openai_client.batches.cancel(**cancel_batch_data)
return response
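
The `_is_async` branch added above follows the same dispatch pattern as the other batch methods: the synchronous method returns a coroutine when `_is_async=True` and the caller awaits it. An illustrative, self-contained sketch of that pattern (not LiteLLM's actual classes):

```python
# Illustration of the sync/async dispatch pattern used by the batches handlers.
import asyncio


class BatchesAPI:
    async def acancel(self, batch_id: str) -> dict:
        # stand-in for an awaited SDK call
        return {"id": batch_id, "status": "cancelling"}

    def cancel(self, _is_async: bool, batch_id: str):
        if _is_async:
            # return the coroutine; the async entry point awaits it
            return self.acancel(batch_id)
        return {"id": batch_id, "status": "cancelling"}


api = BatchesAPI()
print(api.cancel(_is_async=False, batch_id="batch_abc123"))
print(asyncio.run(api.cancel(_is_async=True, batch_id="batch_abc123")))
```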

View file

@@ -11,7 +11,11 @@ from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
import litellm
from litellm._logging import verbose_proxy_logger
-from litellm.batches.main import CreateBatchRequest, RetrieveBatchRequest
+from litellm.batches.main import (
+    CancelBatchRequest,
+    CreateBatchRequest,
+    RetrieveBatchRequest,
+)
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
@@ -353,6 +357,116 @@ async def list_batches(
raise handle_exception_on_proxy(e)
@router.post(
"/{provider}/v1/batches/{batch_id:path}/cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.post(
"/v1/batches/{batch_id:path}/cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
@router.post(
"/batches/{batch_id:path}/cancel",
dependencies=[Depends(user_api_key_auth)],
tags=["batch"],
)
async def cancel_batch(
request: Request,
batch_id: str,
fastapi_response: Response,
provider: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Cancel a batch.
This is the equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/cancel
Example Curl
```
curl http://localhost:4000/v1/batches/batch_abc123/cancel \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-X POST
```
"""
from litellm.proxy.proxy_server import (
add_litellm_data_to_request,
general_settings,
get_custom_headers,
proxy_config,
proxy_logging_obj,
version,
)
data: Dict = {}
try:
data = await _read_request_body(request=request)
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
)
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
custom_llm_provider = (
provider or data.pop("custom_llm_provider", None) or "openai"
)
_cancel_batch_data = CancelBatchRequest(batch_id=batch_id, **data)
response = await litellm.acancel_batch(
custom_llm_provider=custom_llm_provider, # type: ignore
**_cancel_batch_data
)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format(
str(e)
)
)
raise handle_exception_on_proxy(e)
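
The provider-scoped route (`/{provider}/v1/batches/{batch_id:path}/cancel`) can be exercised the same way as the plain route; a sketch against a locally running proxy, with placeholder URL, key and batch id:

```python
# Sketch: cancel a batch through the provider-scoped proxy route.
# Base URL, API key and batch id are placeholders.
import httpx

resp = httpx.post(
    "http://localhost:4000/azure/v1/batches/batch_abc123/cancel",
    headers={
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    },
)
print(resp.status_code, resp.json())
```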
######################################################################
# END OF /v1/batches Endpoints Implementation

View file

@@ -144,6 +144,13 @@ async def test_create_batch(provider):
with open(result_file_name, "wb") as file:
file.write(result)
# Cancel Batch
cancel_batch_response = await litellm.acancel_batch(
batch_id=create_batch_response.id,
custom_llm_provider=provider,
)
print("cancel_batch_response=", cancel_batch_response)
pass
@@ -282,6 +289,13 @@ async def test_async_create_batch(provider):
with open(result_file_name, "wb") as file:
file.write(file_content.content)
# Cancel Batch
cancel_batch_response = await litellm.acancel_batch(
batch_id=create_batch_response.id,
custom_llm_provider=provider,
)
print("cancel_batch_response=", cancel_batch_response)
def cleanup_azure_files():
"""
@@ -301,18 +315,6 @@ def cleanup_azure_files():
assert delete_file_response.id == _file.id
-def test_retrieve_batch():
-    pass
-def test_cancel_batch():
-    pass
-def test_list_batch():
-    pass
@pytest.mark.asyncio
async def test_avertex_batch_prediction():
load_vertex_ai_credentials()

View file

@@ -14,85 +14,57 @@ import time
BASE_URL = "http://localhost:4000"  # Replace with your actual base URL
API_KEY = "sk-1234"  # Replace with your actual API key

+from openai import OpenAI
+
+client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
+
-async def create_batch(session, input_file_id, endpoint, completion_window):
-    url = f"{BASE_URL}/v1/batches"
-    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
-    payload = {
-        "input_file_id": input_file_id,
-        "endpoint": endpoint,
-        "completion_window": completion_window,
-    }
-    async with session.post(url, headers=headers, json=payload) as response:
-        assert response.status == 200, f"Expected status 200, got {response.status}"
-        result = await response.json()
-        print(f"Batch creation successful. Batch ID: {result.get('id', 'N/A')}")
-        return result
-
-
-async def get_batch_by_id(session, batch_id):
-    url = f"{BASE_URL}/v1/batches/{batch_id}"
-    headers = {"Authorization": f"Bearer {API_KEY}"}
-    async with session.get(url, headers=headers) as response:
-        if response.status == 200:
-            result = await response.json()
-            return result
-        else:
-            print(f"Error: Failed to get batch. Status code: {response.status}")
-            return None
-
-
-async def list_batches(session):
-    url = f"{BASE_URL}/v1/batches"
-    headers = {"Authorization": f"Bearer {API_KEY}"}
-    async with session.get(url, headers=headers) as response:
-        if response.status == 200:
-            result = await response.json()
-            return result
-        else:
-            print(f"Error: Failed to get batch. Status code: {response.status}")
-            return None
-
-
@pytest.mark.asyncio
async def test_batches_operations():
-    async with aiohttp.ClientSession() as session:
-        # Test file upload and get file_id
-        file_id = await upload_file(session, purpose="batch")
-
-        create_batch_response = await create_batch(
-            session, file_id, "/v1/chat/completions", "24h"
-        )
-        batch_id = create_batch_response.get("id")
-        assert batch_id is not None
-
-        # Test get batch
-        get_batch_response = await get_batch_by_id(session, batch_id)
-        print("response from get batch", get_batch_response)
-
-        assert get_batch_response["id"] == batch_id
-        assert get_batch_response["input_file_id"] == file_id
-
-        # test LIST Batches
-        list_batch_response = await list_batches(session)
-        print("response from list batch", list_batch_response)
-
-        assert list_batch_response is not None
-        assert len(list_batch_response["data"]) > 0
-
-        element_0 = list_batch_response["data"][0]
-        assert element_0["id"] is not None
-
-        # Test delete file
-        await delete_file(session, file_id)
-
-
-from openai import OpenAI
-
-client = OpenAI(base_url=BASE_URL, api_key=API_KEY)
+    _current_dir = os.path.dirname(os.path.abspath(__file__))
+    input_file_path = os.path.join(_current_dir, "input.jsonl")
+    file_obj = client.files.create(
+        file=open(input_file_path, "rb"),
+        purpose="batch",
+    )
+
+    batch = client.batches.create(
+        input_file_id=file_obj.id,
+        endpoint="/v1/chat/completions",
+        completion_window="24h",
+    )
+    assert batch.id is not None
+
+    # Test get batch
+    _retrieved_batch = client.batches.retrieve(batch_id=batch.id)
+    print("response from get batch", _retrieved_batch)
+    assert _retrieved_batch.id == batch.id
+    assert _retrieved_batch.input_file_id == file_obj.id
+
+    # Test list batches
+    _list_batches = client.batches.list()
+    print("response from list batches", _list_batches)
+    assert _list_batches is not None
+    assert len(_list_batches.data) > 0
+
+    # Clean up
+    # Test cancel batch
+    _canceled_batch = client.batches.cancel(batch_id=batch.id)
+    print("response from cancel batch", _canceled_batch)
+    assert _canceled_batch.status is not None
+    assert (
+        _canceled_batch.status == "cancelling" or _canceled_batch.status == "cancelled"
+    )
+
+    # finally delete the file
+    _deleted_file = client.files.delete(file_id=file_obj.id)
+    print("response from delete file", _deleted_file)
+    assert _deleted_file.deleted is True


def create_batch_oai_sdk(filepath: str, custom_llm_provider: str) -> str: