From 1c61b7b229f1c677cf464c55f78adfd6af095eae Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Wed, 6 Sep 2023 10:23:13 -0700
Subject: [PATCH] add replicate streaming

---
 litellm/llms/replicate.py        | 25 +++++++++++++
 litellm/main.py                  |  5 +--
 litellm/tests/test_completion.py | 63 +++++++++++++++++---------------
 litellm/utils.py                 |  4 +-
 4 files changed, 61 insertions(+), 36 deletions(-)

diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index b570c27d2..4e14fa50d 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -104,14 +104,39 @@ def completion(
         "max_new_tokens": 50,
     }
 
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key="",
+        additional_args={"complete_input_dict": input_data},
+    )
     ## COMPLETION CALL
+    ## Replicate Completion calls have 2 steps
+    ## Step1: Start Prediction: gets a prediction url
+    ## Step2: Poll prediction url for response
+    ## Step2: is handled with and without streaming
     prediction_url = start_prediction(version_id, input_data, api_key)
     print_verbose(prediction_url)
 
     # Handle the prediction response (streaming or non-streaming)
     if "stream" in optional_params and optional_params["stream"] == True:
+        print_verbose("streaming request")
         return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
     else:
         result = handle_prediction_response(prediction_url, api_key, print_verbose)
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key="",
+            original_response=result,
+            additional_args={"complete_input_dict": input_data},
+        )
+
+        print_verbose(f"raw model_response: {result}")
+
+        ## Building RESPONSE OBJECT
         model_response["choices"][0]["message"]["content"] = result
 
     # Calculate usage
diff --git a/litellm/main.py b/litellm/main.py
index 0c68b048f..d27daf403 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -371,7 +371,7 @@ def completion(
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
-            response = CustomStreamWrapper(model_response, model, logging_obj=logging)
+            response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
             return response
         response = model_response
@@ -939,9 +939,6 @@ def text_completion(*args, **kwargs):
 def print_verbose(print_statement):
     if litellm.set_verbose:
         print(f"LiteLLM: {print_statement}")
-        if random.random() <= 0.3:
-            print("Get help - https://discord.com/invite/wuPM9dRgDw")
-
 
 def config_completion(**kwargs):
     if litellm.config_path != None:
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index c31a96256..a2880916d 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -349,35 +349,7 @@ def test_completion_azure_deployment_id():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-# # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
-# def test_completion_replicate_llama_stream():
-#     model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
-#     try:
-#         response = completion(model=model_name, messages=messages, stream=True)
-#         # Add any assertions here to check the response
-#         for result in response:
-#             print(result)
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-
-
-# def test_completion_replicate_stability_stream():
-#     model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
-#     try:
-#         response = completion(
-#             model=model_name,
-#             messages=messages,
-#             stream=True,
-#             custom_llm_provider="replicate",
-#         )
-#         # Add any assertions here to check the response
-#         for chunk in response:
-#             print(chunk["choices"][0]["delta"])
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-
+
+# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
 def test_completion_replicate_llama_2():
     model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
@@ -396,6 +368,39 @@ def test_completion_replicate_llama_2():
 
 # test_completion_replicate_llama_2()
 
+def test_completion_replicate_llama_stream():
+    model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
+    try:
+        response = completion(model=model_name, messages=messages, stream=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+            # chunk_text = result['choices'][0]['delta']['content']
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_completion_replicate_llama_stream()
+
+# def test_completion_replicate_stability_stream():
+#     model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+#     try:
+#         response = completion(
+#             model=model_name,
+#             messages=messages,
+#             # stream=True,
+#             custom_llm_provider="replicate",
+#         )
+#         # print(response)
+#         # Add any assertions here to check the response
+#         # for chunk in response:
+#         #     print(chunk["choices"][0]["delta"])
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+# test_completion_replicate_stability_stream()
+
+
+
 
 ######## Test TogetherAI ########
 def test_completion_together_ai():
diff --git a/litellm/utils.py b/litellm/utils.py
index 7d4a1f488..9fa7b0d93 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -116,8 +116,6 @@ class ModelResponse(OpenAIObject):
 def print_verbose(print_statement):
     if litellm.set_verbose:
         print(f"LiteLLM: {print_statement}")
-        if random.random() <= 0.3:
-            print("Get help - https://discord.com/invite/wuPM9dRgDw")
 
 ####### LOGGING ###################
 from enum import Enum
@@ -1896,7 +1894,7 @@ class CustomStreamWrapper:
             if self.model in litellm.anthropic_models:
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = self.handle_anthropic_chunk(chunk)
-            elif self.model == "replicate":
+            elif self.model == "replicate" or self.custom_llm_provider == "replicate":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = chunk
             elif (
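
Usage note (not part of the patch): a minimal sketch of how the streaming path added here is expected to be exercised, mirroring test_completion_replicate_llama_stream above. The model version hash is the one used in that test; the REPLICATE_API_TOKEN value and the example message are placeholders.

    # Hypothetical example; assumes REPLICATE_API_TOKEN is the credential litellm reads for Replicate.
    import os
    from litellm import completion

    os.environ["REPLICATE_API_TOKEN"] = "<your-replicate-api-token>"  # placeholder credential

    messages = [{"role": "user", "content": "Hello, how are you?"}]

    # stream=True routes through handle_prediction_response_streaming in replicate.py and
    # wraps the result in CustomStreamWrapper(..., custom_llm_provider="replicate") in main.py.
    response = completion(
        model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
        messages=messages,
        stream=True,
    )

    # Each chunk is an OpenAI-style delta, as in the commented-out assertions in the tests.
    for chunk in response:
        print(chunk["choices"][0]["delta"])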