Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
feat(main.py): add support for maritalk api

parent d61e4cab19
commit 0ed3917b09

6 changed files with 274 additions and 7 deletions
litellm/main.py

@@ -47,7 +47,8 @@ from .llms import (
     petals,
     oobabooga,
     palm,
-    vertex_ai)
+    vertex_ai,
+    maritalk)
 from .llms.openai import OpenAIChatCompletion
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
@@ -703,7 +704,7 @@ def completion(
                 response = CustomStreamWrapper(model_response, model, custom_llm_provider="aleph_alpha", logging_obj=logging)
                 return response
             response = model_response
-        elif model in litellm.cohere_models:
+        elif custom_llm_provider == "cohere":
             cohere_key = (
                 api_key
                 or litellm.cohere_key
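The cohere branch now dispatches on custom_llm_provider instead of checking membership in the static litellm.cohere_models list, which is what lets a brand-new provider like maritalk be routed through the same elif chain. A minimal sketch of the idea; the prefix-splitting helper below is illustrative only, not the library's actual routing code:

from typing import Optional, Tuple

def resolve_provider(model: str, custom_llm_provider: Optional[str] = None) -> Tuple[str, str]:
    # Illustrative only: litellm derives the provider either from an explicit
    # custom_llm_provider argument or from a "provider/model" prefix, so the
    # completion() branches no longer depend on hard-coded model lists.
    if custom_llm_provider is None and "/" in model:
        custom_llm_provider, model = model.split("/", 1)
    return model, custom_llm_provider or "openai"

assert resolve_provider("cohere/command-nightly") == ("command-nightly", "cohere")
assert resolve_provider("maritalk", "maritalk") == ("maritalk", "maritalk")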
@@ -738,6 +739,40 @@ def completion(
                 response = CustomStreamWrapper(model_response, model, custom_llm_provider="cohere", logging_obj=logging)
                 return response
             response = model_response
+        elif custom_llm_provider == "maritalk":
+            maritalk_key = (
+                api_key
+                or litellm.maritalk_key
+                or get_secret("MARITALK_API_KEY")
+                or litellm.api_key
+            )
+
+            api_base = (
+                api_base
+                or litellm.api_base
+                or get_secret("MARITALK_API_BASE")
+                or "https://chat.maritaca.ai/api/chat/inference"
+            )
+
+            model_response = maritalk.completion(
+                model=model,
+                messages=messages,
+                api_base=api_base,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                api_key=maritalk_key,
+                logging_obj=logging
+            )
+
+            if "stream" in optional_params and optional_params["stream"] == True:
+                # don't try to access stream object,
+                response = CustomStreamWrapper(model_response, model, custom_llm_provider="maritalk", logging_obj=logging)
+                return response
+            response = model_response
         elif custom_llm_provider == "deepinfra": # for now this NEEDS to be above Hugging Face otherwise all calls to meta-llama/Llama-2-70b-chat-hf go to hf, we need this to go to deep infra if user sets provider to deep infra
             # this can be called with the openai python package
             api_key = (
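With the new branch in place, MariTalk calls go through the ordinary completion() entry point. A minimal usage sketch; passing custom_llm_provider="maritalk" explicitly and the bare model name "maritalk" are assumptions for illustration, while the key resolution (api_key argument, litellm.maritalk_key, the MARITALK_API_KEY secret, then litellm.api_key) and the default api_base are exactly what the diff shows:

import os
import litellm

# Resolved by get_secret("MARITALK_API_KEY") in the new branch.
os.environ["MARITALK_API_KEY"] = "<your-key>"

# Non-streaming: api_base falls back to https://chat.maritaca.ai/api/chat/inference.
response = litellm.completion(
    model="maritalk",                    # assumed model name for illustration
    custom_llm_provider="maritalk",      # routes into the new elif branch
    messages=[{"role": "user", "content": "Com quantos paus se faz uma canoa?"}],
)
print(response)

# Streaming: stream=True lands in optional_params, so the branch wraps the
# handler output in CustomStreamWrapper and returns the iterator immediately.
for chunk in litellm.completion(
    model="maritalk",
    custom_llm_provider="maritalk",
    messages=[{"role": "user", "content": "Olá!"}],
    stream=True,
):
    print(chunk)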