diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f74e53f947..eaf53f0bd7 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -149,6 +149,7 @@ from fastapi import (
     Request,
     HTTPException,
     status,
+    Path,
    Depends,
     Header,
     Response,
@@ -5179,6 +5180,148 @@ async def create_batch(
     )
 
 
+@router.get(
+    "/v1/batches/{batch_id}",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["Batch"],
+)
+@router.get(
+    "/batches/{batch_id}",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["Batch"],
+)
+async def retrieve_batch(
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    batch_id: str = Path(
+        title="Batch ID to retrieve", description="The ID of the batch to retrieve"
+    ),
+):
+    """
+    Retrieves a batch.
+    This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id}
+    Supports the same params as: https://platform.openai.com/docs/api-reference/batch/retrieve
+
+    Example curl:
+    ```
+    curl http://localhost:4000/v1/batches/batch_abc123 \
+        -H "Authorization: Bearer sk-1234" \
+        -H "Content-Type: application/json"
+    ```
+    """
+    global proxy_logging_obj
+    data: Dict = {}
+    try:
+        # parse any form fields sent with the request, excluding file uploads
+        form_data = await request.form()
+        data = {key: value for key, value in form_data.items() if key != "file"}
+
+        # Include original request and headers in the data
+        data["proxy_server_request"] = {  # type: ignore
+            "url": str(request.url),
+            "method": request.method,
+            "headers": dict(request.headers),
+            "body": copy.copy(data),  # use copy instead of deepcopy
+        }
+
+        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+            data["user"] = user_api_key_dict.user_id
+
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        _headers = dict(request.headers)
+        _headers.pop(
+            "authorization", None
+        )  # do not store the original `sk-..` api key in the db
+        data["metadata"]["headers"] = _headers
+        data["metadata"]["user_api_key_alias"] = getattr(
+            user_api_key_dict, "key_alias", None
+        )
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["user_api_key_team_id"] = getattr(
+            user_api_key_dict, "team_id", None
+        )
+        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
+            "global_max_parallel_requests", None
+        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
+        data["metadata"]["endpoint"] = str(request.url)
+
+        ### TEAM-SPECIFIC PARAMS ###
+        if user_api_key_dict.team_id is not None:
+            team_config = await proxy_config.load_team_config(
+                team_id=user_api_key_dict.team_id
+            )
+            if len(team_config) > 0:
+                team_id = team_config.pop("team_id", None)
+                data["metadata"]["team_id"] = team_id
+                data = {
+                    **team_config,
+                    **data,
+                }  # add the team-specific configs to the completion call
+
+        _retrieve_batch_request = RetrieveBatchRequest(
+            batch_id=batch_id,
+        )
+
+        # for now, hardcode custom_llm_provider="openai" -- this will change as LiteLLM adds more providers for aretrieve_batch
+        response = await litellm.aretrieve_batch(
+            custom_llm_provider="openai", **_retrieve_batch_request
+        )
+
+        ### ALERTING ###
+        data["litellm_status"] = "success"  # used for alerting
+
+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers.update(
+            get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
+
+        return response
+    except Exception as e:
+        data["litellm_status"] = "fail"  # used for alerting
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+        )
+        traceback.print_exc()
+        if isinstance(e, HTTPException):
+            raise ProxyException(
+                message=getattr(e, "message", str(e.detail)),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+            )
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )
+
+
 ######################################################################
 # END OF /v1/batches Endpoints Implementation
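For reference, a minimal client-side sketch of exercising the new endpoint. It assumes a proxy running at `http://localhost:4000` with virtual key `sk-1234`, and reuses the placeholder batch ID `batch_abc123` from the docstring; the OpenAI Python SDK's `client.batches.retrieve` issues exactly the `GET /v1/batches/{batch_id}` request this handler serves.

```python
# Sketch: retrieve a batch through the LiteLLM proxy with the OpenAI SDK.
# Assumes the proxy listens on http://localhost:4000 with virtual key
# "sk-1234"; "batch_abc123" is a placeholder batch ID.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:4000/v1",  # route requests through the proxy
    api_key="sk-1234",                    # proxy virtual key, not a provider key
)

# Sends GET /v1/batches/batch_abc123 to the proxy, which forwards it
# upstream via litellm.aretrieve_batch.
batch = client.batches.retrieve("batch_abc123")
print(batch.id, batch.status)
```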