forked from phoenix/litellm-mirror
feat(proxy_server.py): support azure batch api endpoints
parent ada426d652
commit 8625663458
6 changed files with 83 additions and 33 deletions
@@ -4877,6 +4877,11 @@ async def run_thread(
 
 
 ######################################################################
+@router.post(
+    "/{provider}/v1/batches",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["batch"],
+)
 @router.post(
     "/v1/batches",
     dependencies=[Depends(user_api_key_auth)],
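The five added lines stack a second route decorator on the same handler: FastAPI registers the coroutine under both paths, filling provider from the URL on the prefixed route and leaving it None on the plain one. A minimal standalone sketch of the pattern (toy app, not the proxy's actual code):

from typing import Optional

from fastapi import FastAPI

app = FastAPI()

# Each route decorator registers a path and returns the function unchanged,
# so stacking them binds one coroutine to both URLs.
@app.post("/{provider}/v1/batches")
@app.post("/v1/batches")
async def create_batch(provider: Optional[str] = None):
    # "azure" for POST /azure/v1/batches; None (falling back to "openai")
    # for the plain POST /v1/batches.
    return {"provider": provider or "openai"}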
@@ -4890,6 +4895,7 @@ async def run_thread(
 async def create_batch(
     request: Request,
     fastapi_response: Response,
+    provider: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """
@@ -4936,9 +4942,10 @@ async def create_batch(
 
         _create_batch_data = CreateBatchRequest(**data)
 
-        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+        if provider is None:
+            provider = "openai"
         response = await litellm.acreate_batch(
-            custom_llm_provider="openai", **_create_batch_data
+            custom_llm_provider=provider, **_create_batch_data
         )
 
         ### ALERTING ###
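With the prefixed route wired up, a client reaches Azure's batch API through the proxy just by prefixing the path; the bare path keeps its OpenAI default. A rough sketch using requests, where the proxy address, the sk-1234 key, and the file id are placeholders rather than values from this commit:

import requests

PROXY_BASE = "http://localhost:4000"           # placeholder proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # placeholder proxy key

# POST /azure/v1/batches -> create_batch(provider="azure")
# -> litellm.acreate_batch(custom_llm_provider="azure", ...)
resp = requests.post(
    f"{PROXY_BASE}/azure/v1/batches",
    headers=HEADERS,
    json={
        "input_file_id": "file-abc123",  # placeholder: a previously uploaded batch input file
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
)
print(resp.json())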
@@ -4994,6 +5001,11 @@ async def create_batch(
         )
 
 
+@router.get(
+    "/{provider}/v1/batches/{batch_id:path}",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["batch"],
+)
 @router.get(
     "/v1/batches/{batch_id:path}",
     dependencies=[Depends(user_api_key_auth)],
@@ -5008,6 +5020,7 @@ async def retrieve_batch(
     request: Request,
     fastapi_response: Response,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    provider: Optional[str] = None,
     batch_id: str = Path(
         title="Batch ID to retrieve", description="The ID of the batch to retrieve"
     ),
@@ -5032,9 +5045,10 @@ async def retrieve_batch(
             batch_id=batch_id,
         )
 
-        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+        if provider is None:
+            provider = "openai"
         response = await litellm.aretrieve_batch(
-            custom_llm_provider="openai", **_retrieve_batch_request
+            custom_llm_provider=provider, **_retrieve_batch_request
         )
 
         ### ALERTING ###
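Retrieval follows the same pattern, and the {batch_id:path} converter in the route lets ids containing slashes pass through intact. Under the same placeholder assumptions as above:

import requests

PROXY_BASE = "http://localhost:4000"           # placeholder proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # placeholder proxy key

batch_id = "batch_abc123"  # placeholder: an id returned by the create call

# GET /azure/v1/batches/{batch_id} -> retrieve_batch(provider="azure", batch_id=...)
resp = requests.get(f"{PROXY_BASE}/azure/v1/batches/{batch_id}", headers=HEADERS)
print(resp.json().get("status"))  # e.g. "in_progress" or "completed"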
@@ -5091,6 +5105,11 @@ async def retrieve_batch(
         )
 
 
+@router.get(
+    "/{provider}/v1/batches",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["batch"],
+)
 @router.get(
     "/v1/batches",
     dependencies=[Depends(user_api_key_auth)],
@@ -5103,6 +5122,7 @@ async def retrieve_batch(
 )
 async def list_batches(
     fastapi_response: Response,
+    provider: Optional[str] = None,
     limit: Optional[int] = None,
     after: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
@@ -5123,9 +5143,10 @@ async def list_batches(
     global proxy_logging_obj
     verbose_proxy_logger.debug("GET /v1/batches after={} limit={}".format(after, limit))
     try:
-        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+        if provider is None:
+            provider = "openai"
         response = await litellm.alist_batches(
-            custom_llm_provider="openai",
+            custom_llm_provider=provider,
             after=after,
             limit=limit,
         )
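Listing rounds out the set: the after and limit query parameters pass straight through to litellm.alist_batches. Same placeholder assumptions:

import requests

PROXY_BASE = "http://localhost:4000"           # placeholder proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # placeholder proxy key

# GET /azure/v1/batches?limit=10 -> list_batches(provider="azure", limit=10)
resp = requests.get(
    f"{PROXY_BASE}/azure/v1/batches",
    headers=HEADERS,
    params={"limit": 10},
)
for batch in resp.json().get("data", []):
    print(batch["id"], batch.get("status"))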