forked from phoenix/litellm-mirror
fix(replicate.py): fix custom prompt formatting
parent c05da0797b
commit 1f5a1122fc
5 changed files with 177 additions and 80 deletions
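
In short: Replicate requests now honor prompt templates registered with litellm.register_prompt_template (or passed via custom_prompt_dict) instead of naively concatenating message contents. A minimal usage sketch, mirroring the test added in this commit (model ID, template values, and message are taken from that test):

```python
import litellm
from litellm import completion

model = "replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0"

litellm.register_prompt_template(
    model=model,
    initial_prompt_value="You are a good assistant",
    roles={
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]"},
        "assistant": {"pre_message": "\n", "post_message": "\n"},
    },
    final_prompt_value="Now answer as best you can:",
)

# With this commit, the registered template is applied when the Replicate prompt is built.
response = completion(
    model=model,
    messages=[{"role": "user", "content": "what is yc write 1 paragraph"}],
)
print(response)
```
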
@@ -6,6 +6,7 @@ from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 import litellm
 import httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class ReplicateError(Exception):
     def __init__(self, status_code, message):
@@ -186,6 +187,7 @@ def completion(
     logging_obj,
     api_key,
     encoding,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -213,10 +215,19 @@ def completion(
            **optional_params
        }
    else:
-        # Convert messages to prompt
-        prompt = ""
-        for message in messages:
-            prompt += message["content"]
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", {}),
+                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                bos_token=model_prompt_details.get("bos_token", ""),
+                eos_token=model_prompt_details.get("eos_token", ""),
+                messages=messages,
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)

        input_data = {
            "prompt": prompt,
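
For context, a simplified sketch of what a registered template does when it is applied (this is an illustration only, not the actual code in prompt_templates/factory.py, and apply_custom_prompt is a hypothetical name): each message is wrapped in its role's pre/post strings, between the initial and final prompt values.

```python
def apply_custom_prompt(role_dict, initial_prompt_value, final_prompt_value,
                        messages, bos_token="", eos_token=""):
    # Illustrative re-statement of the template logic; the real factory handles more cases.
    prompt = bos_token + initial_prompt_value
    for message in messages:
        role = role_dict.get(message["role"], {})
        prompt += role.get("pre_message", "") + message["content"] + role.get("post_message", "")
    return prompt + final_prompt_value
```
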
@@ -245,7 +256,7 @@ def completion(
        input=prompt,
        api_key="",
        original_response=result,
-        additional_args={"complete_input_dict": input_data,"logs": logs},
+        additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, },
    )

    print_verbose(f"raw model_response: {result}")
@@ -684,6 +684,11 @@ def completion(
            or "https://api.replicate.com/v1"
        )

+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
+
        model_response = replicate.completion(
            model=model,
            messages=messages,
@@ -696,6 +701,7 @@ def completion(
            encoding=encoding, # for calculating input/output tokens
            api_key=replicate_key,
            logging_obj=logging,
+            custom_prompt_dict=custom_prompt_dict
        )
        if "stream" in optional_params and optional_params["stream"] == True:
            # don't try to access stream object,
@@ -99,8 +99,8 @@ class Router:

        # default litellm args
        self.default_litellm_params = default_litellm_params
-        self.default_litellm_params["timeout"] = timeout
-        self.default_litellm_params["max_retries"] = 0
+        self.default_litellm_params.setdefault("timeout", timeout)
+        self.default_litellm_params.setdefault("max_retries", 0)


        ### HEALTH CHECK THREAD ###
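
The switch from direct assignment to dict.setdefault means caller-supplied values in default_litellm_params are no longer overwritten by the Router's own defaults. A tiny illustration (the values are made up):

```python
params = {"timeout": 120}            # caller already configured a timeout
params["timeout"] = 600              # old behavior: caller's value is clobbered -> 600

params = {"timeout": 120}
params.setdefault("timeout", 600)    # new behavior: caller's value wins -> 120
params.setdefault("max_retries", 0)  # still fills in keys that are missing
```
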
@@ -204,7 +204,8 @@ class Router:
            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
            timeout = kwargs.get("request_timeout", self.timeout)
            kwargs.setdefault("metadata", {}).update({"model_group": model})
-            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
+            # response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
+            response = await self.async_function_with_fallbacks(**kwargs)

            return response
        except Exception as e:
@@ -402,7 +403,7 @@ class Router:
            # if the function call is successful, no exception will be raised and we'll break out of the loop
            response = await original_function(*args, **kwargs)
            return response
-        except Exception as e:
+        except (Exception, asyncio.CancelledError) as e:
            original_exception = e
            ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
            if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
@@ -410,7 +411,10 @@ class Router:
                raise original_exception
            ### RETRY
            #### check if it should retry + back-off if required
-            if hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
+            if isinstance(original_exception, asyncio.CancelledError):
+                timeout = 0 # immediately retry
+                await asyncio.sleep(timeout)
+            elif hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
                if hasattr(original_exception.response, "headers"):
                    timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers)
                else:
@@ -428,8 +432,11 @@ class Router:
                        response = await response
                    return response

-                except Exception as e:
-                    if hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
+                except (Exception, asyncio.CancelledError) as e:
+                    if isinstance(original_exception, asyncio.CancelledError):
+                        timeout = 0 # immediately retry
+                        await asyncio.sleep(timeout)
+                    elif hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
                        remaining_retries = num_retries - current_attempt
                        if hasattr(e.response, "headers"):
                            timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=e.response.headers)
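
Both retry paths now catch asyncio.CancelledError alongside Exception and retry it immediately (a backoff of 0) rather than letting the cancellation abort the call. A condensed, self-contained sketch of that pattern; call_with_retries and compute_backoff are hypothetical names, and the backoff values are not the Router's actual calculation (which uses litellm._calculate_retry_after and response headers):

```python
import asyncio

def compute_backoff(exc, attempt, max_retries):
    # Placeholder exponential backoff, capped at 30s; illustration only.
    return min(2 ** attempt, 30)

async def call_with_retries(fn, num_retries=3):
    for current_attempt in range(num_retries + 1):
        try:
            return await fn()
        except (Exception, asyncio.CancelledError) as e:
            if current_attempt == num_retries:
                raise
            if isinstance(e, asyncio.CancelledError):
                backoff = 0  # a cancelled attempt is retried immediately
            else:
                backoff = compute_backoff(e, current_attempt, num_retries)
            await asyncio.sleep(backoff)
```
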
@@ -630,7 +630,7 @@ def test_re_use_openaiClient():
        print(f"response: {response}")
    except Exception as e:
        pytest.fail("got Exception", e)
-test_re_use_openaiClient()
+# test_re_use_openaiClient()

def test_completion_azure():
    try:
@@ -845,7 +845,7 @@ def test_completion_azure_deployment_id():

def test_completion_replicate_vicuna():
    print("TESTING REPLICATE")
-    litellm.set_verbose=False
+    litellm.set_verbose=True
    model_name = "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b"
    try:
        response = completion(
@@ -898,6 +898,43 @@ def test_completion_replicate_llama2_stream():
        pytest.fail(f"Error occurred: {e}")
# test_completion_replicate_llama2_stream()

+def test_replicate_custom_prompt_dict():
+    litellm.set_verbose = True
+    model_name = "replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0"
+    litellm.register_prompt_template(
+        model="replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0",
+        initial_prompt_value="You are a good assistant", # [OPTIONAL]
+        roles={
+            "system": {
+                "pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
+                "post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
+            },
+            "user": {
+                "pre_message": "[INST] ", # [OPTIONAL]
+                "post_message": " [/INST]" # [OPTIONAL]
+            },
+            "assistant": {
+                "pre_message": "\n", # [OPTIONAL]
+                "post_message": "\n" # [OPTIONAL]
+            }
+        },
+        final_prompt_value="Now answer as best you can:" # [OPTIONAL]
+    )
+    response = completion(
+        model=model_name,
+        messages=[
+            {
+                "role": "user",
+                "content": "what is yc write 1 paragraph",
+            }
+        ],
+        num_retries=3
+    )
+    print(f"response: {response}")
+    litellm.custom_prompt_dict = {} # reset
+
+test_replicate_custom_prompt_dict()
+
# commenthing this out since we won't be always testing a custom replicate deployment
# def test_completion_replicate_deployments():
#     print("TESTING REPLICATE")
@@ -1,80 +1,116 @@
-# #### What this tests ####
-# # This profiles a router call to find where calls are taking the most time.
+#### What this tests ####
+# This profiles a router call to find where calls are taking the most time.

-# import sys, os, time, logging
-# import traceback, asyncio, uuid
-# import pytest
-# import cProfile
-# from pstats import Stats
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# ) # Adds the parent directory to the system path
-# import litellm
-# from litellm import Router
-# from concurrent.futures import ThreadPoolExecutor
-# from dotenv import load_dotenv
-# from aiodebug import log_slow_callbacks # Import the aiodebug utility for logging slow callbacks
+import sys, os, time, logging
+import traceback, asyncio, uuid
+import pytest
+import cProfile
+from pstats import Stats
+sys.path.insert(
+    0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+import litellm
+from litellm import Router
+from concurrent.futures import ThreadPoolExecutor
+from dotenv import load_dotenv
+from aiodebug import log_slow_callbacks # Import the aiodebug utility for logging slow callbacks

-# load_dotenv()
+load_dotenv()

-# logging.basicConfig(
-#     level=logging.DEBUG,
-#     format='%(asctime)s %(levelname)s: %(message)s',
-#     datefmt='%I:%M:%S %p',
-#     filename='aiologs.log', # Name of the log file where logs will be written
-#     filemode='w' # 'w' to overwrite the log file on each run, use 'a' to append
-# )
+logging.basicConfig(
+    level=logging.DEBUG,
+    format='%(asctime)s %(levelname)s: %(message)s',
+    datefmt='%I:%M:%S %p',
+    filename='aiologs.log', # Name of the log file where logs will be written
+    filemode='w' # 'w' to overwrite the log file on each run, use 'a' to append
+)


-# model_list = [{
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-turbo",
-#         "api_key": "os.environ/AZURE_FRANCE_API_KEY",
-#         "api_base": "https://openai-france-1234.openai.azure.com",
-#         "rpm": 1440,
-#     }
-# }, {
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-35-turbo",
-#         "api_key": "os.environ/AZURE_EUROPE_API_KEY",
-#         "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
-#         "rpm": 6
-#     }
-# }, {
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-35-turbo",
-#         "api_key": "os.environ/AZURE_CANADA_API_KEY",
-#         "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
-#         "rpm": 6
-#     }
-# }]
+model_list = [{
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-turbo",
+        "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+        "api_base": "https://openai-france-1234.openai.azure.com",
+        "rpm": 1440,
+    }
+}, {
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-35-turbo",
+        "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+        "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+        "rpm": 6
+    }
+}, {
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-35-turbo",
+        "api_key": "os.environ/AZURE_CANADA_API_KEY",
+        "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+        "rpm": 6
+    }
+}]

-# router = Router(model_list=model_list, set_verbose=False, num_retries=3)
+router = Router(model_list=model_list, set_verbose=False, num_retries=3)

-# async def router_completion():
-#     try:
-#         messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
-#         response = await router.acompletion(model="azure-model", messages=messages)
-#         return response
-#     except Exception as e:
-#         print(e, file=sys.stderr)
-#         traceback.print_exc()
-#         return None
+async def router_completion():
+    try:
+        messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
+        response = await router.acompletion(model="azure-model", messages=messages)
+        return response
+    except asyncio.exceptions.CancelledError:
+        print("Task was cancelled")
+        return None
+    except Exception as e:
+        return None

-# async def loadtest_fn():
-#     start = time.time()
-#     n = 1000
-#     tasks = [router_completion() for _ in range(n)]
-#     chat_completions = await asyncio.gather(*tasks)
-#     successful_completions = [c for c in chat_completions if c is not None]
-#     print(n, time.time() - start, len(successful_completions))
+async def loadtest_fn(n = 1000):
+    start = time.time()
+    tasks = [router_completion() for _ in range(n)]
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(n, time.time() - start, len(successful_completions))

# loop = asyncio.get_event_loop()
# loop.set_debug(True)
# log_slow_callbacks.enable(0.05) # Log callbacks slower than 0.05 seconds

# # Excute the load testing function within the asyncio event loop
# loop.run_until_complete(loadtest_fn())

+### SUSTAINED LOAD TESTS ###
+import time, asyncio
+async def make_requests(n):
+    tasks = [router_completion() for _ in range(n)]
+    print(f"num tasks: {len(tasks)}")
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(f"successful_completions: {len(successful_completions)}")
+    return successful_completions
+
+async def main():
+    start_time = time.time()
+    total_successful_requests = 0
+    request_limit = 1000
+    batches = 2 # batches of 1k requests
+    start = time.time()
+    tasks = [] # list to hold all tasks
+
+    async def request_loop():
+        nonlocal tasks
+        for _ in range(batches):
+            # Make 1,000 requests
+            task = asyncio.create_task(make_requests(request_limit))
+            tasks.append(task)
+
+            # Introduce a delay to achieve 1,000 requests per second
+            await asyncio.sleep(1)
+
+    await request_loop()
+    results = await asyncio.gather(*tasks)
+    total_successful_requests = sum(len(res) for res in results)
+
+    print(request_limit*batches, time.time() - start, total_successful_requests)
+
+asyncio.run(main())
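
Since the now-uncommented load-test file imports cProfile and Stats, one possible way to profile the router under load (this snippet is not part of the commit; it assumes the Azure credentials referenced in model_list are set and should be run in place of asyncio.run(main())):

```python
# Hypothetical profiling wrapper around the file's loadtest_fn.
profiler = cProfile.Profile()
profiler.enable()
asyncio.run(loadtest_fn(n=100))   # smaller batch to keep the profile readable
profiler.disable()
Stats(profiler).sort_stats("cumulative").print_stats(25)
```
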