diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py
index f1b2b28b4..47c5a485c 100644
--- a/litellm/llms/custom_llm.py
+++ b/litellm/llms/custom_llm.py
@@ -59,16 +59,88 @@ class CustomLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def completion(self, *args, **kwargs) -> ModelResponse:
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[HTTPHandler] = None,
+    ) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+    def streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[HTTPHandler] = None,
+    ) -> Iterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    async def acompletion(self, *args, **kwargs) -> ModelResponse:
+    async def acompletion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+    async def astreaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> AsyncIterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
 
diff --git a/litellm/main.py b/litellm/main.py
index c3be01373..672029f69 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2711,8 +2711,27 @@ def completion(
             async_fn=acompletion, stream=stream, custom_llm=custom_handler
         )
 
+        headers = headers or litellm.headers
+
         ## CALL FUNCTION
-        response = handler_fn()
+        response = handler_fn(
+            model=model,
+            messages=messages,
+            headers=headers,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            api_key=api_key,
+            api_base=api_base,
+            acompletion=acompletion,
+            logging_obj=logging,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            timeout=timeout,  # type: ignore
+            custom_prompt_dict=custom_prompt_dict,
+            client=client,  # pass AsyncOpenAI, OpenAI client
+            encoding=encoding,
+        )
         if stream is True:
             return CustomStreamWrapper(
                 completion_stream=response,
diff --git a/litellm/tests/test_custom_llm.py b/litellm/tests/test_custom_llm.py
index af88b1f3a..a0f8b569e 100644
--- a/litellm/tests/test_custom_llm.py
+++ b/litellm/tests/test_custom_llm.py
@@ -17,7 +17,16 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Callable,
+    Coroutine,
+    Iterator,
+    Optional,
+    Union,
+)
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
@@ -94,21 +103,75 @@ class CustomModelResponseIterator:
 
 
 class MyCustomLLM(CustomLLM):
-    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore
 
-    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
+    async def acompletion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> litellm.ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore
 
-    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+    def streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> Iterator[GenericStreamingChunk]:
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
@@ -126,7 +189,25 @@ class MyCustomLLM(CustomLLM):
         )
         return custom_iterator
 
-    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
+    async def astreaming(  # type: ignore
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,