fix(custom_llm.py): pass input params to custom llm

This commit is contained in:
Krrish Dholakia 2024-07-25 19:03:52 -07:00
parent bd7af04a72
commit 41abd51240
3 changed files with 182 additions and 10 deletions

View file

@ -17,7 +17,16 @@ sys.path.insert(
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Coroutine,
Iterator,
Optional,
Union,
)
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
@ -94,21 +103,75 @@ class CustomModelResponseIterator:
class MyCustomLLM(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.HTTPHandler] = None,
) -> ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
async def acompletion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.AsyncHTTPHandler] = None,
) -> litellm.ModelResponse:
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
) # type: ignore
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
def streaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.HTTPHandler] = None,
) -> Iterator[GenericStreamingChunk]:
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
@ -126,7 +189,25 @@ class MyCustomLLM(CustomLLM):
)
return custom_iterator
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: # type: ignore
async def astreaming( # type: ignore
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable[..., Any],
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, openai.Timeout]] = None,
client: Optional[litellm.AsyncHTTPHandler] = None,
) -> AsyncIterator[GenericStreamingChunk]: # type: ignore
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,