fix(replicate.py): move replicate calls to being completely async

Closes https://github.com/BerriAI/litellm/issues/3128
Krrish Dholakia 2024-05-16 17:24:08 -07:00
parent a2a5884df1
commit 709373b15c
5 changed files with 326 additions and 59 deletions
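
Note: the hunks below cover the routing side (acompletion() and completion()); the replicate.py handler rewrite itself is not included in this excerpt. In rough terms, "completely async" means the Replicate prediction is started and polled with a non-blocking HTTP client instead of blocking calls, so async callers no longer tie up a thread per request. A minimal sketch of that pattern, assuming httpx and the public Replicate predictions endpoint; the function name and payload fields here are illustrative, not the actual litellm implementation:

import asyncio
import httpx

REPLICATE_API_BASE = "https://api.replicate.com/v1"  # public default; litellm also accepts an api_base override

async def async_replicate_prediction(version: str, prompt_input: dict, api_key: str) -> dict:
    # Illustrative sketch, not the actual replicate.py implementation.
    # Start a prediction, then poll it without blocking the event loop.
    headers = {"Authorization": f"Token {api_key}"}
    async with httpx.AsyncClient(timeout=600.0) as client:
        start = await client.post(
            f"{REPLICATE_API_BASE}/predictions",
            headers=headers,
            json={"version": version, "input": prompt_input},
        )
        start.raise_for_status()
        prediction_url = start.json()["urls"]["get"]

        while True:
            poll = await client.get(prediction_url, headers=headers)
            poll.raise_for_status()
            prediction = poll.json()
            if prediction["status"] in ("succeeded", "failed", "canceled"):
                return prediction
            await asyncio.sleep(0.5)  # yield to the event loop between polls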


@@ -320,6 +320,7 @@ async def acompletion(
             or custom_llm_provider == "huggingface"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "ollama_chat"
+            or custom_llm_provider == "replicate"
             or custom_llm_provider == "vertex_ai"
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"
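
Adding "replicate" to this provider check routes it down acompletion()'s native-async path: rather than running the whole synchronous completion() call in a thread pool, the router invokes completion() with acompletion=True and awaits the coroutine the provider handler returns. A rough sketch of that dispatch pattern, using hypothetical names (acompletion_sketch, ASYNC_CAPABLE_PROVIDERS), not the verbatim litellm code:

import asyncio
from functools import partial

ASYNC_CAPABLE_PROVIDERS = {"openai", "huggingface", "ollama", "replicate"}

async def acompletion_sketch(completion_fn, custom_llm_provider: str, **kwargs):
    if custom_llm_provider in ASYNC_CAPABLE_PROVIDERS:
        # The provider handler receives acompletion=True and hands back a
        # coroutine, which is awaited on the current event loop.
        init_response = completion_fn(acompletion=True, **kwargs)
        if asyncio.iscoroutine(init_response):
            return await init_response
        return init_response
    # Sync-only providers are pushed onto a worker thread so the loop is not blocked.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, partial(completion_fn, **kwargs))

The benefit is that async-capable providers never occupy a worker thread; only legacy sync handlers fall back to the executor path.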
@@ -1188,7 +1189,7 @@ def completion(
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            model_response = replicate.completion(
+            model_response = replicate.completion( # type: ignore
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -1201,12 +1202,10 @@
                 api_key=replicate_key,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
+                acompletion=acompletion,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
-                # don't try to access stream object,
-                model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
 
-            if optional_params.get("stream", False) or acompletion == True:
+            if optional_params.get("stream", False) == True:
                 ## LOGGING
                 logging.post_call(
                     input=messages,
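
Here the router passes acompletion=acompletion down to the handler and drops its own CustomStreamWrapper special-casing, leaving it to the provider module to decide whether to return a plain response, a stream, or a coroutine. A minimal sketch of that handler shape, with placeholder helpers (_async_completion, _sync_completion) standing in for the real replicate.py code:

import asyncio

async def _async_completion(model: str, messages: list, **kwargs) -> dict:
    # Placeholder async path; the real handler would start and poll the prediction here.
    await asyncio.sleep(0)
    return {"model": model, "choices": [{"message": {"content": "..."}}]}

def _sync_completion(model: str, messages: list, **kwargs) -> dict:
    # Placeholder blocking path.
    return {"model": model, "choices": [{"message": {"content": "..."}}]}

def completion(model: str, messages: list, acompletion: bool = False, **kwargs):
    # With acompletion=True the coroutine is returned, not awaited, so the sync
    # signature is preserved and the caller's event loop (acompletion() above)
    # does the awaiting.
    if acompletion:
        return _async_completion(model, messages, **kwargs)
    return _sync_completion(model, messages, **kwargs)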