forked from phoenix/litellm-mirror
fix(replicate.py): move replicate calls to being completely async
Closes https://github.com/BerriAI/litellm/issues/3128
This commit is contained in:
parent a2a5884df1
commit 709373b15c
5 changed files with 326 additions and 59 deletions
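For orientation, a minimal usage sketch of what this change enables, based on the tests added in this commit (the model name, the stream=True flag, and the async-iteration pattern come from those tests; the message content and the asyncio boilerplate are illustrative assumptions):

import asyncio
import litellm

async def main():
    # Non-streaming: the Replicate call now runs natively async end to end.
    response = await litellm.acompletion(
        model="replicate/meta/meta-llama-3-8b-instruct",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response)

    # Streaming: the returned wrapper is async-iterable.
    stream = await litellm.acompletion(
        model="replicate/meta/meta-llama-3-8b-instruct",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk)

asyncio.run(main())
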
@@ -2,11 +2,12 @@ import os, types
 import json
 import requests  # type: ignore
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Union, Tuple, Any
-from litellm.utils import ModelResponse, Usage
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
-import litellm
+import litellm, asyncio
 import httpx  # type: ignore
 from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler


 class ReplicateError(Exception):

@@ -145,6 +146,65 @@ def start_prediction(
     )


+async def async_start_prediction(
+    version_id,
+    input_data,
+    api_token,
+    api_base,
+    logging_obj,
+    print_verbose,
+    http_handler: AsyncHTTPHandler,
+) -> str:
+    base_url = api_base
+    if "deployments" in version_id:
+        print_verbose("\nLiteLLM: Request to custom replicate deployment")
+        version_id = version_id.replace("deployments/", "")
+        base_url = f"https://api.replicate.com/v1/deployments/{version_id}"
+        print_verbose(f"Deployment base URL: {base_url}\n")
+    else:  # assume it's a model
+        base_url = f"https://api.replicate.com/v1/models/{version_id}"
+    headers = {
+        "Authorization": f"Token {api_token}",
+        "Content-Type": "application/json",
+    }
+
+    initial_prediction_data = {
+        "input": input_data,
+    }
+
+    if ":" in version_id and len(version_id) > 64:
+        model_parts = version_id.split(":")
+        if (
+            len(model_parts) > 1 and len(model_parts[1]) == 64
+        ):  ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
+            initial_prediction_data["version"] = model_parts[1]
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input_data["prompt"],
+        api_key="",
+        additional_args={
+            "complete_input_dict": initial_prediction_data,
+            "headers": headers,
+            "api_base": base_url,
+        },
+    )
+
+    response = await http_handler.post(
+        url="{}/predictions".format(base_url),
+        data=json.dumps(initial_prediction_data),
+        headers=headers,
+    )
+
+    if response.status_code == 201:
+        response_data = response.json()
+        return response_data.get("urls", {}).get("get")
+    else:
+        raise ReplicateError(
+            response.status_code, f"Failed to start prediction {response.text}"
+        )
+
+
 # Function to handle prediction response (non-streaming)
 def handle_prediction_response(prediction_url, api_token, print_verbose):
     output_string = ""

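For readers new to Replicate's predictions API, here is a rough standalone sketch of the same start-prediction step using httpx directly. The endpoint, headers, payload shape, and the 201-with-polling-URL response mirror async_start_prediction above; the client construction and error handling are simplified assumptions:

import asyncio
import json
import httpx

async def start_prediction_sketch(model_id: str, prompt: str, api_token: str) -> str:
    # Model-style endpoint, as in the `else` branch above; deployments use a different base URL.
    base_url = f"https://api.replicate.com/v1/models/{model_id}"
    headers = {
        "Authorization": f"Token {api_token}",
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{base_url}/predictions",
            content=json.dumps({"input": {"prompt": prompt}}),
            headers=headers,
        )
    if response.status_code != 201:
        raise RuntimeError(f"Failed to start prediction {response.text}")
    # The prediction is polled later via the URL returned under urls.get.
    return response.json()["urls"]["get"]
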
@@ -178,6 +238,40 @@ def handle_prediction_response(prediction_url, api_token, print_verbose):
     return output_string, logs


+async def async_handle_prediction_response(
+    prediction_url, api_token, print_verbose, http_handler: AsyncHTTPHandler
+) -> Tuple[str, Any]:
+    output_string = ""
+    headers = {
+        "Authorization": f"Token {api_token}",
+        "Content-Type": "application/json",
+    }
+
+    status = ""
+    logs = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        await asyncio.sleep(0.5)
+        response = await http_handler.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            if "output" in response_data:
+                output_string = "".join(response_data["output"])
+                print_verbose(f"Non-streamed output:{output_string}")
+            status = response_data.get("status", None)
+            logs = response_data.get("logs", "")
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400,
+                    message=f"Error: {replicate_error}, \nReplicate logs:{logs}",
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose("Replicate: Failed to fetch prediction status and output.")
+    return output_string, logs
+
+
 # Function to handle prediction response (streaming)
 def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
     previous_output = ""

@@ -214,6 +308,45 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbose):
     )


+# Function to handle prediction response (streaming)
+async def async_handle_prediction_response_streaming(
+    prediction_url, api_token, print_verbose
+):
+    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    previous_output = ""
+    output_string = ""
+
+    headers = {
+        "Authorization": f"Token {api_token}",
+        "Content-Type": "application/json",
+    }
+    status = ""
+    while True and (status not in ["succeeded", "failed", "canceled"]):
+        await asyncio.sleep(0.5)  # prevent being rate limited by replicate
+        print_verbose(f"replicate: polling endpoint: {prediction_url}")
+        response = await http_handler.get(prediction_url, headers=headers)
+        if response.status_code == 200:
+            response_data = response.json()
+            status = response_data["status"]
+            if "output" in response_data:
+                output_string = "".join(response_data["output"])
+                new_output = output_string[len(previous_output) :]
+                print_verbose(f"New chunk: {new_output}")
+                yield {"output": new_output, "status": status}
+                previous_output = output_string
+            status = response_data["status"]
+            if status == "failed":
+                replicate_error = response_data.get("error", "")
+                raise ReplicateError(
+                    status_code=400, message=f"Error: {replicate_error}"
+                )
+        else:
+            # this can fail temporarily but it does not mean the replicate request failed, replicate request fails when status=="failed"
+            print_verbose(
+                f"Replicate: Failed to fetch prediction status and output.{response.status_code}{response.text}"
+            )
+
+
 # Function to extract version ID from model string
 def model_to_version_id(model):
     if ":" in model:

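The helper above is an async generator, so it is consumed with plain async iteration; inside litellm it is handed to CustomStreamWrapper, but a direct consumption sketch looks like this (using print as the print_verbose callback is an assumption):

async def print_stream(prediction_url: str, api_token: str) -> None:
    # Each yielded item is a dict: {"output": "<text since last poll>", "status": "..."}
    async for chunk in async_handle_prediction_response_streaming(
        prediction_url, api_token, print_verbose=print
    ):
        print(chunk["output"], end="", flush=True)
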
@@ -222,6 +355,39 @@ def model_to_version_id(model):
     return model


+def process_response(
+    model_response: ModelResponse,
+    result: str,
+    model: str,
+    encoding: Any,
+    prompt: str,
+) -> ModelResponse:
+    if len(result) == 0:  # edge case, where result from replicate is empty
+        result = " "
+
+    ## Building RESPONSE OBJECT
+    if len(result) > 1:
+        model_response["choices"][0]["message"]["content"] = result
+
+    # Calculate usage
+    prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
+    completion_tokens = len(
+        encoding.encode(
+            model_response["choices"][0]["message"].get("content", ""),
+            disallowed_special=(),
+        )
+    )
+    model_response["model"] = "replicate/" + model
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+    setattr(model_response, "usage", usage)
+
+    return model_response
+
+
 # Main function for prediction completion
 def completion(
     model: str,

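The usage accounting above is plain token counting on both sides of the exchange. As a standalone illustration with tiktoken (litellm passes its own encoding object into process_response, so the specific tokenizer here is an assumption):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")  # assumed tokenizer, for illustration only
prompt = "Hello, how are you?"
completion_text = "I'm doing well, thank you."

prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
completion_tokens = len(encoding.encode(completion_text, disallowed_special=()))
print(prompt_tokens, completion_tokens, prompt_tokens + completion_tokens)
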
@@ -229,14 +395,15 @@ def completion(
     api_base: str,
     model_response: ModelResponse,
     print_verbose: Callable,
+    optional_params: dict,
     logging_obj,
     api_key,
     encoding,
     custom_prompt_dict={},
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
-):
+    acompletion=None,
+) -> Union[ModelResponse, CustomStreamWrapper]:
     # Start a prediction and get the prediction URL
     version_id = model_to_version_id(model)
     ## Load Config

@@ -274,6 +441,12 @@ def completion(
     else:
         prompt = prompt_factory(model=model, messages=messages)

+    if prompt is None or not isinstance(prompt, str):
+        raise ReplicateError(
+            status_code=400,
+            message="LiteLLM Error - prompt is not a string - {}".format(prompt),
+        )
+
     # If system prompt is supported, and a system prompt is provided, use it
     if system_prompt is not None:
         input_data = {

@@ -285,6 +458,20 @@ def completion(
     else:
         input_data = {"prompt": prompt, **optional_params}

+    if acompletion is not None and acompletion == True:
+        return async_completion(
+            model_response=model_response,
+            model=model,
+            prompt=prompt,
+            encoding=encoding,
+            optional_params=optional_params,
+            version_id=version_id,
+            input_data=input_data,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+            print_verbose=print_verbose,
+        )  # type: ignore
     ## COMPLETION CALL
     ## Replicate Compeltion calls have 2 steps
     ## Step1: Start Prediction: gets a prediction url

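Note the control-flow trick in the block above: completion() stays a plain def, but when acompletion is truthy it returns the un-awaited coroutine from async_completion(), and litellm's async entrypoint awaits it. A generic illustration of that handoff (names here are made up, not litellm's):

import asyncio

async def _do_work() -> str:
    await asyncio.sleep(0)  # stand-in for the real async HTTP calls
    return "done"

def entrypoint(use_async: bool = False):
    if use_async:
        # Return the coroutine without awaiting; the async caller awaits it.
        return _do_work()
    return "done (sync)"

async def caller() -> None:
    print(await entrypoint(use_async=True))

asyncio.run(caller())
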
@@ -293,6 +480,7 @@ def completion(
     model_response["created"] = int(
         time.time()
     )  # for pricing this must remain right before calling api
+
     prediction_url = start_prediction(
         version_id,
         input_data,

@@ -306,9 +494,10 @@ def completion(
     # Handle the prediction response (streaming or non-streaming)
     if "stream" in optional_params and optional_params["stream"] == True:
         print_verbose("streaming request")
-        return handle_prediction_response_streaming(
+        _response = handle_prediction_response_streaming(
             prediction_url, api_key, print_verbose
         )
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
     else:
         result, logs = handle_prediction_response(
             prediction_url, api_key, print_verbose

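On the synchronous path the returned CustomStreamWrapper is consumed with a regular for loop, as the streaming test further down does; roughly (the message content is an assumption):

import litellm

response = litellm.completion(
    model="replicate/meta/meta-llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    stream=True,
)
for chunk in response:
    print(chunk)
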
@@ -328,29 +517,56 @@ def completion(

     print_verbose(f"raw model_response: {result}")

-    if len(result) == 0:  # edge case, where result from replicate is empty
-        result = " "
-
-    ## Building RESPONSE OBJECT
-    if len(result) > 1:
-        model_response["choices"][0]["message"]["content"] = result
-
-    # Calculate usage
-    prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))
-    completion_tokens = len(
-        encoding.encode(
-            model_response["choices"][0]["message"].get("content", ""),
-            disallowed_special=(),
-        )
-    )
-    model_response["model"] = "replicate/" + model
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
+    return process_response(
+        model_response=model_response,
+        result=result,
+        model=model,
+        encoding=encoding,
+        prompt=prompt,
+    )
+
+
+async def async_completion(
+    model_response: ModelResponse,
+    model: str,
+    prompt: str,
+    encoding,
+    optional_params: dict,
+    version_id,
+    input_data,
+    api_key,
+    api_base,
+    logging_obj,
+    print_verbose,
+) -> Union[ModelResponse, CustomStreamWrapper]:
+    http_handler = AsyncHTTPHandler(concurrent_limit=1)
+    prediction_url = await async_start_prediction(
+        version_id,
+        input_data,
+        api_key,
+        api_base,
+        logging_obj=logging_obj,
+        print_verbose=print_verbose,
+        http_handler=http_handler,
+    )
+
+    if "stream" in optional_params and optional_params["stream"] == True:
+        _response = async_handle_prediction_response_streaming(
+            prediction_url, api_key, print_verbose
+        )
+        return CustomStreamWrapper(_response, model, logging_obj=logging_obj, custom_llm_provider="replicate")  # type: ignore
+
+    result, logs = await async_handle_prediction_response(
+        prediction_url, api_key, print_verbose, http_handler=http_handler
+    )
+
+    return process_response(
+        model_response=model_response,
+        result=result,
+        model=model,
+        encoding=encoding,
+        prompt=prompt,
+    )


 # # Example usage:

@@ -320,6 +320,7 @@ async def acompletion(
             or custom_llm_provider == "huggingface"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "ollama_chat"
+            or custom_llm_provider == "replicate"
             or custom_llm_provider == "vertex_ai"
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"

@@ -1188,7 +1189,7 @@ def completion(

        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

-        model_response = replicate.completion(
+        model_response = replicate.completion(  # type: ignore
             model=model,
             messages=messages,
             api_base=api_base,

@@ -1201,12 +1202,10 @@ def completion(
             api_key=replicate_key,
             logging_obj=logging,
             custom_prompt_dict=custom_prompt_dict,
+            acompletion=acompletion,
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
-            # don't try to access stream object,
-            model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")  # type: ignore

-        if optional_params.get("stream", False) or acompletion == True:
+        if optional_params.get("stream", False) == True:
             ## LOGGING
             logging.post_call(
                 input=messages,

@@ -2301,36 +2301,28 @@ def test_completion_azure_deployment_id():

 # test_completion_azure_deployment_id()

-# Only works for local endpoint
-# def test_completion_anthropic_openai_proxy():
-# try:
-# response = completion(
-# model="custom_openai/claude-2",
-# messages=messages,
-# api_base="http://0.0.0.0:8000"
-# )
-# # Add any assertions here to check the response
-# print(response)
-# except Exception as e:
-# pytest.fail(f"Error occurred: {e}")

-# test_completion_anthropic_openai_proxy()

-def test_completion_replicate_llama3():
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_completion_replicate_llama3(sync_mode):
     litellm.set_verbose = True
     model_name = "replicate/meta/meta-llama-3-8b-instruct"
     try:
-        response = completion(
-            model=model_name,
-            messages=messages,
-        )
+        if sync_mode:
+            response = completion(
+                model=model_name,
+                messages=messages,
+            )
+        else:
+            response = await litellm.acompletion(
+                model=model_name,
+                messages=messages,
+            )
+            print(f"ASYNC REPLICATE RESPONSE - {response}")
         print(response)
         # Add any assertions here to check the response
-        response_str = response["choices"][0]["message"]["content"]
-        print("RESPONSE STRING\n", response_str)
-        if type(response_str) != str:
-            pytest.fail(f"Error occurred: {e}")
+        assert isinstance(response, litellm.ModelResponse)
+        response_format_tests(response=response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -950,7 +950,63 @@ def test_vertex_ai_stream():

 # test_completion_vertexai_stream_bad_key()

-# def test_completion_replicate_stream():
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_completion_replicate_llama3_streaming(sync_mode):
+    litellm.set_verbose = True
+    model_name = "replicate/meta/meta-llama-3-8b-instruct"
+    try:
+        if sync_mode:
+            final_chunk: Optional[litellm.ModelResponse] = None
+            response: litellm.CustomStreamWrapper = completion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=10,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            for idx, chunk in enumerate(response):
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+        else:
+            response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
+                model=model_name,
+                messages=messages,
+                max_tokens=100,  # type: ignore
+                stream=True,
+            )
+            complete_response = ""
+            # Add any assertions here to check the response
+            has_finish_reason = False
+            idx = 0
+            final_chunk: Optional[litellm.ModelResponse] = None
+            async for chunk in response:
+                final_chunk = chunk
+                chunk, finished = streaming_format_tests(idx, chunk)
+                if finished:
+                    has_finish_reason = True
+                    break
+                complete_response += chunk
+                idx += 1
+            if has_finish_reason == False:
+                raise Exception("finish reason not set")
+            if complete_response.strip() == "":
+                raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # TEMP Commented out - replicate throwing an auth error
 # try:
 #     litellm.set_verbose = True

@@ -984,7 +1040,7 @@ def test_vertex_ai_stream():
 # pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
 async def test_bedrock_cohere_command_r_streaming(sync_mode):
     try:

@@ -8606,7 +8606,10 @@ def exception_type(
                     message=f"ReplicateException - {str(original_exception)}",
                     llm_provider="replicate",
                     model=model,
-                    request=original_exception.request,
+                    request=httpx.Request(
+                        method="POST",
+                        url="https://api.replicate.com/v1/deployments",
+                    ),
                 )
             elif custom_llm_provider == "watsonx":
                 if "token_quota_reached" in error_str:

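The change above stops reaching for original_exception.request, presumably because exceptions raised by the new async path do not always carry one, and instead attaches a placeholder so the re-raised ReplicateException stays well-formed. The constructed object is just a plain httpx request:

import httpx

# Placeholder request passed as `request=` when litellm re-raises the provider error.
request = httpx.Request(
    method="POST",
    url="https://api.replicate.com/v1/deployments",
)
print(request.method, request.url)
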
@@ -11485,6 +11488,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "vertex_ai"
                 or self.custom_llm_provider == "sagemaker"
                 or self.custom_llm_provider == "gemini"
+                or self.custom_llm_provider == "replicate"
                 or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider == "predibase"
                 or (self.custom_llm_provider == "bedrock" and "cohere" in self.model)