Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
fix(custom_llm.py): support async completion calls
parent 54e1ca29b7
commit fe503386ab
3 changed files with 50 additions and 11 deletions
@@ -44,15 +44,6 @@ class CustomLLMError(Exception):  # use this for all your exceptions
         )  # Call the base class constructor with the parameters it needs
 
 
-def custom_chat_llm_router():
-    """
-    Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
-
-    Validates if response is in expected format
-    """
-    pass
-
-
 class CustomLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -68,3 +59,20 @@ class CustomLLM(BaseLLM):
 
     async def astreaming(self, *args, **kwargs):
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
+
+
+def custom_chat_llm_router(
+    async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM
+):
+    """
+    Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
+
+    Validates if response is in expected format
+    """
+    if async_fn:
+        if stream:
+            return custom_llm.astreaming
+        return custom_llm.acompletion
+    if stream:
+        return custom_llm.streaming
+    return custom_llm.completion
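Taken together, these hunks (the custom handler module named in the commit title) replace the old no-op router stub with a real dispatcher: given the async_fn and stream flags, custom_chat_llm_router hands back the matching bound method of the registered CustomLLM instance. A minimal sanity-check sketch, not part of the commit, assuming the module path litellm.llms.custom_llm:

# Illustrative only: the router returns one of the four bound methods on the
# handler, chosen by the (async_fn, stream) pair. Module path is assumed.
from litellm import CustomLLM
from litellm.llms.custom_llm import custom_chat_llm_router

handler = CustomLLM()

assert custom_chat_llm_router(async_fn=False, stream=False, custom_llm=handler) == handler.completion
assert custom_chat_llm_router(async_fn=False, stream=True, custom_llm=handler) == handler.streaming
assert custom_chat_llm_router(async_fn=True, stream=False, custom_llm=handler) == handler.acompletion
assert custom_chat_llm_router(async_fn=True, stream=True, custom_llm=handler) == handler.astreaming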
@@ -382,6 +382,7 @@ async def acompletion(
             or custom_llm_provider == "clarifai"
             or custom_llm_provider == "watsonx"
             or custom_llm_provider in litellm.openai_compatible_providers
+            or custom_llm_provider in litellm._custom_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
             if isinstance(init_response, dict) or isinstance(
@@ -2704,7 +2705,14 @@ def completion(
                 raise ValueError(
                     f"Unable to map your input to a model. Check your input - {args}"
                 )
-            response = custom_handler.completion()
+
+            ## ROUTE LLM CALL ##
+            handler_fn = custom_chat_llm_router(
+                async_fn=acompletion, stream=stream, custom_llm=custom_handler
+            )
+
+            ## CALL FUNCTION
+            response = handler_fn()
         else:
             raise ValueError(
                 f"Unable to map your input to a model. Check your input - {args}"
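Between these two hunks, completion() now asks the router for the right handler method instead of hard-coding custom_handler.completion(), and acompletion() treats any provider registered in litellm._custom_providers as natively async-capable, awaiting the coroutine rather than pushing a blocking call into the thread executor. As far as these hunks show, handler_fn() is still invoked without arguments, so a handler cannot rely on receiving the request kwargs from this call site (the new test below simply hard-codes its response). An end-to-end sketch mirroring that test, with an illustrative provider name and handler class:

# Sketch under assumptions: the provider name ("my-custom-llm") and handler class
# are illustrative; the registration pattern mirrors the new test added below.
import asyncio

import litellm
from litellm import CustomLLM, acompletion


class MyAsyncHandler(CustomLLM):
    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
        # Mocked response, as in the test; a real handler would call its backend here.
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore


litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": MyAsyncHandler()}
]

# With this commit, acompletion() resolves "my-custom-llm/..." to
# MyAsyncHandler.acompletion and awaits it directly.
resp = asyncio.run(
    acompletion(
        model="my-custom-llm/my-fake-model",
        messages=[{"role": "user", "content": "Hello world!"}],
    )
)
print(resp.choices[0].message.content)  # "Hi!"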
@@ -23,7 +23,7 @@ import httpx
 from dotenv import load_dotenv
 
 import litellm
-from litellm import CustomLLM, completion, get_llm_provider
+from litellm import CustomLLM, acompletion, completion, get_llm_provider
 
 
 class MyCustomLLM(CustomLLM):
@@ -35,6 +35,15 @@ class MyCustomLLM(CustomLLM):
         )  # type: ignore
 
 
+class MyCustomAsyncLLM(CustomLLM):
+    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
+        return litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hello world"}],
+            mock_response="Hi!",
+        )  # type: ignore
+
+
 def test_get_llm_provider():
     from litellm.utils import custom_llm_setup
 
@@ -61,3 +70,17 @@ def test_simple_completion():
     )
 
     assert resp.choices[0].message.content == "Hi!"
+
+
+@pytest.mark.asyncio
+async def test_simple_acompletion():
+    my_custom_llm = MyCustomAsyncLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = await acompletion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+    )
+
+    assert resp.choices[0].message.content == "Hi!"
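The test keeps the sync and async handlers in separate classes, but nothing in the router requires that: a single handler may implement both methods, and completion()/acompletion() will each be handed the matching one. A hypothetical sketch (class name illustrative; streaming remains unimplemented, as in the commit):

# Hypothetical dual-mode handler (not part of the commit): completion() gets the
# sync method, acompletion() gets the async one, via custom_chat_llm_router.
import litellm
from litellm import CustomLLM


class MyDualModeLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
        # Reuse the sync implementation; a real handler might call an async client instead.
        return self.completion(*args, **kwargs)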