From 0f08335eddd17c697d0dc57e7e58c5984cd84622 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 24 Oct 2023 17:20:19 -0700
Subject: [PATCH] fix(router.py): adding support for async completion calls

https://github.com/BerriAI/litellm/issues/676
---
 litellm/router.py            | 16 +++++++++++++++-
 litellm/tests/test_router.py | 33 +++++++++++++++++++++++++++++++--
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/litellm/router.py b/litellm/router.py
index 5b330a239..9756e7714 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -54,7 +54,6 @@ class Router:
                    messages: List[Dict[str, str]],
                    is_retry: Optional[bool] = False,
                    is_fallback: Optional[bool] = False,
-                   is_async: Optional[bool] = False,
                    **kwargs):
         """
         Example usage:
@@ -68,6 +67,19 @@ class Router:
         data["caching"] = self.cache_responses
         # call via litellm.completion()
         return litellm.completion(**{**data, **kwargs})
+
+    async def acompletion(self,
+                    model: str,
+                    messages: List[Dict[str, str]],
+                    is_retry: Optional[bool] = False,
+                    is_fallback: Optional[bool] = False,
+                    **kwargs):
+        # pick the one that is available (lowest TPM/RPM)
+        deployment = self.get_available_deployment(model=model, messages=messages)
+        data = deployment["litellm_params"]
+        data["messages"] = messages
+        data["caching"] = self.cache_responses
+        return await litellm.acompletion(**{**data, **kwargs})
 
     def text_completion(self,
                         model: str,
@@ -83,6 +95,7 @@ class Router:
 
         data = deployment["litellm_params"]
         data["prompt"] = prompt
+        data["caching"] = self.cache_responses
         # call via litellm.completion()
         return litellm.text_completion(**{**data, **kwargs})
 
@@ -96,6 +109,7 @@ class Router:
 
         data = deployment["litellm_params"]
         data["input"] = input
+        data["caching"] = self.cache_responses
         # call via litellm.embedding()
         return litellm.embedding(**{**data, **kwargs})
 
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index e581c435d..1e44b5d31 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -2,7 +2,7 @@
 # This tests calling batch_completions by running 100 messages together
 
 import sys, os
-import traceback
+import traceback, asyncio
 import pytest
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -151,4 +151,33 @@ def test_litellm_params_not_overwritten_by_function_calling():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_litellm_params_not_overwritten_by_function_calling()
\ No newline at end of file
+# test_litellm_params_not_overwritten_by_function_calling()
+
+def test_acompletion_on_router():
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 100000,
+                "rpm": 10000,
+            },
+        ]
+
+        messages = [
+            {"role": "user", "content": "What is the weather like in Boston?"}
+        ]
+
+        async def get_response():
+            router = Router(model_list=model_list)
+            response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            return response
+        response = asyncio.run(get_response())
+
+        assert isinstance(response['choices'][0]['message']['content'], str)
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
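
Usage note: a minimal sketch of calling the new async router path, assuming this patch is applied and OPENAI_API_KEY is set in the environment. The model alias and deployment params simply mirror the test case above, and the Router import path is assumed from the test file's usage.

import asyncio, os
from litellm import Router

# one deployment behind the "gpt-3.5-turbo" alias; tpm/rpm values (mirroring
# the test above) let the router rank deployments by available capacity
model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo-0613",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        "tpm": 100000,
        "rpm": 10000,
    },
]

async def main():
    router = Router(model_list=model_list)
    # acompletion picks the lowest-utilization deployment for the alias,
    # merges its litellm_params with the call's kwargs, and awaits
    # litellm.acompletion instead of blocking on litellm.completion
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    )
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())

Note the design choice: rather than keeping the removed is_async flag on completion(), the patch exposes sync and async as separate methods, matching litellm's module-level completion/acompletion split.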