forked from phoenix/litellm-mirror
fix(router.py): overloads for better router.acompletion typing
This commit is contained in:
parent
bd2f46fd75
commit
1312eece6d
2 changed files with 23 additions and 3 deletions
|
@@ -10,6 +10,7 @@
|
||||||
import copy, httpx
|
import copy, httpx
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
|
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
|
||||||
|
from typing_extensions import overload
|
||||||
import random, threading, time, traceback, uuid
|
import random, threading, time, traceback, uuid
|
||||||
import litellm, openai, hashlib, json
|
import litellm, openai, hashlib, json
|
||||||
from litellm.caching import RedisCache, InMemoryCache, DualCache
|
from litellm.caching import RedisCache, InMemoryCache, DualCache
|
||||||
|
@@ -469,9 +470,26 @@ class Router:
|
||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
|
||||||
|
@overload
|
||||||
async def acompletion(
|
async def acompletion(
|
||||||
self, model: str, messages: List[Dict[str, str]], **kwargs
|
self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
|
||||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
) -> CustomStreamWrapper:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def acompletion(
|
||||||
|
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
|
||||||
|
) -> ModelResponse:
|
||||||
|
...
|
||||||
|
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# The actual implementation of the function
|
||||||
|
async def acompletion(
|
||||||
|
self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
kwargs["model"] = model
|
kwargs["model"] = model
|
||||||
kwargs["messages"] = messages
|
kwargs["messages"] = messages
|
||||||
|
|
|
@@ -134,11 +134,13 @@ async def test_router_retries(sync_mode):
|
||||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await router.acompletion(
|
response = await router.acompletion(
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
print(response.choices[0].message)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"mistral_api_base",
|
"mistral_api_base",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue