fix(replicate.py): fix custom prompt formatting
parent c05da0797b
commit 1f5a1122fc
5 changed files with 177 additions and 80 deletions
@@ -6,6 +6,7 @@ from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 import litellm
 import httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class ReplicateError(Exception):
     def __init__(self, status_code, message):

@@ -186,6 +187,7 @@ def completion(
     logging_obj,
     api_key,
     encoding,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -213,10 +215,19 @@ def completion(
             **optional_params
         }
     else:
-        # Convert messages to prompt
-        prompt = ""
-        for message in messages:
-            prompt += message["content"]
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", {}),
+                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                bos_token=model_prompt_details.get("bos_token", ""),
+                eos_token=model_prompt_details.get("eos_token", ""),
+                messages=messages,
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)

         input_data = {
             "prompt": prompt,
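The new branch routes registered models through custom_prompt instead of naively concatenating message contents. As a rough mental model (a simplified sketch, not the library's exact implementation), custom_prompt wraps each message's content in the per-role pre_message/post_message strings and frames the whole thing with initial_prompt_value and final_prompt_value:

def render_custom_prompt(role_dict, messages, initial_prompt_value="", final_prompt_value=""):
    # Simplified illustration of template-based prompt assembly.
    prompt = initial_prompt_value
    for message in messages:
        role_cfg = role_dict.get(message["role"], {})
        prompt += role_cfg.get("pre_message", "") + message["content"] + role_cfg.get("post_message", "")
    return prompt + final_prompt_value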
@@ -245,7 +256,7 @@ def completion(
         input=prompt,
         api_key="",
         original_response=result,
-        additional_args={"complete_input_dict": input_data,"logs": logs},
+        additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, },
     )

     print_verbose(f"raw model_response: {result}")
@@ -684,6 +684,11 @@ def completion(
             or "https://api.replicate.com/v1"
         )

+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
+
         model_response = replicate.completion(
             model=model,
             messages=messages,
@@ -696,6 +701,7 @@ def completion(
             encoding=encoding, # for calculating input/output tokens
             api_key=replicate_key,
             logging_obj=logging,
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
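The expression custom_prompt_dict or litellm.custom_prompt_dict decides which registry the Replicate handler receives: a dict passed for this call wins, and only when it is empty does the module-level registry populated by register_prompt_template apply. A tiny illustration of that precedence (hypothetical dicts, not litellm code):

per_call_registry = {}  # nothing supplied for this request
global_registry = {"some/model": {"initial_prompt_value": "You are a good assistant"}}  # filled by register_prompt_template
effective = per_call_registry or global_registry
assert effective is global_registry  # an empty per-call dict falls back to the global registry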
@@ -99,8 +99,8 @@ class Router:

         # default litellm args
         self.default_litellm_params = default_litellm_params
-        self.default_litellm_params["timeout"] = timeout
-        self.default_litellm_params["max_retries"] = 0
+        self.default_litellm_params.setdefault("timeout", timeout)
+        self.default_litellm_params.setdefault("max_retries", 0)


         ### HEALTH CHECK THREAD ###
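Switching from plain assignment to setdefault means a timeout or max_retries that the caller already placed in default_litellm_params is no longer clobbered. A minimal sketch of the difference (illustrative values only):

caller_defaults = {"timeout": 5}  # the caller already chose a timeout

overwritten = dict(caller_defaults)
overwritten["timeout"] = 600          # old behaviour: unconditional assignment overwrites the caller's value
assert overwritten["timeout"] == 600

preserved = dict(caller_defaults)
preserved.setdefault("timeout", 600)  # new behaviour: only fills in keys that are missing
preserved.setdefault("max_retries", 0)
assert preserved["timeout"] == 5 and preserved["max_retries"] == 0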
@@ -204,7 +204,8 @@ class Router:
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
+            # response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
+            response = await self.async_function_with_fallbacks(**kwargs)

             return response
         except Exception as e:
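The asyncio.wait_for wrapper is commented out because, when its timeout fires, it cancels the wrapped coroutine: the awaiting code sees asyncio.TimeoutError while the inner task sees asyncio.CancelledError, which is exactly the cancellation the retry changes below have to handle. A minimal standalone sketch of that behaviour (not Router code):

import asyncio

async def slow_call():
    await asyncio.sleep(10)
    return "done"

async def main():
    try:
        await asyncio.wait_for(slow_call(), timeout=0.1)
    except asyncio.TimeoutError:
        # wait_for cancelled slow_call(); inside the task the cancellation
        # surfaces as asyncio.CancelledError.
        print("inner call was cancelled by wait_for")

asyncio.run(main())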
@@ -402,7 +403,7 @@ class Router:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = await original_function(*args, **kwargs)
             return response
-        except Exception as e:
+        except (Exception, asyncio.CancelledError) as e:
             original_exception = e
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
             if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
@@ -410,7 +411,10 @@ class Router:
                 raise original_exception
             ### RETRY
             #### check if it should retry + back-off if required
-            if hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
+            if isinstance(original_exception, asyncio.CancelledError):
+                timeout = 0 # immediately retry
+                await asyncio.sleep(timeout)
+            elif hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
                 if hasattr(original_exception.response, "headers"):
                     timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers)
                 else:
@@ -428,8 +432,11 @@ class Router:
                         response = await response
                     return response

-                except Exception as e:
-                    if hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
+                except (Exception, asyncio.CancelledError) as e:
+                    if isinstance(original_exception, asyncio.CancelledError):
+                        timeout = 0 # immediately retry
+                        await asyncio.sleep(timeout)
+                    elif hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
                         remaining_retries = num_retries - current_attempt
                         if hasattr(e.response, "headers"):
                             timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=e.response.headers)
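Since Python 3.8, asyncio.CancelledError derives from BaseException, so a bare except Exception never catches it; naming it explicitly lets the retry loop treat a cancelled attempt as immediately retryable (zero back-off) while HTTP errors keep using the server's Retry-After headers. A generic sketch of that retry shape (not the Router implementation; flaky_call is a made-up stand-in):

import asyncio, random

async def flaky_call():
    # Stand-in for an upstream request that is sometimes cancelled.
    if random.random() < 0.3:
        raise asyncio.CancelledError()
    return "ok"

async def call_with_retries(num_retries=3):
    for attempt in range(num_retries):
        try:
            return await flaky_call()
        except (Exception, asyncio.CancelledError) as e:  # CancelledError must be named explicitly
            if isinstance(e, asyncio.CancelledError):
                backoff = 0                # retry immediately
            else:
                backoff = 2 ** attempt     # simple stand-in for Retry-After based back-off
            await asyncio.sleep(backoff)
    raise RuntimeError("out of retries")

print(asyncio.run(call_with_retries()))

Retrying on cancellation like this only makes sense when the cancellation comes from a timeout wrapper rather than from the caller shutting the task down.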
@@ -630,7 +630,7 @@ def test_re_use_openaiClient():
             print(f"response: {response}")
     except Exception as e:
         pytest.fail("got Exception", e)
-test_re_use_openaiClient()
+# test_re_use_openaiClient()

 def test_completion_azure():
     try:
@@ -845,7 +845,7 @@ def test_completion_azure_deployment_id():

 def test_completion_replicate_vicuna():
     print("TESTING REPLICATE")
-    litellm.set_verbose=False
+    litellm.set_verbose=True
     model_name = "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b"
     try:
         response = completion(
@@ -898,6 +898,43 @@ def test_completion_replicate_llama2_stream():
         pytest.fail(f"Error occurred: {e}")
 # test_completion_replicate_llama2_stream()

+def test_replicate_custom_prompt_dict():
+    litellm.set_verbose = True
+    model_name = "replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0"
+    litellm.register_prompt_template(
+        model="replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0",
+        initial_prompt_value="You are a good assistant", # [OPTIONAL]
+        roles={
+            "system": {
+                "pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
+                "post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
+            },
+            "user": {
+                "pre_message": "[INST] ", # [OPTIONAL]
+                "post_message": " [/INST]" # [OPTIONAL]
+            },
+            "assistant": {
+                "pre_message": "\n", # [OPTIONAL]
+                "post_message": "\n" # [OPTIONAL]
+            }
+        },
+        final_prompt_value="Now answer as best you can:" # [OPTIONAL]
+    )
+    response = completion(
+        model=model_name,
+        messages=[
+            {
+                "role": "user",
+                "content": "what is yc write 1 paragraph",
+            }
+        ],
+        num_retries=3
+    )
+    print(f"response: {response}")
+    litellm.custom_prompt_dict = {} # reset
+
+test_replicate_custom_prompt_dict()
+
 # commenthing this out since we won't be always testing a custom replicate deployment
 # def test_completion_replicate_deployments():
 #     print("TESTING REPLICATE")
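Given the assembly order sketched earlier (initial value, then the per-role wrappers, then the final value), the single user message in this test should render to roughly the following prompt string (an approximation, whitespace aside; not captured from an actual run):

    "You are a good assistant[INST] what is yc write 1 paragraph [/INST]Now answer as best you can:"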
@@ -1,80 +1,116 @@
-# #### What this tests ####
-# # This profiles a router call to find where calls are taking the most time.
+#### What this tests ####
+# This profiles a router call to find where calls are taking the most time.

-# import sys, os, time, logging
-# import traceback, asyncio, uuid
-# import pytest
-# import cProfile
-# from pstats import Stats
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# ) # Adds the parent directory to the system path
-# import litellm
-# from litellm import Router
-# from concurrent.futures import ThreadPoolExecutor
-# from dotenv import load_dotenv
-# from aiodebug import log_slow_callbacks # Import the aiodebug utility for logging slow callbacks
+import sys, os, time, logging
+import traceback, asyncio, uuid
+import pytest
+import cProfile
+from pstats import Stats
+sys.path.insert(
+    0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+import litellm
+from litellm import Router
+from concurrent.futures import ThreadPoolExecutor
+from dotenv import load_dotenv
+from aiodebug import log_slow_callbacks # Import the aiodebug utility for logging slow callbacks

-# load_dotenv()
+load_dotenv()

-# logging.basicConfig(
-#     level=logging.DEBUG,
-#     format='%(asctime)s %(levelname)s: %(message)s',
-#     datefmt='%I:%M:%S %p',
-#     filename='aiologs.log', # Name of the log file where logs will be written
-#     filemode='w' # 'w' to overwrite the log file on each run, use 'a' to append
-# )
+logging.basicConfig(
+    level=logging.DEBUG,
+    format='%(asctime)s %(levelname)s: %(message)s',
+    datefmt='%I:%M:%S %p',
+    filename='aiologs.log', # Name of the log file where logs will be written
+    filemode='w' # 'w' to overwrite the log file on each run, use 'a' to append
+)


-# model_list = [{
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-turbo",
-#         "api_key": "os.environ/AZURE_FRANCE_API_KEY",
-#         "api_base": "https://openai-france-1234.openai.azure.com",
-#         "rpm": 1440,
-#     }
-# }, {
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-35-turbo",
-#         "api_key": "os.environ/AZURE_EUROPE_API_KEY",
-#         "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
-#         "rpm": 6
-#     }
-# }, {
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-35-turbo",
-#         "api_key": "os.environ/AZURE_CANADA_API_KEY",
-#         "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
-#         "rpm": 6
-#     }
-# }]
+model_list = [{
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-turbo",
+        "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+        "api_base": "https://openai-france-1234.openai.azure.com",
+        "rpm": 1440,
+    }
+}, {
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-35-turbo",
+        "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+        "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+        "rpm": 6
+    }
+}, {
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-35-turbo",
+        "api_key": "os.environ/AZURE_CANADA_API_KEY",
+        "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+        "rpm": 6
+    }
+}]

-# router = Router(model_list=model_list, set_verbose=False, num_retries=3)
+router = Router(model_list=model_list, set_verbose=False, num_retries=3)

-# async def router_completion():
-#     try:
-#         messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
-#         response = await router.acompletion(model="azure-model", messages=messages)
-#         return response
-#     except Exception as e:
-#         print(e, file=sys.stderr)
-#         traceback.print_exc()
-#         return None
+async def router_completion():
+    try:
+        messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
+        response = await router.acompletion(model="azure-model", messages=messages)
+        return response
+    except asyncio.exceptions.CancelledError:
+        print("Task was cancelled")
+        return None
+    except Exception as e:
+        return None

-# async def loadtest_fn():
-#     start = time.time()
-#     n = 1000
-#     tasks = [router_completion() for _ in range(n)]
-#     chat_completions = await asyncio.gather(*tasks)
-#     successful_completions = [c for c in chat_completions if c is not None]
-#     print(n, time.time() - start, len(successful_completions))
+async def loadtest_fn(n = 1000):
+    start = time.time()
+    tasks = [router_completion() for _ in range(n)]
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(n, time.time() - start, len(successful_completions))

 # loop = asyncio.get_event_loop()
 # loop.set_debug(True)
 # log_slow_callbacks.enable(0.05) # Log callbacks slower than 0.05 seconds

 # # Excute the load testing function within the asyncio event loop
 # loop.run_until_complete(loadtest_fn())
 # loop.run_until_complete(loadtest_fn())

+### SUSTAINED LOAD TESTS ###
+import time, asyncio
+async def make_requests(n):
+    tasks = [router_completion() for _ in range(n)]
+    print(f"num tasks: {len(tasks)}")
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(f"successful_completions: {len(successful_completions)}")
+    return successful_completions
+
+async def main():
+    start_time = time.time()
+    total_successful_requests = 0
+    request_limit = 1000
+    batches = 2 # batches of 1k requests
+    start = time.time()
+    tasks = [] # list to hold all tasks
+
+    async def request_loop():
+        nonlocal tasks
+        for _ in range(batches):
+            # Make 1,000 requests
+            task = asyncio.create_task(make_requests(request_limit))
+            tasks.append(task)
+
+            # Introduce a delay to achieve 1,000 requests per second
+            await asyncio.sleep(1)
+
+    await request_loop()
+    results = await asyncio.gather(*tasks)
+    total_successful_requests = sum(len(res) for res in results)
+
+    print(request_limit*batches, time.time() - start, total_successful_requests)
+
+asyncio.run(main())
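The final print in main() emits three numbers: requests attempted (request_limit*batches), elapsed seconds, and successful completions. A small helper for turning that triple into a throughput figure (hypothetical, not part of the diff; the numbers below are purely illustrative):

def summarize(total_requests: int, elapsed_s: float, successes: int) -> str:
    # Convert the raw triple printed by the load test into a readable summary.
    rps = successes / elapsed_s if elapsed_s else 0.0
    return f"{successes}/{total_requests} succeeded in {elapsed_s:.1f}s ({rps:.1f} req/s)"

print(summarize(2000, 21.4, 1987))  # illustrative values only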