feat(custom_llm.py): initial working commit for writing your own custom LLM handler

Fixes https://github.com/BerriAI/litellm/issues/4675
Also addresses https://github.com/BerriAI/litellm/discussions/4677

commit 6bf1b9353b
parent 711496e260

6 changed files with 183 additions and 0 deletions
litellm/__init__.py

@@ -813,6 +813,7 @@ from .utils import (
 )
 
 from .types.utils import ImageObject
+from .llms.custom_llm import CustomLLM
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
 from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig

@@ -909,3 +910,11 @@ from .cost_calculator import response_cost_calculator, cost_per_token
 from .types.adapter import AdapterItem
 
 adapters: List[AdapterItem] = []
+
+### CUSTOM LLMs ###
+from .types.llms.custom_llm import CustomLLMItem
+
+custom_provider_map: List[CustomLLMItem] = []
+_custom_providers: List[str] = (
+    []
+)  # internal helper util, used to track names of custom providers
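The second hunk is the public registration surface: `custom_provider_map` holds user-supplied handler objects, while `_custom_providers` tracks just the provider names internally. A minimal sketch of what a registration looks like at runtime — the provider name "my_custom_llm" is a placeholder for illustration; each entry's shape matches the `CustomLLMItem` TypedDict added later in this commit:

import litellm
from litellm import CustomLLM

# placeholder handler; a real one subclasses CustomLLM (see below)
handler = CustomLLM()

litellm.custom_provider_map = [
    {"provider": "my_custom_llm", "custom_handler": handler}
]
# after custom_llm_setup() runs (see the utils.py hunk at the end),
# "my_custom_llm" also lands in litellm._custom_providers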
litellm/llms/custom_llm.py — new file, 70 lines

@@ -0,0 +1,70 @@
+# What is this?
+## Handler file for a Custom Chat LLM
+
+"""
+- completion
+- acompletion
+- streaming
+- async_streaming
+"""
+
+import copy
+import json
+import os
+import time
+import types
+from enum import Enum
+from functools import partial
+from typing import Callable, List, Literal, Optional, Tuple, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.databricks import GenericStreamingChunk
+from litellm.types.utils import ProviderField
+from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage
+
+from .base import BaseLLM
+from .prompt_templates.factory import custom_prompt, prompt_factory
+
+
+class CustomLLMError(Exception):  # use this for all your exceptions
+    def __init__(
+        self,
+        status_code,
+        message,
+    ):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+def custom_chat_llm_router():
+    """
+    Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type
+
+    Validates if response is in expected format
+    """
+    pass
+
+
+class CustomLLM(BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def completion(self, *args, **kwargs) -> ModelResponse:
+        raise CustomLLMError(status_code=500, message="Not implemented yet!")
+
+    def streaming(self, *args, **kwargs):
+        raise CustomLLMError(status_code=500, message="Not implemented yet!")
+
+    async def acompletion(self, *args, **kwargs) -> ModelResponse:
+        raise CustomLLMError(status_code=500, message="Not implemented yet!")
+
+    async def astreaming(self, *args, **kwargs):
+        raise CustomLLMError(status_code=500, message="Not implemented yet!")
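The base class raises `CustomLLMError` from every method, so a useful handler overrides at least `completion`. An end-to-end sketch, mirroring the unit test added later in this commit — "my_custom_llm" is a placeholder provider name, and `mock_response` keeps the example offline:

import litellm
from litellm import CustomLLM


class MyCustomLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        # anything returning a litellm.ModelResponse works here; delegating
        # to litellm.completion with mock_response avoids a real API call
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore


litellm.custom_provider_map = [
    {"provider": "my_custom_llm", "custom_handler": MyCustomLLM()}
]

resp = litellm.completion(
    model="my_custom_llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
)
print(resp.choices[0].message.content)  # "Hi!"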
litellm/main.py

@@ -107,6 +107,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.azure import AzureChatCompletion
 from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
+from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion

@@ -2690,6 +2691,20 @@ def completion(
             model_response.created = int(time.time())
             model_response.model = model
             response = model_response
+        elif (
+            custom_llm_provider in litellm._custom_providers
+        ):  # Assume custom LLM provider
+            # Get the Custom Handler
+            custom_handler: Optional[CustomLLM] = None
+            for item in litellm.custom_provider_map:
+                if item["provider"] == custom_llm_provider:
+                    custom_handler = item["custom_handler"]
+
+            if custom_handler is None:
+                raise ValueError(
+                    f"Unable to map your input to a model. Check your input - {args}"
+                )
+            response = custom_handler.completion()
         else:
             raise ValueError(
                 f"Unable to map your input to a model. Check your input - {args}"
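Two things are worth noting in this hunk: `custom_handler.completion()` is invoked with no arguments in this initial commit, so the model, messages, and optional params are not yet forwarded to the handler; and `custom_chat_llm_router` is imported but still a stub (`pass`). A hedged sketch of what the router could dispatch on once wired up — this is an assumption, not code from the commit:

# hypothetical router sketch; the commit leaves custom_chat_llm_router empty
from litellm.llms.custom_llm import CustomLLM


def custom_chat_llm_router(async_fn: bool, stream: bool, custom_llm: CustomLLM):
    """Pick the handler method matching the call type."""
    if async_fn:
        return custom_llm.astreaming if stream else custom_llm.acompletion
    return custom_llm.streaming if stream else custom_llm.completion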
litellm/tests/test_custom_llm.py — new file, 63 lines

@@ -0,0 +1,63 @@
+# What is this?
+## Unit tests for the CustomLLM class
+
+
+import asyncio
+import os
+import sys
+import time
+import traceback
+
+import openai
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import os
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import httpx
+from dotenv import load_dotenv
+
+import litellm
+from litellm import CustomLLM, completion, get_llm_provider
+
+
+class MyCustomLLM(CustomLLM):
+    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
+        return litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hello world"}],
+            mock_response="Hi!",
+        )  # type: ignore
+
+
+def test_get_llm_provider():
+    from litellm.utils import custom_llm_setup
+
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+
+    custom_llm_setup()
+
+    model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model")
+
+    assert provider == "custom_llm"
+
+
+def test_simple_completion():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+    resp = completion(
+        model="custom_llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+    )
+
+    assert resp.choices[0].message.content == "Hi!"
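Both tests run fully offline thanks to `mock_response`, and only the synchronous `completion` path is exercised; `streaming`, `acompletion`, and `astreaming` still raise `CustomLLMError` on the base class. A hedged sketch of a test pinning down that fallback behavior — not part of this commit, and `test_base_class_raises` is a hypothetical name:

import pytest

from litellm.llms.custom_llm import CustomLLM, CustomLLMError


def test_base_class_raises():
    # hypothetical test, not in this commit: the bare base class should
    # raise CustomLLMError(500) until a subclass overrides completion()
    handler = CustomLLM()
    with pytest.raises(CustomLLMError) as excinfo:
        handler.completion()
    assert excinfo.value.status_code == 500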
litellm/types/llms/custom_llm.py — new file, 10 lines

@@ -0,0 +1,10 @@
+from typing import List
+
+from typing_extensions import Dict, Required, TypedDict, override
+
+from litellm.llms.custom_llm import CustomLLM
+
+
+class CustomLLMItem(TypedDict):
+    provider: str
+    custom_handler: CustomLLM
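Because `CustomLLMItem` is a `TypedDict`, registrations stay plain dicts at runtime while a type checker can still validate their shape. A small sketch, with `my_llm` as a placeholder handler:

from litellm.llms.custom_llm import CustomLLM
from litellm.types.llms.custom_llm import CustomLLMItem

my_llm = CustomLLM()  # placeholder; a real handler would subclass CustomLLM
item: CustomLLMItem = {"provider": "my_custom_llm", "custom_handler": my_llm}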
litellm/utils.py

@@ -330,6 +330,18 @@ class Rules:
 
 ####### CLIENT ###################
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
+def custom_llm_setup():
+    """
+    Add custom_llm provider to provider list
+    """
+    for custom_llm in litellm.custom_provider_map:
+        if custom_llm["provider"] not in litellm.provider_list:
+            litellm.provider_list.append(custom_llm["provider"])
+
+        if custom_llm["provider"] not in litellm._custom_providers:
+            litellm._custom_providers.append(custom_llm["provider"])
+
+
 def function_setup(
     original_function: str, rules_obj, start_time, *args, **kwargs
 ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.

@@ -341,6 +353,10 @@ def function_setup(
     try:
         global callback_list, add_breadcrumb, user_logger_fn, Logging
+
+        ## CUSTOM LLM SETUP ##
+        custom_llm_setup()
+
         ## LOGGING SETUP
         function_id = kwargs["id"] if "id" in kwargs else None
 
         if len(litellm.callbacks) > 0:
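Since `function_setup()` now calls `custom_llm_setup()` on every wrapped call, handlers registered in `litellm.custom_provider_map` are picked up lazily: the first `completion()` call adds the provider name to both `litellm.provider_list` and `litellm._custom_providers`. Verifying that wiring directly, along the lines of `test_get_llm_provider` above ("my_custom_llm" is again a placeholder name):

import litellm
from litellm import CustomLLM
from litellm.utils import custom_llm_setup

litellm.custom_provider_map = [
    {"provider": "my_custom_llm", "custom_handler": CustomLLM()}
]

custom_llm_setup()

assert "my_custom_llm" in litellm.provider_list
assert "my_custom_llm" in litellm._custom_providers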