fix(custom_llm.py): pass input params to custom llm

This commit is contained in:
Krrish Dholakia 2024-07-25 19:03:52 -07:00
parent bd7af04a72
commit 41abd51240
3 changed files with 182 additions and 10 deletions

View file

@ -59,16 +59,88 @@ class CustomLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def completion(self, *args, **kwargs) -> ModelResponse:
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
def streaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[HTTPHandler] = None,
) -> Iterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def acompletion(self, *args, **kwargs) -> ModelResponse:
async def acompletion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> ModelResponse:
raise CustomLLMError(status_code=500, message="Not implemented yet!")
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
async def astreaming(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[AsyncHTTPHandler] = None,
) -> AsyncIterator[GenericStreamingChunk]:
raise CustomLLMError(status_code=500, message="Not implemented yet!")