mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(custom_llm.py): initial working commit for writing your own custom LLM handler
Fixes https://github.com/BerriAI/litellm/issues/4675 Also Addresses https://github.com/BerriAI/litellm/discussions/4677
This commit is contained in:
parent
711496e260
commit
6bf1b9353b
6 changed files with 183 additions and 0 deletions
63
litellm/tests/test_custom_llm.py
Normal file
63
litellm/tests/test_custom_llm.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# What is this?
|
||||
## Unit tests for the CustomLLM class
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import litellm
|
||||
from litellm import CustomLLM, completion, get_llm_provider
|
||||
|
||||
|
||||
class MyCustomLLM(CustomLLM):
|
||||
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
|
||||
return litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello world"}],
|
||||
mock_response="Hi!",
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_get_llm_provider():
|
||||
from litellm.utils import custom_llm_setup
|
||||
|
||||
my_custom_llm = MyCustomLLM()
|
||||
litellm.custom_provider_map = [
|
||||
{"provider": "custom_llm", "custom_handler": my_custom_llm}
|
||||
]
|
||||
|
||||
custom_llm_setup()
|
||||
|
||||
model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")
|
||||
|
||||
assert provider == "custom_llm"
|
||||
|
||||
|
||||
def test_simple_completion():
|
||||
my_custom_llm = MyCustomLLM()
|
||||
litellm.custom_provider_map = [
|
||||
{"provider": "custom_llm", "custom_handler": my_custom_llm}
|
||||
]
|
||||
resp = completion(
|
||||
model="custom_llm/my-fake-model",
|
||||
messages=[{"role": "user", "content": "Hello world!"}],
|
||||
)
|
||||
|
||||
assert resp.choices[0].message.content == "Hi!"
|
Loading…
Add table
Add a link
Reference in a new issue