mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(replicate.py): move replicate calls to being completely async
Closes https://github.com/BerriAI/litellm/issues/3128
This commit is contained in:
parent
a2a5884df1
commit
709373b15c
5 changed files with 326 additions and 59 deletions
|
@ -2,11 +2,12 @@ import os, types
|
|||
import json
|
||||
import requests # type: ignore
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
import litellm
|
||||
from typing import Callable, Optional, Union, Tuple, Any
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
|
||||
import litellm, asyncio
|
||||
import httpx # type: ignore
|
||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
|
||||
class ReplicateError(Exception):
|
||||
|
@ -145,6 +146,65 @@ def start_prediction(
|
|||
)
|
||||
|
||||
|
||||
async def async_start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
api_token,
|
||||
api_base,
|
||||
logging_obj,
|
||||
print_verbose,
|
||||
http_handler: AsyncHTTPHandler,
|
||||
) -> str:
|
||||
base_url = api_base
|
||||
if "deployments" in version_id:
|
||||
print_verbose("\nLiteLLM: Request to custom replicate deployment")
|
||||
version_id = version_id.replace("deployments/", "")
|
||||
base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
|
||||
print_verbose(f"Deployment base URL: {base_url}\n")
|
||||
else: # assume it's a model
|
||||
base_url = f"https://api.replicate.com/v1/models/{version_id}"
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
initial_prediction_data = {
|
||||
"input": input_data,
|
||||
}
|
||||
|
||||
if ":" in version_id and len(version_id) > 64:
|
||||
model_parts = version_id.split(":")
|
||||
if (
|
||||
len(model_parts) > 1 and len(model_parts[1]) == 64
|
||||
): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
|
||||
initial_prediction_data["version"] = model_parts[1]
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input_data["prompt"],
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": initial_prediction_data,
|
||||
"headers": headers,
|
||||
"api_base": base_url,
|
||||
},
|
||||
)
|
||||
|
||||
response = await http_handler.post(
|
||||
url="{}/predictions".format(base_url),
|
||||
data=json.dumps(initial_prediction_data),
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if response.status_code == 201:
|
||||
response_data = response.json()
|
||||
return response_data.get("urls", {}).get("get")
|
||||
else:
|
||||
raise ReplicateError(
|
||||
response.status_code, f"Failed to start prediction {response.text}"
|
||||
)
|
||||
|
||||
|
||||
# Function to handle prediction response (non-streaming)
|
||||
def handle_prediction_response(prediction_url, api_token, print_verbose):
|
||||
output_string = ""
|
||||
|
@ -178,6 +238,40 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
|
|||
return output_string, logs
|
||||
|
||||
|
||||
async def async_handle_prediction_response(
|
||||
prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
|
||||
) -> Tuple[str, Any]:
|
||||
output_string = ""
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
status = ""
|
||||
logs = ""
|
||||
while True and (status not in ["succeeded", "failed", "canceled"]):
|
||||
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
||||
await asyncio.sleep(0.5)
|
||||
response = await http_handler.get(prediction_url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if "output" in response_data:
|
||||
output_string = "".join(response_data["output"])
|
||||
print_verbose(f"Non-streamed output:{output_string}")
|
||||
status = response_data.get("status", None)
|
||||
logs = response_data.get("logs", "")
|
||||
if status == "failed":
|
||||
replicate_error = response_data.get("error", "")
|
||||
raise ReplicateError(
|
||||
status_code=400,
|
||||
message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
|
||||
)
|
||||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose("Replicate: Failed to fetch prediction status and output.")
|
||||
return output_string, logs
|
||||
|
||||
|
||||
# Function to handle prediction response (streaming)
|
||||
def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
|
||||
previous_output = ""
|
||||
|
@ -214,6 +308,45 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
|
|||
)
|
||||
|
||||
|
||||
# Function to handle prediction response (streaming)
|
||||
async def async_handle_prediction_response_streaming(
|
||||
prediction_url, api_token, print_verbose
|
||||
):
|
||||
http_handler = AsyncHTTPHandler(concurrent_limit=1)
|
||||
previous_output = ""
|
||||
output_string = ""
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Token {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
status = ""
|
||||
while True and (status not in ["succeeded", "failed", "canceled"]):
|
||||
await asyncio.sleep(0.5) # prevent being rate limited by replicate
|
||||
print_verbose(f"replicate: polling endpoint: {prediction_url}")
|
||||
response = await http_handler.get(prediction_url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
status = response_data["status"]
|
||||
if "output" in response_data:
|
||||
output_string = "".join(response_data["output"])
|
||||
new_output = output_string[len(previous_output) :]
|
||||
print_verbose(f"New chunk: {new_output}")
|
||||
yield {"output": new_output, "status": status}
|
||||
previous_output = output_string
|
||||
status = response_data["status"]
|
||||
if status == "failed":
|
||||
replicate_error = response_data.get("error", "")
|
||||
raise ReplicateError(
|
||||
status_code=400, message=f"Error: {replicate_error}"
|
||||
)
|
||||
else:
|
||||
# this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
|
||||
print_verbose(
|
||||
f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
|
||||
)
|
||||
|
||||
|
||||
# Function to extract version ID from model string
|
||||
def model_to_version_id(model):
|
||||
if ":" in model:
|
||||
|
@ -222,6 +355,39 @@ def model_to_version_id(model):
|
|||
return model
|
||||
|
||||
|
||||
def process_response(
|
||||
model_response: ModelResponse,
|
||||
result: str,
|
||||
model: str,
|
||||
encoding: Any,
|
||||
prompt: str,
|
||||
) -> ModelResponse:
|
||||
if len(result) == 0: # edge case, where result from replicate is empty
|
||||
result = " "
|
||||
|
||||
## Building RESPONSE OBJECT
|
||||
if len(result) > 1:
|
||||
model_response["choices"][0]["message"]["content"] = result
|
||||
|
||||
# Calculate usage
|
||||
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
|
||||
completion_tokens = len(
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", ""),
|
||||
disallowed_special=(),
|
||||
)
|
||||
)
|
||||
model_response["model"] = "replicate/" + model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
# Main function for prediction completion
|
||||
def completion(
|
||||
model: str,
|
||||
|
@ -229,14 +395,15 @@ def completion(
|
|||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
optional_params: dict,
|
||||
logging_obj,
|
||||
api_key,
|
||||
encoding,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
acompletion=None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
# Start a prediction and get the prediction URL
|
||||
version_id = model_to_version_id(model)
|
||||
## Load Config
|
||||
|
@ -274,6 +441,12 @@ def completion(
|
|||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
if prompt is None or not isinstance(prompt, str):
|
||||
raise ReplicateError(
|
||||
status_code=400,
|
||||
message="LiteLLM Error - prompt is not a string - {}".format(prompt),
|
||||
)
|
||||
|
||||
# If system prompt is supported, and a system prompt is provided, use it
|
||||
if system_prompt is not None:
|
||||
input_data = {
|
||||
|
@ -285,6 +458,20 @@ def completion(
|
|||
else:
|
||||
input_data = {"prompt": prompt, **optional_params}
|
||||
|
||||
if acompletion is not None and acompletion == True:
|
||||
return async_completion(
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
encoding=encoding,
|
||||
optional_params=optional_params,
|
||||
version_id=version_id,
|
||||
input_data=input_data,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
logging_obj=logging_obj,
|
||||
print_verbose=print_verbose,
|
||||
) # type: ignore
|
||||
## COMPLETION CALL
|
||||
## Replicate Compeltion calls have 2 steps
|
||||
## Step1: Start Prediction: gets a prediction url
|
||||
|
@ -293,6 +480,7 @@ def completion(
|
|||
model_response["created"] = int(
|
||||
time.time()
|
||||
) # for pricing this must remain right before calling api
|
||||
|
||||
prediction_url = start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
|
@ -306,9 +494,10 @@ def completion(
|
|||
# Handle the prediction response (streaming or non-streaming)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
print_verbose("streaming request")
|
||||
return handle_prediction_response_streaming(
|
||||
_response = handle_prediction_response_streaming(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
||||
else:
|
||||
result, logs = handle_prediction_response(
|
||||
prediction_url, api_key, print_verbose
|
||||
|
@ -328,29 +517,56 @@ def completion(
|
|||
|
||||
print_verbose(f"raw model_response: {result}")
|
||||
|
||||
if len(result) == 0: # edge case, where result from replicate is empty
|
||||
result = " "
|
||||
|
||||
## Building RESPONSE OBJECT
|
||||
if len(result) > 1:
|
||||
model_response["choices"][0]["message"]["content"] = result
|
||||
|
||||
# Calculate usage
|
||||
prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
|
||||
completion_tokens = len(
|
||||
encoding.encode(
|
||||
model_response["choices"][0]["message"].get("content", ""),
|
||||
disallowed_special=(),
|
||||
)
|
||||
return process_response(
|
||||
model_response=model_response,
|
||||
result=result,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
prompt=prompt,
|
||||
)
|
||||
model_response["model"] = "replicate/" + model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
|
||||
|
||||
async def async_completion(
|
||||
model_response: ModelResponse,
|
||||
model: str,
|
||||
prompt: str,
|
||||
encoding,
|
||||
optional_params: dict,
|
||||
version_id,
|
||||
input_data,
|
||||
api_key,
|
||||
api_base,
|
||||
logging_obj,
|
||||
print_verbose,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
http_handler = AsyncHTTPHandler(concurrent_limit=1)
|
||||
prediction_url = await async_start_prediction(
|
||||
version_id,
|
||||
input_data,
|
||||
api_key,
|
||||
api_base,
|
||||
logging_obj=logging_obj,
|
||||
print_verbose=print_verbose,
|
||||
http_handler=http_handler,
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
_response = async_handle_prediction_response_streaming(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate") # type: ignore
|
||||
|
||||
result, logs = await async_handle_prediction_response(
|
||||
prediction_url, api_key, print_verbose, http_handler=http_handler
|
||||
)
|
||||
|
||||
return process_response(
|
||||
model_response=model_response,
|
||||
result=result,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# # Example usage:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue