fix(router.py): add support for async completion calls

https://github.com/BerriAI/litellm/issues/676
Krrish Dholakia 2023-10-24 17:20:19 -07:00
parent 30dd0b5c6b
commit 0f08335edd
2 changed files with 46 additions and 3 deletions


@@ -54,7 +54,6 @@ class Router:
                    messages: List[Dict[str, str]],
                    is_retry: Optional[bool] = False,
                    is_fallback: Optional[bool] = False,
-                   is_async: Optional[bool] = False,
                    **kwargs):
         """
         Example usage:
@@ -68,6 +67,19 @@ class Router:
         data["caching"] = self.cache_responses
         # call via litellm.completion()
         return litellm.completion(**{**data, **kwargs})
+
+    async def acompletion(self,
+                          model: str,
+                          messages: List[Dict[str, str]],
+                          is_retry: Optional[bool] = False,
+                          is_fallback: Optional[bool] = False,
+                          **kwargs):
+        # pick the one that is available (lowest TPM/RPM)
+        deployment = self.get_available_deployment(model=model, messages=messages)
+        data = deployment["litellm_params"]
+        data["messages"] = messages
+        data["caching"] = self.cache_responses
+        return await litellm.acompletion(**{**data, **kwargs})
 
     def text_completion(self,
                         model: str,
@@ -83,6 +95,7 @@ class Router:
         data = deployment["litellm_params"]
         data["prompt"] = prompt
         data["caching"] = self.cache_responses
+        # call via litellm.text_completion()
         return litellm.text_completion(**{**data, **kwargs})
@@ -96,6 +109,7 @@ class Router:
         data = deployment["litellm_params"]
         data["input"] = input
         data["caching"] = self.cache_responses
+        # call via litellm.embedding()
         return litellm.embedding(**{**data, **kwargs})
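
With this change, callers can keep several router requests in flight at once instead of issuing one blocking call at a time. A minimal sketch, assuming Router is importable as `from litellm import Router` and a valid OPENAI_API_KEY is set; the deployment entry and the asyncio.gather fan-out below are illustrative, not part of this commit:

    import asyncio, os
    from litellm import Router

    # Illustrative single-deployment list; mirrors the shape used in the test below.
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo-0613",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
            "tpm": 100000,
            "rpm": 10000,
        },
    ]

    async def main():
        router = Router(model_list=model_list)
        # acompletion awaits litellm.acompletion, so multiple requests can run
        # concurrently; each call still picks the lowest-TPM/RPM deployment.
        tasks = [
            router.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": f"Say hello #{i}"}],
            )
            for i in range(3)
        ]
        for response in await asyncio.gather(*tasks):
            print(response["choices"][0]["message"]["content"])

    asyncio.run(main())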


@@ -2,7 +2,7 @@
 # This tests calling batch_completions by running 100 messages together
 import sys, os
-import traceback
+import traceback, asyncio
 import pytest
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -151,4 +151,33 @@ def test_litellm_params_not_overwritten_by_function_calling():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_litellm_params_not_overwritten_by_function_calling()
+# test_litellm_params_not_overwritten_by_function_calling()
+
+
+def test_acompletion_on_router():
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 100000,
+                "rpm": 10000,
+            },
+        ]
+        messages = [
+            {"role": "user", "content": "What is the weather like in Boston?"}
+        ]
+
+        async def get_response():
+            router = Router(model_list=model_list)
+            response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            return response
+
+        response = asyncio.run(get_response())
+        assert isinstance(response['choices'][0]['message']['content'], str)
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")