feat(main.py): add support for maritalk api

Krrish Dholakia 2023-10-30 17:36:32 -07:00
parent d61e4cab19
commit 0ed3917b09
6 changed files with 274 additions and 7 deletions


@@ -1285,8 +1285,25 @@ def get_optional_params( # use the openai defaults
optional_params["presence_penalty"] = presence_penalty
if stop:
optional_params["stop_sequences"] = stop
elif custom_llm_provider == "perplexity":
optional_params[""]
elif custom_llm_provider == "maritalk":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "presence_penalty", "stop"]
_check_valid_arg(supported_params=supported_params)
# handle maritalk params
if stream:
optional_params["stream"] = stream
if temperature:
optional_params["temperature"] = temperature
if max_tokens:
optional_params["max_tokens"] = max_tokens
if logit_bias != {}:
optional_params["logit_bias"] = logit_bias
if top_p:
optional_params["p"] = top_p
if presence_penalty:
optional_params["repetition_penalty"] = presence_penalty
if stop:
optional_params["stopping_tokens"] = stop
elif custom_llm_provider == "replicate":
## check if unsupported param passed in
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]
@@ -1585,7 +1602,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
return model, custom_llm_provider, dynamic_api_key, api_base
# check if llm provider part of model name
if model.split("/",1)[0] in litellm.provider_list:
if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list:
custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1]
if custom_llm_provider == "perplexity":
@@ -1631,6 +1648,9 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
## openrouter
elif model in litellm.openrouter_models:
custom_llm_provider = "openrouter"
## maritalk
elif model in litellm.maritalk_models:
custom_llm_provider = "maritalk"
## vertex - text + chat models
elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models:
custom_llm_provider = "vertex_ai"
@@ -3328,7 +3348,7 @@ def exception_type(
elif custom_llm_provider == "ollama":
if "no attribute 'async_get_ollama_response_stream" in error_str:
raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'")
elif custom_llm_provider == "custom_openai":
elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
if hasattr(original_exception, "status_code"):
exception_mapping_worked = True
if original_exception.status_code == 401:
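
Because maritalk now shares the custom_openai status-code mapping, an HTTP 401 from the MariTalk API should presumably surface as LiteLLM's AuthenticationError (the mapping line itself is cut off in this hunk). A hedged usage sketch; the model string is illustrative and it assumes credentials can be passed via api_key:

import litellm

try:
    litellm.completion(
        model="maritalk/maritalk",  # illustrative model string
        messages=[{"role": "user", "content": "oi"}],
        api_key="invalid-key",
    )
except litellm.exceptions.AuthenticationError as e:
    print("401 mapped to", type(e).__name__)
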
@@ -3590,6 +3610,17 @@ class CustomStreamWrapper:
except:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_maritalk_chunk(self, chunk): # fake streaming
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
try:
text = data_json["answer"]
is_finished = True
finish_reason = "stop"
return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
except:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_nlp_cloud_chunk(self, chunk):
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
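
A small sketch of the payload shape handle_maritalk_chunk expects: a single JSON object with an "answer" field (inferred from the code above), which is why the comment calls it fake streaming; the sample payload is illustrative:

import json

raw_chunk = b'{"answer": "Ola! Como posso ajudar?"}'  # illustrative MariTalk payload
data_json = json.loads(raw_chunk.decode("utf-8"))
parsed = {
    "text": data_json["answer"],
    "is_finished": True,  # the whole answer arrives at once
    "finish_reason": "stop",
}
print(parsed)
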
@@ -3776,6 +3807,12 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
chunk = next(self.completion_stream)
response_obj = self.handle_maritalk_chunk(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
chunk = next(self.completion_stream)
completion_obj["content"] = chunk[0].outputs[0].text