diff --git a/litellm/main.py b/litellm/main.py
index 0ee123b64..ffb273946 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -212,15 +212,15 @@ async def acompletion(*args, **kwargs):
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
-        if kwargs.get("stream", False):  # return an async generator
-            return _async_streaming(
-                response=response,
-                model=model,
-                custom_llm_provider=custom_llm_provider,
-                args=args,
-            )
-        else:
-            return response
+        # if kwargs.get("stream", False):  # return an async generator
+        #     return _async_streaming(
+        #         response=response,
+        #         model=model,
+        #         custom_llm_provider=custom_llm_provider,
+        #         args=args,
+        #     )
+        # else:
+        return response
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index c130625b1..aa820379e 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -86,6 +86,7 @@ from fastapi import (
     Depends,
     BackgroundTasks,
     Header,
+    Response,
 )
 from fastapi.routing import APIRouter
 from fastapi.security import OAuth2PasswordBearer
@@ -1068,6 +1069,7 @@ def model_list():
 )
 async def completion(
     request: Request,
+    fastapi_response: Response,
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
     background_tasks: BackgroundTasks = BackgroundTasks(),
@@ -1143,17 +1145,23 @@ async def completion(
         else:  # router is not set
             response = await litellm.atext_completion(**data)
+        model_id = response._hidden_params.get("model_id", None) or ""
+        print(f"final response: {response}")
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
+            custom_headers = {"x-litellm-model-id": model_id}
             return StreamingResponse(
                 async_data_generator(
-                    user_api_key_dict=user_api_key_dict, response=response
+                    user_api_key_dict=user_api_key_dict,
+                    response=response,
                 ),
                 media_type="text/event-stream",
+                headers=custom_headers,
             )
+        fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
         print(f"EXCEPTION RAISED IN PROXY MAIN.PY")
@@ -1187,6 +1195,7 @@
 )  # azure compatible endpoint
 async def chat_completion(
     request: Request,
+    fastapi_response: Response,
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
     background_tasks: BackgroundTasks = BackgroundTasks(),
@@ -1282,19 +1291,24 @@
         else:  # router is not set
             response = await litellm.acompletion(**data)
-        print(f"final response: {response}")
+        model_id = response._hidden_params.get("model_id", None) or ""
         if (
             "stream" in data and data["stream"] == True
         ):  # use generate_responses to stream responses
+            custom_headers = {"x-litellm-model-id": model_id}
             return StreamingResponse(
                 async_data_generator(
-                    user_api_key_dict=user_api_key_dict, response=response
+                    user_api_key_dict=user_api_key_dict,
+                    response=response,
                 ),
                 media_type="text/event-stream",
+                headers=custom_headers,
             )
+        fastapi_response.headers["x-litellm-model-id"] = model_id
         return response
     except Exception as e:
+        traceback.print_exc()
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e
         )
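
With the two proxy endpoints above now attaching the serving deployment's id to every response, a caller can read it from the x-litellm-model-id response header. A minimal client-side sketch, not part of this diff; the proxy address, route, and API key below are placeholder assumptions:

# Hypothetical client check of the new header. Assumes a litellm proxy is
# running locally with "azure-model" configured as a model group.
import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",       # assumed local proxy address
    headers={"Authorization": "Bearer sk-1234"},  # placeholder key
    json={"model": "azure-model", "messages": [{"role": "user", "content": "hi"}]},
)
print(resp.headers.get("x-litellm-model-id"))  # id of the deployment that served the call
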
diff --git a/litellm/router.py b/litellm/router.py
index 02eb91019..a181ff515 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -191,7 +191,9 @@ class Router:
         )  # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
         ### ROUTING SETUP ###
         if routing_strategy == "least-busy":
-            self.leastbusy_logger = LeastBusyLoggingHandler(router_cache=self.cache)
+            self.leastbusy_logger = LeastBusyLoggingHandler(
+                router_cache=self.cache, model_list=self.model_list
+            )
             ## add callback
             if isinstance(litellm.input_callback, list):
                 litellm.input_callback.append(self.leastbusy_logger)  # type: ignore
@@ -506,7 +508,13 @@ class Router:
         **kwargs,
     ):
         try:
+            kwargs["model"] = model
+            kwargs["prompt"] = prompt
+            kwargs["original_function"] = self._acompletion
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
+            messages = [{"role": "user", "content": prompt}]
             # pick the one that is available (lowest TPM/RPM)
             deployment = self.get_available_deployment(
@@ -530,7 +538,6 @@
             if self.num_retries > 0:
                 kwargs["model"] = model
                 kwargs["messages"] = messages
-                kwargs["original_exception"] = e
                 kwargs["original_function"] = self.completion
                 return self.function_with_retries(**kwargs)
             else:
@@ -546,16 +553,34 @@
         **kwargs,
     ):
         try:
+            kwargs["model"] = model
+            kwargs["prompt"] = prompt
+            kwargs["original_function"] = self._atext_completion
+            kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
+            timeout = kwargs.get("request_timeout", self.timeout)
             kwargs.setdefault("metadata", {}).update({"model_group": model})
-            messages = [{"role": "user", "content": prompt}]
-            # pick the one that is available (lowest TPM/RPM)
+            response = await self.async_function_with_fallbacks(**kwargs)
+
+            return response
+        except Exception as e:
+            raise e
+
+    async def _atext_completion(self, model: str, prompt: str, **kwargs):
+        try:
+            self.print_verbose(
+                f"Inside _atext_completion()- model: {model}; kwargs: {kwargs}"
+            )
             deployment = self.get_available_deployment(
                 model=model,
-                messages=messages,
+                messages=[{"role": "user", "content": prompt}],
                 specific_deployment=kwargs.pop("specific_deployment", None),
             )
-
+            kwargs.setdefault("metadata", {}).update(
+                {"deployment": deployment["litellm_params"]["model"]}
+            )
+            kwargs["model_info"] = deployment.get("model_info", {})
             data = deployment["litellm_params"].copy()
+            model_name = data["model"]
             for k, v in self.default_litellm_params.items():
                 if (
                     k not in kwargs
"prompt": prompt, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + ), + timeout=self.timeout, + ) + self.success_calls[model_name] += 1 return response except Exception as e: - if self.num_retries > 0: - kwargs["model"] = model - kwargs["messages"] = messages - kwargs["original_exception"] = e - kwargs["original_function"] = self.completion - return self.function_with_retries(**kwargs) - else: - raise e + if model_name is not None: + self.fail_calls[model_name] += 1 + raise e def embedding( self, @@ -1531,34 +1567,10 @@ class Router: model ] # update the model to the actual value if an alias has been passed in if self.routing_strategy == "least-busy" and self.leastbusy_logger is not None: - deployments = self.leastbusy_logger.get_available_deployments( - model_group=model + deployment = self.leastbusy_logger.get_available_deployments( + model_group=model, healthy_deployments=healthy_deployments ) - self.print_verbose(f"deployments in least-busy router: {deployments}") - # pick least busy deployment - min_traffic = float("inf") - min_deployment = None - for k, v in deployments.items(): - if v < min_traffic: - min_traffic = v - min_deployment = k - self.print_verbose(f"min_deployment: {min_deployment};") - ############## No Available Deployments passed, we do a random pick ################# - if min_deployment is None: - min_deployment = random.choice(healthy_deployments) - ############## Available Deployments passed, we find the relevant item ################# - else: - ## check if min deployment is a string, if so, cast it to int - for m in healthy_deployments: - if isinstance(min_deployment, str) and isinstance( - m["model_info"]["id"], int - ): - min_deployment = int(min_deployment) - if m["model_info"]["id"] == min_deployment: - return m - self.print_verbose(f"no healthy deployment with that id found!") - min_deployment = random.choice(healthy_deployments) - return min_deployment + return deployment elif self.routing_strategy == "simple-shuffle": # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm ############## Check if we can do a RPM/TPM based weighted pick ################# diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py index 64a0aa99a..b2d2983b0 100644 --- a/litellm/router_strategy/least_busy.py +++ b/litellm/router_strategy/least_busy.py @@ -6,7 +6,7 @@ # - use litellm.success + failure callbacks to log when a request completed # - in get_available_deployment, for a given model group name -> pick based on traffic -import dotenv, os, requests +import dotenv, os, requests, random from typing import Optional dotenv.load_dotenv() # Loading env variables using dotenv @@ -20,9 +20,10 @@ class LeastBusyLoggingHandler(CustomLogger): logged_success: int = 0 logged_failure: int = 0 - def __init__(self, router_cache: DualCache): + def __init__(self, router_cache: DualCache, model_list: list): self.router_cache = router_cache self.mapping_deployment_to_id: dict = {} + self.model_list = model_list def log_pre_api_call(self, model, messages, kwargs): """ @@ -168,8 +169,28 @@ class LeastBusyLoggingHandler(CustomLogger): except Exception as e: pass - def get_available_deployments(self, model_group: str): + def get_available_deployments(self, model_group: str, healthy_deployments: list): request_count_api_key = f"{model_group}_request_count" - return_dict = self.router_cache.get_cache(key=request_count_api_key) or {} + deployments = self.router_cache.get_cache(key=request_count_api_key) or {} + 
diff --git a/litellm/router_strategy/least_busy.py b/litellm/router_strategy/least_busy.py
index 64a0aa99a..b2d2983b0 100644
--- a/litellm/router_strategy/least_busy.py
+++ b/litellm/router_strategy/least_busy.py
@@ -6,7 +6,7 @@
 # - use litellm.success + failure callbacks to log when a request completed
 # - in get_available_deployment, for a given model group name -> pick based on traffic
 
-import dotenv, os, requests
+import dotenv, os, requests, random
 from typing import Optional
 
 dotenv.load_dotenv()  # Loading env variables using dotenv
@@ -20,9 +20,10 @@ class LeastBusyLoggingHandler(CustomLogger):
     logged_success: int = 0
     logged_failure: int = 0
 
-    def __init__(self, router_cache: DualCache):
+    def __init__(self, router_cache: DualCache, model_list: list):
         self.router_cache = router_cache
         self.mapping_deployment_to_id: dict = {}
+        self.model_list = model_list
 
     def log_pre_api_call(self, model, messages, kwargs):
         """
@@ -168,8 +169,28 @@
         except Exception as e:
             pass
 
-    def get_available_deployments(self, model_group: str):
+    def get_available_deployments(self, model_group: str, healthy_deployments: list):
         request_count_api_key = f"{model_group}_request_count"
-        return_dict = self.router_cache.get_cache(key=request_count_api_key) or {}
+        deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
+        all_deployments = deployments
+        for d in healthy_deployments:
+            ## if healthy deployment not yet used
+            if d["model_info"]["id"] not in all_deployments:
+                all_deployments[d["model_info"]["id"]] = 0
         # map deployment to id
-        return return_dict
+        # pick least busy deployment
+        min_traffic = float("inf")
+        min_deployment = None
+        for k, v in all_deployments.items():
+            if v < min_traffic:
+                min_traffic = v
+                min_deployment = k
+        if min_deployment is not None:
+            ## check if min deployment is a string, if so, cast it to int
+            for m in healthy_deployments:
+                if m["model_info"]["id"] == min_deployment:
+                    return m
+            min_deployment = random.choice(healthy_deployments)
+        else:
+            min_deployment = random.choice(healthy_deployments)
+        return min_deployment
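
Because the router forwards each deployment's model_info into litellm and the client wrapper in utils.py (end of this diff) copies its id into _hidden_params, callers using the Router directly can also see which deployment answered. A sketch with a placeholder single-deployment config, not drawn from this diff:

# Sketch only: deployment values below are placeholders, not real credentials.
import asyncio
from litellm import Router

model_list = [
    {
        "model_name": "azure-model",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",                        # placeholder deployment
            "api_key": "os.environ/AZURE_EUROPE_API_KEY",         # placeholder key reference
            "api_base": "https://my-endpoint.openai.azure.com",   # placeholder endpoint
        },
        "model_info": {"id": 1},
    }
]

async def check_model_id():
    router = Router(model_list=model_list, routing_strategy="least-busy")
    response = await router.acompletion(
        model="azure-model", messages=[{"role": "user", "content": "hi"}]
    )
    # same value the proxy exposes as the x-litellm-model-id header
    print(response._hidden_params.get("model_id"))

asyncio.run(check_model_id())
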
diff --git a/litellm/tests/test_least_busy_routing.py b/litellm/tests/test_least_busy_routing.py
index 74e27f34a..c404683a9 100644
--- a/litellm/tests/test_least_busy_routing.py
+++ b/litellm/tests/test_least_busy_routing.py
@@ -1,7 +1,7 @@
 #### What this tests ####
 # This tests the router's ability to identify the least busy deployment
 
-import sys, os, asyncio, time
+import sys, os, asyncio, time, random
 import traceback
 from dotenv import load_dotenv
@@ -128,3 +128,139 @@
     assert return_dict[1] == 10
     assert return_dict[2] == 54
     assert return_dict[3] == 100
+
+
+## Test with Real calls ##
+
+
+@pytest.mark.asyncio
+async def test_router_atext_completion_streaming():
+    prompt = "Hello, can you generate a 500 words poem?"
+    model = "azure-model"
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_CANADA_API_KEY",
+                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 3},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="least-busy",
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+
+    ### Call the async calls in sequence, so we start 1 call before going to the next.
+
+    ## CALL 1
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.atext_completion(model=model, prompt=prompt, stream=True)
+
+    ## CALL 2
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.atext_completion(model=model, prompt=prompt, stream=True)
+
+    ## CALL 3
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.atext_completion(model=model, prompt=prompt, stream=True)
+
+    cache_key = f"{model}_request_count"
+    ## check if calls equally distributed
+    cache_dict = router.cache.get_cache(key=cache_key)
+    for k, v in cache_dict.items():
+        assert v == 1
+
+
+# asyncio.run(test_router_atext_completion_streaming())
+
+
+@pytest.mark.asyncio
+async def test_router_completion_streaming():
+    messages = [
+        {"role": "user", "content": "Hello, can you generate a 500 words poem?"}
+    ]
+    model = "azure-model"
+    model_list = [
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-turbo",
+                "api_key": "os.environ/AZURE_FRANCE_API_KEY",
+                "api_base": "https://openai-france-1234.openai.azure.com",
+                "rpm": 1440,
+            },
+            "model_info": {"id": 1},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
+                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 2},
+        },
+        {
+            "model_name": "azure-model",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "os.environ/AZURE_CANADA_API_KEY",
+                "api_base": "https://my-endpoint-canada-berri992.openai.azure.com",
+                "rpm": 6,
+            },
+            "model_info": {"id": 3},
+        },
+    ]
+    router = Router(
+        model_list=model_list,
+        routing_strategy="least-busy",
+        set_verbose=False,
+        num_retries=3,
+    )  # type: ignore
+
+    ### Call the async calls in sequence, so we start 1 call before going to the next.
+
+    ## CALL 1
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.acompletion(model=model, messages=messages, stream=True)
+
+    ## CALL 2
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.acompletion(model=model, messages=messages, stream=True)
+
+    ## CALL 3
+    await asyncio.sleep(random.uniform(0, 2))
+    await router.acompletion(model=model, messages=messages, stream=True)
+
+    cache_key = f"{model}_request_count"
+    ## check if calls equally distributed
+    cache_dict = router.cache.get_cache(key=cache_key)
+    for k, v in cache_dict.items():
+        assert v == 1
diff --git a/litellm/utils.py b/litellm/utils.py
index 2cb1b9e64..5b0fe8b1d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -479,6 +479,8 @@ class EmbeddingResponse(OpenAIObject):
     usage: Optional[Usage] = None
     """Usage statistics for the embedding request."""
 
+    _hidden_params: dict = {}
+
     def __init__(
         self, model=None, usage=None, stream=False, response_ms=None, data=None
     ):
@@ -640,6 +642,8 @@ class ImageResponse(OpenAIObject):
 
     usage: Optional[dict] = None
 
+    _hidden_params: dict = {}
+
     def __init__(self, created=None, data=None, response_ms=None):
         if response_ms:
             _response_ms = response_ms
@@ -2053,6 +2057,10 @@
                 target=logging_obj.success_handler, args=(result, start_time, end_time)
             ).start()
             # RETURN RESULT
+            if hasattr(result, "_hidden_params"):
+                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
+                    "id", None
+                )
             result._response_ms = (
                 end_time - start_time
             ).total_seconds() * 1000  # return response latency in ms like openai
@@ -2273,6 +2281,10 @@
                 target=logging_obj.success_handler, args=(result, start_time, end_time)
             ).start()
             # RETURN RESULT
+            if hasattr(result, "_hidden_params"):
+                result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
+                    "id", None
+                )
             if isinstance(result, ModelResponse):
                 result._response_ms = (
                     end_time - start_time
@@ -6527,6 +6539,13 @@
         self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
         self.holding_chunk = ""
         self.complete_response = ""
+        self._hidden_params = {
+            "model_id": (
+                self.logging_obj.model_call_details.get("litellm_params", {})
+                .get("model_info", {})
+                .get("id", None)
+            )
+        }  # returned as x-litellm-model-id response header in proxy
 
     def __iter__(self):
         return self
@@ -7417,6 +7436,15 @@
                     threading.Thread(
                         target=self.logging_obj.success_handler, args=(response,)
                     ).start()  # log response
+                    # RETURN RESULT
+                    if hasattr(response, "_hidden_params"):
+                        response._hidden_params["model_id"] = (
+                            self.logging_obj.model_call_details.get(
+                                "litellm_params", {}
+                            )
+                            .get("model_info", {})
+                            .get("id", None)
+                        )
                     return response
             except StopIteration:
                 raise  # Re-raise StopIteration
@@ -7467,6 +7495,16 @@
                                 processed_chunk,
                             )
                         )
+                    # RETURN RESULT
+                    if hasattr(processed_chunk, "_hidden_params"):
+                        model_id = (
+                            self.logging_obj.model_call_details.get(
+                                "litellm_params", {}
+                            )
+                            .get("model_info", {})
+                            .get("id", None)
+                        )
+                        processed_chunk._hidden_params["model_id"] = model_id
                     return processed_chunk
                 raise StopAsyncIteration
             else:  # temporary patch for non-aiohttp async calls
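
The CustomStreamWrapper change above is what makes the id available on streaming responses as well, which the proxy relies on when it sets the header before the stream starts. A sketch of reading it on a streamed Router call; the deployment config is again a placeholder:

# Sketch: with stream=True the router returns a CustomStreamWrapper, whose
# _hidden_params now carries the serving deployment's id before any chunk
# is consumed. Deployment values below are placeholders.
import asyncio
from litellm import Router

model_list = [
    {
        "model_name": "azure-model",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",                        # placeholder deployment
            "api_key": "os.environ/AZURE_EUROPE_API_KEY",         # placeholder key reference
            "api_base": "https://my-endpoint.openai.azure.com",   # placeholder endpoint
        },
        "model_info": {"id": 1},
    }
]

async def check_streaming_model_id():
    router = Router(model_list=model_list, routing_strategy="least-busy")
    stream = await router.acompletion(
        model="azure-model",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    print(stream._hidden_params.get("model_id"))  # available before iterating
    async for _ in stream:  # then consume the stream as usual
        pass

asyncio.run(check_streaming_model_id())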