forked from phoenix/litellm-mirror
fix(replicate.py): move replicate calls to being completely async
Closes https://github.com/BerriAI/litellm/issues/3128
This commit is contained in:
parent
a2a5884df1
commit
709373b15c
5 changed files with 326 additions and 59 deletions
|
@ -2,11 +2,12 @@ import os, types
|
|||
import json
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
import litellm
|
||||
from typing import Callable, Optional, Union, Tuple, Any
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
|
||||
import litellm, asyncio
|
||||
import httpx # type: ignore
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
|
||||
class ReplicateError(Exception):
|
||||
|
@ -145,6 +146,65 @@ def start_prediction(
|
|||
)
|
||||
|
||||
|
||||
async def async_start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
api_token,
|
||||
api_base,
|
||||
logging_obj,
|
||||
print_verbose,
|
||||
http_handler: AsyncHTTPHandler,
|
||||
) -> str:
|
||||
base_url = api_base
|
||||
if "deployments" in version_id:
|
||||
print_verbose("\nLiteLLM: Request to custom replicate deployment")
|
||||
version_id = version_id.replace("deployments/", "")
|
||||
base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
|
||||
print_verbose(f"Deployment base URL: {base_url}\n")
|
||||
else: # assume it's a model
|
||||
base_url = f"https://api.replicate.com/v1/models/{version_id}"
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
initial_prediction_data = {
|
||||
"input": input_data,
|
||||
}
|
||||
|
||||
if ":" in version_id and len(version_id) > 64:
|
||||
model_parts = version_id.split(":")
|
||||
if (
|
||||
len(model_parts) > 1 and len(model_parts[1]) == 64
|
||||
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
|
||||
initial_prediction_data["version"] = model_parts[1]
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input_data["prompt"],
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": initial_prediction_data,
|
||||
"headers": headers,
|
||||
"api_base": base_url,
|
||||
},
|
||||
)
|
||||
|
||||
response = await http_handler.post(
|
||||
url="{}/predictions".format(base_url),
|
||||
data=json.dumps(initial_prediction_data),
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code == 201:
|
||||
response_data = response.json()
|
||||
return response_data.get("urls", {}).get("get")
|
||||
else:
|
||||
raise ReplicateError(
|
||||
response.status_code, f"Failed to start prediction {response.text}"
|
||||
)
|
||||
|
||||
|
||||
# Function to handle prediction response (non-streaming)
|
||||
def handle_prediction_response(prediction_url, api_token, print_verbose):
|
||||
output_string = ""
|
||||
|
@ -178,6 +238,40 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
|
|||
return output_string, logs
|
||||
|
||||
|
||||
async def async_handle_prediction_response(
|
||||
prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
|
||||
) -> Tuple[str, Any]:
|
||||
output_string = ""
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
status = ""
|
||||
logs = ""
|
||||
while True and (status not in ["succeeded", "failed", "canceled"]):
|
||||
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
||||
await asyncio.sleep(0.5)
|
||||
response = await http_handler.get(prediction_url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if "output" in response_data:
|
||||
output_string = "".join(response_data["output"])
|
||||
print_verbose(f"Non-streamed output:{output_string}")
|
||||
status = response_data.get("status", None)
|
||||
logs = response_data.get("logs", "")
|
||||
if status == "failed":
|
||||
replicate_error = response_data.get("error", "")
|
||||
raise ReplicateError(
|
||||
status_code=400,
|
||||
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
|
||||
)
|
||||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose("Replicate: Failed to fetch prediction status and output.")
|
||||
return output_string, logs
|
||||
|
||||
|
||||
# Function to handle prediction response (streaming)
|
||||
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
|
||||
previous_output = ""
|
||||
|
@ -214,6 +308,45 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
|
|||
)
|
||||
|
||||
|
||||
# Function to handle prediction response (streaming)
|
||||
async def async_handle_prediction_response_streaming(
|
||||
prediction_url, api_token, print_verbose
|
||||
):
|
||||
http_handler = AsyncHTTPHandler(concurrent_limit=1)
|
||||
previous_output = ""
|
||||
output_string = ""
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
status = ""
|
||||
while True and (status not in ["succeeded", "failed", "canceled"]):
|
||||
await asyncio.sleep(0.5) # prevent being rate limited by replicate
|
||||
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
||||
response = await http_handler.get(prediction_url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
status = response_data["status"]
|
||||
if "output" in response_data:
|
||||
output_string = "".join(response_data["output"])
|
||||
new_output = output_string[len(previous_output) :]
|
||||
print_verbose(f"New chunk: {new_output}")
|
||||
yield {"output": new_output, "status": status}
|
||||
previous_output = output_string
|
||||
status = response_data["status"]
|
||||
if status == "failed":
|
||||
replicate_error = response_data.get("error", "")
|
||||
raise ReplicateError(
|
||||
status_code=400, message=f"Error: {replicate_error}"
|
||||
)
|
||||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose(
|
||||
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
|
||||
)
|
||||
|
||||
|
||||
# Function to extract version ID from model string
|
||||
def model_to_version_id(model):
|
||||
if ":" in model:
|
||||
|
@ -222,6 +355,39 @@ def model_to_version_id(model):
|
|||
return model
|
||||
|
||||
|
||||
def process_response(
|
||||
model_response: ModelResponse,
|
||||
result: str,
|
||||
model: str,
|
||||
encoding: Any,
|
||||
prompt: str,
|
||||
) -> ModelResponse:
|
||||
if len(result) == 0: # edge case, where result from replicate is empty
|
||||
result = " "
|
||||
|
||||
## Building RESPONSE OBJECT
|
||||
if len(result) > 1:
|
||||
model_response["choices"][0]["message"]["content"] = result
|
||||
|
||||
# Calculate usage
|
||||
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
|
||||
completion_tokens = len(
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", ""),
|
||||
disallowed_special=(),
|
||||
)
|
||||
)
|
||||
model_response["model"] = "replicate/" + model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
# Main function for prediction completion
|
||||
def completion(
|
||||
model: str,
|
||||
|
@ -229,14 +395,15 @@ def completion(
|
|||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
optional_params: dict,
|
||||
logging_obj,
|
||||
api_key,
|
||||
encoding,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
acompletion=None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
# Start a prediction and get the prediction URL
|
||||
version_id = model_to_version_id(model)
|
||||
## Load Config
|
||||
|
@ -274,6 +441,12 @@ def completion(
|
|||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
if prompt is None or not isinstance(prompt, str):
|
||||
raise ReplicateError(
|
||||
status_code=400,
|
||||
message="LiteLLM Error - prompt is not a string - {}".format(prompt),
|
||||
)
|
||||
|
||||
# If system prompt is supported, and a system prompt is provided, use it
|
||||
if system_prompt is not None:
|
||||
input_data = {
|
||||
|
@ -285,6 +458,20 @@ def completion(
|
|||
else:
|
||||
input_data = {"prompt": prompt, **optional_params}
|
||||
|
||||
if acompletion is not None and acompletion == True:
|
||||
return async_completion(
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
encoding=encoding,
|
||||
optional_params=optional_params,
|
||||
version_id=version_id,
|
||||
input_data=input_data,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
logging_obj=logging_obj,
|
||||
print_verbose=print_verbose,
|
||||
) # type: ignore
|
||||
## COMPLETION CALL
|
||||
## Replicate Compeltion calls have 2 steps
|
||||
## Step1: Start Prediction: gets a prediction url
|
||||
|
@ -293,6 +480,7 @@ def completion(
|
|||
model_response["created"] = int(
|
||||
time.time()
|
||||
) # for pricing this must remain right before calling api
|
||||
|
||||
prediction_url = start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
|
@ -306,9 +494,10 @@ def completion(
|
|||
# Handle the prediction response (streaming or non-streaming)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
print_verbose("streaming request")
|
||||
return handle_prediction_response_streaming(
|
||||
_response = handle_prediction_response_streaming(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
||||
else:
|
||||
result, logs = handle_prediction_response(
|
||||
prediction_url, api_key, print_verbose
|
||||
|
@ -328,29 +517,56 @@ def completion(
|
|||
|
||||
print_verbose(f"raw model_response: {result}")
|
||||
|
||||
if len(result) == 0: # edge case, where result from replicate is empty
|
||||
result = " "
|
||||
|
||||
## Building RESPONSE OBJECT
|
||||
if len(result) > 1:
|
||||
model_response["choices"][0]["message"]["content"] = result
|
||||
|
||||
# Calculate usage
|
||||
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
|
||||
completion_tokens = len(
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", ""),
|
||||
disallowed_special=(),
|
||||
)
|
||||
return process_response(
|
||||
model_response=model_response,
|
||||
result=result,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
prompt=prompt,
|
||||
)
|
||||
model_response["model"] = "replicate/" + model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
|
||||
|
||||
async def async_completion(
|
||||
model_response: ModelResponse,
|
||||
model: str,
|
||||
prompt: str,
|
||||
encoding,
|
||||
optional_params: dict,
|
||||
version_id,
|
||||
input_data,
|
||||
api_key,
|
||||
api_base,
|
||||
logging_obj,
|
||||
print_verbose,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
http_handler = AsyncHTTPHandler(concurrent_limit=1)
|
||||
prediction_url = await async_start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
api_key,
|
||||
api_base,
|
||||
logging_obj=logging_obj,
|
||||
print_verbose=print_verbose,
|
||||
http_handler=http_handler,
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
_response = async_handle_prediction_response_streaming(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
||||
|
||||
result, logs = await async_handle_prediction_response(
|
||||
prediction_url, api_key, print_verbose, http_handler=http_handler
|
||||
)
|
||||
|
||||
return process_response(
|
||||
model_response=model_response,
|
||||
result=result,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# # Example usage:
|
||||
|
|
|
@ -320,6 +320,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "ollama_chat"
|
||||
or custom_llm_provider == "replicate"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
or custom_llm_provider == "gemini"
|
||||
or custom_llm_provider == "sagemaker"
|
||||
|
@ -1188,7 +1189,7 @@ def completion(
|
|||
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
|
||||
model_response = replicate.completion(
|
||||
model_response = replicate.completion( # type: ignore
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
|
@ -1201,12 +1202,10 @@ def completion(
|
|||
api_key=replicate_key,
|
||||
logging_obj=logging,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
if optional_params.get("stream", False) == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
|
|
@ -2301,36 +2301,28 @@ def test_completion_azure_deployment_id():
|
|||
|
||||
# test_completion_azure_deployment_id()
|
||||
|
||||
# Only works for local endpoint
|
||||
# def test_completion_anthropic_openai_proxy():
|
||||
# try:
|
||||
# response = completion(
|
||||
# model="custom_openai/claude-2",
|
||||
# messages=messages,
|
||||
# api_base="http://0.0.0.0:8000"
|
||||
# )
|
||||
# # Add any assertions here to check the response
|
||||
# print(response)
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_completion_anthropic_openai_proxy()
|
||||
|
||||
|
||||
def test_completion_replicate_llama3():
|
||||
@pytest.mark.parametrize("sync_mode", [False, True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_replicate_llama3(sync_mode):
|
||||
litellm.set_verbose = True
|
||||
model_name = "replicate/meta/meta-llama-3-8b-instruct"
|
||||
try:
|
||||
response = completion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
)
|
||||
if sync_mode:
|
||||
response = completion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
response = await litellm.acompletion(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
)
|
||||
print(f"ASYNC REPLICATE RESPONSE - {response}")
|
||||
print(response)
|
||||
# Add any assertions here to check the response
|
||||
response_str = response["choices"][0]["message"]["content"]
|
||||
print("RESPONSE STRING\n", response_str)
|
||||
if type(response_str) != str:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
response_format_tests(response=response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
|
@ -950,7 +950,63 @@ def test_vertex_ai_stream():
|
|||
|
||||
# test_completion_vertexai_stream_bad_key()
|
||||
|
||||
# def test_completion_replicate_stream():
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [False, True])
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_replicate_llama3_streaming(sync_mode):
|
||||
litellm.set_verbose = True
|
||||
model_name = "replicate/meta/meta-llama-3-8b-instruct"
|
||||
try:
|
||||
if sync_mode:
|
||||
final_chunk: Optional[litellm.ModelResponse] = None
|
||||
response: litellm.CustomStreamWrapper = completion( # type: ignore
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10, # type: ignore
|
||||
stream=True,
|
||||
)
|
||||
complete_response = ""
|
||||
# Add any assertions here to check the response
|
||||
has_finish_reason = False
|
||||
for idx, chunk in enumerate(response):
|
||||
final_chunk = chunk
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
if finished:
|
||||
has_finish_reason = True
|
||||
break
|
||||
complete_response += chunk
|
||||
if has_finish_reason == False:
|
||||
raise Exception("finish reason not set")
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
else:
|
||||
response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=100, # type: ignore
|
||||
stream=True,
|
||||
)
|
||||
complete_response = ""
|
||||
# Add any assertions here to check the response
|
||||
has_finish_reason = False
|
||||
idx = 0
|
||||
final_chunk: Optional[litellm.ModelResponse] = None
|
||||
async for chunk in response:
|
||||
final_chunk = chunk
|
||||
chunk, finished = streaming_format_tests(idx, chunk)
|
||||
if finished:
|
||||
has_finish_reason = True
|
||||
break
|
||||
complete_response += chunk
|
||||
idx += 1
|
||||
if has_finish_reason == False:
|
||||
raise Exception("finish reason not set")
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# TEMP Commented out - replicate throwing an auth error
|
||||
# try:
|
||||
# litellm.set_verbose = True
|
||||
|
@ -984,7 +1040,7 @@ def test_vertex_ai_stream():
|
|||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True])
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_bedrock_cohere_command_r_streaming(sync_mode):
|
||||
try:
|
||||
|
|
|
@ -8606,7 +8606,10 @@ def exception_type(
|
|||
message=f"ReplicateException - {str(original_exception)}",
|
||||
llm_provider="replicate",
|
||||
model=model,
|
||||
request=original_exception.request,
|
||||
request=httpx.Request(
|
||||
method="POST",
|
||||
url="https://api.replicate.com/v1/deployments",
|
||||
),
|
||||
)
|
||||
elif custom_llm_provider == "watsonx":
|
||||
if "token_quota_reached" in error_str:
|
||||
|
@ -11485,6 +11488,7 @@ class CustomStreamWrapper:
|
|||
or self.custom_llm_provider == "vertex_ai"
|
||||
or self.custom_llm_provider == "sagemaker"
|
||||
or self.custom_llm_provider == "gemini"
|
||||
or self.custom_llm_provider == "replicate"
|
||||
or self.custom_llm_provider == "cached_response"
|
||||
or self.custom_llm_provider == "predibase"
|
||||
or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue