forked from phoenix/litellm-mirror
add replicate streaming
This commit is contained in:
parent c45b132675
commit 1c61b7b229
4 changed files with 61 additions and 36 deletions
@@ -104,14 +104,39 @@ def completion(
         "max_new_tokens": 50,
     }

+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key="",
+        additional_args={"complete_input_dict": input_data},
+    )
+
+    ## COMPLETION CALL
+    ## Replicate completion calls have 2 steps
+    ## Step 1: Start Prediction: gets a prediction url
+    ## Step 2: Poll prediction url for response
+    ## Step 2 is handled with and without streaming
     prediction_url = start_prediction(version_id, input_data, api_key)
     print_verbose(prediction_url)

     # Handle the prediction response (streaming or non-streaming)
     if "stream" in optional_params and optional_params["stream"] == True:
+        print_verbose("streaming request")
         return handle_prediction_response_streaming(prediction_url, api_key, print_verbose)
     else:
         result = handle_prediction_response(prediction_url, api_key, print_verbose)
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key="",
+            original_response=result,
+            additional_args={"complete_input_dict": input_data},
+        )
+
+        print_verbose(f"raw model_response: {result}")
+
+        ## Building RESPONSE OBJECT
         model_response["choices"][0]["message"]["content"] = result

         # Calculate usage
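For readers unfamiliar with the two-step flow the comments above describe, here is a minimal sketch of what start_prediction, handle_prediction_response, and handle_prediction_response_streaming could look like against Replicate's public REST API. The actual helpers in this commit may differ in signature, polling cadence, and error handling; treat this as an illustration, not the shipped implementation.

# Illustrative only -- assumes Replicate's public REST API; not the commit's actual code.
import time
import requests

def start_prediction(version_id, input_data, api_key):
    # Step 1: create the prediction and return the URL used to poll it.
    response = requests.post(
        "https://api.replicate.com/v1/predictions",
        json={"version": version_id, "input": input_data},
        headers={"Authorization": f"Token {api_key}", "Content-Type": "application/json"},
    )
    response.raise_for_status()
    return response.json()["urls"]["get"]

def handle_prediction_response(prediction_url, api_key, print_verbose):
    # Step 2, non-streaming: poll until the prediction settles, then join the output list.
    headers = {"Authorization": f"Token {api_key}"}
    while True:
        data = requests.get(prediction_url, headers=headers).json()
        print_verbose(f"prediction status: {data.get('status')}")
        if data.get("status") == "succeeded":
            return "".join(data.get("output") or [])
        if data.get("status") in ("failed", "canceled"):
            raise Exception(f"Replicate prediction {data.get('status')}: {data.get('error')}")
        time.sleep(0.5)

def handle_prediction_response_streaming(prediction_url, api_key, print_verbose):
    # Step 2, streaming: poll the same URL and yield only the newly produced text.
    headers = {"Authorization": f"Token {api_key}"}
    seen = ""
    while True:
        data = requests.get(prediction_url, headers=headers).json()
        output = "".join(data.get("output") or [])
        if len(output) > len(seen):
            yield output[len(seen):]
            seen = output
        if data.get("status") in ("succeeded", "failed", "canceled"):
            break
        time.sleep(0.5)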
@@ -371,7 +371,7 @@ def completion(
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
-            response = CustomStreamWrapper(model_response, model, logging_obj=logging)
+            response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
             return response
         response = model_response
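With custom_llm_provider="replicate" passed into CustomStreamWrapper, a caller consumes the Replicate stream like any other provider. A hypothetical usage sketch (the messages payload is made up; the model id is the one used in the new streaming test further down):

import litellm

messages = [{"role": "user", "content": "write a short poem about streaming"}]  # example payload, not from the commit
response = litellm.completion(
    model="replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1",
    messages=messages,
    stream=True,
)
for chunk in response:
    # chunks follow the OpenAI streaming shape, so the text lives under the delta
    print(chunk["choices"][0]["delta"])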
@@ -939,9 +939,6 @@ def text_completion(*args, **kwargs):
 def print_verbose(print_statement):
     if litellm.set_verbose:
         print(f"LiteLLM: {print_statement}")
-        if random.random() <= 0.3:
-            print("Get help - https://discord.com/invite/wuPM9dRgDw")
-

 def config_completion(**kwargs):
     if litellm.config_path != None:
@@ -349,35 +349,7 @@ def test_completion_azure_deployment_id():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-# # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
+# Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
-# def test_completion_replicate_llama_stream():
-#     model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
-#     try:
-#         response = completion(model=model_name, messages=messages, stream=True)
-#         # Add any assertions here to check the response
-#         for result in response:
-#             print(result)
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-
-# def test_completion_replicate_stability_stream():
-#     model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
-#     try:
-#         response = completion(
-#             model=model_name,
-#             messages=messages,
-#             stream=True,
-#             custom_llm_provider="replicate",
-#         )
-#         # Add any assertions here to check the response
-#         for chunk in response:
-#             print(chunk["choices"][0]["delta"])
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")

 def test_completion_replicate_llama_2():
     model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
@@ -396,6 +368,39 @@ def test_completion_replicate_llama_2():

 # test_completion_replicate_llama_2()

+def test_completion_replicate_llama_stream():
+    model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
+    try:
+        response = completion(model=model_name, messages=messages, stream=True)
+        # Add any assertions here to check the response
+        for result in response:
+            print(result)
+            # chunk_text = result['choices'][0]['delta']['content']
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_completion_replicate_llama_stream()
+
+# def test_completion_replicate_stability_stream():
+#     model_name = "stability-ai/stablelm-tuned-alpha-7b:c49dae362cbaecd2ceabb5bd34fdb68413c4ff775111fea065d259d577757beb"
+#     try:
+#         response = completion(
+#             model=model_name,
+#             messages=messages,
+#             # stream=True,
+#             custom_llm_provider="replicate",
+#         )
+#         # print(response)
+#         # Add any assertions here to check the response
+#         # for chunk in response:
+#         #     print(chunk["choices"][0]["delta"])
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+# test_completion_replicate_stability_stream()

 ######## Test TogetherAI ########
 def test_completion_together_ai():
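The new streaming test prints each raw chunk; its commented-out chunk_text line hints at how the streamed pieces could be reassembled into a full completion. A small sketch of that assembly (the .get("content", "") fallback is an assumption for chunks that carry no text):

full_text = ""
for result in response:
    delta = result["choices"][0]["delta"]
    full_text += delta.get("content", "")  # assumed fallback: skip chunks without content
print(full_text)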
@@ -116,8 +116,6 @@ class ModelResponse(OpenAIObject):
 def print_verbose(print_statement):
     if litellm.set_verbose:
         print(f"LiteLLM: {print_statement}")
-        if random.random() <= 0.3:
-            print("Get help - https://discord.com/invite/wuPM9dRgDw")

 ####### LOGGING ###################
 from enum import Enum
@@ -1896,7 +1894,7 @@ class CustomStreamWrapper:
         if self.model in litellm.anthropic_models:
             chunk = next(self.completion_stream)
             completion_obj["content"] = self.handle_anthropic_chunk(chunk)
-        elif self.model == "replicate":
+        elif self.model == "replicate" or self.custom_llm_provider == "replicate":
             chunk = next(self.completion_stream)
             completion_obj["content"] = chunk
         elif (
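For context, the branch changed above sits inside CustomStreamWrapper's iteration logic, which normalizes each provider's raw chunk into an OpenAI-style streaming response. A simplified sketch of that dispatch pattern (not litellm's actual class; attribute names beyond those shown in the hunk are assumptions):

class CustomStreamWrapperSketch:
    # Simplified sketch, not litellm's real CustomStreamWrapper.
    def __init__(self, completion_stream, model, custom_llm_provider=None, logging_obj=None):
        self.completion_stream = completion_stream
        self.model = model
        self.custom_llm_provider = custom_llm_provider
        self.logging_obj = logging_obj

    def __iter__(self):
        return self

    def __next__(self):
        completion_obj = {"role": "assistant", "content": ""}
        chunk = next(self.completion_stream)  # raises StopIteration when the provider is done
        if self.model == "replicate" or self.custom_llm_provider == "replicate":
            # Replicate chunks are already plain text, so they pass through untouched.
            completion_obj["content"] = chunk
        else:
            completion_obj["content"] = str(chunk)
        # wrap the text in an OpenAI-style streaming delta
        return {"choices": [{"delta": completion_obj}]}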