diff --git a/litellm/main.py b/litellm/main.py
index d19463f53..826eb5c12 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2166,7 +2166,7 @@ def completion(
     """
     assume input to custom LLM api bases follow this format:
     resp = requests.post(
-        api_base, 
+        api_base,
     json={
         'model': 'meta-llama/Llama-2-13b-hf', # model name
         'params': {
@@ -2331,6 +2331,10 @@
         list: A list of completion results.
     """
     args = locals()
+
+    # Atomically extract and remove the control kwarg so it is never
+    # forwarded to the underlying completion() calls.
+    return_exceptions = args.get("kwargs", {}).pop("return_exceptions", False)
+
     batch_messages = messages
     completions = []
     model = model
@@ -2384,7 +2388,17 @@
         completions.append(future)
 
     # Retrieve the results from the futures
-    results = [future.result() for future in completions]
+    if return_exceptions:
+        # Return each request's exception in its result slot instead of
+        # letting the first failure propagate and discard the rest.
+        results = []
+        for future in completions:
+            try:
+                results.append(future.result())
+            except Exception as exc:
+                results.append(exc)
+    else:
+        results = [future.result() for future in completions]
 
     return results
 
diff --git a/litellm/tests/test_batch_completion_return_exceptions.py b/litellm/tests/test_batch_completion_return_exceptions.py
new file mode 100644
index 000000000..df330f65d
--- /dev/null
+++ b/litellm/tests/test_batch_completion_return_exceptions.py
@@ -0,0 +1,32 @@
+"""Test batch_completion's return_exceptions flag.
+
+Ref: https://github.com/BerriAI/litellm/pull/3397/commits/a7ec1772b1457594d3af48cdcb0a382279b841c7#diff-44852387ceb00aade916d6b314dfd5d180499e54f35209ae9c07179febe08b4b
+"""
+import pytest
+
+import litellm
+
+msg1 = [{"role": "user", "content": "hi 1"}]
+msg2 = [{"role": "user", "content": "hi 2"}]
+
+
+def test_batch_completion_return_exceptions_default():
+    """By default the first per-request error propagates to the caller."""
+    with pytest.raises(Exception):
+        _ = litellm.batch_completion(
+            model="gpt-3.5-turbo",
+            messages=[msg1, msg2],
+            api_key="sk_xxx",  # deliberately set invalid key
+        )
+
+
+def test_batch_completion_return_exceptions_true():
+    """With return_exceptions=True each failure is returned in its slot."""
+    res = litellm.batch_completion(
+        model="gpt-3.5-turbo",
+        messages=[msg1, msg2],
+        api_key="sk_xxx",  # deliberately set invalid key
+        return_exceptions=True,
+    )
+
+    assert all(isinstance(r, litellm.exceptions.AuthenticationError) for r in res)