diff --git a/litellm/router.py b/litellm/router.py
index f1e590545..c886bec2c 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -10,6 +10,7 @@
 import copy, httpx
 from datetime import datetime
 from typing import Dict, List, Optional, Union, Literal, Any, BinaryIO, Tuple
+from typing_extensions import overload
 import random, threading, time, traceback, uuid
 import litellm, openai, hashlib, json
 from litellm.caching import RedisCache, InMemoryCache, DualCache
@@ -469,9 +470,26 @@ class Router:
             )
             raise e
 
+    # fmt: off
+
+    @overload
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], **kwargs
-    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        self, model: str, messages: List[Dict[str, str]], stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+    @overload
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    # fmt: on
+
+    # The actual implementation of the function
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
+    ):
         try:
             kwargs["model"] = model
             kwargs["messages"] = messages
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index 7c59acb79..40b0410a4 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -134,11 +134,13 @@ async def test_router_retries(sync_mode):
             messages=[{"role": "user", "content": "Hey, how's it going?"}],
         )
     else:
-        await router.acompletion(
+        response = await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
+        print(response.choices[0].message)
+
 
 
 @pytest.mark.parametrize(
     "mistral_api_base",
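
What the overloads buy callers, as a minimal sketch (not part of the PR): before this change `acompletion` was annotated `Union[ModelResponse, CustomStreamWrapper]`, so type checkers required an `isinstance` check or cast before touching `response.choices`. With the `Literal`-typed `stream` parameter, the checker narrows the return type from the call site alone. The `router` setup, the `gpt-3.5-turbo` alias, and the `OPENAI_API_KEY` requirement below are illustrative assumptions, as is importing both response types from `litellm.utils` (where they are defined).

```python
# Illustrative usage sketch, assuming OPENAI_API_KEY is set in the environment.
import asyncio

from litellm import Router
from litellm.utils import CustomStreamWrapper, ModelResponse

# Hypothetical single-deployment router, mirroring the test fixtures.
router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }
    ]
)


async def main() -> None:
    # stream omitted (defaults to False): the Literal[False] overload applies,
    # so the result is typed as ModelResponse and .choices needs no narrowing.
    response: ModelResponse = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response.choices[0].message)

    # stream=True: the Literal[True] overload applies, so the result is typed
    # as CustomStreamWrapper and can be iterated without an isinstance check.
    stream: CustomStreamWrapper = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk)


asyncio.run(main())
```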