From cf0a9f591ccf6a6eca1eeaa8aa4e31e2f94e48af Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 17 Nov 2023 17:56:09 -0800 Subject: [PATCH] fix(router.py): introducing usage-based-routing --- litellm/router.py | 217 ++++++++++++++++++----------------- litellm/tests/test_router.py | 36 +++--- litellm/utils.py | 10 -- 3 files changed, 133 insertions(+), 130 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 2e078f0761..02096d81ba 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4,6 +4,7 @@ import random, threading, time import litellm, openai import logging, asyncio import inspect +from openai import AsyncOpenAI class Router: """ @@ -36,7 +37,7 @@ class Router: num_retries: int = 0, timeout: float = 600, default_litellm_params = {}, # default params for Router.chat.completion.create - routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None: + routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing"] = "simple-shuffle") -> None: if model_list: self.set_model_list(model_list) @@ -69,8 +70,9 @@ class Router: if cache_responses: litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests self.cache_responses = cache_responses - - + self.cache = litellm.Cache(cache_config) # use Redis for tracking load balancing + ## USAGE TRACKING ## + litellm.success_callback = [self.deployment_callback] def _start_health_check_thread(self): """ @@ -138,6 +140,8 @@ class Router: potential_deployments.append(item) item = random.choice(potential_deployments) return item or item[0] + elif self.routing_strategy == "usage-based-routing": + return self.get_usage_based_available_deployment(model=model, messages=messages, input=input) raise ValueError("No models available.") @@ -242,6 +246,9 @@ class Router: if k not in data: # prioritize model-specific params > default router params data[k] = v response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs}) + # client = AsyncOpenAI() + # print(f"MAKING OPENAI CALL") + # response = await client.chat.completions.create(model=model, messages=messages) return response except Exception as e: if self.num_retries > 0: @@ -301,119 +308,117 @@ class Router: data[k] = v return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs}) - # def deployment_callback( - # self, - # kwargs, # kwargs to completion - # completion_response, # response from completion - # start_time, end_time # start/end time - # ): - # """ - # Function LiteLLM submits a callback to after a successful - # completion. Purpose of this is ti update TPM/RPM usage per model - # """ - # model_name = kwargs.get('model', None) # i.e. gpt35turbo - # custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure - # if custom_llm_provider: - # model_name = f"{custom_llm_provider}/{model_name}" - # total_tokens = completion_response['usage']['total_tokens'] - # self._set_deployment_usage(model_name, total_tokens) + def deployment_callback( + self, + kwargs, # kwargs to completion + completion_response, # response from completion + start_time, end_time # start/end time + ): + """ + Function LiteLLM submits a callback to after a successful + completion. Purpose of this is to update TPM/RPM usage per model + """ + model_name = kwargs.get('model', None) # i.e. gpt35turbo + custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. 
azure + if custom_llm_provider: + model_name = f"{custom_llm_provider}/{model_name}" + total_tokens = completion_response['usage']['total_tokens'] + self._set_deployment_usage(model_name, total_tokens) - # def get_available_deployment(self, - # model: str, - # messages: Optional[List[Dict[str, str]]] = None, - # input: Optional[Union[str, List]] = None): - # """ - # Returns a deployment with the lowest TPM/RPM usage. - # """ - # # get list of potential deployments - # potential_deployments = [] - # for item in self.model_list: - # if item["model_name"] == model: - # potential_deployments.append(item) + def get_usage_based_available_deployment(self, + model: str, + messages: Optional[List[Dict[str, str]]] = None, + input: Optional[Union[str, List]] = None): + """ + Returns a deployment with the lowest TPM/RPM usage. + """ + # get list of potential deployments + potential_deployments = [] + for item in self.model_list: + if item["model_name"] == model: + potential_deployments.append(item) - # # set first model as current model to calculate token count - # deployment = potential_deployments[0] + # get current call usage + token_count = 0 + if messages is not None: + token_count = litellm.token_counter(model=model, messages=messages) + elif input is not None: + if isinstance(input, List): + input_text = "".join(text for text in input) + else: + input_text = input + token_count = litellm.token_counter(model=model, text=input_text) - # # get encoding - # token_count = 0 - # if messages is not None: - # token_count = litellm.token_counter(model=deployment["model_name"], messages=messages) - # elif input is not None: - # if isinstance(input, List): - # input_text = "".join(text for text in input) - # else: - # input_text = input - # token_count = litellm.token_counter(model=deployment["model_name"], text=input_text) + # ----------------------- + # Find lowest used model + # ---------------------- + lowest_tpm = float("inf") + deployment = None - # # ----------------------- - # # Find lowest used model - # # ---------------------- - # lowest_tpm = float("inf") - # deployment = None + # return deployment with lowest tpm usage + for item in potential_deployments: + item_tpm, item_rpm = self._get_deployment_usage(deployment_name=item["litellm_params"]["model"]) - # # Go through all the models to get tpm, rpm - # for item in potential_deployments: - # item_tpm, item_rpm = self._get_deployment_usage(deployment_name=item["litellm_params"]["model"]) + if item_tpm == 0: + return item + elif ("tpm" in item and item_tpm + token_count > item["tpm"] + or "rpm" in item and item_rpm + 1 >= item["rpm"]): # if user passed in tpm / rpm in the model_list + continue + elif item_tpm < lowest_tpm: + lowest_tpm = item_tpm + deployment = item - # if item_tpm == 0: - # return item - # elif item_tpm + token_count > item["tpm"] or item_rpm + 1 >= item["rpm"]: - # continue - # elif item_tpm < lowest_tpm: - # lowest_tpm = item_tpm - # deployment = item + # if none, raise exception + if deployment is None: + raise ValueError("No models available.") - # # if none, raise exception - # if deployment is None: - # raise ValueError("No models available.") + # return model + return deployment - # # return model - # return deployment + def _get_deployment_usage( + self, + deployment_name: str + ): + # ------------ + # Setup values + # ------------ + current_minute = datetime.now().strftime("%H-%M") + tpm_key = f'{deployment_name}:tpm:{current_minute}' + rpm_key = f'{deployment_name}:rpm:{current_minute}' - # def _get_deployment_usage( 
- # self, - # deployment_name: str - # ): - # # ------------ - # # Setup values - # # ------------ - # current_minute = datetime.now().strftime("%H-%M") - # tpm_key = f'{deployment_name}:tpm:{current_minute}' - # rpm_key = f'{deployment_name}:rpm:{current_minute}' + # ------------ + # Return usage + # ------------ + tpm = self.cache.get_cache(cache_key=tpm_key) or 0 + rpm = self.cache.get_cache(cache_key=rpm_key) or 0 - # # ------------ - # # Return usage - # # ------------ - # tpm = self.cache.get_cache(cache_key=tpm_key) or 0 - # rpm = self.cache.get_cache(cache_key=rpm_key) or 0 + return int(tpm), int(rpm) - # return int(tpm), int(rpm) + def increment(self, key: str, increment_value: int): + # get value + cached_value = self.cache.get_cache(cache_key=key) + # update value + try: + cached_value = cached_value + increment_value + except: + cached_value = increment_value + # save updated value + self.cache.add_cache(result=cached_value, cache_key=key, ttl=self.default_cache_time_seconds) - # def increment(self, key: str, increment_value: int): - # # get value - # cached_value = self.cache.get_cache(cache_key=key) - # # update value - # try: - # cached_value = cached_value + increment_value - # except: - # cached_value = increment_value - # # save updated value - # self.cache.add_cache(result=cached_value, cache_key=key, ttl=self.default_cache_time_seconds) + def _set_deployment_usage( + self, + model_name: str, + total_tokens: int + ): + # ------------ + # Setup values + # ------------ + current_minute = datetime.now().strftime("%H-%M") + tpm_key = f'{model_name}:tpm:{current_minute}' + rpm_key = f'{model_name}:rpm:{current_minute}' - # def _set_deployment_usage( - # self, - # model_name: str, - # total_tokens: int - # ): - # # ------------ - # # Setup values - # # ------------ - # current_minute = datetime.now().strftime("%H-%M") - # tpm_key = f'{model_name}:tpm:{current_minute}' - # rpm_key = f'{model_name}:rpm:{current_minute}' - - # # ------------ - # # Update usage - # # ------------ - # self.increment(tpm_key, total_tokens) - # self.increment(rpm_key, 1) \ No newline at end of file + # ------------ + # Update usage + # ------------ + self.increment(tpm_key, total_tokens) + self.increment(rpm_key, 1) \ No newline at end of file diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 9508054437..847827746c 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -239,14 +239,31 @@ def test_acompletion_on_router(): "tpm": 100000, "rpm": 10000, }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_base": os.getenv("AZURE_API_BASE"), + "api_version": os.getenv("AZURE_API_VERSION") + }, + "tpm": 100000, + "rpm": 10000, + } ] messages = [ - {"role": "user", "content": "What is the weather like in SF?"} + {"role": "user", "content": "What is the weather like in Boston?"} ] - + start_time = time.time() async def get_response(): - router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=10) + router = Router(model_list=model_list, + redis_host=os.environ["REDIS_HOST"], + redis_password=os.environ["REDIS_PASSWORD"], + redis_port=os.environ["REDIS_PORT"], + cache_responses=True, + timeout=30, + routing_strategy="usage-based-routing") response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages) print(f"response1: 
{response1}") response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages) @@ -254,6 +271,8 @@ def test_acompletion_on_router(): assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"] asyncio.run(get_response()) except litellm.Timeout as e: + end_time = time.time() + print(f"timeout error occurred: {end_time - start_time}") pass except Exception as e: traceback.print_exc() @@ -304,17 +323,6 @@ def test_function_calling_on_router(): ] response = router.completion(model="gpt-3.5-turbo", messages=messages, functions=function1) print(f"final returned response: {response}") - # async def get_response(): - # messages=[ - # { - # "role": "user", - # "content": "what's the weather in boston" - # } - # ], - # response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, functions=function1) - # print(f"response1: {response1}") - # return response - # response = asyncio.run(get_response()) assert isinstance(response["choices"][0]["message"]["function_call"], dict) except Exception as e: print(f"An exception occurred: {e}") diff --git a/litellm/utils.py b/litellm/utils.py index f1ccdbfc1a..be5fc86caf 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1069,16 +1069,6 @@ def client(original_function): try: global callback_list, add_breadcrumb, user_logger_fn, Logging function_id = kwargs["id"] if "id" in kwargs else None - if litellm.client_session is None: - litellm.client_session = httpx.Client( - limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), - timeout = None - ) - if litellm.aclient_session is None: - litellm.aclient_session = httpx.AsyncClient( - limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), - timeout = None - ) if litellm.use_client or ("use_client" in kwargs and kwargs["use_client"] == True): print_verbose(f"litedebugger initialized") if "lite_debugger" not in litellm.input_callback: