fix(custom_llm.py): pass input params to custom llm

Krrish Dholakia 2024-07-25 19:03:52 -07:00
parent bd7af04a72
commit 41abd51240
3 changed files with 182 additions and 10 deletions

@@ -59,16 +59,88 @@ class CustomLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def completion(self, *args, **kwargs) -> ModelResponse:
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[HTTPHandler] = None,
+    ) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+    def streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[HTTPHandler] = None,
+    ) -> Iterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    async def acompletion(self, *args, **kwargs) -> ModelResponse:
+    async def acompletion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+    async def astreaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> AsyncIterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
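
With the typed signatures above, a subclass can read the routed request parameters directly instead of unpacking *args/**kwargs. A minimal sketch of a handler on the new interface — the EchoLLM class and the "my-custom-llm" provider name are illustrative, and the registration step follows litellm's documented custom_provider_map pattern rather than anything in this commit:

import litellm
from litellm import CustomLLM, completion


class EchoLLM(CustomLLM):
    def completion(self, model: str, messages: list, optional_params: dict, **kwargs) -> litellm.ModelResponse:
        # model, messages, optional_params etc. now arrive populated from
        # the caller (see the handler_fn() change in the next file).
        return completion(
            model="gpt-3.5-turbo",
            messages=messages,
            mock_response=f"echo: {messages[-1]['content']}",
        )  # type: ignore


litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": EchoLLM()}
]

resp = completion(
    model="my-custom-llm/anything",
    messages=[{"role": "user", "content": "hi"}],
)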

@@ -2711,8 +2711,27 @@ def completion(
             async_fn=acompletion, stream=stream, custom_llm=custom_handler
         )
+        headers = headers or litellm.headers
         ## CALL FUNCTION
-        response = handler_fn()
+        response = handler_fn(
+            model=model,
+            messages=messages,
+            headers=headers,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            api_key=api_key,
+            api_base=api_base,
+            acompletion=acompletion,
+            logging_obj=logging,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            timeout=timeout,  # type: ignore
+            custom_prompt_dict=custom_prompt_dict,
+            client=client,  # pass AsyncOpenAI, OpenAI client
+            encoding=encoding,
+        )
         if stream is True:
             return CustomStreamWrapper(
                 completion_stream=response,
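
The dispatch above now forwards the full request context (model, messages, optional_params, timeout, client, ...) to the handler as keyword arguments, and when stream is True the handler's iterator is wrapped in CustomStreamWrapper so callers consume ordinary OpenAI-style chunks. A rough sketch of a streaming handler under the new signature, assuming the GenericStreamingChunk shape used in the test file below (provider name and handler are illustrative):

from typing import Iterator

import litellm
from litellm import CustomLLM, completion
from litellm.types.utils import GenericStreamingChunk


class StreamingEcho(CustomLLM):
    def streaming(self, model: str, messages: list, **kwargs) -> Iterator[GenericStreamingChunk]:
        # One terminal chunk; CustomStreamWrapper converts these dicts
        # into standard streaming responses.
        yield {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": "Hello!",
            "tool_use": None,
            "usage": {"completion_tokens": 2, "prompt_tokens": 1, "total_tokens": 3},
        }


litellm.custom_provider_map = [
    {"provider": "stream-echo", "custom_handler": StreamingEcho()}
]

for chunk in completion(
    model="stream-echo/my-model",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
):
    print(chunk.choices[0].delta.content)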

@@ -17,7 +17,16 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Callable,
+    Coroutine,
+    Iterator,
+    Optional,
+    Union,
+)
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
@@ -94,21 +103,75 @@ class CustomModelResponseIterator:
 class MyCustomLLM(CustomLLM):
-    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore
 
-    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
+    async def acompletion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> litellm.ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore
 
-    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+    def streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> Iterator[GenericStreamingChunk]:
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
@@ -126,7 +189,25 @@ class MyCustomLLM(CustomLLM):
         )
         return custom_iterator
 
-    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
+    async def astreaming(  # type: ignore
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,