mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Refactor) Code Quality improvement - use Common base handler for Cohere (#7117)
* fix use new format for Cohere config * fix base llm http handler * Litellm code qa common config (#7116) * feat(base_llm): initial commit for common base config class Addresses code qa critique https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132 * feat(base_llm/): add transform request/response abstract methods to base config class --------- Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com> * use base transform helpers * use base_llm_http_handler for cohere * working cohere using base llm handler * add async cohere chat completion support on base handler * fix completion code * working sync cohere stream * add async support cohere_chat * fix types get_model_response_iterator * async / sync tests cohere * feat cohere using base llm class * fix linting errors * fix _abc error * add cohere params to transformation * remove old cohere file * fix type error * fix merge conflicts * fix cohere merge conflicts * fix linting error * fix litellm.llms.custom_httpx.http_handler.HTTPHandler.post * fix passing cohere specific params --------- Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
parent
501885d653
commit
c5e0407703
14 changed files with 933 additions and 720 deletions
|
@ -111,9 +111,9 @@ from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
|
|||
from .llms.bedrock.embed.embedding import BedrockEmbedding
|
||||
from .llms.bedrock.image.image_handler import BedrockImageGeneration
|
||||
from .llms.clarifai.chat import handler
|
||||
from .llms.cohere.chat import handler as cohere_chat
|
||||
from .llms.cohere.completion import completion as cohere_completion # type: ignore
|
||||
from .llms.cohere.embed import handler as cohere_embed
|
||||
from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
|
||||
from .llms.databricks.chat.handler import DatabricksChatCompletion
|
||||
from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
|
||||
|
@ -233,6 +233,7 @@ sagemaker_llm = SagemakerLLM()
|
|||
watsonx_chat_completion = WatsonXChatHandler()
|
||||
openai_like_embedding = OpenAILikeEmbeddingHandler()
|
||||
databricks_embedding = DatabricksEmbeddingHandler()
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -446,6 +447,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "groq"
|
||||
or custom_llm_provider == "nvidia_nim"
|
||||
or custom_llm_provider == "cohere_chat"
|
||||
or custom_llm_provider == "cerebras"
|
||||
or custom_llm_provider == "sambanova"
|
||||
or custom_llm_provider == "ai21_chat"
|
||||
|
@ -1941,15 +1943,15 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
cohere_key = (
|
||||
api_key
|
||||
or litellm.cohere_key
|
||||
or get_secret("COHERE_API_KEY")
|
||||
or get_secret("CO_API_KEY")
|
||||
or get_secret_str("COHERE_API_KEY")
|
||||
or get_secret_str("CO_API_KEY")
|
||||
or litellm.api_key
|
||||
)
|
||||
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret("COHERE_API_BASE")
|
||||
or get_secret_str("COHERE_API_BASE")
|
||||
or "https://api.cohere.ai/v1/chat"
|
||||
)
|
||||
|
||||
|
@ -1960,32 +1962,22 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
|
||||
model_response = cohere_chat.completion(
|
||||
response = base_llm_http_handler.completion(
|
||||
model=model,
|
||||
stream=stream,
|
||||
messages=messages,
|
||||
acompletion=acompletion,
|
||||
api_base=api_base,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
custom_llm_provider="cohere_chat",
|
||||
timeout=timeout,
|
||||
headers=headers,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
api_key=cohere_key,
|
||||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||
)
|
||||
|
||||
# if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# # don't try to access stream object,
|
||||
# response = CustomStreamWrapper(
|
||||
# model_response,
|
||||
# model,
|
||||
# custom_llm_provider="cohere_chat",
|
||||
# logging_obj=logging,
|
||||
# _response_headers=headers,
|
||||
# )
|
||||
# return response
|
||||
response = model_response
|
||||
elif custom_llm_provider == "maritalk":
|
||||
maritalk_key = (
|
||||
api_key
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue