mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 02:34:29 +00:00
(feat) parallel HF text completion + completion_with_retries show exception
This commit is contained in:
parent 7219fcb968
commit 2498d95dc5

1 changed file with 20 additions and 15 deletions
@@ -1412,8 +1412,8 @@ def completion_with_retries(*args, **kwargs):
     """
     try:
         import tenacity
-    except:
-        raise Exception("tenacity import failed please run `pip install tenacity`")
+    except Exception as e:
+        raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
 
     num_retries = kwargs.pop("num_retries", 3)
     retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
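For context on the first hunk: completion_with_retries builds on tenacity's Retrying object. Below is a minimal, standalone sketch of that retry pattern, assuming tenacity is installed; call_with_retries and flaky are hypothetical names used only for illustration, while the stop_after_attempt/reraise=True usage mirrors the line above.

import tenacity

def call_with_retries(fn, *args, num_retries=3, **kwargs):
    # retry up to num_retries attempts; reraise=True re-raises the last error unchanged
    retryer = tenacity.Retrying(
        stop=tenacity.stop_after_attempt(num_retries),
        reraise=True,
    )
    # a Retrying instance is callable: it invokes fn and retries on any exception
    return retryer(fn, *args, **kwargs)

# hypothetical flaky callable for the demo: fails twice, then succeeds
attempts = {"count": 0}
def flaky():
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

print(call_with_retries(flaky))  # prints "ok" on the third attempt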
@@ -1989,27 +1989,32 @@ def text_completion(
 
     # processing prompt - users can pass raw tokens to OpenAI Completion()
     if type(prompt) == list:
+        import concurrent.futures
         tokenizer = tiktoken.encoding_for_model("text-davinci-003")
         ## if it's a 2d list - each element in the list is a text_completion() request
         if len(prompt) > 0 and type(prompt[0]) == list:
             responses = [None for x in prompt] # init responses
-            for i, individual_prompt in enumerate(prompt):
-                decoded_prompt = tokenizer.decode(individual_prompt) # type: ignore
-                all_params = {**kwargs, **optional_params} # combine optional params and kwargs
+            def process_prompt(i, individual_prompt):
+                decoded_prompt = tokenizer.decode(individual_prompt)
+                all_params = {**kwargs, **optional_params}
                 response = text_completion(
-                    model = model, # type: ignore
-                    prompt = decoded_prompt, # type: ignore
+                    model=model,
+                    prompt=decoded_prompt,
+                    num_retries=3, # ensure this does not fail for the batch
                     *args,
                     **all_params,
                 )
-                responses[i] = response["choices"][0]
-
-            text_completion_response["id"] = response.get("id", None)
-            text_completion_response["object"] = "text_completion"
-            text_completion_response["created"] = response.get("created", None)
-            text_completion_response["model"] = response.get("model", None)
-            text_completion_response["choices"] = responses
-            text_completion_response["usage"] = response.get("usage", None)
+                #print(response)
+                text_completion_response["id"] = response.get("id", None)
+                text_completion_response["object"] = "text_completion"
+                text_completion_response["created"] = response.get("created", None)
+                text_completion_response["model"] = response.get("model", None)
+                return response["choices"][0]
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                futures = [executor.submit(process_prompt, i, individual_prompt) for i, individual_prompt in enumerate(prompt)]
+                for i, future in enumerate(concurrent.futures.as_completed(futures)):
+                    responses[i] = future.result()
+            text_completion_response["choices"] = responses
 
             return text_completion_response
         else:
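The second hunk moves each sub-prompt into a worker function and fans the batch out over a thread pool. Below is a standalone sketch of that concurrent.futures pattern; fake_completion and prompts are hypothetical stand-ins for the per-prompt text_completion() call and the 2d prompt batch, and futures are mapped back to their submission index so responses stay aligned with the prompts.

import concurrent.futures

def fake_completion(prompt):
    # hypothetical stand-in for the per-prompt text_completion() call
    return f"completion for: {prompt}"

prompts = ["first prompt", "second prompt", "third prompt"]  # hypothetical batch
responses = [None] * len(prompts)

with concurrent.futures.ThreadPoolExecutor() as executor:
    # submit one task per prompt and remember which slot each future fills
    future_to_index = {
        executor.submit(fake_completion, p): i for i, p in enumerate(prompts)
    }
    # as_completed yields futures in completion order, so write each result back by index
    for future in concurrent.futures.as_completed(future_to_index):
        responses[future_to_index[future]] = future.result()

print(responses)  # results aligned with the original prompt order

Keeping an explicit future-to-index map is one way to keep a batch ordered even though worker threads can finish in any order.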
|
|