diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 3bd1d07a47..3963a4e114 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -216,8 +216,91 @@ def create_batch( raise e -def retrieve_batch(): - pass +def retrieve_batch( + batch_id: str, + custom_llm_provider: Literal["openai"] = "openai", + metadata: Optional[Dict[str, str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + extra_body: Optional[Dict[str, str]] = None, + **kwargs, +): + """ + Retrieves a batch. + + LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id} + """ + try: + optional_params = GenericLiteLLMParams(**kwargs) + if custom_llm_provider == "openai": + + # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + api_base = ( + optional_params.api_base + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + ### TIMEOUT LOGIC ### + timeout = ( + optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + ) + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + _retrieve_batch_request = RetrieveBatchRequest( + batch_id=batch_id, + extra_headers=extra_headers, + extra_body=extra_body, + ) + + response = openai_batches_instance.retrieve_batch( + retrieve_batch_data=_retrieve_batch_request, + api_base=api_base, + api_key=api_key, + organization=organization, + timeout=timeout, + max_retries=optional_params.max_retries, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + except Exception as e: + raise e def cancel_batch(): diff --git a/litellm/tests/test_openai_batches.py b/litellm/tests/test_openai_batches.py index b99991baba..fc797635b0 100644 --- a/litellm/tests/test_openai_batches.py +++ b/litellm/tests/test_openai_batches.py @@ -14,6 +14,7 @@ from litellm import ( create_batch, create_file, ) +import time def test_create_batch(): @@ -34,7 +35,7 @@ def test_create_batch(): batch_input_file_id is not None ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}" - response = litellm.create_batch( + create_batch_response = litellm.create_batch( completion_window="24h", endpoint="/v1/chat/completions", input_file_id=batch_input_file_id, @@ -42,17 +43,28 @@ def test_create_batch(): metadata={"key1": "value1", "key2": "value2"}, ) - print("response from litellm.create_batch=", response) + print("response from litellm.create_batch=", create_batch_response) assert ( - response.id is not None - ), f"Failed to create batch, expected a non null batch_id but got {response.id}" + create_batch_response.id is not None + ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}" assert ( - response.endpoint == "/v1/chat/completions" - ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {response.endpoint}" + create_batch_response.endpoint == "/v1/chat/completions" + ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}" assert ( - response.input_file_id == batch_input_file_id - ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {response.input_file_id}" + create_batch_response.input_file_id == batch_input_file_id + ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" + + time.sleep(30) + + retrieved_batch = litellm.retrieve_batch( + batch_id=create_batch_response.id, custom_llm_provider="openai" + ) + print("retrieved batch=", retrieved_batch) + # just assert that we retrieved a non None batch + + assert retrieved_batch.id == create_batch_response.id + pass