Add return_exceptions to batch_completion (retry)

This commit is contained in:
ffreemt 2024-05-05 13:11:21 +08:00
parent e5df63e015
commit 0b408ba6f4
2 changed files with 47 additions and 2 deletions

View file

@ -2166,7 +2166,7 @@ def completion(
""" """
assume input to custom LLM api bases follow this format: assume input to custom LLM api bases follow this format:
resp = requests.post( resp = requests.post(
api_base, api_base,
json={ json={
'model': 'meta-llama/Llama-2-13b-hf', # model name 'model': 'meta-llama/Llama-2-13b-hf', # model name
'params': { 'params': {
@ -2331,6 +2331,12 @@ def batch_completion(
        list: A list of completion results.
    """
    args = locals()
# extra kw for dealing with exceptions
return_exceptions = args.get("kwargs").get("return_exceptions", False)
if "return_exceptions" in args.get("kwargs"):
args.get("kwargs").pop("return_exceptions")
    batch_messages = messages
    completions = []
    model = model
@ -2384,7 +2390,16 @@ def batch_completion(
            completions.append(future)
    # Retrieve the results from the futures
    # results = [future.result() for future in completions]
if return_exceptions:
results = []
for future in completions:
try:
results.append(future.result())
except Exception as exc:
results.append(exc)
else: # original
results = [future.result() for future in completions]
    return results

View file

@ -0,0 +1,30 @@
"""https://github.com/BerriAI/litellm/pull/3397/commits/a7ec1772b1457594d3af48cdcb0a382279b841c7#diff-44852387ceb00aade916d6b314dfd5d180499e54f35209ae9c07179febe08b4b."""
"""Test batch_completion's return_exceptions."""
import pytest
import litellm
msg1 = [{"role": "user", "content": "hi 1"}]
msg2 = [{"role": "user", "content": "hi 2"}]
def test_batch_completion_return_exceptions_default():
    """By default (return_exceptions=False) a failed request propagates its exception.

    The deliberately invalid key makes every request fail with an
    AuthenticationError (the sibling test confirms this is the exception
    type produced); without return_exceptions, batch_completion re-raises
    it instead of returning it.
    """
    # Expect the specific exception, not bare Exception — a broad
    # pytest.raises(Exception) would also pass on unrelated errors
    # such as a TypeError from a bad call signature.
    with pytest.raises(litellm.exceptions.AuthenticationError):
        _ = litellm.batch_completion(
            model="gpt-3.5-turbo",
            messages=[msg1, msg2],
            api_key="sk_xxx",  # deliberately set invalid key
            # return_exceptions defaults to False
        )
def test_batch_completion_return_exceptions_true():
    """With return_exceptions=True, failures are returned in the results list.

    Each failed call should appear in the result list as the exception
    instance itself rather than being raised.
    """
    res = litellm.batch_completion(
        model="gpt-3.5-turbo",
        messages=[msg1, msg2],
        api_key="sk_xxx",  # deliberately set invalid key
        return_exceptions=True,
    )
    # One result per input conversation — checking only res[0] would miss
    # a dropped or raised second entry.
    assert len(res) == 2
    assert all(isinstance(r, litellm.exceptions.AuthenticationError) for r in res)