add replicate streaming

ishaan-jaff 2023-09-06 10:23:13 -07:00
parent c45b132675
commit 1c61b7b229
4 changed files with 61 additions and 36 deletions


@@ -104,14 +104,39 @@ def completion(
"max_new_tokens": 50,
}
## LOGGING
logging_obj.pre_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": input_data},
)
## COMPLETION CALL
## Replicate Completion calls have 2 steps
## Step1: Start Prediction: gets a prediction url
## Step2: Poll prediction url for response
## Step2 is handled both with and without streaming
prediction_url = start_prediction(version_id, input_data, api_key)
print_verbose(prediction_url)
# Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] == True:
print_verbose("streaming request")
return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
else:
result = handle_prediction_response(prediction_url, api_key, print_verbose)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
original_response=result,
additional_args={"complete_input_dict": input_data},
)
print_verbose(f"raw model_response: {result}")
## Building RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = result
# Calculate usage
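For context on the two-step flow described in the comments above, here is a minimal sketch of what start_prediction and handle_prediction_response_streaming could look like. It assumes Replicate's predictions HTTP API (POST /v1/predictions to create a prediction, then GET the returned polling URL until the status is terminal); the helper names come from this diff, but the bodies below are illustrative rather than the actual implementation.

import time
import requests

def start_prediction(version_id, input_data, api_key):
    # Step 1: create the prediction and return the URL to poll for results.
    response = requests.post(
        "https://api.replicate.com/v1/predictions",
        headers={"Authorization": f"Token {api_key}"},
        json={"version": version_id, "input": input_data},
    )
    response.raise_for_status()
    return response.json()["urls"]["get"]

def handle_prediction_response_streaming(prediction_url, api_key, print_verbose):
    # Step 2, streaming case: poll the prediction URL and yield only the output
    # that has appeared since the previous poll.
    previous_output = ""
    status = ""
    while status not in ("succeeded", "failed", "canceled"):
        time.sleep(0.5)  # polling interval is an arbitrary choice for this sketch
        prediction = requests.get(
            prediction_url, headers={"Authorization": f"Token {api_key}"}
        ).json()
        status = prediction.get("status", "")
        output = "".join(prediction.get("output") or [])
        if len(output) > len(previous_output):
            yield output[len(previous_output):]
            previous_output = output
        print_verbose(f"prediction status: {status}")

The non-streaming handle_prediction_response would run the same polling loop but return the joined output once the status is terminal, which is what the post_call logging above records.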


@@ -371,7 +371,7 @@ def completion(
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(model_response, model, logging_obj=logging)
response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
return response
response = model_response
@@ -939,9 +939,6 @@ def text_completion(*args, **kwargs):
def print_verbose(print_statement):
if litellm.set_verbose:
print(f"LiteLLM: {print_statement}")
if random.random() <= 0.3:
print("Get help - https://discord.com/invite/wuPM9dRgDw")
def config_completion(**kwargs):
if litellm.config_path != None:


@@ -349,35 +349,7 @@ def test_completion_azure_deployment_id():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
# def test_completion_replicate_llama_stream():
# model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
# try:
# response = completion(model=model_name, messages=messages, stream=True)
# # Add any assertions here to check the response
# for result in response:
# print(result)
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# def test_completion_replicate_stability_stream():
# model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
# try:
# response = completion(
# model=model_name,
# messages=messages,
# stream=True,
# custom_llm_provider="replicate",
# )
# # Add any assertions here to check the response
# for chunk in response:
# print(chunk["choices"][0]["delta"])
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# Replicate API endpoints are unstable -> they throw random CUDA errors -> tests can fail even when our code is correct.
def test_completion_replicate_llama_2():
model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
@@ -396,6 +368,39 @@ def test_completion_replicate_llama_2():
# test_completion_replicate_llama_2()
def test_completion_replicate_llama_stream():
model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
try:
response = completion(model=model_name, messages=messages, stream=True)
# Add any assertions here to check the response
for result in response:
print(result)
# chunk_text = result['choices'][0]['delta']['content']
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_replicate_llama_stream()
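As a usage note, the chunks yielded by the streaming test above follow the OpenAI-style delta shape hinted at by the commented-out chunk_text line, so a caller could reassemble the full completion roughly like this (a sketch; the exact chunk fields are an assumption based on that hint):

# Sketch: rebuild the full text from the streamed chunks.
# The chunk layout (choices[0]["delta"]["content"]) is assumed from the
# commented-out hint in the test above.
streamed_text = ""
for chunk in completion(model=model_name, messages=messages, stream=True):
    streamed_text += chunk["choices"][0]["delta"].get("content", "")
print(streamed_text)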
# def test_completion_replicate_stability_stream():
# model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
# try:
# response = completion(
# model=model_name,
# messages=messages,
# # stream=True,
# custom_llm_provider="replicate",
# )
# # print(response)
# # Add any assertions here to check the response
# # for chunk in response:
# # print(chunk["choices"][0]["delta"])
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_replicate_stability_stream()
######## Test TogetherAI ########
def test_completion_together_ai():


@@ -116,8 +116,6 @@ class ModelResponse(OpenAIObject):
def print_verbose(print_statement):
if litellm.set_verbose:
print(f"LiteLLM: {print_statement}")
if random.random() <= 0.3:
print("Get help - https://discord.com/invite/wuPM9dRgDw")
####### LOGGING ###################
from enum import Enum
@@ -1896,7 +1894,7 @@ class CustomStreamWrapper:
if self.model in litellm.anthropic_models:
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_anthropic_chunk(chunk)
elif self.model == "replicate":
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
chunk = next(self.completion_stream)
completion_obj["content"] = chunk
elif (
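The new "or self.custom_llm_provider == 'replicate'" condition is what lets the wrapper constructed earlier in this commit (the call that now passes custom_llm_provider="replicate") reach this branch: for Replicate, self.model is the full versioned identifier, never the bare string "replicate". A small illustrative check, using a model string from the tests in this commit:

# Illustrative only: why the model-string check alone never fires for Replicate.
model = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
print(model == "replicate")            # False -> the pre-existing check misses it
print(model.startswith("replicate/"))  # True  -> the provider only appears as a prefix

Passing the provider tag explicitly keeps the dispatch unambiguous, and the branch itself can stay trivial because the polling generator already yields plain text chunks that the wrapper forwards as delta content.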