Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
fix(custom_llm.py): pass input params to custom llm
This commit is contained in:
parent: bd7af04a72
commit: 41abd51240

3 changed files with 182 additions and 10 deletions
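The change replaces the catch-all *args/**kwargs signatures on the CustomLLM test handlers with the explicit parameters that custom_llm.py now passes through. For context, here is a minimal sketch of how such a handler is registered and called, following the custom_provider_map pattern from litellm's custom-provider docs; the "echo-llm" provider name and the EchoLLM class are invented for illustration:

import litellm
from litellm import CustomLLM


class EchoLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        # A *args/**kwargs handler keeps working after this commit; the
        # explicit signatures in the diff below just name the params that
        # custom_llm.py now forwards (model, messages, api_base, ...).
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore


# Register the handler, then route calls to it via "<provider>/<model>".
litellm.custom_provider_map = [
    {"provider": "echo-llm", "custom_handler": EchoLLM()}
]

resp = litellm.completion(
    model="echo-llm/any-model",
    messages=[{"role": "user", "content": "Hello"}],
)
print(resp.choices[0].message.content)  # -> "Hi!"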
@@ -17,7 +17,16 @@ sys.path.insert(
 import os
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, AsyncGenerator, AsyncIterator, Coroutine, Iterator, Union
+from typing import (
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Callable,
+    Coroutine,
+    Iterator,
+    Optional,
+    Union,
+)
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import httpx
@@ -94,21 +103,75 @@ class CustomModelResponseIterator:
 
 
 class MyCustomLLM(CustomLLM):
-    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore
 
-    async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
+    async def acompletion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> litellm.ModelResponse:
         return litellm.completion(
             model="gpt-3.5-turbo",
             messages=[{"role": "user", "content": "Hello world"}],
             mock_response="Hi!",
         )  # type: ignore
 
-    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+    def streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.HTTPHandler] = None,
+    ) -> Iterator[GenericStreamingChunk]:
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
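With the inputs forwarded explicitly, a handler can act on them instead of discarding them. A hedged sketch of what that enables, assuming litellm invokes the handler with keyword arguments (parameter names are taken from the diff above, the rest collapsed into **kwargs; the echo behavior is invented):

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        optional_params: dict,
        **kwargs,
    ) -> ModelResponse:
        # Echo the caller's last user message back, demonstrating that
        # the input params actually reach the custom handler now.
        last_user = messages[-1]["content"] if messages else ""
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=messages,
            mock_response=f"echo: {last_user}",
        )  # type: ignore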
@@ -126,7 +189,25 @@ class MyCustomLLM(CustomLLM):
         )
         return custom_iterator
 
-    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
+    async def astreaming(  # type: ignore
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_prompt_dict: dict,
+        model_response: ModelResponse,
+        print_verbose: Callable[..., Any],
+        encoding,
+        api_key,
+        logging_obj,
+        optional_params: dict,
+        acompletion=None,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        timeout: Optional[Union[float, openai.Timeout]] = None,
+        client: Optional[litellm.AsyncHTTPHandler] = None,
+    ) -> AsyncIterator[GenericStreamingChunk]:  # type: ignore
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",
             "index": 0,
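Both streaming hunks truncate the GenericStreamingChunk literal. Assuming the fields of litellm's GenericStreamingChunk TypedDict (litellm.types.utils), a complete chunk looks like this, with illustrative values:

from litellm.types.utils import GenericStreamingChunk

generic_streaming_chunk: GenericStreamingChunk = {
    "finish_reason": "stop",
    "index": 0,
    "is_finished": True,
    "text": "Hello world",
    "tool_use": None,
    "usage": {"completion_tokens": 10, "prompt_tokens": 20, "total_tokens": 30},
}

On the caller's side, litellm wraps each chunk in an OpenAI-style streaming response, so, given a handler that also implements streaming() and is registered as in the first sketch, consumption is the usual loop:

stream = litellm.completion(
    model="echo-llm/any-model",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")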