diff --git a/litellm/llms/replicate.py b/litellm/llms/replicate.py
index 3bcf5f34d..d639a8d1e 100644
--- a/litellm/llms/replicate.py
+++ b/litellm/llms/replicate.py
@@ -6,6 +6,7 @@ from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 import litellm
 import httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt
 
 class ReplicateError(Exception):
     def __init__(self, status_code, message):
@@ -186,6 +187,7 @@ def completion(
     logging_obj,
     api_key,
     encoding,
+    custom_prompt_dict={},
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
@@ -213,10 +215,19 @@ def completion(
             **optional_params
         }
     else:
-        # Convert messages to prompt
-        prompt = ""
-        for message in messages:
-            prompt += message["content"]
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            prompt = custom_prompt(
+                role_dict=model_prompt_details.get("roles", {}),
+                initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
+                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
+                bos_token=model_prompt_details.get("bos_token", ""),
+                eos_token=model_prompt_details.get("eos_token", ""),
+                messages=messages,
+            )
+        else:
+            prompt = prompt_factory(model=model, messages=messages)
 
         input_data = {
             "prompt": prompt,
@@ -245,7 +256,7 @@ def completion(
             input=prompt,
             api_key="",
             original_response=result,
-            additional_args={"complete_input_dict": input_data,"logs": logs},
+            additional_args={"complete_input_dict": input_data,"logs": logs, "api_base": prediction_url, },
         )
 
         print_verbose(f"raw model_response: {result}")
diff --git a/litellm/main.py b/litellm/main.py
index ccac102b3..e98d88656 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -684,6 +684,11 @@ def completion(
             or "https://api.replicate.com/v1"
         )
 
+        custom_prompt_dict = (
+            custom_prompt_dict
+            or litellm.custom_prompt_dict
+        )
+
         model_response = replicate.completion(
             model=model,
             messages=messages,
@@ -696,6 +701,7 @@ def completion(
             encoding=encoding, # for calculating input/output tokens
             api_key=replicate_key,
             logging_obj=logging,
+            custom_prompt_dict=custom_prompt_dict
         )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
diff --git a/litellm/router.py b/litellm/router.py
index 1798f75b9..a10f300f2 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -99,8 +99,8 @@ class Router:
         # default litellm args
         self.default_litellm_params = default_litellm_params
-        self.default_litellm_params["timeout"] = timeout
-        self.default_litellm_params["max_retries"] = 0
+        self.default_litellm_params.setdefault("timeout", timeout)
+        self.default_litellm_params.setdefault("max_retries", 0)
 
         ### HEALTH CHECK THREAD ###
@@ -204,7 +204,8 @@ class Router:
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
             timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
+            # response = await asyncio.wait_for(self.async_function_with_fallbacks(**kwargs), timeout=timeout)
+            response = await self.async_function_with_fallbacks(**kwargs)
 
             return response
         except Exception as e:
@@ -402,7 +403,7 @@ class Router:
             # if the function call is successful, no exception will be raised and we'll break out of the loop
             response = await original_function(*args, **kwargs)
             return response
-        except Exception as e:
+        except (Exception, asyncio.CancelledError) as e:
             original_exception = e
             ### CHECK IF RATE LIMIT / CONTEXT WINDOW ERROR w/ fallbacks available
             if ((isinstance(original_exception, litellm.ContextWindowExceededError) and context_window_fallbacks is None)
@@ -410,7 +411,10 @@ class Router:
                 raise original_exception
             ### RETRY
             #### check if it should retry + back-off if required
-            if hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
+            if isinstance(original_exception, asyncio.CancelledError):
+                timeout = 0 # immediately retry
+                await asyncio.sleep(timeout)
+            elif hasattr(original_exception, "status_code") and hasattr(original_exception, "response") and litellm._should_retry(status_code=original_exception.status_code):
                 if hasattr(original_exception.response, "headers"):
                     timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=original_exception.response.headers)
                 else:
@@ -428,8 +432,11 @@ class Router:
                         response = await response
                     return response
 
-                except Exception as e:
-                    if hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
+                except (Exception, asyncio.CancelledError) as e:
+                    if isinstance(original_exception, asyncio.CancelledError):
+                        timeout = 0 # immediately retry
+                        await asyncio.sleep(timeout)
+                    elif hasattr(e, "status_code") and hasattr(e, "response") and litellm._should_retry(status_code=e.status_code):
                         remaining_retries = num_retries - current_attempt
                         if hasattr(e.response, "headers"):
                             timeout = litellm._calculate_retry_after(remaining_retries=num_retries, max_retries=num_retries, response_headers=e.response.headers)
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index ec4abf113..819f08e1b 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -630,7 +630,7 @@ def test_re_use_openaiClient():
         print(f"response: {response}")
     except Exception as e:
         pytest.fail("got Exception", e)
-test_re_use_openaiClient()
+# test_re_use_openaiClient()
 
 def test_completion_azure():
     try:
@@ -845,7 +845,7 @@ def test_completion_azure_deployment_id():
 
 def test_completion_replicate_vicuna():
     print("TESTING REPLICATE")
-    litellm.set_verbose=False
+    litellm.set_verbose=True
     model_name = "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b"
     try:
         response = completion(
@@ -898,6 +898,43 @@ def test_completion_replicate_llama2_stream():
         pytest.fail(f"Error occurred: {e}")
 # test_completion_replicate_llama2_stream()
 
+def test_replicate_custom_prompt_dict():
+    litellm.set_verbose = True
+    model_name = "replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0"
+    litellm.register_prompt_template(
+        model="replicate/meta/llama-2-7b-chat:13c3cdee13ee059ab779f0291d29054dab00a47dad8261375654de5540165fb0",
+        initial_prompt_value="You are a good assistant", # [OPTIONAL]
+        roles={
+            "system": {
+                "pre_message": "[INST] <<SYS>>\n", # [OPTIONAL]
+                "post_message": "\n<</SYS>>\n [/INST]\n" # [OPTIONAL]
+            },
+            "user": {
+                "pre_message": "[INST] ", # [OPTIONAL]
+                "post_message": " [/INST]" # [OPTIONAL]
+            },
+            "assistant": {
+                "pre_message": "\n", # [OPTIONAL]
+                "post_message": "\n" # [OPTIONAL]
+            }
+        },
+        final_prompt_value="Now answer as best you can:" # [OPTIONAL]
+    )
+    response = completion(
+        model=model_name,
+        messages=[
+            {
+                "role": "user",
+                "content": "what is yc write 1 paragraph",
+            }
+        ],
+        num_retries=3
+    )
+    print(f"response: {response}")
+    litellm.custom_prompt_dict = {} # reset
+
+test_replicate_custom_prompt_dict()
+
 # commenthing this out since we won't be always testing a custom replicate deployment
 # def test_completion_replicate_deployments():
 #     print("TESTING REPLICATE")
diff --git a/litellm/tests/test_profiling_router.py b/litellm/tests/test_profiling_router.py
index d9c6590ad..d922b816b 100644
--- a/litellm/tests/test_profiling_router.py
+++ b/litellm/tests/test_profiling_router.py
@@ -1,80 +1,116 @@
-# #### What this tests ####
-# # This profiles a router call to find where calls are taking the most time.
+#### What this tests ####
+# This profiles a router call to find where calls are taking the most time.
 
-# import sys, os, time, logging
-# import traceback, asyncio, uuid
-# import pytest
-# import cProfile
-# from pstats import Stats
-# sys.path.insert(
-#     0, os.path.abspath("../..")
-# )  # Adds the parent directory to the system path
-# import litellm
-# from litellm import Router
-# from concurrent.futures import ThreadPoolExecutor
-# from dotenv import load_dotenv
-# from aiodebug import log_slow_callbacks # Import the aiodebug utility for logging slow callbacks
+import sys, os, time, logging
+import traceback, asyncio, uuid
+import pytest
+import cProfile
+from pstats import Stats
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm import Router
+from concurrent.futures import ThreadPoolExecutor
+from dotenv import load_dotenv
+from aiodebug import log_slow_callbacks # Import the aiodebug utility for logging slow callbacks
 
-# load_dotenv()
+load_dotenv()
 
-# logging.basicConfig(
-#     level=logging.DEBUG,
-#     format='%(asctime)s %(levelname)s: %(message)s',
-#     datefmt='%I:%M:%S %p',
-#     filename='aiologs.log', # Name of the log file where logs will be written
-#     filemode='w' # 'w' to overwrite the log file on each run, use 'a' to append
-# )
+logging.basicConfig(
+    level=logging.DEBUG,
+    format='%(asctime)s %(levelname)s: %(message)s',
+    datefmt='%I:%M:%S %p',
+    filename='aiologs.log', # Name of the log file where logs will be written
+    filemode='w' # 'w' to overwrite the log file on each run, use 'a' to append
+)
 
-# model_list = [{
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-turbo",
-#         "api_key": "os.environ/AZURE_FRANCE_API_KEY",
-#         "api_base": "https://openai-france-1234.openai.azure.com",
-#         "rpm": 1440,
-#     }
-# }, {
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-35-turbo",
-#         "api_key": "os.environ/AZURE_EUROPE_API_KEY",
-#         "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
-#         "rpm": 6
-#     }
-# }, {
-#     "model_name": "azure-model",
-#     "litellm_params": {
-#         "model": "azure/gpt-35-turbo",
-#         "api_key": "os.environ/AZURE_CANADA_API_KEY",
-#         "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
-#         "rpm": 6
-#     }
-# }]
+model_list = [{
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-turbo",
+        "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+        "api_base": "https://openai-france-1234.openai.azure.com",
+        "rpm": 1440,
+    }
+}, {
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-35-turbo",
+        "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+        "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+        "rpm": 6
+    }
+}, {
+    "model_name": "azure-model",
+    "litellm_params": {
+        "model": "azure/gpt-35-turbo",
+        "api_key": "os.environ/AZURE_CANADA_API_KEY",
+        "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+        "rpm": 6
+    }
+}]
 
-# router = Router(model_list=model_list, set_verbose=False, num_retries=3)
+router = Router(model_list=model_list, set_verbose=False, num_retries=3)
 
-# async def router_completion():
-#     try:
-#         messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
-#         response = await router.acompletion(model="azure-model", messages=messages)
-#         return response
-#     except Exception as e:
-#         print(e, file=sys.stderr)
-#         traceback.print_exc()
-#         return None
+async def router_completion():
+    try:
+        messages=[{"role": "user", "content": f"This is a test: {uuid.uuid4()}"}]
+        response = await router.acompletion(model="azure-model", messages=messages)
+        return response
+    except asyncio.exceptions.CancelledError:
+        print("Task was cancelled")
+        return None
+    except Exception as e:
+        return None
 
-# async def loadtest_fn():
-#     start = time.time()
-#     n = 1000
-#     tasks = [router_completion() for _ in range(n)]
-#     chat_completions = await asyncio.gather(*tasks)
-#     successful_completions = [c for c in chat_completions if c is not None]
-#     print(n, time.time() - start, len(successful_completions))
+async def loadtest_fn(n = 1000):
+    start = time.time()
+    tasks = [router_completion() for _ in range(n)]
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(n, time.time() - start, len(successful_completions))
 
 # loop = asyncio.get_event_loop()
 # loop.set_debug(True)
 # log_slow_callbacks.enable(0.05) # Log callbacks slower than 0.05 seconds
 
 # # Excute the load testing function within the asyncio event loop
-# loop.run_until_complete(loadtest_fn())
\ No newline at end of file
+# loop.run_until_complete(loadtest_fn())
+
+### SUSTAINED LOAD TESTS ###
+import time, asyncio
+async def make_requests(n):
+    tasks = [router_completion() for _ in range(n)]
+    print(f"num tasks: {len(tasks)}")
+    chat_completions = await asyncio.gather(*tasks)
+    successful_completions = [c for c in chat_completions if c is not None]
+    print(f"successful_completions: {len(successful_completions)}")
+    return successful_completions
+
+async def main():
+    start_time = time.time()
+    total_successful_requests = 0
+    request_limit = 1000
+    batches = 2 # batches of 1k requests
+    start = time.time()
+    tasks = [] # list to hold all tasks
+
+    async def request_loop():
+        nonlocal tasks
+        for _ in range(batches):
+            # Make 1,000 requests
+            task = asyncio.create_task(make_requests(request_limit))
+            tasks.append(task)
+
+            # Introduce a delay to achieve 1,000 requests per second
+            await asyncio.sleep(1)
+
+    await request_loop()
+    results = await asyncio.gather(*tasks)
+    total_successful_requests = sum(len(res) for res in results)
+
+    print(request_limit*batches, time.time() - start, total_successful_requests)
+
+asyncio.run(main())
\ No newline at end of file