fix(router.py): overloads for better router.acompletion typing

This commit is contained in:
Krrish Dholakia 2024-05-13 14:27:16 -07:00
parent bd2f46fd75
commit 1312eece6d
2 changed files with 23 additions and 3 deletions

View file

@ -10,6 +10,7 @@
import copy, httpx import copy, httpx
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
from typing_extensions import overload
import random, threading, time, traceback, uuid import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache from litellm.caching import RedisCache, InMemoryCache, DualCache
@ -469,9 +470,26 @@ class Router:
) )
raise e raise e
# fmt: off
@overload
async def acompletion( async def acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]: ) -> CustomStreamWrapper:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
# fmt: on
# The actual implementation of the function
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
):
try: try:
kwargs["model"] = model kwargs["model"] = model
kwargs["messages"] = messages kwargs["messages"] = messages

View file

@ -134,11 +134,13 @@ async def test_router_retries(sync_mode):
messages=[{"role": "user", "content": "Hey, how's it going?"}], messages=[{"role": "user", "content": "Hey, how's it going?"}],
) )
else: else:
await router.acompletion( response = await router.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}], messages=[{"role": "user", "content": "Hey, how's it going?"}],
) )
print(response.choices[0].message)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mistral_api_base", "mistral_api_base",