fix(replicate.py): fix custom prompt formatting

Krrish Dholakia 2023-11-29 19:44:02 -08:00
parent c05da0797b
commit 1f5a1122fc
5 changed files with 177 additions and 80 deletions

File 1 of 5: replicate.py (Replicate completion handler)

@@ -6,6 +6,7 @@ from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 import litellm
 import httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt

 class ReplicateError(Exception):
     def __init__(self, status_code, message):

@@ -186,6 +187,7 @@ def completion(
     logging_obj,
     api_key,
     encoding,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,

@@ -213,10 +215,19 @@ def completion(
             **optional_params
         }
     else:
-        # Convert messages to prompt
-        prompt = ""
-        for message in messages:
-            prompt += message["content"]
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", {}),
+                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                bos_token=model_prompt_details.get("bos_token", ""),
+                eos_token=model_prompt_details.get("eos_token", ""),
+                messages=messages,
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)

     input_data = {
         "prompt": prompt,

@@ -245,7 +256,7 @@ def completion(
         input=prompt,
         api_key="",
         original_response=result,
-        additional_args={"complete_input_dict": input_data, "logs": logs},
+        additional_args={"complete_input_dict": input_data, "logs": logs, "api_base": prediction_url},
     )
     print_verbose(f"raw model_response: {result}")

File 2 of 5: top-level completion() entrypoint that dispatches to the Replicate handler

@@ -684,6 +684,11 @@ def completion(
             or "https://api.replicate.com/v1"
         )

+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
+
         model_response = replicate.completion(
             model=model,
             messages=messages,

@@ -696,6 +701,7 @@ def completion(
             encoding=encoding,  # for calculating input/output tokens
             api_key=replicate_key,
             logging_obj=logging,
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
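Usage note: with the fallback above, a custom_prompt_dict already set for the call wins, and otherwise the handler falls back to the library-level registry that litellm.register_prompt_template populates (the new test further down does exactly this for a real Replicate deployment). A minimal sketch of the global route; the model name is a placeholder:

import litellm
from litellm import completion

# Registering once populates litellm.custom_prompt_dict, which the Replicate
# handler now reads when no per-call custom_prompt_dict is supplied.
litellm.register_prompt_template(
    model="replicate/your-org/your-llama-deployment",  # placeholder model name
    initial_prompt_value="",
    roles={"user": {"pre_message": "[INST] ", "post_message": " [/INST]"}},
    final_prompt_value="",
)

response = completion(
    model="replicate/your-org/your-llama-deployment",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
)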

File 3 of 5: Router

@@ -99,8 +99,8 @@ class Router:
         # default litellm args
         self.default_litellm_params = default_litellm_params
-        self.default_litellm_params["timeout"] = timeout
-        self.default_litellm_params["max_retries"] = 0
+        self.default_litellm_params.setdefault("timeout", timeout)
+        self.default_litellm_params.setdefault("max_retries", 0)

         ### HEALTH CHECK THREAD ###
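Side note on the setdefault change: dict.setdefault only writes a key when it is absent, so a timeout or max_retries the caller already put in default_litellm_params is no longer overwritten. A quick illustration:

params = {"timeout": 120}             # value supplied by the caller
params.setdefault("timeout", 600)     # no-op: key exists, 120 is kept
params.setdefault("max_retries", 0)   # added: key was missing
print(params)                         # {'timeout': 120, 'max_retries': 0}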
@@ -204,7 +204,8 @@
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
+            # response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
+            response = await self.async_function_with_fallbacks(**kwargs)
             return response
         except Exception as e:

@@ -402,7 +403,7 @@
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = await original_function(*args, **kwargs)
             return response
-        except Exception as e:
+        except (Exception, asyncio.CancelledError) as e:
             original_exception = e
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
             if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)

@@ -410,7 +411,10 @@
                 raise original_exception
             ### RETRY
             #### check if it should retry + back-off if required
-            if hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
+            if isinstance(original_exception, asyncio.CancelledError):
+                timeout = 0  # immediately retry
+                await asyncio.sleep(timeout)
+            elif hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
                 if hasattr(original_exception.response, "headers"):
                     timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers)
                 else:

@@ -428,8 +432,11 @@
                 response = await response
                 return response
-            except Exception as e:
-                if hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
+            except (Exception, asyncio.CancelledError) as e:
+                if isinstance(original_exception, asyncio.CancelledError):
+                    timeout = 0  # immediately retry
+                    await asyncio.sleep(timeout)
+                elif hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
                     remaining_retries = num_retries - current_attempt
                     if hasattr(e.response, "headers"):
                         timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=e.response.headers)
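The retry changes above follow a recognizable pattern: asyncio.CancelledError stopped inheriting from Exception in Python 3.8, so it has to be caught explicitly, and here it is retried immediately while HTTP-style errors back off (the Router derives the wait from response headers via litellm._calculate_retry_after). Below is a standalone sketch of that pattern under assumed names; flaky_call and the fixed exponential backoff are placeholders, not the Router code:

import asyncio, random

async def flaky_call():
    # Placeholder for the real async work (e.g. a provider API call).
    if random.random() < 0.5:
        raise asyncio.CancelledError()
    return "ok"

async def call_with_retries(num_retries: int = 3):
    for attempt in range(num_retries + 1):
        try:
            return await flaky_call()
        except asyncio.CancelledError:
            delay = 0  # cancelled work is retried immediately, mirroring the diff
        except Exception:
            if attempt == num_retries:
                raise
            delay = min(2 ** attempt, 10)  # placeholder backoff; the Router reads retry headers
        await asyncio.sleep(delay)
    raise RuntimeError("exhausted retries")

# asyncio.run(call_with_retries())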

File 4 of 5: completion tests

@@ -630,7 +630,7 @@ def test_re_use_openaiClient():
         print(f"response: {response}")
     except Exception as e:
         pytest.fail("got Exception", e)
-test_re_use_openaiClient()
+# test_re_use_openaiClient()

 def test_completion_azure():
     try:

@@ -845,7 +845,7 @@ def test_completion_azure_deployment_id():
 def test_completion_replicate_vicuna():
     print("TESTING REPLICATE")
-    litellm.set_verbose=False
+    litellm.set_verbose=True
     model_name = "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b"
     try:
         response = completion(

@@ -898,6 +898,43 @@ def test_completion_replicate_llama2_stream():
         pytest.fail(f"Error occurred: {e}")
 # test_completion_replicate_llama2_stream()

+def test_replicate_custom_prompt_dict():
+    litellm.set_verbose = True
+    model_name = "replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0"
+    litellm.register_prompt_template(
+        model="replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0",
+        initial_prompt_value="You are a good assistant",  # [OPTIONAL]
+        roles={
+            "system": {
+                "pre_message": "[INST] <<SYS>>\n",  # [OPTIONAL]
+                "post_message": "\n<</SYS>>\n [/INST]\n"  # [OPTIONAL]
+            },
+            "user": {
+                "pre_message": "[INST] ",  # [OPTIONAL]
+                "post_message": " [/INST]"  # [OPTIONAL]
+            },
+            "assistant": {
+                "pre_message": "\n",  # [OPTIONAL]
+                "post_message": "\n"  # [OPTIONAL]
+            }
+        },
+        final_prompt_value="Now answer as best you can:"  # [OPTIONAL]
+    )
+    response = completion(
+        model=model_name,
+        messages=[
+            {
+                "role": "user",
+                "content": "what is yc write 1 paragraph",
+            }
+        ],
+        num_retries=3
+    )
+    print(f"response: {response}")
+    litellm.custom_prompt_dict = {}  # reset
+
+test_replicate_custom_prompt_dict()

 # commenthing this out since we won't be always testing a custom replicate deployment
 # def test_completion_replicate_deployments():
 #     print("TESTING REPLICATE")

File 5 of 5: router load-test / profiling script

@@ -1,80 +1,116 @@
-# #### What this tests ####
-# # This profiles a router call to find where calls are taking the most time.
+#### What this tests ####
+# This profiles a router call to find where calls are taking the most time.

-# import sys, os, time, logging
-# import traceback, asyncio, uuid
-# import pytest
-# import cProfile
-# from pstats import Stats
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# )  # Adds the parent directory to the system path
-# import litellm
-# from litellm import Router
-# from concurrent.futures import ThreadPoolExecutor
-# from dotenv import load_dotenv
-# from aiodebug import log_slow_callbacks  # Import the aiodebug utility for logging slow callbacks
+import sys, os, time, logging
+import traceback, asyncio, uuid
+import pytest
+import cProfile
+from pstats import Stats
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm import Router
+from concurrent.futures import ThreadPoolExecutor
+from dotenv import load_dotenv
+from aiodebug import log_slow_callbacks  # Import the aiodebug utility for logging slow callbacks

-# load_dotenv()
+load_dotenv()

-# logging.basicConfig(
-#     level=logging.DEBUG,
-#     format='%(asctime)s %(levelname)s: %(message)s',
-#     datefmt='%I:%M:%S %p',
-#     filename='aiologs.log',  # Name of the log file where logs will be written
-#     filemode='w'  # 'w' to overwrite the log file on each run, use 'a' to append
-# )
+logging.basicConfig(
+    level=logging.DEBUG,
+    format='%(asctime)s %(levelname)s: %(message)s',
+    datefmt='%I:%M:%S %p',
+    filename='aiologs.log',  # Name of the log file where logs will be written
+    filemode='w'  # 'w' to overwrite the log file on each run, use 'a' to append
+)

-# model_list = [{
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-turbo",
-#         "api_key": "os.environ/AZURE_FRANCE_API_KEY",
-#         "api_base": "https://openai-france-1234.openai.azure.com",
-#         "rpm": 1440,
-#     }
-# }, {
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-35-turbo",
-#         "api_key": "os.environ/AZURE_EUROPE_API_KEY",
-#         "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
-#         "rpm": 6
-#     }
-# }, {
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-35-turbo",
-#         "api_key": "os.environ/AZURE_CANADA_API_KEY",
-#         "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
-#         "rpm": 6
-#     }
-# }]
+model_list = [{
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-turbo",
+        "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+        "api_base": "https://openai-france-1234.openai.azure.com",
+        "rpm": 1440,
+    }
+}, {
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-35-turbo",
+        "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+        "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+        "rpm": 6
+    }
+}, {
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-35-turbo",
+        "api_key": "os.environ/AZURE_CANADA_API_KEY",
+        "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+        "rpm": 6
+    }
+}]

-# router = Router(model_list=model_list, set_verbose=False, num_retries=3)
+router = Router(model_list=model_list, set_verbose=False, num_retries=3)

-# async def router_completion():
-#     try:
-#         messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
-#         response = await router.acompletion(model="azure-model", messages=messages)
-#         return response
-#     except Exception as e:
-#         print(e, file=sys.stderr)
-#         traceback.print_exc()
-#         return None
+async def router_completion():
+    try:
+        messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
+        response = await router.acompletion(model="azure-model", messages=messages)
+        return response
+    except asyncio.exceptions.CancelledError:
+        print("Task was cancelled")
+        return None
+    except Exception as e:
+        return None

-# async def loadtest_fn():
-#     start = time.time()
-#     n = 1000
-#     tasks = [router_completion() for _ in range(n)]
-#     chat_completions = await asyncio.gather(*tasks)
-#     successful_completions = [c for c in chat_completions if c is not None]
-#     print(n, time.time() - start, len(successful_completions))
+async def loadtest_fn(n = 1000):
+    start = time.time()
+    tasks = [router_completion() for _ in range(n)]
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(n, time.time() - start, len(successful_completions))

 # loop = asyncio.get_event_loop()
 # loop.set_debug(True)
 # log_slow_callbacks.enable(0.05)  # Log callbacks slower than 0.05 seconds

 # # Excute the load testing function within the asyncio event loop
 # loop.run_until_complete(loadtest_fn())
+
+### SUSTAINED LOAD TESTS ###
+import time, asyncio
+
+async def make_requests(n):
+    tasks = [router_completion() for _ in range(n)]
+    print(f"num tasks: {len(tasks)}")
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(f"successful_completions: {len(successful_completions)}")
+    return successful_completions
+
+async def main():
+    start_time = time.time()
+    total_successful_requests = 0
+    request_limit = 1000
+    batches = 2  # batches of 1k requests
+    start = time.time()
+    tasks = []  # list to hold all tasks
+
+    async def request_loop():
+        nonlocal tasks
+        for _ in range(batches):
+            # Make 1,000 requests
+            task = asyncio.create_task(make_requests(request_limit))
+            tasks.append(task)
+            # Introduce a delay to achieve 1,000 requests per second
+            await asyncio.sleep(1)
+
+    await request_loop()
+    results = await asyncio.gather(*tasks)
+    total_successful_requests = sum(len(res) for res in results)
+
+    print(request_limit*batches, time.time() - start, total_successful_requests)
+
+asyncio.run(main())