forked from phoenix/litellm-mirror
fix(custom_llm.py): pass input params to custom llm
This commit is contained in:
parent
bd7af04a72
commit
41abd51240
3 changed files with 182 additions and 10 deletions
|
@ -59,16 +59,88 @@ class CustomLLM(BaseLLM):
|
|||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def completion(self, *args, **kwargs) -> ModelResponse:
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[HTTPHandler] = None,
|
||||
) -> ModelResponse:
|
||||
raise CustomLLMError(status_code=500, message="Not implemented yet!")
|
||||
|
||||
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
|
||||
def streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[HTTPHandler] = None,
|
||||
) -> Iterator[GenericStreamingChunk]:
|
||||
raise CustomLLMError(status_code=500, message="Not implemented yet!")
|
||||
|
||||
async def acompletion(self, *args, **kwargs) -> ModelResponse:
|
||||
async def acompletion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ModelResponse:
|
||||
raise CustomLLMError(status_code=500, message="Not implemented yet!")
|
||||
|
||||
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
|
||||
async def astreaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> AsyncIterator[GenericStreamingChunk]:
|
||||
raise CustomLLMError(status_code=500, message="Not implemented yet!")
|
||||
|
||||
|
||||
|
|
|
@ -2711,8 +2711,27 @@ def completion(
|
|||
async_fn=acompletion, stream=stream, custom_llm=custom_handler
|
||||
)
|
||||
|
||||
headers = headers or litellm.headers
|
||||
|
||||
## CALL FUNCTION
|
||||
response = handler_fn()
|
||||
response = handler_fn(
|
||||
model=model,
|
||||
messages=messages,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
acompletion=acompletion,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
timeout=timeout, # type: ignore
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client, # pass AsyncOpenAI, OpenAI client
|
||||
encoding=encoding,
|
||||
)
|
||||
if stream is True:
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
|
|
|
@ -17,7 +17,16 @@ sys.path.insert(
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Iterator,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
|
@ -94,21 +103,75 @@ class CustomModelResponseIterator:
|
|||
|
||||
|
||||
class MyCustomLLM(CustomLLM):
|
||||
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable[..., Any],
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, openai.Timeout]] = None,
|
||||
client: Optional[litellm.HTTPHandler] = None,
|
||||
) -> ModelResponse:
|
||||
return litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello world"}],
|
||||
mock_response="Hi!",
|
||||
) # type: ignore
|
||||
|
||||
async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
|
||||
async def acompletion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable[..., Any],
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, openai.Timeout]] = None,
|
||||
client: Optional[litellm.AsyncHTTPHandler] = None,
|
||||
) -> litellm.ModelResponse:
|
||||
return litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello world"}],
|
||||
mock_response="Hi!",
|
||||
) # type: ignore
|
||||
|
||||
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
|
||||
def streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable[..., Any],
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, openai.Timeout]] = None,
|
||||
client: Optional[litellm.HTTPHandler] = None,
|
||||
) -> Iterator[GenericStreamingChunk]:
|
||||
generic_streaming_chunk: GenericStreamingChunk = {
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
|
@ -126,7 +189,25 @@ class MyCustomLLM(CustomLLM):
|
|||
)
|
||||
return custom_iterator
|
||||
|
||||
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: # type: ignore
|
||||
async def astreaming( # type: ignore
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable[..., Any],
|
||||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
timeout: Optional[Union[float, openai.Timeout]] = None,
|
||||
client: Optional[litellm.AsyncHTTPHandler] = None,
|
||||
) -> AsyncIterator[GenericStreamingChunk]: # type: ignore
|
||||
generic_streaming_chunk: GenericStreamingChunk = {
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue