forked from phoenix/litellm-mirror
fix(router.py): adding support for async completion calls
https://github.com/BerriAI/litellm/issues/676
parent 30dd0b5c6b
commit 0f08335edd
2 changed files with 46 additions and 3 deletions
@@ -54,7 +54,6 @@ class Router:
                    messages: List[Dict[str, str]],
                    is_retry: Optional[bool] = False,
                    is_fallback: Optional[bool] = False,
-                   is_async: Optional[bool] = False,
                    **kwargs):
        """
        Example usage:
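The is_async flag is dropped from completion(): synchronous callers are unaffected, and async callers move to the dedicated acompletion() coroutine added in the next hunk. A minimal sketch of the synchronous path, assuming Router is importable from the litellm package; the model name, key, and limits below are placeholders:

import os
from litellm import Router

model_list = [{
    "model_name": "gpt-3.5-turbo",
    "litellm_params": {
        "model": "gpt-3.5-turbo-0613",
        "api_key": os.getenv("OPENAI_API_KEY"),
    },
    "tpm": 100000,
    "rpm": 10000,
}]

router = Router(model_list=model_list)
# routes to the deployment with available TPM/RPM, then calls litellm.completion()
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
)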
@@ -68,6 +67,19 @@ class Router:
        data["caching"] = self.cache_responses
        # call via litellm.completion()
        return litellm.completion(**{**data, **kwargs})
+
+    async def acompletion(self,
+                          model: str,
+                          messages: List[Dict[str, str]],
+                          is_retry: Optional[bool] = False,
+                          is_fallback: Optional[bool] = False,
+                          **kwargs):
+        # pick the one that is available (lowest TPM/RPM)
+        deployment = self.get_available_deployment(model=model, messages=messages)
+        data = deployment["litellm_params"]
+        data["messages"] = messages
+        data["caching"] = self.cache_responses
+        return await litellm.acompletion(**{**data, **kwargs})

    def text_completion(self,
                        model: str,
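A usage sketch for the new acompletion() coroutine (not part of the diff). It mirrors the test added further down and reuses model_list and the imports from the sketch above, plus asyncio:

import asyncio

async def main():
    router = Router(model_list=model_list)
    # same routing as completion(), but awaits litellm.acompletion() under the hood
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    )
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())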
@@ -83,6 +95,7 @@ class Router:

        data = deployment["litellm_params"]
        data["prompt"] = prompt
+       data["caching"] = self.cache_responses
        # call via litellm.completion()
        return litellm.text_completion(**{**data, **kwargs})
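text_completion() routes a plain prompt string instead of a chat messages list and, with this hunk, also forwards the router's cache_responses setting. A short sketch reusing the router from the first example; the prompt text is a placeholder:

text_response = router.text_completion(
    model="gpt-3.5-turbo",
    prompt="Write one sentence about Boston.",
)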
@@ -96,6 +109,7 @@ class Router:

        data = deployment["litellm_params"]
        data["input"] = input
+       data["caching"] = self.cache_responses
        # call via litellm.embedding()
        return litellm.embedding(**{**data, **kwargs})
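embedding() follows the same pattern with an input list and now picks up cache_responses as well. A hedged sketch; the embedding deployment entry below is an assumption (any embedding model litellm supports uses the same model_list shape):

import os
from litellm import Router

embedding_router = Router(model_list=[{
    "model_name": "embeddings",
    "litellm_params": {
        "model": "text-embedding-ada-002",  # assumed embedding deployment
        "api_key": os.getenv("OPENAI_API_KEY"),
    },
    "tpm": 100000,
    "rpm": 10000,
}])
vectors = embedding_router.embedding(model="embeddings", input=["hello world"])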
@@ -2,7 +2,7 @@
# This tests calling batch_completions by running 100 messages together

import sys, os
-import traceback
+import traceback, asyncio
import pytest
sys.path.insert(
    0, os.path.abspath("../..")
@@ -151,4 +151,33 @@ def test_litellm_params_not_overwritten_by_function_calling():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

-test_litellm_params_not_overwritten_by_function_calling()
+# test_litellm_params_not_overwritten_by_function_calling()
+
+def test_acompletion_on_router():
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 100000,
+                "rpm": 10000,
+            },
+        ]
+
+        messages = [
+            {"role": "user", "content": "What is the weather like in Boston?"}
+        ]
+
+        async def get_response():
+            router = Router(model_list=model_list)
+            response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            return response
+        response = asyncio.run(get_response())
+
+        assert isinstance(response['choices'][0]['message']['content'], str)
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")
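The new test drives the coroutine with asyncio.run from inside a plain synchronous test function, so no extra pytest plugin is required. An alternative sketch, assuming the pytest-asyncio plugin were installed (it is not part of this commit), with model_list and messages as defined in the test above:

import pytest

@pytest.mark.asyncio
async def test_acompletion_on_router_async_def():
    # pytest-asyncio manages the event loop, so no asyncio.run() wrapper is needed
    router = Router(model_list=model_list)
    response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
    assert isinstance(response["choices"][0]["message"]["content"], str)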