Merge pull request #2474 from BerriAI/litellm_support_command_r

[New-Model] Cohere/command-r
This commit is contained in:
Ishaan Jaff 2024-03-12 11:11:56 -07:00 committed by GitHub
commit 5172fb1de9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 386 additions and 27 deletions

View file

@ -55,6 +55,7 @@ from .llms import (
ollama_chat,
cloudflare,
cohere,
cohere_chat,
petals,
oobabooga,
openrouter,
@ -1287,6 +1288,46 @@ def completion(
)
return response
response = model_response
elif custom_llm_provider == "cohere_chat":
cohere_key = (
api_key
or litellm.cohere_key
or get_secret("COHERE_API_KEY")
or get_secret("CO_API_KEY")
or litellm.api_key
)
api_base = (
api_base
or litellm.api_base
or get_secret("COHERE_API_BASE")
or "https://api.cohere.ai/v1/chat"
)
model_response = cohere_chat.completion(
model=model,
messages=messages,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=cohere_key,
logging_obj=logging,  # model call logging is done inside the handler, as we may need to modify I/O to fit Cohere's requirements
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="cohere_chat",
logging_obj=logging,
)
return response
response = model_response
elif custom_llm_provider == "maritalk":
maritalk_key = (
api_key