fix(router.py): adding support for async completion calls
https://github.com/BerriAI/litellm/issues/676
parent 30dd0b5c6b
commit 0f08335edd
2 changed files with 46 additions and 3 deletions
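In short: Router.completion drops its unused is_async flag, a new Router.acompletion awaits litellm.acompletion against the deployment the router selects, and text_completion / embedding now forward the router-level cache_responses setting. A minimal usage sketch of the new async path (assumptions: Router is importable from the litellm package, as in the test below, and OPENAI_API_KEY is set; the model list mirrors the one used in the test):

    import os, asyncio
    from litellm import Router

    model_list = [{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo-0613",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        "tpm": 100000,
        "rpm": 10000,
    }]

    async def main():
        router = Router(model_list=model_list)
        # awaitable counterpart of router.completion(); same routing underneath
        response = await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
        )
        print(response["choices"][0]["message"]["content"])

    asyncio.run(main())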
router.py
@@ -54,7 +54,6 @@ class Router:
                    messages: List[Dict[str, str]],
                    is_retry: Optional[bool] = False,
                    is_fallback: Optional[bool] = False,
-                   is_async: Optional[bool] = False,
                    **kwargs):
         """
         Example usage:
@@ -68,6 +67,19 @@ class Router:
         data["caching"] = self.cache_responses
         # call via litellm.completion()
         return litellm.completion(**{**data, **kwargs})
 
+    async def acompletion(self,
+                    model: str,
+                    messages: List[Dict[str, str]],
+                    is_retry: Optional[bool] = False,
+                    is_fallback: Optional[bool] = False,
+                    **kwargs):
+        # pick the one that is available (lowest TPM/RPM)
+        deployment = self.get_available_deployment(model=model, messages=messages)
+        data = deployment["litellm_params"]
+        data["messages"] = messages
+        data["caching"] = self.cache_responses
+        return await litellm.acompletion(**{**data, **kwargs})
+
     def text_completion(self,
                         model: str,
@@ -83,6 +95,7 @@ class Router:
 
         data = deployment["litellm_params"]
         data["prompt"] = prompt
+        data["caching"] = self.cache_responses
         # call via litellm.completion()
         return litellm.text_completion(**{**data, **kwargs})
 
@@ -96,6 +109,7 @@ class Router:
 
         data = deployment["litellm_params"]
         data["input"] = input
+        data["caching"] = self.cache_responses
         # call via litellm.embedding()
         return litellm.embedding(**{**data, **kwargs})
 
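acompletion mirrors the sync completion path: pick the lowest-TPM/RPM deployment, merge its litellm_params with the caller's kwargs, and await litellm.acompletion instead of calling litellm.completion. The two new data["caching"] = self.cache_responses lines make text_completion and embedding respect the router's cache setting the same way completion already does. The practical payoff of the async path is concurrent fan-out on a single event loop; a sketch, with hypothetical prompts and the router from the example above:

    import asyncio

    async def fan_out(router, prompts):
        # schedule all requests at once; the router picks a deployment per call
        tasks = [
            router.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": p}],
            )
            for p in prompts
        ]
        return await asyncio.gather(*tasks)

    # e.g. responses = asyncio.run(fan_out(router, ["q1", "q2", "q3"]))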
litellm/tests/test_router.py
@@ -2,7 +2,7 @@
 # This tests calling batch_completions by running 100 messages together
 
 import sys, os
-import traceback
+import traceback, asyncio
 import pytest
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -151,4 +151,33 @@ def test_litellm_params_not_overwritten_by_function_calling():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_litellm_params_not_overwritten_by_function_calling()
+# test_litellm_params_not_overwritten_by_function_calling()
+
+def test_acompletion_on_router():
+    try:
+        model_list = [
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo-0613",
+                    "api_key": os.getenv("OPENAI_API_KEY"),
+                },
+                "tpm": 100000,
+                "rpm": 10000,
+            },
+        ]
+
+        messages = [
+            {"role": "user", "content": "What is the weather like in Boston?"}
+        ]
+
+        async def get_response():
+            router = Router(model_list=model_list)
+            response = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
+            return response
+        response = asyncio.run(get_response())
+
+        assert isinstance(response['choices'][0]['message']['content'], str)
+    except Exception as e:
+        traceback.print_exc()
+        pytest.fail(f"Error occurred: {e}")