fix(custom_llm.py): support async completion calls

Author: Krrish Dholakia
Date: 2024-07-25 15:51:39 -07:00
parent 54e1ca29b7
commit fe503386ab
3 changed files with 50 additions and 11 deletions


@@ -44,15 +44,6 @@ class CustomLLMError(Exception):  # use this for all your exceptions
         )  # Call the base class constructor with the parameters it needs
-
-
-def custom_chat_llm_router():
-    """
-    Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
-
-    Validates if response is in expected format
-    """
-    pass


 class CustomLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -68,3 +59,20 @@ class CustomLLM(BaseLLM):
     async def astreaming(self, *args, **kwargs):
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
+
+
+def custom_chat_llm_router(
+    async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM
+):
+    """
+    Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
+
+    Validates if response is in expected format
+    """
+    if async_fn:
+        if stream:
+            return custom_llm.astreaming
+        return custom_llm.acompletion
+    if stream:
+        return custom_llm.streaming
+    return custom_llm.completion
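
The new router is pure dispatch: it returns the matching bound method without calling it, so the caller invokes whatever comes back the same way in every case. A minimal sketch of that behavior (the NoopLLM subclass is illustrative, and the router's module path litellm.llms.custom_llm is an assumption):

from litellm import CustomLLM
from litellm.llms.custom_llm import custom_chat_llm_router  # assumed module path

class NoopLLM(CustomLLM):
    pass  # inherits the completion/acompletion/streaming/astreaming stubs

handler = NoopLLM()

# sync, non-streaming -> handler.completion
assert custom_chat_llm_router(async_fn=False, stream=False, custom_llm=handler) == handler.completion
# async, streaming -> handler.astreaming
assert custom_chat_llm_router(async_fn=True, stream=True, custom_llm=handler) == handler.astreaming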


@@ -382,6 +382,7 @@ async def acompletion(
            or custom_llm_provider == "clarifai"
            or custom_llm_provider == "watsonx"
            or custom_llm_provider in litellm.openai_compatible_providers
+           or custom_llm_provider in litellm._custom_providers
        ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
            init_response = await loop.run_in_executor(None, func_with_context)
            if isinstance(init_response, dict) or isinstance(
@@ -2704,7 +2705,14 @@ def completion(
                raise ValueError(
                    f"Unable to map your input to a model. Check your input - {args}"
                )
-           response = custom_handler.completion()
+
+           ## ROUTE LLM CALL ##
+           handler_fn = custom_chat_llm_router(
+               async_fn=acompletion, stream=stream, custom_llm=custom_handler
+           )
+
+           ## CALL FUNCTION
+           response = handler_fn()
        else:
            raise ValueError(
                f"Unable to map your input to a model. Check your input - {args}


@@ -23,7 +23,7 @@ import httpx
 from dotenv import load_dotenv

 import litellm
-from litellm import CustomLLM, completion, get_llm_provider
+from litellm import CustomLLM, acompletion, completion, get_llm_provider


 class MyCustomLLM(CustomLLM):
@@ -35,6 +35,15 @@ class MyCustomLLM(CustomLLM):
         )  # type: ignore


+class MyCustomAsyncLLM(CustomLLM):
+    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
+        return litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hello world"}],
+            mock_response="Hi!",
+        )  # type: ignore
+
+
 def test_get_llm_provider():
     from litellm.utils import custom_llm_setup
@@ -61,3 +70,17 @@ def test_simple_completion():
     )

     assert resp.choices[0].message.content == "Hi!"
+
+
+@pytest.mark.asyncio
+async def test_simple_acompletion():
+    my_custom_llm = MyCustomAsyncLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = await acompletion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+    )
+
+    assert resp.choices[0].message.content == "Hi!"
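
The tests keep the sync and async paths in separate handler classes, but nothing in the router requires that: a single subclass can implement both methods, so one custom_provider_map entry serves completion() and acompletion() alike. A hypothetical combined handler (not part of the commit, same mocked response as the tests):

class MyHybridLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
        # Reuse the sync implementation; a real handler would await its async client here.
        return self.completion(*args, **kwargs)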