(Refactor) Code Quality improvement - Use Common base handler for Cohere /generate API (#7122)

* use validate_environment in common utils

* use transform request / response for cohere

* remove unused file

* use cohere base_llm_http_handler

* working cohere generate api on llm http handler

* streaming cohere generate api

* fix get_model_response_iterator

* fix streaming handler

* fix get_model_response_iterator

* test_cohere_generate_api_completion

* fix linting error

* fix testing cohere raising error

* fix get_model_response_iterator type

* add testing cohere generate api
This commit is contained in:
Ishaan Jaff 2024-12-10 10:44:42 -08:00 committed by GitHub
parent bd39e1ab5d
commit 5e016fe66a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 439 additions and 382 deletions

View file

@ -109,7 +109,6 @@ from .llms.azure_text import AzureTextCompletion
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.bedrock.image.image_handler import BedrockImageGeneration
from .llms.cohere.completion import completion as cohere_completion # type: ignore
from .llms.cohere.embed import handler as cohere_embed
from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
@ -446,6 +445,7 @@ async def acompletion(
or custom_llm_provider == "groq"
or custom_llm_provider == "nvidia_nim"
or custom_llm_provider == "cohere_chat"
or custom_llm_provider == "cohere"
or custom_llm_provider == "cerebras"
or custom_llm_provider == "sambanova"
or custom_llm_provider == "ai21_chat"
@ -1895,31 +1895,22 @@ def completion( # type: ignore # noqa: PLR0915
if extra_headers is not None:
headers.update(extra_headers)
model_response = cohere_completion.completion(
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
acompletion=acompletion,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
custom_llm_provider="cohere",
timeout=timeout,
headers=headers,
encoding=encoding,
api_key=cohere_key,
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
)
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="cohere",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "cohere_chat":
cohere_key = (
api_key