From 6bf1b9353bbc675390cac2a5821eaa76a4788c28 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 25 Jul 2024 15:33:05 -0700 Subject: [PATCH 1/7] feat(custom_llm.py): initial working commit for writing your own custom LLM handler Fixes https://github.com/BerriAI/litellm/issues/4675 Also Addresses https://github.com/BerriAI/litellm/discussions/4677 --- litellm/__init__.py | 9 ++++ litellm/llms/custom_llm.py | 70 ++++++++++++++++++++++++++++++++ litellm/main.py | 15 +++++++ litellm/tests/test_custom_llm.py | 63 ++++++++++++++++++++++++++++ litellm/types/llms/custom_llm.py | 10 +++++ litellm/utils.py | 16 ++++++++ 6 files changed, 183 insertions(+) create mode 100644 litellm/llms/custom_llm.py create mode 100644 litellm/tests/test_custom_llm.py create mode 100644 litellm/types/llms/custom_llm.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 956834afc..0527ef199 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -813,6 +813,7 @@ from .utils import ( ) from .types.utils import ImageObject +from .llms.custom_llm import CustomLLM from .llms.huggingface_restapi import HuggingfaceConfig from .llms.anthropic import AnthropicConfig from .llms.databricks import DatabricksConfig, DatabricksEmbeddingConfig @@ -909,3 +910,11 @@ from .cost_calculator import response_cost_calculator, cost_per_token from .types.adapter import AdapterItem adapters: List[AdapterItem] = [] + +### CUSTOM LLMs ### +from .types.llms.custom_llm import CustomLLMItem + +custom_provider_map: List[CustomLLMItem] = [] +_custom_providers: List[str] = ( + [] +) # internal helper util, used to track names of custom providers diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py new file mode 100644 index 000000000..fac1eb293 --- /dev/null +++ b/litellm/llms/custom_llm.py @@ -0,0 +1,70 @@ +# What is this? 
+## Handler file for a Custom Chat LLM + +""" +- completion +- acompletion +- streaming +- async_streaming +""" + +import copy +import json +import os +import time +import types +from enum import Enum +from functools import partial +from typing import Callable, List, Literal, Optional, Tuple, Union + +import httpx # type: ignore +import requests # type: ignore + +import litellm +from litellm.litellm_core_utils.core_helpers import map_finish_reason +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.types.llms.databricks import GenericStreamingChunk +from litellm.types.utils import ProviderField +from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage + +from .base import BaseLLM +from .prompt_templates.factory import custom_prompt, prompt_factory + + +class CustomLLMError(Exception): # use this for all your exceptions + def __init__( + self, + status_code, + message, + ): + self.status_code = status_code + self.message = message + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +def custom_chat_llm_router(): + """ + Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type + + Validates if response is in expected format + """ + pass + + +class CustomLLM(BaseLLM): + def __init__(self) -> None: + super().__init__() + + def completion(self, *args, **kwargs) -> ModelResponse: + raise CustomLLMError(status_code=500, message="Not implemented yet!") + + def streaming(self, *args, **kwargs): + raise CustomLLMError(status_code=500, message="Not implemented yet!") + + async def acompletion(self, *args, **kwargs) -> ModelResponse: + raise CustomLLMError(status_code=500, message="Not implemented yet!") + + async def astreaming(self, *args, **kwargs): + raise CustomLLMError(status_code=500, message="Not implemented yet!") diff --git a/litellm/main.py b/litellm/main.py index f724a68bd..539c3d3e1 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -107,6 +107,7 @@ from .llms.anthropic_text import AnthropicTextCompletion from .llms.azure import AzureChatCompletion from .llms.azure_text import AzureTextCompletion from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM +from .llms.custom_llm import CustomLLM, custom_chat_llm_router from .llms.databricks import DatabricksChatCompletion from .llms.huggingface_restapi import Huggingface from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion @@ -2690,6 +2691,20 @@ def completion( model_response.created = int(time.time()) model_response.model = model response = model_response + elif ( + custom_llm_provider in litellm._custom_providers + ): # Assume custom LLM provider + # Get the Custom Handler + custom_handler: Optional[CustomLLM] = None + for item in litellm.custom_provider_map: + if item["provider"] == custom_llm_provider: + custom_handler = item["custom_handler"] + + if custom_handler is None: + raise ValueError( + f"Unable to map your input to a model. Check your input - {args}" + ) + response = custom_handler.completion() else: raise ValueError( f"Unable to map your input to a model. Check your input - {args}" diff --git a/litellm/tests/test_custom_llm.py b/litellm/tests/test_custom_llm.py new file mode 100644 index 000000000..0506986eb --- /dev/null +++ b/litellm/tests/test_custom_llm.py @@ -0,0 +1,63 @@ +# What is this? 
+## Unit tests for the CustomLLM class + + +import asyncio +import os +import sys +import time +import traceback + +import openai +import pytest + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import os +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +from dotenv import load_dotenv + +import litellm +from litellm import CustomLLM, completion, get_llm_provider + + +class MyCustomLLM(CustomLLM): + def completion(self, *args, **kwargs) -> litellm.ModelResponse: + return litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello world"}], + mock_response="Hi!", + ) # type: ignore + + +def test_get_llm_provider(): + from litellm.utils import custom_llm_setup + + my_custom_llm = MyCustomLLM() + litellm.custom_provider_map = [ + {"provider": "custom_llm", "custom_handler": my_custom_llm} + ] + + custom_llm_setup() + + model, provider, _, _ = get_llm_provider(model="custom_llm/my-fake-model") + + assert provider == "custom_llm" + + +def test_simple_completion(): + my_custom_llm = MyCustomLLM() + litellm.custom_provider_map = [ + {"provider": "custom_llm", "custom_handler": my_custom_llm} + ] + resp = completion( + model="custom_llm/my-fake-model", + messages=[{"role": "user", "content": "Hello world!"}], + ) + + assert resp.choices[0].message.content == "Hi!" diff --git a/litellm/types/llms/custom_llm.py b/litellm/types/llms/custom_llm.py new file mode 100644 index 000000000..d5499a419 --- /dev/null +++ b/litellm/types/llms/custom_llm.py @@ -0,0 +1,10 @@ +from typing import List + +from typing_extensions import Dict, Required, TypedDict, override + +from litellm.llms.custom_llm import CustomLLM + + +class CustomLLMItem(TypedDict): + provider: str + custom_handler: CustomLLM diff --git a/litellm/utils.py b/litellm/utils.py index e104de958..0f1b0315d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -330,6 +330,18 @@ class Rules: ####### CLIENT ################### # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking +def custom_llm_setup(): + """ + Add custom_llm provider to provider list + """ + for custom_llm in litellm.custom_provider_map: + if custom_llm["provider"] not in litellm.provider_list: + litellm.provider_list.append(custom_llm["provider"]) + + if custom_llm["provider"] not in litellm._custom_providers: + litellm._custom_providers.append(custom_llm["provider"]) + + def function_setup( original_function: str, rules_obj, start_time, *args, **kwargs ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc. 
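The `custom_llm_setup()` helper added above is what lets a custom provider resolve like a built-in one: each entry in `litellm.custom_provider_map` is appended to `litellm.provider_list` and tracked in `litellm._custom_providers`, so a `"provider/model"` string routes to the registered handler. Below is a minimal sketch of that registration flow based on the interfaces this patch introduces; the `"my-custom-llm"` name is illustrative.

```python
import litellm
from litellm import CustomLLM, get_llm_provider
from litellm.utils import custom_llm_setup

# Any CustomLLM (sub)class instance works as a handler; a bare instance is
# enough to demonstrate registration and provider resolution.
litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": CustomLLM()}
]

# Normally invoked for you inside function_setup() on the first completion() call.
custom_llm_setup()

assert "my-custom-llm" in litellm.provider_list
assert "my-custom-llm" in litellm._custom_providers

# The provider prefix now resolves like any built-in provider.
_, provider, _, _ = get_llm_provider(model="my-custom-llm/my-fake-model")
assert provider == "my-custom-llm"
```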
@@ -341,6 +353,10 @@ def function_setup( try: global callback_list, add_breadcrumb, user_logger_fn, Logging + ## CUSTOM LLM SETUP ## + custom_llm_setup() + + ## LOGGING SETUP function_id = kwargs["id"] if "id" in kwargs else None if len(litellm.callbacks) > 0: From 9f97436308de5c1ddc1acf14567b0caf0c23ab2d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 25 Jul 2024 15:51:39 -0700 Subject: [PATCH 2/7] fix(custom_llm.py): support async completion calls --- litellm/llms/custom_llm.py | 26 +++++++++++++++++--------- litellm/main.py | 10 +++++++++- litellm/tests/test_custom_llm.py | 25 ++++++++++++++++++++++++- 3 files changed, 50 insertions(+), 11 deletions(-) diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py index fac1eb293..5e9933194 100644 --- a/litellm/llms/custom_llm.py +++ b/litellm/llms/custom_llm.py @@ -44,15 +44,6 @@ class CustomLLMError(Exception): # use this for all your exceptions ) # Call the base class constructor with the parameters it needs -def custom_chat_llm_router(): - """ - Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type - - Validates if response is in expected format - """ - pass - - class CustomLLM(BaseLLM): def __init__(self) -> None: super().__init__() @@ -68,3 +59,20 @@ class CustomLLM(BaseLLM): async def astreaming(self, *args, **kwargs): raise CustomLLMError(status_code=500, message="Not implemented yet!") + + +def custom_chat_llm_router( + async_fn: bool, stream: Optional[bool], custom_llm: CustomLLM +): + """ + Routes call to CustomLLM completion/acompletion/streaming/astreaming functions, based on call type + + Validates if response is in expected format + """ + if async_fn: + if stream: + return custom_llm.astreaming + return custom_llm.acompletion + if stream: + return custom_llm.streaming + return custom_llm.completion diff --git a/litellm/main.py b/litellm/main.py index 539c3d3e1..51e7c611c 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -382,6 +382,7 @@ async def acompletion( or custom_llm_provider == "clarifai" or custom_llm_provider == "watsonx" or custom_llm_provider in litellm.openai_compatible_providers + or custom_llm_provider in litellm._custom_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. init_response = await loop.run_in_executor(None, func_with_context) if isinstance(init_response, dict) or isinstance( @@ -2704,7 +2705,14 @@ def completion( raise ValueError( f"Unable to map your input to a model. Check your input - {args}" ) - response = custom_handler.completion() + + ## ROUTE LLM CALL ## + handler_fn = custom_chat_llm_router( + async_fn=acompletion, stream=stream, custom_llm=custom_handler + ) + + ## CALL FUNCTION + response = handler_fn() else: raise ValueError( f"Unable to map your input to a model. 
Check your input - {args}" diff --git a/litellm/tests/test_custom_llm.py b/litellm/tests/test_custom_llm.py index 0506986eb..fd46c892e 100644 --- a/litellm/tests/test_custom_llm.py +++ b/litellm/tests/test_custom_llm.py @@ -23,7 +23,7 @@ import httpx from dotenv import load_dotenv import litellm -from litellm import CustomLLM, completion, get_llm_provider +from litellm import CustomLLM, acompletion, completion, get_llm_provider class MyCustomLLM(CustomLLM): @@ -35,6 +35,15 @@ class MyCustomLLM(CustomLLM): ) # type: ignore +class MyCustomAsyncLLM(CustomLLM): + async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse: + return litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello world"}], + mock_response="Hi!", + ) # type: ignore + + def test_get_llm_provider(): from litellm.utils import custom_llm_setup @@ -61,3 +70,17 @@ def test_simple_completion(): ) assert resp.choices[0].message.content == "Hi!" + + +@pytest.mark.asyncio +async def test_simple_acompletion(): + my_custom_llm = MyCustomAsyncLLM() + litellm.custom_provider_map = [ + {"provider": "custom_llm", "custom_handler": my_custom_llm} + ] + resp = await acompletion( + model="custom_llm/my-fake-model", + messages=[{"role": "user", "content": "Hello world!"}], + ) + + assert resp.choices[0].message.content == "Hi!" From b4e3a77ad0b823fb5ab44f6ee92a48e2b929993d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 25 Jul 2024 16:47:32 -0700 Subject: [PATCH 3/7] feat(utils.py): support sync streaming for custom llm provider --- litellm/__init__.py | 1 + litellm/llms/custom_llm.py | 19 ++++-- litellm/main.py | 8 +++ litellm/tests/test_custom_llm.py | 111 +++++++++++++++++++++++++++++-- litellm/utils.py | 10 ++- 5 files changed, 139 insertions(+), 10 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index 0527ef199..b6aacad1a 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -913,6 +913,7 @@ adapters: List[AdapterItem] = [] ### CUSTOM LLMs ### from .types.llms.custom_llm import CustomLLMItem +from .types.utils import GenericStreamingChunk custom_provider_map: List[CustomLLMItem] = [] _custom_providers: List[str] = ( diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py index 5e9933194..f00d02ab7 100644 --- a/litellm/llms/custom_llm.py +++ b/litellm/llms/custom_llm.py @@ -15,7 +15,17 @@ import time import types from enum import Enum from functools import partial -from typing import Callable, List, Literal, Optional, Tuple, Union +from typing import ( + Any, + AsyncIterator, + Callable, + Iterator, + List, + Literal, + Optional, + Tuple, + Union, +) import httpx # type: ignore import requests # type: ignore @@ -23,8 +33,7 @@ import requests # type: ignore import litellm from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.types.llms.databricks import GenericStreamingChunk -from litellm.types.utils import ProviderField +from litellm.types.utils import GenericStreamingChunk, ProviderField from litellm.utils import CustomStreamWrapper, EmbeddingResponse, ModelResponse, Usage from .base import BaseLLM @@ -51,13 +60,13 @@ class CustomLLM(BaseLLM): def completion(self, *args, **kwargs) -> ModelResponse: raise CustomLLMError(status_code=500, message="Not implemented yet!") - def streaming(self, *args, **kwargs): + def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]: raise CustomLLMError(status_code=500, message="Not 
implemented yet!") async def acompletion(self, *args, **kwargs) -> ModelResponse: raise CustomLLMError(status_code=500, message="Not implemented yet!") - async def astreaming(self, *args, **kwargs): + async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: raise CustomLLMError(status_code=500, message="Not implemented yet!") diff --git a/litellm/main.py b/litellm/main.py index 51e7c611c..c3be01373 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2713,6 +2713,14 @@ def completion( ## CALL FUNCTION response = handler_fn() + if stream is True: + return CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider=custom_llm_provider, + logging_obj=logging, + ) + else: raise ValueError( f"Unable to map your input to a model. Check your input - {args}" diff --git a/litellm/tests/test_custom_llm.py b/litellm/tests/test_custom_llm.py index fd46c892e..4cc355e4b 100644 --- a/litellm/tests/test_custom_llm.py +++ b/litellm/tests/test_custom_llm.py @@ -17,13 +17,80 @@ sys.path.insert( import os from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +from typing import Any, AsyncIterator, Iterator, Union from unittest.mock import AsyncMock, MagicMock, patch import httpx from dotenv import load_dotenv import litellm -from litellm import CustomLLM, acompletion, completion, get_llm_provider +from litellm import ( + ChatCompletionDeltaChunk, + ChatCompletionUsageBlock, + CustomLLM, + GenericStreamingChunk, + ModelResponse, + acompletion, + completion, + get_llm_provider, +) +from litellm.utils import ModelResponseIterator + + +class CustomModelResponseIterator: + def __init__(self, streaming_response: Union[Iterator, AsyncIterator]): + self.streaming_response = streaming_response + + def chunk_parser(self, chunk: Any) -> GenericStreamingChunk: + return GenericStreamingChunk( + text="hello world", + tool_use=None, + is_finished=True, + finish_reason="stop", + usage=ChatCompletionUsageBlock( + prompt_tokens=10, completion_tokens=20, total_tokens=30 + ), + index=0, + ) + + # Sync iterator + def __iter__(self): + return self + + def __next__(self) -> GenericStreamingChunk: + try: + chunk: Any = self.streaming_response.__next__() # type: ignore + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + return self.chunk_parser(chunk=chunk) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") + + # Async iterator + def __aiter__(self): + self.async_response_iterator = self.streaming_response.__aiter__() # type: ignore + return self + + async def __anext__(self) -> GenericStreamingChunk: + try: + chunk = await self.async_response_iterator.__anext__() + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + return self.chunk_parser(chunk=chunk) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") class MyCustomLLM(CustomLLM): @@ -34,8 +101,6 @@ class MyCustomLLM(CustomLLM): mock_response="Hi!", ) # type: ignore - -class MyCustomAsyncLLM(CustomLLM): async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse: return litellm.completion( model="gpt-3.5-turbo", @@ -43,8 +108,27 @@ class MyCustomAsyncLLM(CustomLLM): mock_response="Hi!", ) # type: 
ignore + def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]: + generic_streaming_chunk: GenericStreamingChunk = { + "finish_reason": "stop", + "index": 0, + "is_finished": True, + "text": "Hello world", + "tool_use": None, + "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30}, + } + + completion_stream = ModelResponseIterator( + model_response=generic_streaming_chunk # type: ignore + ) + custom_iterator = CustomModelResponseIterator( + streaming_response=completion_stream + ) + return custom_iterator + def test_get_llm_provider(): + """""" from litellm.utils import custom_llm_setup my_custom_llm = MyCustomLLM() @@ -74,7 +158,7 @@ def test_simple_completion(): @pytest.mark.asyncio async def test_simple_acompletion(): - my_custom_llm = MyCustomAsyncLLM() + my_custom_llm = MyCustomLLM() litellm.custom_provider_map = [ {"provider": "custom_llm", "custom_handler": my_custom_llm} ] @@ -84,3 +168,22 @@ async def test_simple_acompletion(): ) assert resp.choices[0].message.content == "Hi!" + + +def test_simple_completion_streaming(): + my_custom_llm = MyCustomLLM() + litellm.custom_provider_map = [ + {"provider": "custom_llm", "custom_handler": my_custom_llm} + ] + resp = completion( + model="custom_llm/my-fake-model", + messages=[{"role": "user", "content": "Hello world!"}], + stream=True, + ) + + for chunk in resp: + print(chunk) + if chunk.choices[0].finish_reason is None: + assert isinstance(chunk.choices[0].delta.content, str) + else: + assert chunk.choices[0].finish_reason == "stop" diff --git a/litellm/utils.py b/litellm/utils.py index 0f1b0315d..c14ab36dd 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -9262,7 +9262,10 @@ class CustomStreamWrapper: try: # return this for all models completion_obj = {"content": ""} - if self.custom_llm_provider and self.custom_llm_provider == "anthropic": + if self.custom_llm_provider and ( + self.custom_llm_provider == "anthropic" + or self.custom_llm_provider in litellm._custom_providers + ): from litellm.types.utils import GenericStreamingChunk as GChunk if self.received_finish_reason is not None: @@ -10981,3 +10984,8 @@ class ModelResponseIterator: raise StopAsyncIteration self.is_done = True return self.model_response + + +class CustomModelResponseIterator(Iterable): + def __init__(self) -> None: + super().__init__() From 060249c7e0477fee7740a856b4bb7d58ba3c8079 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 25 Jul 2024 17:11:57 -0700 Subject: [PATCH 4/7] feat(utils.py): support async streaming for custom llm provider --- litellm/llms/custom_llm.py | 2 ++ litellm/tests/test_custom_llm.py | 36 ++++++++++++++++++++++++++++++-- litellm/utils.py | 2 ++ 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py index f00d02ab7..f1b2b28b4 100644 --- a/litellm/llms/custom_llm.py +++ b/litellm/llms/custom_llm.py @@ -17,8 +17,10 @@ from enum import Enum from functools import partial from typing import ( Any, + AsyncGenerator, AsyncIterator, Callable, + Coroutine, Iterator, List, Literal, diff --git a/litellm/tests/test_custom_llm.py b/litellm/tests/test_custom_llm.py index 4cc355e4b..af88b1f3a 100644 --- a/litellm/tests/test_custom_llm.py +++ b/litellm/tests/test_custom_llm.py @@ -17,7 +17,7 @@ sys.path.insert( import os from collections import defaultdict from concurrent.futures import ThreadPoolExecutor -from typing import Any, AsyncIterator, Iterator, Union +from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, 
Union from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -75,7 +75,7 @@ class CustomModelResponseIterator: # Async iterator def __aiter__(self): self.async_response_iterator = self.streaming_response.__aiter__() # type: ignore - return self + return self.streaming_response async def __anext__(self) -> GenericStreamingChunk: try: @@ -126,6 +126,18 @@ class MyCustomLLM(CustomLLM): ) return custom_iterator + async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: # type: ignore + generic_streaming_chunk: GenericStreamingChunk = { + "finish_reason": "stop", + "index": 0, + "is_finished": True, + "text": "Hello world", + "tool_use": None, + "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30}, + } + + yield generic_streaming_chunk # type: ignore + def test_get_llm_provider(): """""" @@ -187,3 +199,23 @@ def test_simple_completion_streaming(): assert isinstance(chunk.choices[0].delta.content, str) else: assert chunk.choices[0].finish_reason == "stop" + + +@pytest.mark.asyncio +async def test_simple_completion_async_streaming(): + my_custom_llm = MyCustomLLM() + litellm.custom_provider_map = [ + {"provider": "custom_llm", "custom_handler": my_custom_llm} + ] + resp = await litellm.acompletion( + model="custom_llm/my-fake-model", + messages=[{"role": "user", "content": "Hello world!"}], + stream=True, + ) + + async for chunk in resp: + print(chunk) + if chunk.choices[0].finish_reason is None: + assert isinstance(chunk.choices[0].delta.content, str) + else: + assert chunk.choices[0].finish_reason == "stop" diff --git a/litellm/utils.py b/litellm/utils.py index c14ab36dd..9158afb74 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -10132,6 +10132,7 @@ class CustomStreamWrapper: try: if self.completion_stream is None: await self.fetch_stream() + if ( self.custom_llm_provider == "openai" or self.custom_llm_provider == "azure" @@ -10156,6 +10157,7 @@ class CustomStreamWrapper: or self.custom_llm_provider == "triton" or self.custom_llm_provider == "watsonx" or self.custom_llm_provider in litellm.openai_compatible_endpoints + or self.custom_llm_provider in litellm._custom_providers ): async for chunk in self.completion_stream: print_verbose(f"value of async chunk: {chunk}") From a2d07cfe64e24f2a42612213f46e49114a94ff8e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 25 Jul 2024 17:41:19 -0700 Subject: [PATCH 5/7] docs(custom_llm_server.md): add calling custom llm server to docs --- .../docs/providers/custom_llm_server.md | 73 ++++++++++ .../docs/providers/custom_openai_proxy.md | 129 ------------------ docs/my-website/sidebars.js | 3 +- 3 files changed, 75 insertions(+), 130 deletions(-) create mode 100644 docs/my-website/docs/providers/custom_llm_server.md delete mode 100644 docs/my-website/docs/providers/custom_openai_proxy.md diff --git a/docs/my-website/docs/providers/custom_llm_server.md b/docs/my-website/docs/providers/custom_llm_server.md new file mode 100644 index 000000000..f8d5fb551 --- /dev/null +++ b/docs/my-website/docs/providers/custom_llm_server.md @@ -0,0 +1,73 @@ +# Custom API Server (Custom Format) + +LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format + + +:::info + +For calling an openai-compatible endpoint, [go here](./openai_compatible.md) +::: + +## Quick Start + +```python +import litellm +from litellm import CustomLLM, completion, get_llm_provider + + +class MyCustomLLM(CustomLLM): + def completion(self, *args, **kwargs) -> litellm.ModelResponse: + return 
litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello world"}], + mock_response="Hi!", + ) # type: ignore + +litellm.custom_provider_map = [ # 👈 KEY STEP - REGISTER HANDLER + {"provider": "my-custom-llm", "custom_handler": my_custom_llm} + ] + +resp = completion( + model="my-custom-llm/my-fake-model", + messages=[{"role": "user", "content": "Hello world!"}], + ) + +assert resp.choices[0].message.content == "Hi!" +``` + + +## Custom Handler Spec + +```python +from litellm.types.utils import GenericStreamingChunk, ModelResponse +from typing import Iterator, AsyncIterator +from litellm.llms.base import BaseLLM + +class CustomLLMError(Exception): # use this for all your exceptions + def __init__( + self, + status_code, + message, + ): + self.status_code = status_code + self.message = message + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + +class CustomLLM(BaseLLM): + def __init__(self) -> None: + super().__init__() + + def completion(self, *args, **kwargs) -> ModelResponse: + raise CustomLLMError(status_code=500, message="Not implemented yet!") + + def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]: + raise CustomLLMError(status_code=500, message="Not implemented yet!") + + async def acompletion(self, *args, **kwargs) -> ModelResponse: + raise CustomLLMError(status_code=500, message="Not implemented yet!") + + async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: + raise CustomLLMError(status_code=500, message="Not implemented yet!") +``` \ No newline at end of file diff --git a/docs/my-website/docs/providers/custom_openai_proxy.md b/docs/my-website/docs/providers/custom_openai_proxy.md deleted file mode 100644 index b6f2eccac..000000000 --- a/docs/my-website/docs/providers/custom_openai_proxy.md +++ /dev/null @@ -1,129 +0,0 @@ -# Custom API Server (OpenAI Format) - -LiteLLM allows you to call your custom endpoint in the OpenAI ChatCompletion format - -## API KEYS -No api keys required - -## Set up your Custom API Server -Your server should have the following Endpoints: - -Here's an example OpenAI proxy server with routes: https://replit.com/@BerriAI/openai-proxy#main.py - -### Required Endpoints -- POST `/chat/completions` - chat completions endpoint - -### Optional Endpoints -- POST `/completions` - completions endpoint -- Get `/models` - available models on server -- POST `/embeddings` - creates an embedding vector representing the input text. 
- - -## Example Usage - -### Call `/chat/completions` -In order to use your custom OpenAI Chat Completion proxy with LiteLLM, ensure you set - -* `api_base` to your proxy url, example "https://openai-proxy.berriai.repl.co" -* `custom_llm_provider` to `openai` this ensures litellm uses the `openai.ChatCompletion` to your api_base - -```python -import os -from litellm import completion - -## set ENV variables -os.environ["OPENAI_API_KEY"] = "anything" #key is not used for proxy - -messages = [{ "content": "Hello, how are you?","role": "user"}] - -response = completion( - model="command-nightly", - messages=[{ "content": "Hello, how are you?","role": "user"}], - api_base="https://openai-proxy.berriai.repl.co", - custom_llm_provider="openai" # litellm will use the openai.ChatCompletion to make the request - -) -print(response) -``` - -#### Response -```json -{ - "object": - "chat.completion", - "choices": [{ - "finish_reason": "stop", - "index": 0, - "message": { - "content": - "The sky, a canvas of blue,\nA work of art, pure and true,\nA", - "role": "assistant" - } - }], - "id": - "chatcmpl-7fbd6077-de10-4cb4-a8a4-3ef11a98b7c8", - "created": - 1699290237.408061, - "model": - "togethercomputer/llama-2-70b-chat", - "usage": { - "completion_tokens": 18, - "prompt_tokens": 14, - "total_tokens": 32 - } - } -``` - - -### Call `/completions` -In order to use your custom OpenAI Completion proxy with LiteLLM, ensure you set - -* `api_base` to your proxy url, example "https://openai-proxy.berriai.repl.co" -* `custom_llm_provider` to `text-completion-openai` this ensures litellm uses the `openai.Completion` to your api_base - -```python -import os -from litellm import completion - -## set ENV variables -os.environ["OPENAI_API_KEY"] = "anything" #key is not used for proxy - -messages = [{ "content": "Hello, how are you?","role": "user"}] - -response = completion( - model="command-nightly", - messages=[{ "content": "Hello, how are you?","role": "user"}], - api_base="https://openai-proxy.berriai.repl.co", - custom_llm_provider="text-completion-openai" # litellm will use the openai.Completion to make the request - -) -print(response) -``` - -#### Response -```json -{ - "warning": - "This model version is deprecated. Migrate before January 4, 2024 to avoid disruption of service. 
Learn more https://platform.openai.com/docs/deprecations", - "id": - "cmpl-8HxHqF5dymQdALmLplS0dWKZVFe3r", - "object": - "text_completion", - "created": - 1699290166, - "model": - "text-davinci-003", - "choices": [{ - "text": - "\n\nThe weather in San Francisco varies depending on what time of year and time", - "index": 0, - "logprobs": None, - "finish_reason": "length" - }], - "usage": { - "prompt_tokens": 7, - "completion_tokens": 16, - "total_tokens": 23 - } - } -``` \ No newline at end of file diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index d228e09d2..c1ce83068 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -175,7 +175,8 @@ const sidebars = { "providers/aleph_alpha", "providers/baseten", "providers/openrouter", - "providers/custom_openai_proxy", + // "providers/custom_openai_proxy", + "providers/custom_llm_server", "providers/petals", ], From bd7af04a725e74290aeb0d87889538041aa0cc3a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 25 Jul 2024 17:56:34 -0700 Subject: [PATCH 6/7] feat(proxy_server.py): support custom llm handler on proxy --- .../docs/providers/custom_llm_server.md | 97 ++++++++++++++++++- litellm/proxy/_new_secret_config.yaml | 9 +- litellm/proxy/custom_handler.py | 21 ++++ litellm/proxy/proxy_server.py | 15 +++ 4 files changed, 140 insertions(+), 2 deletions(-) create mode 100644 litellm/proxy/custom_handler.py diff --git a/docs/my-website/docs/providers/custom_llm_server.md b/docs/my-website/docs/providers/custom_llm_server.md index f8d5fb551..70fc4cea5 100644 --- a/docs/my-website/docs/providers/custom_llm_server.md +++ b/docs/my-website/docs/providers/custom_llm_server.md @@ -35,6 +35,101 @@ resp = completion( assert resp.choices[0].message.content == "Hi!" ``` +## OpenAI Proxy Usage + +1. Setup your `custom_handler.py` file + +```python +import litellm +from litellm import CustomLLM, completion, get_llm_provider + + +class MyCustomLLM(CustomLLM): + def completion(self, *args, **kwargs) -> litellm.ModelResponse: + return litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello world"}], + mock_response="Hi!", + ) # type: ignore + + async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse: + return litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello world"}], + mock_response="Hi!", + ) # type: ignore + + +my_custom_llm = MyCustomLLM() +``` + +2. Add to `config.yaml` + +In the config below, we pass + +python_filename: `custom_handler.py` +custom_handler_instance_name: `my_custom_llm`. This is defined in Step 1 + +custom_handler: `custom_handler.my_custom_llm` + +```yaml +model_list: + - model_name: "test-model" + litellm_params: + model: "openai/text-embedding-ada-002" + - model_name: "my-custom-model" + litellm_params: + model: "my-custom-llm/my-model" + +litellm_settings: + custom_provider_map: + - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm} +``` + +```bash +litellm --config /path/to/config.yaml +``` + +3. Test it! 
+ +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "my-custom-model", + "messages": [{"role": "user", "content": "Say \"this is a test\" in JSON!"}], +}' +``` + +Expected Response + +``` +{ + "id": "chatcmpl-06f1b9cd-08bc-43f7-9814-a69173921216", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "Hi!", + "role": "assistant", + "tool_calls": null, + "function_call": null + } + } + ], + "created": 1721955063, + "model": "gpt-3.5-turbo", + "object": "chat.completion", + "system_fingerprint": null, + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30 + } +} +``` ## Custom Handler Spec @@ -70,4 +165,4 @@ class CustomLLM(BaseLLM): async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: raise CustomLLMError(status_code=500, message="Not implemented yet!") -``` \ No newline at end of file +``` diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index a81d133e5..0854f0901 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,4 +1,11 @@ model_list: - model_name: "test-model" litellm_params: - model: "openai/text-embedding-ada-002" \ No newline at end of file + model: "openai/text-embedding-ada-002" + - model_name: "my-custom-model" + litellm_params: + model: "my-custom-llm/my-model" + +litellm_settings: + custom_provider_map: + - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm} \ No newline at end of file diff --git a/litellm/proxy/custom_handler.py b/litellm/proxy/custom_handler.py new file mode 100644 index 000000000..56943c34d --- /dev/null +++ b/litellm/proxy/custom_handler.py @@ -0,0 +1,21 @@ +import litellm +from litellm import CustomLLM, completion, get_llm_provider + + +class MyCustomLLM(CustomLLM): + def completion(self, *args, **kwargs) -> litellm.ModelResponse: + return litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello world"}], + mock_response="Hi!", + ) # type: ignore + + async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse: + return litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello world"}], + mock_response="Hi!", + ) # type: ignore + + +my_custom_llm = MyCustomLLM() diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f22f25f73..bad1abae2 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1507,6 +1507,21 @@ class ProxyConfig: verbose_proxy_logger.debug( f"litellm.post_call_rules: {litellm.post_call_rules}" ) + elif key == "custom_provider_map": + from litellm.utils import custom_llm_setup + + litellm.custom_provider_map = [ + { + "provider": item["provider"], + "custom_handler": get_instance_fn( + value=item["custom_handler"], + config_file_path=config_file_path, + ), + } + for item in value + ] + + custom_llm_setup() elif key == "success_callback": litellm.success_callback = [] From 41abd5124023c931aa7856271d6e5761804358e6 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 25 Jul 2024 19:03:52 -0700 Subject: [PATCH 7/7] fix(custom_llm.py): pass input params to custom llm --- litellm/llms/custom_llm.py | 80 ++++++++++++++++++++++++++-- litellm/main.py | 21 +++++++- litellm/tests/test_custom_llm.py | 91 ++++++++++++++++++++++++++++++-- 3 files changed, 182 insertions(+), 10 deletions(-) diff --git 
a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py index f1b2b28b4..47c5a485c 100644 --- a/litellm/llms/custom_llm.py +++ b/litellm/llms/custom_llm.py @@ -59,16 +59,88 @@ class CustomLLM(BaseLLM): def __init__(self) -> None: super().__init__() - def completion(self, *args, **kwargs) -> ModelResponse: + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[HTTPHandler] = None, + ) -> ModelResponse: raise CustomLLMError(status_code=500, message="Not implemented yet!") - def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]: + def streaming( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[HTTPHandler] = None, + ) -> Iterator[GenericStreamingChunk]: raise CustomLLMError(status_code=500, message="Not implemented yet!") - async def acompletion(self, *args, **kwargs) -> ModelResponse: + async def acompletion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[AsyncHTTPHandler] = None, + ) -> ModelResponse: raise CustomLLMError(status_code=500, message="Not implemented yet!") - async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: + async def astreaming( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[AsyncHTTPHandler] = None, + ) -> AsyncIterator[GenericStreamingChunk]: raise CustomLLMError(status_code=500, message="Not implemented yet!") diff --git a/litellm/main.py b/litellm/main.py index c3be01373..672029f69 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2711,8 +2711,27 @@ def completion( async_fn=acompletion, stream=stream, custom_llm=custom_handler ) + headers = headers or litellm.headers + ## CALL FUNCTION - response = handler_fn() + response = handler_fn( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + encoding=encoding, + ) if stream is True: return CustomStreamWrapper( completion_stream=response, diff --git a/litellm/tests/test_custom_llm.py b/litellm/tests/test_custom_llm.py index af88b1f3a..a0f8b569e 100644 --- 
a/litellm/tests/test_custom_llm.py +++ b/litellm/tests/test_custom_llm.py @@ -17,7 +17,16 @@ sys.path.insert( import os from collections import defaultdict from concurrent.futures import ThreadPoolExecutor -from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Callable, + Coroutine, + Iterator, + Optional, + Union, +) from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -94,21 +103,75 @@ class CustomModelResponseIterator: class MyCustomLLM(CustomLLM): - def completion(self, *args, **kwargs) -> litellm.ModelResponse: + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable[..., Any], + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, openai.Timeout]] = None, + client: Optional[litellm.HTTPHandler] = None, + ) -> ModelResponse: return litellm.completion( model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}], mock_response="Hi!", ) # type: ignore - async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse: + async def acompletion( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable[..., Any], + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, openai.Timeout]] = None, + client: Optional[litellm.AsyncHTTPHandler] = None, + ) -> litellm.ModelResponse: return litellm.completion( model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}], mock_response="Hi!", ) # type: ignore - def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]: + def streaming( + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable[..., Any], + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, openai.Timeout]] = None, + client: Optional[litellm.HTTPHandler] = None, + ) -> Iterator[GenericStreamingChunk]: generic_streaming_chunk: GenericStreamingChunk = { "finish_reason": "stop", "index": 0, @@ -126,7 +189,25 @@ class MyCustomLLM(CustomLLM): ) return custom_iterator - async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: # type: ignore + async def astreaming( # type: ignore + self, + model: str, + messages: list, + api_base: str, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable[..., Any], + encoding, + api_key, + logging_obj, + optional_params: dict, + acompletion=None, + litellm_params=None, + logger_fn=None, + headers={}, + timeout: Optional[Union[float, openai.Timeout]] = None, + client: Optional[litellm.AsyncHTTPHandler] = None, + ) -> AsyncIterator[GenericStreamingChunk]: # type: ignore generic_streaming_chunk: GenericStreamingChunk = { "finish_reason": "stop", "index": 0,
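The expanded signatures in this final patch mean a handler receives the normalized inputs (`model`, `messages`, `optional_params`, and so on) that `completion()` assembles, rather than bare `*args/**kwargs`. The sketch below is not part of the patch: it shows one way a handler might use those inputs, reusing LiteLLM's `mock_response` path to build a well-formed `ModelResponse`; the `EchoLLM` name and echo behavior are illustrative.

```python
from typing import Callable, Optional, Union

import httpx

import litellm
from litellm import CustomLLM, completion
from litellm.types.utils import ModelResponse


class EchoLLM(CustomLLM):
    """Toy handler: replies with the last user message it was given."""

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client=None,
    ) -> ModelResponse:
        last_user = next(
            m["content"] for m in reversed(messages) if m["role"] == "user"
        )
        # Reuse the mock path to produce a well-formed ModelResponse without
        # calling any real backend.
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=messages,
            mock_response=f"echo: {last_user}",
        )  # type: ignore


litellm.custom_provider_map = [
    {"provider": "echo-llm", "custom_handler": EchoLLM()}
]

resp = completion(
    model="echo-llm/any-model",
    messages=[{"role": "user", "content": "Hello world!"}],
)
print(resp.choices[0].message.content)  # echo: Hello world!
```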