Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
fix(custom_llm.py): pass input params to custom llm
This commit is contained in:
parent bd7af04a72
commit 41abd51240

3 changed files with 182 additions and 10 deletions
@@ -59,16 +59,88 @@ class CustomLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def completion(self, *args, **kwargs) -> ModelResponse:
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[HTTPHandler] = None,
+    ) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+    def streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[HTTPHandler] = None,
+    ) -> Iterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    async def acompletion(self, *args, **kwargs) -> ModelResponse:
+    async def acompletion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> ModelResponse:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
 
-    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+    async def astreaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> AsyncIterator[GenericStreamingChunk]:
         raise CustomLLMError(status_code=500, message="Not implemented yet!")
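Because the base-class methods in the custom_llm.py hunk above now declare each input explicitly, a subclass can name just the values it needs (for example api_base or optional_params) instead of unpacking **kwargs. Below is a minimal sketch under that assumption; EchoLLM and its behaviour are illustrative, not code from this commit, and the trailing **kwargs simply absorbs the remaining keyword arguments listed in the full signature.

import litellm
from litellm import CustomLLM, ModelResponse


class EchoLLM(CustomLLM):
    # Illustrative subclass (not part of this commit): it consumes a few of
    # the now-explicit inputs and absorbs the rest via **kwargs.
    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        optional_params: dict,
        **kwargs,
    ) -> ModelResponse:
        # api_base and optional_params arrive as real arguments instead of
        # being fished out of a kwargs dict.
        note = f"model={model} api_base={api_base} params={sorted(optional_params)}"
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=messages,
            mock_response=note,
        )  # type: ignore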
@@ -2711,8 +2711,27 @@ def completion(
                 async_fn=acompletion, stream=stream, custom_llm=custom_handler
             )
 
+            headers = headers or litellm.headers
+
             ## CALL FUNCTION
-            response = handler_fn()
+            response = handler_fn(
+                model=model,
+                messages=messages,
+                headers=headers,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                api_key=api_key,
+                api_base=api_base,
+                acompletion=acompletion,
+                logging_obj=logging,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                timeout=timeout,  # type: ignore
+                custom_prompt_dict=custom_prompt_dict,
+                client=client,  # pass AsyncOpenAI, OpenAI client
+                encoding=encoding,
+            )
             if stream is True:
                 return CustomStreamWrapper(
                     completion_stream=response,
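The hunk above only shows the call site: handler_fn is produced a few lines earlier by a router call (the async_fn=..., stream=..., custom_llm=... arguments visible at the top of the hunk), and this commit's change is that the chosen handler now receives the full keyword set instead of an empty call. As a rough sketch of the selection such a router presumably performs (an assumption for illustration, not the actual litellm implementation):

from typing import Callable, Optional

from litellm.llms.custom_llm import CustomLLM


def pick_custom_handler(
    async_fn: Optional[bool], stream: Optional[bool], custom_llm: CustomLLM
) -> Callable:
    # Hypothetical stand-in: return the bound CustomLLM method matching the
    # requested sync/async and streaming mode.
    if async_fn:
        return custom_llm.astreaming if stream else custom_llm.acompletion
    return custom_llm.streaming if stream else custom_llm.completion

Because all four entry points now share one keyword interface, the single handler_fn(model=..., messages=..., ...) call works no matter which branch was picked.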
@@ -17,7 +17,16 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Callable,
+    Coroutine,
+    Iterator,
+    Optional,
+    Union,
+)
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
@@ -94,21 +103,75 @@ class CustomModelResponseIterator:
 
 
 class MyCustomLLM(CustomLLM):
-    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore
 
-    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
+    async def acompletion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> litellm.ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore
 
-    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+    def streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> Iterator[GenericStreamingChunk]:
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
@@ -126,7 +189,25 @@ class MyCustomLLM(CustomLLM):
         )
         return custom_iterator
 
-    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
+    async def astreaming(  # type: ignore
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
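To exercise a handler like MyCustomLLM above, the usual pattern is to register it in litellm.custom_provider_map and route requests through a model name carrying that provider prefix. The sketch below follows litellm's documented custom-provider mechanism; the prefix "my-custom-llm" and the model alias are assumed example values, not part of this commit.

import litellm

my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
    {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
]

# Non-streaming: dispatched to MyCustomLLM.completion(), which mocks "Hi!".
resp = litellm.completion(
    model="my-custom-llm/my-model",
    messages=[{"role": "user", "content": "Hello world!"}],
)
print(resp.choices[0].message.content)

# Streaming: dispatched to MyCustomLLM.streaming(); chunks come back through
# litellm's CustomStreamWrapper as regular delta chunks.
for chunk in litellm.completion(
    model="my-custom-llm/my-model",
    messages=[{"role": "user", "content": "Hello world!"}],
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")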