diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index c29728134..386d24f59 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -2,11 +2,12 @@ import os, types
 import json
 import requests  # type: ignore
 import time
-from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage
-import litellm
+from typing import Callable, Optional, Union, Tuple, Any
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
+import litellm, asyncio
 import httpx  # type: ignore
 from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler


 class ReplicateError(Exception):
@@ -145,6 +146,65 @@ def start_prediction(
         )


+async def async_start_prediction(
+    version_id,
+    input_data,
+    api_token,
+    api_base,
+    logging_obj,
+    print_verbose,
+    http_handler: AsyncHTTPHandler,
+) -> str:
+    base_url = api_base
+    if "deployments" in version_id:
+        print_verbose("\nLiteLLM: Request to custom replicate deployment")
+        version_id = version_id.replace("deployments/", "")
+        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+        print_verbose(f"Deployment base URL: {base_url}\n")
+    else:  # assume it's a model
+        base_url = f"https://api.replicate.com/v1/models/{version_id}"
+    headers = {
+        "Authorization": f"Token {api_token}",
+        "Content-Type": "application/json",
+    }
+
+    initial_prediction_data = {
+        "input": input_data,
+    }
+
+    if ":" in version_id and len(version_id) > 64:
+        model_parts = version_id.split(":")
+        if (
+            len(model_parts) > 1 and len(model_parts[1]) == 64
+        ):  ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
+            initial_prediction_data["version"] = model_parts[1]
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input_data["prompt"],
+        api_key="",
+        additional_args={
+            "complete_input_dict": initial_prediction_data,
+            "headers": headers,
+            "api_base": base_url,
+        },
+    )
+
+    response = await http_handler.post(
+        url="{}/predictions".format(base_url),
+        data=json.dumps(initial_prediction_data),
+        headers=headers,
+    )
+
+    if response.status_code == 201:
+        response_data = response.json()
+        return response_data.get("urls", {}).get("get")
+    else:
+        raise ReplicateError(
+            response.status_code, f"Failed to start prediction {response.text}"
+        )
+
+
 # Function to handle prediction response (non-streaming)
 def handle_prediction_response(prediction_url, api_token, print_verbose):
     output_string = ""
@@ -178,6 +238,40 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
     return output_string, logs


+async def async_handle_prediction_response(
+    prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
+) -> Tuple[str, Any]:
+    output_string = ""
+    headers = {
+        "Authorization": f"Token {api_token}",
+        "Content-Type": "application/json",
+    }
+
+    status = ""
+    logs = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        await asyncio.sleep(0.5)
+        response = await http_handler.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            if "output" in response_data:
+                output_string = "".join(response_data["output"])
+                print_verbose(f"Non-streamed output:{output_string}")
+            status = response_data.get("status", None)
+            logs = response_data.get("logs", "")
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400,
+                    message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose("Replicate: Failed to fetch prediction status and output.")
+    return output_string, logs
+
+
 # Function to handle prediction response (streaming)
 def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
     previous_output = ""
@@ -214,6 +308,45 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
             )


+# Function to handle prediction response (streaming)
+async def async_handle_prediction_response_streaming(
+    prediction_url, api_token, print_verbose
+):
+    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    previous_output = ""
+    output_string = ""
+
+    headers = {
+        "Authorization": f"Token {api_token}",
+        "Content-Type": "application/json",
+    }
+    status = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        await asyncio.sleep(0.5)  # prevent being rate limited by replicate
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        response = await http_handler.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            status = response_data["status"]
+            if "output" in response_data:
+                output_string = "".join(response_data["output"])
+                new_output = output_string[len(previous_output) :]
+                print_verbose(f"New chunk: {new_output}")
+                yield {"output": new_output, "status": status}
+                previous_output = output_string
+            status = response_data["status"]
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400, message=f"Error: {replicate_error}"
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose(
+                f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
+            )
+
+
 # Function to extract version ID from model string
 def model_to_version_id(model):
     if ":" in model:
@@ -222,6 +355,39 @@ def model_to_version_id(model):
     return model


+def process_response(
+    model_response: ModelResponse,
+    result: str,
+    model: str,
+    encoding: Any,
+    prompt: str,
+) -> ModelResponse:
+    if len(result) == 0:  # edge case, where result from replicate is empty
+        result = " "
+
+    ## Building RESPONSE OBJECT
+    if len(result) > 1:
+        model_response["choices"][0]["message"]["content"] = result
+
+    # Calculate usage
+    prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
+    completion_tokens = len(
+        encoding.encode(
+            model_response["choices"][0]["message"].get("content", ""),
+            disallowed_special=(),
+        )
+    )
+    model_response["model"] = "replicate/" + model
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+    setattr(model_response, "usage", usage)
+
+    return model_response
+
+
 # Main function for prediction completion
 def completion(
     model: str,
@@ -229,14 +395,15 @@ def completion(
     api_base: str,
     model_response: ModelResponse,
     print_verbose: Callable,
+    optional_params: dict,
     logging_obj,
     api_key,
     encoding,
     custom_prompt_dict={},
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
-):
+    acompletion=None,
+) -> Union[ModelResponse, CustomStreamWrapper]:
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)
     ## Load Config
@@ -274,6 +441,12 @@ def completion(
     else:
         prompt = prompt_factory(model=model, messages=messages)

+    if prompt is None or not isinstance(prompt, str):
+        raise ReplicateError(
+            status_code=400,
+            message="LiteLLM Error - prompt is not a string - {}".format(prompt),
+        )
+
     # If system prompt is supported, and a system prompt is provided, use it
     if system_prompt is not None:
         input_data = {
@@ -285,6 +458,20 @@ def completion(
     else:
         input_data = {"prompt": prompt, **optional_params}

+    if acompletion is not None and acompletion == True:
+        return async_completion(
+            model_response=model_response,
+            model=model,
+            prompt=prompt,
+            encoding=encoding,
+            optional_params=optional_params,
+            version_id=version_id,
+            input_data=input_data,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+            print_verbose=print_verbose,
+        )  # type: ignore
     ## COMPLETION CALL
     ## Replicate Compeltion calls have 2 steps
     ## Step1: Start Prediction: gets a prediction url
@@ -293,6 +480,7 @@ def completion(
     model_response["created"] = int(
         time.time()
     )  # for pricing this must remain right before calling api
+
     prediction_url = start_prediction(
         version_id,
         input_data,
@@ -306,9 +494,10 @@ def completion(
     # Handle the prediction response (streaming or non-streaming)
     if "stream" in optional_params and optional_params["stream"] == True:
         print_verbose("streaming request")
-        return handle_prediction_response_streaming(
+        _response = handle_prediction_response_streaming(
             prediction_url, api_key, print_verbose
         )
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
     else:
         result, logs = handle_prediction_response(
             prediction_url, api_key, print_verbose
@@ -328,29 +517,56 @@ def completion(

     print_verbose(f"raw model_response: {result}")

-    if len(result) == 0:  # edge case, where result from replicate is empty
-        result = " "
-
-    ## Building RESPONSE OBJECT
-    if len(result) > 1:
-        model_response["choices"][0]["message"]["content"] = result
-
-    # Calculate usage
-    prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
-    completion_tokens = len(
-        encoding.encode(
-            model_response["choices"][0]["message"].get("content", ""),
-            disallowed_special=(),
-        )
+    return process_response(
+        model_response=model_response,
+        result=result,
+        model=model,
+        encoding=encoding,
+        prompt=prompt,
     )
-    model_response["model"] = "replicate/" + model
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
+
+
+async def async_completion(
+    model_response: ModelResponse,
+    model: str,
+    prompt: str,
+    encoding,
+    optional_params: dict,
+    version_id,
+    input_data,
+    api_key,
+    api_base,
+    logging_obj,
+    print_verbose,
+) -> Union[ModelResponse, CustomStreamWrapper]:
+    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    prediction_url = await async_start_prediction(
+        version_id,
+        input_data,
+        api_key,
+        api_base,
+        logging_obj=logging_obj,
+        print_verbose=print_verbose,
+        http_handler=http_handler,
+    )
+
+    if "stream" in optional_params and optional_params["stream"] == True:
+        _response = async_handle_prediction_response_streaming(
+            prediction_url, api_key, print_verbose
         )
-    setattr(model_response, "usage", usage)
-    return model_response
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
+
+    result, logs = await async_handle_prediction_response(
+        prediction_url, api_key, print_verbose, http_handler=http_handler
+    )
+
+    return process_response(
+        model_response=model_response,
+        result=result,
+        model=model,
+        encoding=encoding,
+        prompt=prompt,
+    )


 # # Example usage:
diff --git a/litellm/main.py b/litellm/main.py
index f4420435d..2e4132a42 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -320,6 +320,7 @@ async def acompletion(
             or custom_llm_provider == "huggingface"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "ollama_chat"
+            or custom_llm_provider == "replicate"
             or custom_llm_provider == "vertex_ai"
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"
@@ -1188,7 +1189,7 @@ def completion(

         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

-        model_response = replicate.completion(
+        model_response = replicate.completion(  # type: ignore
             model=model,
             messages=messages,
             api_base=api_base,
@@ -1201,12 +1202,10 @@ def completion(
             api_key=replicate_key,
             logging_obj=logging,
             custom_prompt_dict=custom_prompt_dict,
+            acompletion=acompletion,
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
-            # don't try to access stream object,
-            model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")  # type: ignore

-        if optional_params.get("stream", False) or acompletion == True:
+        if optional_params.get("stream", False) == True:
             ## LOGGING
             logging.post_call(
                 input=messages,
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 7ef5d93c1..a7e965e62 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -2301,36 +2301,28 @@ def test_completion_azure_deployment_id():
 # test_completion_azure_deployment_id()


-# Only works for local endpoint
-# def test_completion_anthropic_openai_proxy():
-#     try:
-#         response = completion(
-#             model="custom_openai/claude-2",
-#             messages=messages,
-#             api_base="http://0.0.0.0:8000"
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-# test_completion_anthropic_openai_proxy()
-
-
-def test_completion_replicate_llama3():
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_completion_replicate_llama3(sync_mode):
     litellm.set_verbose = True
     model_name = "replicate/meta/meta-llama-3-8b-instruct"
     try:
-        response = completion(
-            model=model_name,
-            messages=messages,
-        )
+        if sync_mode:
+            response = completion(
+                model=model_name,
+                messages=messages,
+            )
+        else:
+            response = await litellm.acompletion(
+                model=model_name,
+                messages=messages,
+            )
+            print(f"ASYNC REPLICATE RESPONSE - {response}")
         print(response)
         # Add any assertions here to check the response
-        response_str = response["choices"][0]["message"]["content"]
-        print("RESPONSE STRING\n", response_str)
-        if type(response_str) != str:
-            pytest.fail(f"Error occurred: {e}")
+        assert isinstance(response, litellm.ModelResponse)
+        response_format_tests(response=response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 6dcdbeb17..ac5062938 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -950,7 +950,63 @@ def test_vertex_ai_stream():
 # test_completion_vertexai_stream_bad_key()


-# def test_completion_replicate_stream():
+
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_completion_replicate_llama3_streaming(sync_mode):
+    litellm.set_verbose = True
+    model_name = "replicate/meta/meta-llama-3-8b-instruct"
+    try:
+        if sync_mode:
+            final_chunk: Optional[litellm.ModelResponse] = None
+            response: litellm.CustomStreamWrapper = completion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=10,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            for idx, chunk in enumerate(response):
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        else:
+            response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=100,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            idx = 0
+            final_chunk: Optional[litellm.ModelResponse] = None
+            async for chunk in response:
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+                idx += 1
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # TEMP Commented out - replicate throwing an auth error
 # try:
 #     litellm.set_verbose = True
@@ -984,7 +1040,7 @@
 #         pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_bedrock_cohere_command_r_streaming(sync_mode):
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 90795457a..53e2c57fa 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -8606,7 +8606,10 @@ def exception_type(
                         message=f"ReplicateException - {str(original_exception)}",
                         llm_provider="replicate",
                         model=model,
-                        request=original_exception.request,
+                        request=httpx.Request(
+                            method="POST",
+                            url="https://api.replicate.com/v1/deployments",
+                        ),
                     )
             elif custom_llm_provider == "watsonx":
                 if "token_quota_reached" in error_str:
@@ -11485,6 +11488,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "vertex_ai"
                 or self.custom_llm_provider == "sagemaker"
                 or self.custom_llm_provider == "gemini"
+                or self.custom_llm_provider == "replicate"
                 or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider == "predibase"
                 or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)
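Usage sketch (not part of the diff): based on the new tests above, this is roughly how the async and streaming Replicate paths added here would be exercised once the change lands. The model name is taken from the tests; a Replicate API token available to litellm (e.g. via the REPLICATE_API_KEY environment variable) is assumed.

import asyncio
import litellm


async def main():
    messages = [{"role": "user", "content": "good morning"}]

    # Non-streaming async call - replicate.completion() now routes to
    # async_completion() and returns a litellm.ModelResponse
    response = await litellm.acompletion(
        model="replicate/meta/meta-llama-3-8b-instruct",
        messages=messages,
    )
    print(response.choices[0].message.content)

    # Streaming async call - chunks are produced by
    # async_handle_prediction_response_streaming via CustomStreamWrapper
    stream = await litellm.acompletion(
        model="replicate/meta/meta-llama-3-8b-instruct",
        messages=messages,
        stream=True,
        max_tokens=100,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")


if __name__ == "__main__":
    asyncio.run(main())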