forked from phoenix/litellm-mirror
fix(router.py): overloads for better router.acompletion typing
commit 1312eece6d
parent bd2f46fd75
2 changed files with 23 additions and 3 deletions
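For context on what the overloads buy callers: before this commit, acompletion was annotated to return Union[ModelResponse, CustomStreamWrapper] (the annotation removed in the diff below), so even non-streaming callers had to narrow the result by hand before touching .choices. A minimal pre-commit caller sketch (the helper name ask and the model name are illustrative, not from the commit):

    from typing import Dict, List
    from litellm import Router
    from litellm.utils import ModelResponse

    # Pre-commit typing: the Union return type forced manual narrowing.
    async def ask(router: Router, msgs: List[Dict[str, str]]) -> None:
        response = await router.acompletion(model="gpt-3.5-turbo", messages=msgs)
        assert isinstance(response, ModelResponse)  # needed before this commit
        print(response.choices[0].message)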
router.py:

@@ -10,6 +10,7 @@
 import copy, httpx
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
+from typing_extensions import overload
 import random, threading, time, traceback, uuid
 import litellm, openai, hashlib, json
 from litellm.caching import RedisCache, InMemoryCache, DualCache

@@ -469,9 +470,26 @@ class Router:
             )
             raise e

+    # fmt: off
+
+    @overload
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], **kwargs
-    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+    @overload
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    # fmt: on
+
+    # The actual implementation of the function
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
+    ):
         try:
             kwargs["model"] = model
             kwargs["messages"] = messages
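With the overloads in place, a type checker picks the return type from the stream argument at each call site: stream=True resolves to the Literal[True] overload (CustomStreamWrapper), while stream omitted or False resolves to the Literal[False] overload (ModelResponse). A minimal usage sketch (the model_list entry is illustrative and not part of this commit):

    import asyncio
    from litellm import Router

    # Illustrative config; real deployments list their own models/credentials.
    router = Router(model_list=[{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo"},
    }])

    async def main() -> None:
        # stream omitted -> Literal[False] overload -> ModelResponse
        response = await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
        print(response.choices[0].message)

        # stream=True -> Literal[True] overload -> CustomStreamWrapper
        stream = await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            stream=True,
        )
        async for chunk in stream:  # CustomStreamWrapper is an async iterator
            print(chunk)

    asyncio.run(main())

The # fmt: off / # fmt: on markers appear to exclude the stub block from the auto-formatter, so the one-line overload signatures stay as written.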
@@ -134,11 +134,13 @@ async def test_router_retries(sync_mode):
             messages=[{"role": "user", "content": "Hey, how's it going?"}],
         )
     else:
-        await router.acompletion(
+        response = await router.acompletion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hey, how's it going?"}],
         )

+        print(response.choices[0].message)
+

 @pytest.mark.parametrize(
     "mistral_api_base",
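One way to check the narrowing statically (a sketch, not part of the commit) is to run mypy or pyright over a probe like the following; reveal_type output comes from the checker, and importing it from typing_extensions (4.1+) keeps the snippet runnable at runtime as well:

    from typing import Dict, List
    from typing_extensions import reveal_type
    from litellm import Router

    async def probe(router: Router, msgs: List[Dict[str, str]]) -> None:
        r1 = await router.acompletion(model="gpt-3.5-turbo", messages=msgs)
        reveal_type(r1)  # checker reports: ModelResponse
        r2 = await router.acompletion(model="gpt-3.5-turbo", messages=msgs, stream=True)
        reveal_type(r2)  # checker reports: CustomStreamWrapper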