diff --git a/litellm/main.py b/litellm/main.py
index 0d05a29904..32bf7ce97e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -144,7 +144,11 @@ async def acompletion(*args, **kwargs):
             response = completion(*args, **kwargs)
         else:
             # Await normally
-            response = await completion(*args, **kwargs)
+            init_response = completion(*args, **kwargs)
+            if isinstance(init_response, dict):
+                response = init_response
+            else:
+                response = await init_response
     else:
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)
diff --git a/litellm/router.py b/litellm/router.py
index 08d3987bf7..73796e85ee 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Literal
 import random, threading, time
 import litellm
 import logging
@@ -29,12 +29,16 @@ class Router:
                  redis_host: Optional[str] = None,
                  redis_port: Optional[int] = None,
                  redis_password: Optional[str] = None,
-                 cache_responses: bool = False) -> None:
+                 cache_responses: bool = False,
+                 routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
         if model_list:
             self.set_model_list(model_list)
-        self.healthy_deployments: List = []
-        ### HEALTH CHECK THREAD ### - commenting out as further testing required
-        self._start_health_check_thread()
+            self.healthy_deployments: List = self.model_list
+
+        self.routing_strategy = routing_strategy
+        ### HEALTH CHECK THREAD ###
+        if self.routing_strategy == "least-busy":
+            self._start_health_check_thread()

         ### CACHING ###
         if redis_host is not None and redis_port is not None and redis_password is not None:
@@ -104,13 +108,15 @@ class Router:
         """
         Returns the deployment with the shortest queue
         """
-        ### COMMENTING OUT AS IT NEEDS FURTHER TESTING
         logging.debug(f"self.healthy_deployments: {self.healthy_deployments}")
-        if len(self.healthy_deployments) > 0:
-            for item in self.healthy_deployments:
-                if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
-                    return item[0]
-        else:
+        if self.routing_strategy == "least-busy":
+            if len(self.healthy_deployments) > 0:
+                for item in self.healthy_deployments:
+                    if item[0]["model_name"] == model: # first one in queue will be the one with the most availability
+                        return item[0]
+            else:
+                raise ValueError("No models available.")
+        elif self.routing_strategy == "simple-shuffle":
             potential_deployments = []
             for item in self.model_list:
                 if item["model_name"] == model:
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index a14ed1f3ce..bd2eacc602 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -134,7 +134,7 @@ Who among the mentioned figures from Ancient Greece contributed to the domain of
     print(results)

-test_multiple_deployments()
+# test_multiple_deployments()

 ### FUNCTION CALLING

 def test_function_calling():
@@ -228,6 +228,7 @@ def test_function_calling():

 def test_acompletion_on_router():
     try:
+        litellm.set_verbose = True
         model_list = [
             {
                 "model_name": "gpt-3.5-turbo",
@@ -245,16 +246,69 @@ def test_acompletion_on_router():
         ]

         async def get_response():
-            router = Router(model_list=model_list)
-            response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
-            return response
-        response = asyncio.run(get_response())
-
-        assert isinstance(response['choices'][0]['message']['content'], str)
+            router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True)
redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True) + response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages) + print(f"response1: {response1}") + response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages) + print(f"response2: {response2}") + assert response1["choices"][0]["message"]["content"] == response2["choices"][0]["message"]["content"] + asyncio.run(get_response()) except Exception as e: traceback.print_exc() pytest.fail(f"Error occurred: {e}") +test_acompletion_on_router() + +def test_function_calling_on_router(): + try: + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo-0613", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + }, + ] + function1 = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + router = Router( + model_list=model_list, + redis_host=os.getenv("REDIS_HOST"), + redis_password=os.getenv("REDIS_PASSWORD"), + redis_port=os.getenv("REDIS_PORT") + ) + async def get_response(): + messages=[ + { + "role": "user", + "content": "what's the weather in boston" + } + ], + response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, functions=function1) + print(f"response1: {response1}") + return response + response = asyncio.run(get_response()) + assert isinstance(response["choices"][0]["message"]["content"]["function_call"], str) + except Exception as e: + print(f"An exception occurred: {e}") + +# test_function_calling_on_router() def test_aembedding_on_router(): try: diff --git a/litellm/utils.py b/litellm/utils.py index 53c7e7117f..e5a7afd6d9 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -907,6 +907,29 @@ def client(original_function): # [Non-Blocking Error] pass + def convert_to_model_response_object(response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None): + try: + if response_object is None or model_response_object is None: + raise OpenAIError(status_code=500, message="Error in response object format") + choice_list=[] + for idx, choice in enumerate(response_object["choices"]): + message = Message(content=choice["message"]["content"], role=choice["message"]["role"]) + choice = Choices(finish_reason=choice["finish_reason"], index=idx, message=message) + choice_list.append(choice) + model_response_object.choices = choice_list + + if "usage" in response_object: + model_response_object.usage = response_object["usage"] + + if "id" in response_object: + model_response_object.id = response_object["id"] + + if "model" in response_object: + model_response_object.model = response_object["model"] + return model_response_object + except: + OpenAIError(status_code=500, message="Invalid response object.") + def wrapper(*args, **kwargs): start_time = datetime.datetime.now() result = None @@ -932,7 +955,7 @@ def client(original_function): # [OPTIONAL] CHECK CACHE # remove this after deprecating litellm.caching - print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: {litellm.caching_with_models}") + print_verbose(f"litellm.caching: {litellm.caching}; litellm.caching_with_models: 
         if (litellm.caching or litellm.caching_with_models) and litellm.cache is None:
             litellm.cache = Cache()
@@ -945,7 +968,7 @@ def client(original_function):
                 cached_result = litellm.cache.get_cache(*args, **kwargs)
                 if cached_result != None:
                     print_verbose(f"Cache Hit!")
-                    return cached_result
+                    return convert_to_model_response_object(response_object=cached_result, model_response_object=ModelResponse())
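
For context, here is a minimal usage sketch of the `routing_strategy` option added to `Router` in `litellm/router.py` above. The deployment entry, model name, and environment variable are placeholder assumptions, not values taken from this diff.

```python
# Minimal sketch of the new routing_strategy flag (placeholder model/config values).
import os
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",  # alias that callers pass to the router
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

# "simple-shuffle" (the default) picks a matching deployment at random;
# "least-busy" starts the health-check thread and returns the deployment
# at the front of the queue, per get_available_deployment above.
router = Router(model_list=model_list, routing_strategy="simple-shuffle")

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
print(response["choices"][0]["message"]["content"])
```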
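The `acompletion` change in `litellm/main.py` above only awaits the inner `completion()` call when it did not already return a plain `dict` (for example, a cache hit materialized synchronously). Below is an illustrative sketch of that guard pattern; it uses `inspect.isawaitable` rather than the `isinstance(..., dict)` check from the diff, and the helper name is made up.

```python
import asyncio
import inspect

async def call_and_maybe_await(fn, *args, **kwargs):
    # fn may hand back either an awaitable (normal async path) or an
    # already-materialized response (e.g. served from cache).
    init_response = fn(*args, **kwargs)
    if inspect.isawaitable(init_response):
        return await init_response
    return init_response

async def main():
    # Synchronous stand-in: returns a plain dict, so nothing is awaited.
    result = await call_and_maybe_await(lambda: {"cached": True})
    print(result)  # {'cached': True}

asyncio.run(main())
```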
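The `litellm/utils.py` hunks make a cache hit return a rebuilt `ModelResponse` (via `convert_to_model_response_object`) instead of the raw cached dict. A rough standalone illustration of that field-by-field conversion follows; the cached payload is made up, and it assumes `ModelResponse`, `Message`, and `Choices` are importable from `litellm.utils` as referenced in the diff.

```python
# Illustration of the cache-hit conversion above; the cached payload is invented.
from litellm.utils import Choices, Message, ModelResponse

cached_result = {
    "id": "chatcmpl-cached-123",
    "model": "gpt-3.5-turbo",
    "choices": [
        {"finish_reason": "stop", "message": {"role": "assistant", "content": "Hi there!"}}
    ],
    "usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8},
}

# Same field-by-field mapping as convert_to_model_response_object in the diff.
model_response = ModelResponse()
model_response.choices = [
    Choices(
        finish_reason=choice["finish_reason"],
        index=idx,
        message=Message(content=choice["message"]["content"], role=choice["message"]["role"]),
    )
    for idx, choice in enumerate(cached_result["choices"])
]
model_response.usage = cached_result["usage"]
model_response.id = cached_result["id"]
model_response.model = cached_result["model"]

print(model_response)
```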