fix(router.py): overloads for better router.acompletion typing

This commit is contained in:
Krrish Dholakia 2024-05-13 14:27:16 -07:00
parent bd2f46fd75
commit 1312eece6d
2 changed files with 23 additions and 3 deletions

View file

@ -10,6 +10,7 @@
import copy, httpx
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
from typing_extensions import overload
import random, threading, time, traceback, uuid
import litellm, openai, hashlib, json
from litellm.caching import RedisCache, InMemoryCache, DualCache
@ -469,9 +470,26 @@ class Router:
)
raise e
# fmt: off
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
) -> CustomStreamWrapper:
...
@overload
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
) -> ModelResponse:
...
# fmt: on
# The actual implementation of the function
async def acompletion(
self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
):
try:
kwargs["model"] = model
kwargs["messages"] = messages