Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00
feat(main.py): add support for maritalk api
commit 0ed3917b09 (parent d61e4cab19)
6 changed files with 274 additions and 7 deletions
@@ -1285,8 +1285,25 @@ def get_optional_params( # use the openai defaults
             optional_params["presence_penalty"] = presence_penalty
         if stop:
             optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "perplexity":
-        optional_params[""]
+    elif custom_llm_provider == "maritalk":
+        ## check if unsupported param passed in
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "presence_penalty", "stop"]
+        _check_valid_arg(supported_params=supported_params)
+        # handle maritalk params
+        if stream:
+            optional_params["stream"] = stream
+        if temperature:
+            optional_params["temperature"] = temperature
+        if max_tokens:
+            optional_params["max_tokens"] = max_tokens
+        if logit_bias != {}:
+            optional_params["logit_bias"] = logit_bias
+        if top_p:
+            optional_params["p"] = top_p
+        if presence_penalty:
+            optional_params["repetition_penalty"] = presence_penalty
+        if stop:
+            optional_params["stopping_tokens"] = stop
     elif custom_llm_provider == "replicate":
         ## check if unsupported param passed in
         supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]
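For reference, the new maritalk branch maps OpenAI-style arguments onto the parameter names MariTalk's API expects: top_p becomes p, presence_penalty becomes repetition_penalty, and stop becomes stopping_tokens. A minimal standalone sketch of that mapping (the helper name map_openai_params_to_maritalk is hypothetical, not part of litellm):

```python
from typing import Optional, Union

def map_openai_params_to_maritalk(
    stream: bool = False,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    presence_penalty: Optional[float] = None,
    stop: Optional[Union[str, list]] = None,
) -> dict:
    """Hypothetical helper mirroring the maritalk branch above:
    OpenAI-style kwargs in, MariTalk request params out."""
    optional_params: dict = {}
    if stream:
        optional_params["stream"] = stream
    if temperature is not None:
        optional_params["temperature"] = temperature
    if max_tokens is not None:
        optional_params["max_tokens"] = max_tokens
    if top_p is not None:
        optional_params["p"] = top_p  # MariTalk's name for top_p
    if presence_penalty is not None:
        optional_params["repetition_penalty"] = presence_penalty  # closest MariTalk equivalent
    if stop is not None:
        optional_params["stopping_tokens"] = stop  # MariTalk's name for stop sequences
    return optional_params

# map_openai_params_to_maritalk(temperature=0.7, top_p=0.95, stop=["\n"])
# -> {"temperature": 0.7, "p": 0.95, "stopping_tokens": ["\n"]}
```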
@@ -1585,7 +1602,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
             return model, custom_llm_provider, dynamic_api_key, api_base

         # check if llm provider part of model name
-        if model.split("/",1)[0] in litellm.provider_list:
+        if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list:
             custom_llm_provider = model.split("/", 1)[0]
             model = model.split("/", 1)[1]
             if custom_llm_provider == "perplexity":
@@ -1631,6 +1648,9 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
         ## openrouter
         elif model in litellm.openrouter_models:
             custom_llm_provider = "openrouter"
+        ## maritalk
+        elif model in litellm.maritalk_models:
+            custom_llm_provider = "maritalk"
         ## vertex - text + chat models
         elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models:
             custom_llm_provider = "vertex_ai"
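These two hunks wire "maritalk" into provider resolution: the prefix check now skips prefixes that are themselves registered model names (so a bare "maritalk" model string is not split on "/"), and bare model names found in litellm.maritalk_models resolve to the maritalk provider. A standalone sketch of that logic, with assumed registry contents and a hypothetical infer_provider helper:

```python
# Assumed, abbreviated registries; litellm's real lists are much larger.
PROVIDER_LIST = ["openai", "cohere", "openrouter", "maritalk", "vertex_ai"]
MARITALK_MODELS = ["maritalk"]
MODEL_LIST = list(MARITALK_MODELS)  # assumed union of all per-provider model registries

def infer_provider(model: str):
    """Return (model, provider), mirroring the two routing rules in get_llm_provider."""
    prefix = model.split("/", 1)[0]
    # rule 1: "provider/model" prefixes, skipped when the prefix is itself a model name
    if prefix in PROVIDER_LIST and prefix not in MODEL_LIST:
        return model.split("/", 1)[1], prefix
    # rule 2: bare model names looked up in the provider registries
    if model in MARITALK_MODELS:
        return model, "maritalk"
    raise ValueError(f"could not infer provider for model={model}")

print(infer_provider("openrouter/openai/gpt-4"))  # ('openai/gpt-4', 'openrouter')
print(infer_provider("maritalk"))  # ('maritalk', 'maritalk'); the model_list guard keeps the bare name intact
```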
@@ -3328,7 +3348,7 @@ exception_type(
         elif custom_llm_provider == "ollama":
             if "no attribute 'async_get_ollama_response_stream" in error_str:
                 raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'")
-        elif custom_llm_provider == "custom_openai":
+        elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
             if hasattr(original_exception, "status_code"):
                 exception_mapping_worked = True
                 if original_exception.status_code == 401:
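The maritalk provider reuses the custom_openai exception path, which inspects original_exception.status_code and re-raises a typed litellm exception (the hunk shows the 401 case). A minimal sketch of that pattern with locally defined stand-in exception classes; the 429 branch is an assumed illustration of the same pattern, not taken from this diff:

```python
class AuthenticationError(Exception):
    pass

class RateLimitError(Exception):
    pass

def map_provider_exception(original_exception: Exception, provider: str, model: str) -> None:
    """Re-raise an HTTP-level provider error as a typed exception, keyed on status_code."""
    status = getattr(original_exception, "status_code", None)
    if status == 401:
        raise AuthenticationError(f"{provider}Exception ({model}) - {original_exception}")
    if status == 429:
        raise RateLimitError(f"{provider}Exception ({model}) - {original_exception}")
    # anything without a recognised status code is re-raised untouched
    raise original_exception
```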
@@ -3590,6 +3610,17 @@ class CustomStreamWrapper:
             except:
                 raise ValueError(f"Unable to parse response. Original response: {chunk}")
+
+    def handle_maritalk_chunk(self, chunk): # fake streaming
+        chunk = chunk.decode("utf-8")
+        data_json = json.loads(chunk)
+        try:
+            text = data_json["answer"]
+            is_finished = True
+            finish_reason = "stop"
+            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
+        except:
+            raise ValueError(f"Unable to parse response. Original response: {chunk}")

     def handle_nlp_cloud_chunk(self, chunk):
         chunk = chunk.decode("utf-8")
         data_json = json.loads(chunk)
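handle_maritalk_chunk documents the response shape this integration relies on: the endpoint returns one JSON body whose answer field holds the whole completion, so streaming is faked by emitting a single finished chunk. A standalone restatement of that parsing, with an assumed example payload:

```python
import json

# Example raw body assumed to match what handle_maritalk_chunk expects:
# a single JSON object whose "answer" field carries the full completion.
raw_chunk = b'{"answer": "Oi! Como posso ajudar?"}'

def parse_maritalk_chunk(chunk: bytes) -> dict:
    """Standalone restatement of handle_maritalk_chunk: one chunk == one full answer."""
    data = json.loads(chunk.decode("utf-8"))
    if "answer" not in data:
        raise ValueError(f"Unable to parse response. Original response: {chunk!r}")
    # the whole answer arrives at once, so the "stream" finishes immediately
    return {"text": data["answer"], "is_finished": True, "finish_reason": "stop"}

print(parse_maritalk_chunk(raw_chunk))
# {'text': 'Oi! Como posso ajudar?', 'is_finished': True, 'finish_reason': 'stop'}
```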
@@ -3776,6 +3807,12 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
+                chunk = next(self.completion_stream)
+                response_obj = self.handle_maritalk_chunk(chunk)
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = chunk[0].outputs[0].text
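Taken together, these changes let a MariTalk model be called through the normal litellm interface. A usage sketch under assumptions: MARITALK_API_KEY is the environment variable this integration reads, "maritalk" is a name present in litellm.maritalk_models, and streamed chunks expose OpenAI-style choices[0].delta objects:

```python
import os
import litellm

os.environ["MARITALK_API_KEEY".replace("EE", "E")] = "sk-..."  # placeholder key; variable name assumed

response = litellm.completion(
    model="maritalk",
    messages=[{"role": "user", "content": "Qual é a capital do Brasil?"}],
    stream=True,  # served via the fake-streaming path added above
    max_tokens=64,
)

for chunk in response:
    delta = chunk.choices[0].delta
    # with fake streaming, the full answer arrives in a single content-bearing chunk
    if getattr(delta, "content", None):
        print(delta.content, end="")
```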