(Refactor) Code Quality improvement - use Common base handler for Cohere (#7117)

* fix use new format for Cohere config

* fix base llm http handler

* Litellm code qa common config (#7116)

* feat(base_llm): initial commit for common base config class

Addresses code qa critique https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132

* feat(base_llm/): add transform request/response abstract methods to base config class

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>

* use base transform helpers

* use base_llm_http_handler for cohere

* working cohere using base llm handler

* add async cohere chat completion support on base handler

* fix completion code

* working sync cohere stream

* add async support cohere_chat

* fix types get_model_response_iterator

* async / sync tests cohere

* feat cohere using base llm class

* fix linting errors

* fix _abc error

* add cohere params to transformation

* remove old cohere file

* fix type error

* fix merge conflicts

* fix cohere merge conflicts

* fix linting error

* fix litellm.llms.custom_httpx.http_handler.HTTPHandler.post

* fix passing cohere specific params

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
Ishaan Jaff 2024-12-09 17:45:29 -08:00 committed by GitHub
parent 5bbf906c83
commit ff7c95694d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 933 additions and 720 deletions

View file

@@ -4,7 +4,16 @@ Common base config for all LLM providers
import types
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Iterator,
List,
Optional,
Union,
)
import httpx
@@ -12,11 +21,11 @@ from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LoggingClass = Any
LiteLLMLoggingObj = Any
class BaseLLMException(Exception):
@@ -78,11 +87,11 @@ class BaseConfig(ABC):
@abstractmethod
def validate_environment(
self,
api_key: str,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
) -> dict:
pass
@@ -109,21 +118,26 @@
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LoggingClass,
api_key: str,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
encoding: Any,
encoding: str,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
pass
@abstractmethod
def get_error_class(
self,
error_message: str,
status_code: int,
headers: dict,
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
pass
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str]],
sync_stream: bool,
json_mode: Optional[bool] = False,
) -> Any:
pass