mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
add replicate streaming
This commit is contained in:
parent
c45b132675
commit
1c61b7b229
4 changed files with 61 additions and 36 deletions
|
@ -104,14 +104,39 @@ def completion(
|
|||
"max_new_tokens": 50,
|
||||
}
|
||||
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": input_data},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
## Replicate Compeltion calls have 2 steps
|
||||
## Step1: Start Prediction: gets a prediction url
|
||||
## Step2: Poll prediction url for response
|
||||
## Step2: is handled with and without streaming
|
||||
prediction_url = start_prediction(version_id, input_data, api_key)
|
||||
print_verbose(prediction_url)
|
||||
|
||||
# Handle the prediction response (streaming or non-streaming)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
print_verbose("streaming request")
|
||||
return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
|
||||
else:
|
||||
result = handle_prediction_response(prediction_url, api_key, print_verbose)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=result,
|
||||
additional_args={"complete_input_dict": input_data},
|
||||
)
|
||||
|
||||
print_verbose(f"raw model_response: {result}")
|
||||
|
||||
## Building RESPONSE OBJECT
|
||||
model_response["choices"][0]["message"]["content"] = result
|
||||
|
||||
# Calculate usage
|
||||
|
|
|
@ -371,7 +371,7 @@ def completion(
|
|||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(model_response, model, logging_obj=logging)
|
||||
response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
|
||||
return response
|
||||
response = model_response
|
||||
|
||||
|
@ -939,9 +939,6 @@ def text_completion(*args, **kwargs):
|
|||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(f"LiteLLM: {print_statement}")
|
||||
if random.random() <= 0.3:
|
||||
print("Get help - https://discord.com/invite/wuPM9dRgDw")
|
||||
|
||||
|
||||
def config_completion(**kwargs):
|
||||
if litellm.config_path != None:
|
||||
|
|
|
@ -349,35 +349,7 @@ def test_completion_azure_deployment_id():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
|
||||
# def test_completion_replicate_llama_stream():
|
||||
# model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
|
||||
# try:
|
||||
# response = completion(model=model_name, messages=messages, stream=True)
|
||||
# # Add any assertions here to check the response
|
||||
# for result in response:
|
||||
# print(result)
|
||||
# print(response)
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# def test_completion_replicate_stability_stream():
|
||||
# model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
|
||||
# try:
|
||||
# response = completion(
|
||||
# model=model_name,
|
||||
# messages=messages,
|
||||
# stream=True,
|
||||
# custom_llm_provider="replicate",
|
||||
# )
|
||||
# # Add any assertions here to check the response
|
||||
# for chunk in response:
|
||||
# print(chunk["choices"][0]["delta"])
|
||||
# print(response)
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
|
||||
|
||||
def test_completion_replicate_llama_2():
|
||||
model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
|
||||
|
@ -396,6 +368,39 @@ def test_completion_replicate_llama_2():
|
|||
|
||||
# test_completion_replicate_llama_2()
|
||||
|
||||
def test_completion_replicate_llama_stream():
|
||||
model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
|
||||
try:
|
||||
response = completion(model=model_name, messages=messages, stream=True)
|
||||
# Add any assertions here to check the response
|
||||
for result in response:
|
||||
print(result)
|
||||
# chunk_text = result['choices'][0]['delta']['content']
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_replicate_llama_stream()
|
||||
|
||||
# def test_completion_replicate_stability_stream():
|
||||
# model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
|
||||
# try:
|
||||
# response = completion(
|
||||
# model=model_name,
|
||||
# messages=messages,
|
||||
# # stream=True,
|
||||
# custom_llm_provider="replicate",
|
||||
# )
|
||||
# # print(response)
|
||||
# # Add any assertions here to check the response
|
||||
# # for chunk in response:
|
||||
# # print(chunk["choices"][0]["delta"])
|
||||
# print(response)
|
||||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_replicate_stability_stream()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
######## Test TogetherAI ########
|
||||
def test_completion_together_ai():
|
||||
|
|
|
@ -116,8 +116,6 @@ class ModelResponse(OpenAIObject):
|
|||
def print_verbose(print_statement):
|
||||
if litellm.set_verbose:
|
||||
print(f"LiteLLM: {print_statement}")
|
||||
if random.random() <= 0.3:
|
||||
print("Get help - https://discord.com/invite/wuPM9dRgDw")
|
||||
|
||||
####### LOGGING ###################
|
||||
from enum import Enum
|
||||
|
@ -1896,7 +1894,7 @@ class CustomStreamWrapper:
|
|||
if self.model in litellm.anthropic_models:
|
||||
chunk = next(self.completion_stream)
|
||||
completion_obj["content"] = self.handle_anthropic_chunk(chunk)
|
||||
elif self.model == "replicate":
|
||||
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
|
||||
chunk = next(self.completion_stream)
|
||||
completion_obj["content"] = chunk
|
||||
elif (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue