diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py
index fac1eb2936..5e9933194d 100644
--- a/litellm/llms/custom_llm.py
+++ b/litellm/llms/custom_llm.py
@@ -44,15 +44,6 @@ class CustomLLMError(Exception):  # use this for all your exceptions
         )  # Call the base class constructor with the parameters it needs
 
 
-def custom_chat_llm_router():
-    """
-    Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
-
-    Validates if response is in expected format
-    """
-    pass
-
-
 class CustomLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -68,3 +59,20 @@ class CustomLLM(BaseLLM):
 
     async def astreaming(self, *args, **kwargs):
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
+
+
+def custom_chat_llm_router(
+    async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM
+):
+    """
+    Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
+
+    Validates if response is in expected format
+    """
+    if async_fn:
+        if stream:
+            return custom_llm.astreaming
+        return custom_llm.acompletion
+    if stream:
+        return custom_llm.streaming
+    return custom_llm.completion
diff --git a/litellm/main.py b/litellm/main.py
index 539c3d3e1c..51e7c611c9 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -382,6 +382,7 @@ async def acompletion(
             or custom_llm_provider == "clarifai"
             or custom_llm_provider == "watsonx"
             or custom_llm_provider in litellm.openai_compatible_providers
+            or custom_llm_provider in litellm._custom_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
             if isinstance(init_response, dict) or isinstance(
@@ -2704,7 +2705,14 @@ def completion(
                 raise ValueError(
                     f"Unable to map your input to a model. Check your input - {args}"
                 )
-            response = custom_handler.completion()
+
+            ## ROUTE LLM CALL ##
+            handler_fn = custom_chat_llm_router(
+                async_fn=acompletion, stream=stream, custom_llm=custom_handler
+            )
+
+            ## CALL FUNCTION
+            response = handler_fn()
         else:
             raise ValueError(
                 f"Unable to map your input to a model. Check your input - {args}"
diff --git a/litellm/tests/test_custom_llm.py b/litellm/tests/test_custom_llm.py
index 0506986eb9..fd46c892e3 100644
--- a/litellm/tests/test_custom_llm.py
+++ b/litellm/tests/test_custom_llm.py
@@ -23,7 +23,7 @@ import httpx
 from dotenv import load_dotenv
 
 import litellm
-from litellm import CustomLLM, completion, get_llm_provider
+from litellm import CustomLLM, acompletion, completion, get_llm_provider
 
 
 class MyCustomLLM(CustomLLM):
@@ -35,6 +35,15 @@ class MyCustomLLM(CustomLLM):
         )  # type: ignore
 
 
+class MyCustomAsyncLLM(CustomLLM):
+    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
+        return litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hello world"}],
+            mock_response="Hi!",
+        )  # type: ignore
+
+
 def test_get_llm_provider():
     from litellm.utils import custom_llm_setup
 
@@ -61,3 +70,17 @@ def test_simple_completion():
     )
 
     assert resp.choices[0].message.content == "Hi!"
+
+
+@pytest.mark.asyncio
+async def test_simple_acompletion():
+    my_custom_llm = MyCustomAsyncLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = await acompletion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+    )
+
+    assert resp.choices[0].message.content == "Hi!"
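Usage sketch: a minimal end-to-end example of how the new custom_chat_llm_router gets exercised, following the registration pattern from the tests added above. The handler class name, provider key, model string, and mock response simply mirror those tests and are illustrative, not prescribed names.

    import asyncio

    import litellm
    from litellm import CustomLLM, acompletion, completion


    class MyCustomLLM(CustomLLM):
        def completion(self, *args, **kwargs) -> litellm.ModelResponse:
            # Sync path: custom_chat_llm_router(async_fn=False, stream=None, ...) returns this method
            return litellm.completion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hello world"}],
                mock_response="Hi!",
            )  # type: ignore

        async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
            # Async path: custom_chat_llm_router(async_fn=True, stream=None, ...) returns this method
            return litellm.completion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hello world"}],
                mock_response="Hi!",
            )  # type: ignore


    # Register the handler under a provider prefix, as in the new tests
    litellm.custom_provider_map = [
        {"provider": "custom_llm", "custom_handler": MyCustomLLM()}
    ]

    # "custom_llm/..." now routes through the registered handler, sync or async
    sync_resp = completion(
        model="custom_llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )
    async_resp = asyncio.run(
        acompletion(
            model="custom_llm/my-fake-model",
            messages=[{"role": "user", "content": "Hello world!"}],
        )
    )
    print(sync_resp.choices[0].message.content, async_resp.choices[0].message.content)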