fix(replicate.py): move replicate calls to being completely async

Closes https://github.com/BerriAI/litellm/issues/3128
Krrish Dholakia 2024-05-16 17:24:08 -07:00
parent a2a5884df1
commit 709373b15c
5 changed files with 326 additions and 59 deletions
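For context, a minimal usage sketch of what this change enables (not part of the diff; assumes a valid REPLICATE_API_TOKEN is exported, and reuses the model name from the tests below):

# Sketch: calling the now fully-async Replicate path via litellm.acompletion.
# Assumes REPLICATE_API_TOKEN is set; the model name mirrors the tests in this commit.
import asyncio
import litellm


async def main():
    # Non-streaming: "replicate/..." models are now routed through replicate.async_completion
    response = await litellm.acompletion(
        model="replicate/meta/meta-llama-3-8b-instruct",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response)

    # Streaming: replicate.py wraps the async generator in CustomStreamWrapper
    stream = await litellm.acompletion(
        model="replicate/meta/meta-llama-3-8b-instruct",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())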

litellm/llms/replicate.py

@@ -2,11 +2,12 @@ import os, types
 import json
 import requests  # type: ignore
 import time
-from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage
-import litellm
+from typing import Callable, Optional, Union, Tuple, Any
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
+import litellm, asyncio
 import httpx  # type: ignore
 from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 
 
 class ReplicateError(Exception):
@@ -145,6 +146,65 @@ def start_prediction(
         )
 
 
+async def async_start_prediction(
+    version_id,
+    input_data,
+    api_token,
+    api_base,
+    logging_obj,
+    print_verbose,
+    http_handler: AsyncHTTPHandler,
+) -> str:
+    base_url = api_base
+    if "deployments" in version_id:
+        print_verbose("\nLiteLLM: Request to custom replicate deployment")
+        version_id = version_id.replace("deployments/", "")
+        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+        print_verbose(f"Deployment base URL: {base_url}\n")
+    else:  # assume it's a model
+        base_url = f"https://api.replicate.com/v1/models/{version_id}"
+
+    headers = {
+        "Authorization": f"Token {api_token}",
+        "Content-Type": "application/json",
+    }
+
+    initial_prediction_data = {
+        "input": input_data,
+    }
+
+    if ":" in version_id and len(version_id) > 64:
+        model_parts = version_id.split(":")
+        if (
+            len(model_parts) > 1 and len(model_parts[1]) == 64
+        ):  ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
+            initial_prediction_data["version"] = model_parts[1]
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input_data["prompt"],
+        api_key="",
+        additional_args={
+            "complete_input_dict": initial_prediction_data,
+            "headers": headers,
+            "api_base": base_url,
+        },
+    )
+
+    response = await http_handler.post(
+        url="{}/predictions".format(base_url),
+        data=json.dumps(initial_prediction_data),
+        headers=headers,
+    )
+
+    if response.status_code == 201:
+        response_data = response.json()
+        return response_data.get("urls", {}).get("get")
+    else:
+        raise ReplicateError(
+            response.status_code, f"Failed to start prediction {response.text}"
+        )
+
+
 # Function to handle prediction response (non-streaming)
 def handle_prediction_response(prediction_url, api_token, print_verbose):
     output_string = ""
@@ -178,6 +238,40 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
     return output_string, logs
 
 
+async def async_handle_prediction_response(
+    prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
+) -> Tuple[str, Any]:
+    output_string = ""
+    headers = {
+        "Authorization": f"Token {api_token}",
+        "Content-Type": "application/json",
+    }
+
+    status = ""
+    logs = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        await asyncio.sleep(0.5)
+        response = await http_handler.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            if "output" in response_data:
+                output_string = "".join(response_data["output"])
+                print_verbose(f"Non-streamed output:{output_string}")
+            status = response_data.get("status", None)
+            logs = response_data.get("logs", "")
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400,
+                    message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose("Replicate: Failed to fetch prediction status and output.")
+    return output_string, logs
+
+
 # Function to handle prediction response (streaming)
 def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
     previous_output = ""
@@ -214,6 +308,45 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
             )
 
 
+# Function to handle prediction response (streaming)
+async def async_handle_prediction_response_streaming(
+    prediction_url, api_token, print_verbose
+):
+    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    previous_output = ""
+    output_string = ""
+
+    headers = {
+        "Authorization": f"Token {api_token}",
+        "Content-Type": "application/json",
+    }
+    status = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        await asyncio.sleep(0.5)  # prevent being rate limited by replicate
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        response = await http_handler.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            status = response_data["status"]
+            if "output" in response_data:
+                output_string = "".join(response_data["output"])
+                new_output = output_string[len(previous_output) :]
+                print_verbose(f"New chunk: {new_output}")
+                yield {"output": new_output, "status": status}
+                previous_output = output_string
+            status = response_data["status"]
+
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400, message=f"Error: {replicate_error}"
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose(
+                f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
+            )
+
+
 # Function to extract version ID from model string
 def model_to_version_id(model):
     if ":" in model:
@@ -222,6 +355,39 @@ def model_to_version_id(model):
     return model
 
 
+def process_response(
+    model_response: ModelResponse,
+    result: str,
+    model: str,
+    encoding: Any,
+    prompt: str,
+) -> ModelResponse:
+    if len(result) == 0:  # edge case, where result from replicate is empty
+        result = " "
+
+    ## Building RESPONSE OBJECT
+    if len(result) > 1:
+        model_response["choices"][0]["message"]["content"] = result
+
+    # Calculate usage
+    prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
+    completion_tokens = len(
+        encoding.encode(
+            model_response["choices"][0]["message"].get("content", ""),
+            disallowed_special=(),
+        )
+    )
+    model_response["model"] = "replicate/" + model
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+    setattr(model_response, "usage", usage)
+    return model_response
+
+
 # Main function for prediction completion
 def completion(
     model: str,
@@ -229,14 +395,15 @@ def completion(
     api_base: str,
     model_response: ModelResponse,
     print_verbose: Callable,
+    optional_params: dict,
     logging_obj,
     api_key,
     encoding,
     custom_prompt_dict={},
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
-):
+    acompletion=None,
+) -> Union[ModelResponse, CustomStreamWrapper]:
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)
     ## Load Config
@@ -274,6 +441,12 @@ def completion(
     else:
         prompt = prompt_factory(model=model, messages=messages)
 
+    if prompt is None or not isinstance(prompt, str):
+        raise ReplicateError(
+            status_code=400,
+            message="LiteLLM Error - prompt is not a string - {}".format(prompt),
+        )
+
     # If system prompt is supported, and a system prompt is provided, use it
     if system_prompt is not None:
         input_data = {
@@ -285,6 +458,20 @@ def completion(
     else:
         input_data = {"prompt": prompt, **optional_params}
 
+    if acompletion is not None and acompletion == True:
+        return async_completion(
+            model_response=model_response,
+            model=model,
+            prompt=prompt,
+            encoding=encoding,
+            optional_params=optional_params,
+            version_id=version_id,
+            input_data=input_data,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+            print_verbose=print_verbose,
+        )  # type: ignore
+
     ## COMPLETION CALL
     ## Replicate Compeltion calls have 2 steps
     ## Step1: Start Prediction: gets a prediction url
@@ -293,6 +480,7 @@ def completion(
     model_response["created"] = int(
         time.time()
     )  # for pricing this must remain right before calling api
+
     prediction_url = start_prediction(
         version_id,
         input_data,
@@ -306,9 +494,10 @@ def completion(
     # Handle the prediction response (streaming or non-streaming)
     if "stream" in optional_params and optional_params["stream"] == True:
         print_verbose("streaming request")
-        return handle_prediction_response_streaming(
+        _response = handle_prediction_response_streaming(
             prediction_url, api_key, print_verbose
         )
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
     else:
         result, logs = handle_prediction_response(
             prediction_url, api_key, print_verbose
@@ -328,29 +517,56 @@ def completion(
 
         print_verbose(f"raw model_response: {result}")
 
-        if len(result) == 0:  # edge case, where result from replicate is empty
-            result = " "
-
-        ## Building RESPONSE OBJECT
-        if len(result) > 1:
-            model_response["choices"][0]["message"]["content"] = result
-
-        # Calculate usage
-        prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
-        completion_tokens = len(
-            encoding.encode(
-                model_response["choices"][0]["message"].get("content", ""),
-                disallowed_special=(),
-            )
-        )
-        model_response["model"] = "replicate/" + model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
-        setattr(model_response, "usage", usage)
-        return model_response
+        return process_response(
+            model_response=model_response,
+            result=result,
+            model=model,
+            encoding=encoding,
+            prompt=prompt,
+        )
+
+
+async def async_completion(
+    model_response: ModelResponse,
+    model: str,
+    prompt: str,
+    encoding,
+    optional_params: dict,
+    version_id,
+    input_data,
+    api_key,
+    api_base,
+    logging_obj,
+    print_verbose,
+) -> Union[ModelResponse, CustomStreamWrapper]:
+    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    prediction_url = await async_start_prediction(
+        version_id,
+        input_data,
+        api_key,
+        api_base,
+        logging_obj=logging_obj,
+        print_verbose=print_verbose,
+        http_handler=http_handler,
+    )
+
+    if "stream" in optional_params and optional_params["stream"] == True:
+        _response = async_handle_prediction_response_streaming(
+            prediction_url, api_key, print_verbose
+        )
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
+
+    result, logs = await async_handle_prediction_response(
+        prediction_url, api_key, print_verbose, http_handler=http_handler
+    )
+
+    return process_response(
+        model_response=model_response,
+        result=result,
+        model=model,
+        encoding=encoding,
+        prompt=prompt,
+    )
 
 
 # # Example usage:
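For readers unfamiliar with the Replicate API, the start-then-poll flow that async_start_prediction and async_handle_prediction_response implement above can be sketched with plain httpx (illustrative only; assumes REPLICATE_API_TOKEN is set, error handling simplified):

# Sketch of Replicate's two-step protocol, mirroring the helpers above.
# Assumes REPLICATE_API_TOKEN is exported; the model name is a placeholder.
import asyncio
import os

import httpx


async def run_prediction(prompt: str) -> str:
    headers = {
        "Authorization": f"Token {os.environ['REPLICATE_API_TOKEN']}",
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient() as client:
        # Step 1: start the prediction; the response includes a polling URL under urls.get
        start = await client.post(
            "https://api.replicate.com/v1/models/meta/meta-llama-3-8b-instruct/predictions",
            json={"input": {"prompt": prompt}},
            headers=headers,
        )
        start.raise_for_status()
        poll_url = start.json()["urls"]["get"]

        # Step 2: poll until the prediction reaches a terminal status
        while True:
            await asyncio.sleep(0.5)  # same back-off the new helpers use
            data = (await client.get(poll_url, headers=headers)).json()
            if data["status"] in ("succeeded", "failed", "canceled"):
                break

    if data["status"] != "succeeded":
        raise RuntimeError(f"prediction {data['status']}: {data.get('error')}")
    return "".join(data["output"])  # output arrives as a list of text chunks


print(asyncio.run(run_prediction("Hello, how are you?")))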

litellm/main.py

@@ -320,6 +320,7 @@ async def acompletion(
             or custom_llm_provider == "huggingface"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "ollama_chat"
+            or custom_llm_provider == "replicate"
             or custom_llm_provider == "vertex_ai"
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"
@@ -1188,7 +1189,7 @@ def completion(
 
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
 
-            model_response = replicate.completion(
+            model_response = replicate.completion(  # type: ignore
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -1201,12 +1202,10 @@ def completion(
                 api_key=replicate_key,
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
+                acompletion=acompletion,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
-                # don't try to access stream object,
-                model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")  # type: ignore
 
-            if optional_params.get("stream", False) or acompletion == True:
+            if optional_params.get("stream", False) == True:
                 ## LOGGING
                 logging.post_call(
                     input=messages,

litellm/tests/test_completion.py

@@ -2301,36 +2301,28 @@ def test_completion_azure_deployment_id():
 # test_completion_azure_deployment_id()
 
-# Only works for local endpoint
-# def test_completion_anthropic_openai_proxy():
-#     try:
-#         response = completion(
-#             model="custom_openai/claude-2",
-#             messages=messages,
-#             api_base="http://0.0.0.0:8000"
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
 
-# test_completion_anthropic_openai_proxy()
-
-def test_completion_replicate_llama3():
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_completion_replicate_llama3(sync_mode):
     litellm.set_verbose = True
     model_name = "replicate/meta/meta-llama-3-8b-instruct"
     try:
-        response = completion(
-            model=model_name,
-            messages=messages,
-        )
+        if sync_mode:
+            response = completion(
+                model=model_name,
+                messages=messages,
+            )
+        else:
+            response = await litellm.acompletion(
+                model=model_name,
+                messages=messages,
+            )
+            print(f"ASYNC REPLICATE RESPONSE - {response}")
         print(response)
         # Add any assertions here to check the response
-        response_str = response["choices"][0]["message"]["content"]
-        print("RESPONSE STRING\n", response_str)
-        if type(response_str) != str:
-            pytest.fail(f"Error occurred: {e}")
+        assert isinstance(response, litellm.ModelResponse)
+        response_format_tests(response=response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

litellm/tests/test_streaming.py

@@ -950,7 +950,63 @@ def test_vertex_ai_stream():
 # test_completion_vertexai_stream_bad_key()
 
-# def test_completion_replicate_stream():
+
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_completion_replicate_llama3_streaming(sync_mode):
+    litellm.set_verbose = True
+    model_name = "replicate/meta/meta-llama-3-8b-instruct"
+    try:
+        if sync_mode:
+            final_chunk: Optional[litellm.ModelResponse] = None
+            response: litellm.CustomStreamWrapper = completion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=10,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            for idx, chunk in enumerate(response):
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        else:
+            response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=100,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            idx = 0
+            final_chunk: Optional[litellm.ModelResponse] = None
+            async for chunk in response:
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+                idx += 1
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # TEMP Commented out - replicate throwing an auth error
 # try:
 #     litellm.set_verbose = True
@@ -984,7 +1040,7 @@ def test_vertex_ai_stream():
 #         pytest.fail(f"Error occurred: {e}")
 
 
-@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_bedrock_cohere_command_r_streaming(sync_mode):
     try:

litellm/utils.py

@@ -8606,7 +8606,10 @@ def exception_type(
                     message=f"ReplicateException - {str(original_exception)}",
                     llm_provider="replicate",
                     model=model,
-                    request=original_exception.request,
+                    request=httpx.Request(
+                        method="POST",
+                        url="https://api.replicate.com/v1/deployments",
+                    ),
                 )
         elif custom_llm_provider == "watsonx":
             if "token_quota_reached" in error_str:
@@ -11485,6 +11488,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "vertex_ai"
                 or self.custom_llm_provider == "sagemaker"
                 or self.custom_llm_provider == "gemini"
+                or self.custom_llm_provider == "replicate"
                 or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider == "predibase"
                 or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)