feat(custom_llm.py): initial working commit for writing your own custom LLM handler

Fixes https://github.com/BerriAI/litellm/issues/4675

 Also Addresses https://github.com/BerriAI/litellm/discussions/4677
This commit is contained in:
Krrish Dholakia 2024-07-25 15:33:05 -07:00
parent 1d33759bb1
commit 54e1ca29b7
6 changed files with 183 additions and 0 deletions

View file

@ -813,6 +813,7 @@ from .utils import (
)
from .types.utils import ImageObject
from .llms.custom_llm import CustomLLM
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig
@ -909,3 +910,11 @@ from .cost_calculator import response_cost_calculator, cost_per_token
from .types.adapter import AdapterItem
adapters: List[AdapterItem] = []
### CUSTOM LLMs ###
from .types.llms.custom_llm import CustomLLMItem
custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
[]
) # internal helper util, used to track names of custom providers

View file

@ -0,0 +1,70 @@
# What is this?
## Handler file for a Custom Chat LLM
"""
- completion
- acompletion
- streaming
- async_streaming
"""
import copy
import json
import os
import time
import types
from enum import Enum
from functools import partial
from typing import Callable, List, Literal, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.types.utils import ProviderField
from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import custom_prompt, prompt_factory
class CustomLLMError(Exception): # use this for all your exceptions
def __init__(
self,
status_code,
message,
):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
def custom_chat_llm_router():
"""
Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
Validates if response is in expected format
"""
pass
class CustomLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def completion(self, *args, **kwargs) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def streaming(self, *args, **kwargs):
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def acompletion(self, *args, **kwargs) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def astreaming(self, *args, **kwargs):
raise CustomLLMError(status_code=500, message="Not implemented yet!")

View file

@ -107,6 +107,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
@ -2690,6 +2691,20 @@ def completion(
model_response.created = int(time.time())
model_response.model = model
response = model_response
elif (
custom_llm_provider in litellm._custom_providers
): # Assume custom LLM provider
# Get the Custom Handler
custom_handler: Optional[CustomLLM] = None
for item in litellm.custom_provider_map:
if item["provider"] == custom_llm_provider:
custom_handler = item["custom_handler"]
if custom_handler is None:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
)
response = custom_handler.completion()
else:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"

View file

@ -0,0 +1,63 @@
# What is this?
## Unit tests for the CustomLLM class
import asyncio
import os
import sys
import time
import traceback
import openai
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
from dotenv import load_dotenv
import litellm
from litellm import CustomLLM, completion, get_llm_provider
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
def test_get_llm_provider():
from litellm.utils import custom_llm_setup
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
custom_llm_setup()
model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")
assert provider == "custom_llm"
def test_simple_completion():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
resp = completion(
model="custom_llm/my-fake-model",
messages=[{"role": "user", "content": "Hello world!"}],
)
assert resp.choices[0].message.content == "Hi!"

View file

@ -0,0 +1,10 @@
from typing import List
from typing_extensions import Dict, Required, TypedDict, override
from litellm.llms.custom_llm import CustomLLM
class CustomLLMItem(TypedDict):
provider: str
custom_handler: CustomLLM

View file

@ -330,6 +330,18 @@ class Rules:
####### CLIENT ###################
# make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
def custom_llm_setup():
"""
Add custom_llm provider to provider list
"""
for custom_llm in litellm.custom_provider_map:
if custom_llm["provider"] not in litellm.provider_list:
litellm.provider_list.append(custom_llm["provider"])
if custom_llm["provider"] not in litellm._custom_providers:
litellm._custom_providers.append(custom_llm["provider"])
def function_setup(
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@ -341,6 +353,10 @@ def function_setup(
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
## CUSTOM LLM SETUP ##
custom_llm_setup()
## LOGGING SETUP
function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0: