(feat) cohere_chat provider

This commit is contained in:
ishaan-jaff 2024-03-12 10:29:26 -07:00
parent 042a71cdc7
commit 7635c764cf
2 changed files with 245 additions and 0 deletions

View file

@ -54,6 +54,7 @@ from .llms import (
ollama_chat,
cloudflare,
cohere,
cohere_chat,
petals,
oobabooga,
openrouter,
@ -1218,6 +1219,46 @@ def completion(
)
return response
response = model_response
elif custom_llm_provider == "cohere_chat":
cohere_key = (
api_key
or litellm.cohere_key
or get_secret("COHERE_API_KEY")
or get_secret("CO_API_KEY")
or litellm.api_key
)
api_base = (
api_base
or litellm.api_base
or get_secret("COHERE_API_BASE")
or "https://api.cohere.ai/v1/chat"
)
model_response = cohere_chat.completion(
model=model,
messages=messages,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=cohere_key,
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="cohere_chat",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "maritalk":
maritalk_key = (
api_key