(Refactor) Code quality improvement - use common base handler for Cohere (#7117)

* fix: use new format for Cohere config

* fix base llm http handler

* Litellm code qa common config (#7116)

* feat(base_llm): initial commit for common base config class

Addresses code qa critique https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132

* feat(base_llm/): add transform request/response abstract methods to base config class (see the sketch after the commit list)

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>

* use base transform helpers

* use base_llm_http_handler for cohere

* working cohere using base llm handler

* add async cohere chat completion support on base handler

* fix completion code

* working sync cohere stream

* add async support cohere_chat

* fix types for get_model_response_iterator

* add async / sync tests for cohere

* feat: cohere using base llm class

* fix linting errors

* fix _abc error

* add cohere params to transformation

* remove old cohere file

* fix type error

* fix merge conflicts

* fix cohere merge conflicts

* fix linting error

* fix litellm.llms.custom_httpx.http_handler.HTTPHandler.post

* fix passing cohere specific params

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
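
The pattern behind these commits, roughly: each provider ships a config object implementing shared transform hooks, and one generic HTTP handler (BaseLLMHTTPHandler) drives the request/response cycle through those hooks. A minimal sketch of the idea — names are illustrative, not litellm's exact classes:

    from abc import ABC, abstractmethod
    from typing import Any, Dict, List


    class BaseConfigSketch(ABC):
        """Hypothetical common base config: a provider only describes how to
        translate to/from its wire format; transport lives in the handler."""

        @abstractmethod
        def transform_request(
            self,
            model: str,
            messages: List[Dict[str, str]],
            optional_params: Dict[str, Any],
        ) -> Dict[str, Any]:
            """Build the provider-specific request body."""

        @abstractmethod
        def transform_response(self, raw: Dict[str, Any]) -> Dict[str, Any]:
            """Map the provider response back to the OpenAI-style shape."""


    class CohereChatConfigSketch(BaseConfigSketch):
        def transform_request(self, model, messages, optional_params):
            # Illustrative only: Cohere's chat API takes the latest turn as
            # `message` and earlier turns as `chat_history` (real code also
            # remaps roles to Cohere's format).
            return {
                "model": model,
                "message": messages[-1]["content"],
                "chat_history": messages[:-1],
                **optional_params,
            }

        def transform_response(self, raw):
            return {
                "choices": [
                    {"message": {"role": "assistant", "content": raw.get("text", "")}}
                ]
            }

With configs shaped like this, a single handler can serve every provider: call transform_request, POST the body, then feed the raw JSON through transform_response.
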
Ishaan Jaff 2024-12-09 17:45:29 -08:00 committed by GitHub
parent 501885d653
commit c5e0407703
14 changed files with 933 additions and 720 deletions

litellm/main.py

@@ -111,9 +111,9 @@ from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
 from .llms.clarifai.chat import handler
-from .llms.cohere.chat import handler as cohere_chat
 from .llms.cohere.completion import completion as cohere_completion  # type: ignore
 from .llms.cohere.embed import handler as cohere_embed
+from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
@@ -233,6 +233,7 @@ sagemaker_llm = SagemakerLLM()
 watsonx_chat_completion = WatsonXChatHandler()
 openai_like_embedding = OpenAILikeEmbeddingHandler()
 databricks_embedding = DatabricksEmbeddingHandler()
+base_llm_http_handler = BaseLLMHTTPHandler()

 ####### COMPLETION ENDPOINTS ################
@@ -446,6 +447,7 @@ async def acompletion(
         or custom_llm_provider == "perplexity"
         or custom_llm_provider == "groq"
         or custom_llm_provider == "nvidia_nim"
+        or custom_llm_provider == "cohere_chat"
         or custom_llm_provider == "cerebras"
         or custom_llm_provider == "sambanova"
         or custom_llm_provider == "ai21_chat"
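
This or-chain gates which providers take the native async path in acompletion; adding "cohere_chat" means async Cohere calls now await the new handler directly instead of falling back to a worker thread. In spirit (not litellm's actual code):

    import asyncio

    # Illustrative subset; the real check is the or-chain in acompletion above.
    NATIVE_ASYNC_PROVIDERS = {"groq", "nvidia_nim", "cohere_chat"}

    async def acompletion_sketch(provider: str, sync_fn, async_fn, **kwargs):
        if provider in NATIVE_ASYNC_PROVIDERS:
            # Provider is wired to an async-capable handler: await it directly.
            return await async_fn(**kwargs)
        # Otherwise run the blocking sync implementation off the event loop.
        return await asyncio.to_thread(sync_fn, **kwargs)
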
@@ -1941,15 +1943,15 @@ def completion(  # type: ignore  # noqa: PLR0915
             cohere_key = (
                 api_key
                 or litellm.cohere_key
-                or get_secret("COHERE_API_KEY")
-                or get_secret("CO_API_KEY")
+                or get_secret_str("COHERE_API_KEY")
+                or get_secret_str("CO_API_KEY")
                 or litellm.api_key
             )

             api_base = (
                 api_base
                 or litellm.api_base
-                or get_secret("COHERE_API_BASE")
+                or get_secret_str("COHERE_API_BASE")
                 or "https://api.cohere.ai/v1/chat"
             )
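
The get_secret → get_secret_str swap is a typing fix: get_secret can return non-string secret objects, while get_secret_str narrows the result so the or-chain stays Optional[str]. A self-contained sketch of such a wrapper (env-var lookup stands in for litellm's real secret resolution):

    import os
    from typing import Optional

    def get_secret_str_sketch(name: str) -> Optional[str]:
        # Resolve the secret (a secret manager in real litellm; env var here)
        # and hand back only actual strings, so chains like
        # `api_key or get_secret_str_sketch("COHERE_API_KEY")`
        # type-check as Optional[str].
        value = os.environ.get(name)
        return value if isinstance(value, str) else None
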
@@ -1960,32 +1962,22 @@ def completion(  # type: ignore  # noqa: PLR0915
             if extra_headers is not None:
                 headers.update(extra_headers)

-            model_response = cohere_chat.completion(
+            response = base_llm_http_handler.completion(
                 model=model,
+                stream=stream,
                 messages=messages,
+                acompletion=acompletion,
                 api_base=api_base,
                 model_response=model_response,
-                print_verbose=print_verbose,
                 optional_params=optional_params,
+                litellm_params=litellm_params,
+                custom_llm_provider="cohere_chat",
+                timeout=timeout,
                 headers=headers,
-                logger_fn=logger_fn,
                 encoding=encoding,
                 api_key=cohere_key,
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
             )
-
-            # if "stream" in optional_params and optional_params["stream"] is True:
-            #     # don't try to access stream object,
-            #     response = CustomStreamWrapper(
-            #         model_response,
-            #         model,
-            #         custom_llm_provider="cohere_chat",
-            #         logging_obj=logging,
-            #         _response_headers=headers,
-            #     )
-            #     return response
-            response = model_response
         elif custom_llm_provider == "maritalk":
             maritalk_key = (
                 api_key
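
Caller-facing behavior is unchanged by the refactor: completion and acompletion keep their signatures, and only the internal routing moves onto base_llm_http_handler. A usage sketch (model name illustrative; assumes COHERE_API_KEY is set in the environment):

    import asyncio
    import litellm

    messages = [{"role": "user", "content": "Say hi in one word."}]

    # Sync path: now served by base_llm_http_handler under the hood.
    resp = litellm.completion(model="cohere_chat/command-r", messages=messages)
    print(resp.choices[0].message.content)

    # Async + streaming path: exercises the new native async support.
    async def main() -> None:
        stream = await litellm.acompletion(
            model="cohere_chat/command-r", messages=messages, stream=True
        )
        async for chunk in stream:
            print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())
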