forked from phoenix/litellm-mirror
Add return_exceptions to batch_completion (retry)
parent fc63c3f555
commit 2713272bba

2 changed files with 47 additions and 2 deletions
@@ -2166,7 +2166,7 @@ def completion(
         """
         assume input to custom LLM api bases follow this format:
         resp = requests.post(
             api_base,
             json={
                 'model': 'meta-llama/Llama-2-13b-hf', # model name
                 'params': {
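The docstring context in this hunk describes the request shape litellm expects a custom `api_base` deployment to accept. A minimal sketch of such a request follows; the endpoint URL and the keys inside "params" are illustrative placeholders, not values taken from this commit (only "model" and "params" appear in the docstring above).

import requests

# Illustrative only: URL and the contents of "params" are placeholders.
resp = requests.post(
    "http://localhost:8000/generate",  # placeholder custom api_base
    json={
        "model": "meta-llama/Llama-2-13b-hf",  # model name
        "params": {
            "prompt": ["The capital of France is"],  # placeholder prompt list
            "max_tokens": 32,                        # placeholder generation setting
        },
    },
    timeout=30,
)
print(resp.status_code, resp.json())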
@@ -2331,6 +2331,12 @@ def batch_completion(
         list: A list of completion results.
     """
     args = locals()
+
+    # extra kw for dealing with exceptions
+    return_exceptions = args.get("kwargs").get("return_exceptions", False)
+    if "return_exceptions" in args.get("kwargs"):
+        args.get("kwargs").pop("return_exceptions")
+
     batch_messages = messages
     completions = []
     model = model
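The added lines read `return_exceptions` out of the collected `kwargs` and then pop it, so the flag is not forwarded to the underlying per-request `completion()` calls. A standalone sketch of that read-then-pop pattern, with a hypothetical `batch_like` function standing in for `batch_completion`:

def batch_like(model, messages, **kwargs):
    # Hypothetical stand-in used only to show the pattern, not litellm code.
    # Read the optional flag, defaulting to False when the caller omits it...
    return_exceptions = kwargs.get("return_exceptions", False)
    # ...then drop it so the remaining kwargs can be forwarded untouched.
    kwargs.pop("return_exceptions", None)
    return return_exceptions, kwargs


flag, forwarded = batch_like("gpt-3.5-turbo", [], return_exceptions=True, temperature=0)
assert flag is True and "return_exceptions" not in forwarded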
@@ -2384,7 +2390,16 @@ def batch_completion(
                 completions.append(future)
 
     # Retrieve the results from the futures
-    results = [future.result() for future in completions]
+    # results = [future.result() for future in completions]
+    if return_exceptions:
+        results = []
+        for future in completions:
+            try:
+                results.append(future.result())
+            except Exception as exc:
+                results.append(exc)
+    else:  # original
+        results = [future.result() for future in completions]
     return results
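With the branch above, a caller can opt into receiving exception objects in place of failed results instead of the whole batch raising. A usage sketch, under the assumption that the API key is invalid (mirroring the test added below) so every entry comes back as an exception:

import litellm

msgs = [
    [{"role": "user", "content": "hi 1"}],
    [{"role": "user", "content": "hi 2"}],
]

results = litellm.batch_completion(
    model="gpt-3.5-turbo",
    messages=msgs,
    api_key="sk_xxx",        # deliberately invalid, as in the test below
    return_exceptions=True,  # failures are returned, not raised
)

for prompt, res in zip(msgs, results):
    if isinstance(res, Exception):
        print(f"{prompt[0]['content']!r} failed with {type(res).__name__}")
    else:
        print(res.choices[0].message.content)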
litellm/tests/test_batch_completion_return_exceptions.py (new file, 30 additions)

@@ -0,0 +1,30 @@
+"""https://github.com/BerriAI/litellm/pull/3397/commits/a7ec1772b1457594d3af48cdcb0a382279b841c7#diff-44852387ceb00aade916d6b314dfd5d180499e54f35209ae9c07179febe08b4b."""
+"""Test batch_completion's return_exceptions."""
+import pytest
+import litellm
+
+msg1 = [{"role": "user", "content": "hi 1"}]
+msg2 = [{"role": "user", "content": "hi 2"}]
+
+
+def test_batch_completion_return_exceptions_default():
+    """Test batch_completion's return_exceptions."""
+    with pytest.raises(Exception):
+        _ = litellm.batch_completion(
+            model="gpt-3.5-turbo",
+            messages=[msg1, msg2],
+            api_key="sk_xxx",  # deliberately set invalid key
+            # return_exceptions=False,
+        )
+
+
+def test_batch_completion_return_exceptions_true():
+    """Test batch_completion's return_exceptions."""
+    res = litellm.batch_completion(
+        model="gpt-3.5-turbo",
+        messages=[msg1, msg2],
+        api_key="sk_xxx",  # deliberately set invalid key
+        return_exceptions=True,
+    )
+
+    assert isinstance(res[0], litellm.exceptions.AuthenticationError)