feat(main.py): add support for maritalk api

Author: Krrish Dholakia, 2023-10-30 17:36:32 -07:00
Commit: 0ed3917b09 (parent: d61e4cab19)
6 changed files with 274 additions and 7 deletions
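For context, a minimal usage sketch of the new provider based on the changes below (the model name "maritalk", the provider route, and the MARITALK_API_KEY variable all come from this diff; the key value and prompt are placeholders):

import os
from litellm import completion

os.environ["MARITALK_API_KEY"] = "..."  # placeholder; read via get_secret("MARITALK_API_KEY") in main.py below

messages = [{"role": "user", "content": "Hey"}]
response = completion("maritalk", messages=messages)  # "maritalk" resolves to custom_llm_provider == "maritalk"
print(response["choices"][0]["message"]["content"])   # populated from MariTalk's "answer" field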

litellm/__init__.py

@@ -23,6 +23,7 @@ azure_key: Optional[str] = None
anthropic_key: Optional[str] = None
replicate_key: Optional[str] = None
cohere_key: Optional[str] = None
+maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None
openrouter_key: Optional[str] = None
huggingface_key: Optional[str] = None
@@ -218,6 +219,10 @@ ollama_models = [
    "llama2"
]
+maritalk_models = [
+    "maritalk"
+]
model_list = (
    open_ai_chat_completion_models
    + open_ai_text_completion_models
@@ -237,6 +242,7 @@ model_list = (
    + bedrock_models
    + deepinfra_models
    + perplexity_models
+    + maritalk_models
)
provider_list: List = [
@@ -263,6 +269,7 @@ provider_list: List = [
    "deepinfra",
    "perplexity",
    "anyscale",
+    "maritalk",
    "custom", # custom apis
]
@@ -282,6 +289,7 @@ models_by_provider: dict = {
    "ollama": ollama_models,
    "deepinfra": deepinfra_models,
    "perplexity": perplexity_models,
+    "maritalk": maritalk_models
}
# mapping for those models which have larger equivalents
@@ -347,6 +355,7 @@ from .llms.petals import PetalsConfig
from .llms.vertex_ai import VertexAIConfig
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
+from .llms.maritalk import MaritTalkConfig
from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig, AzureOpenAIConfig
from .main import *  # type: ignore
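Besides the environment variable, the module-level maritalk_key added above can be set in code; a small sketch (the key value is a placeholder):

import litellm
from litellm import completion

litellm.maritalk_key = "..."  # placeholder; main.py below falls back to this when no api_key argument is passed
response = completion("maritalk", messages=[{"role": "user", "content": "Hey"}])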

litellm/llms/maritalk.py (new file, 161 lines added)

@@ -0,0 +1,161 @@
import os, types
import json
from enum import Enum
import requests
import time, traceback
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Choices, Message
import litellm


class MaritalkError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class MaritTalkConfig():
    """
    The class `MaritTalkConfig` provides configuration for the MaritTalk's API interface. Here are the parameters:

    - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1.

    - `model` (string): The model used for conversation. Default is 'maritalk'.

    - `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True.

    - `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7.

    - `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95.

    - `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1.
    - `stopping_tokens` (list of string): List of tokens at which the conversation can be stopped.
"""
max_tokens: Optional[int] = None
model: Optional[str] = None
do_sample: Optional[bool] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
repetition_penalty: Optional[float] = None
stopping_tokens: Optional[List[str]] = None
def __init__(self,
max_tokens: Optional[int]=None,
model: Optional[str] = None,
do_sample: Optional[bool] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
stopping_tokens: Optional[List[str]] = None) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
def validate_environment(api_key):
headers = {
"accept": "application/json",
"content-type": "application/json",
}
if api_key:
headers["Authorization"] = f"Key {api_key}"
return headers
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
headers = validate_environment(api_key)
completion_url = api_base
model = model
## Load Config
config=litellm.MaritTalkConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
"messages": messages,
**optional_params,
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False
)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
else:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
if "error" in completion_response:
raise MaritalkError(
message=completion_response["error"],
status_code=response.status_code,
)
else:
try:
if len(completion_response["answer"]) > 0:
model_response["choices"][0]["message"]["content"] = completion_response["answer"]
except Exception as e:
raise MaritalkError(message=response.text, status_code=response.status_code)
## CALCULATING USAGE
prompt = "".join(m["content"] for m in messages)
prompt_tokens = len(
encoding.encode(prompt)
)
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
)
model_response["created"] = time.time()
model_response["model"] = model
model_response.usage.completion_tokens = completion_tokens
model_response.usage.prompt_tokens = prompt_tokens
model_response.usage.total_tokens = prompt_tokens + completion_tokens
return model_response
def embedding(
model: str,
input: list,
api_key: Optional[str] = None,
logging_obj=None,
model_response=None,
encoding=None,
):
pass
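For reference, the raw HTTP call the handler above makes amounts to the following sketch (the endpoint is the default api_base set in main.py below, the header format comes from validate_environment above, and the key and parameter values are placeholders):

import json
import requests

api_key = "..."  # placeholder MariTalk key
url = "https://chat.maritaca.ai/api/chat/inference"  # default api_base used in main.py below
headers = {
    "accept": "application/json",
    "content-type": "application/json",
    "Authorization": f"Key {api_key}",
}
data = {
    "messages": [{"role": "user", "content": "Hey"}],
    "max_tokens": 100,   # illustrative optional params, merged from MaritTalkConfig / completion kwargs
    "do_sample": True,
}
resp = requests.post(url, headers=headers, data=json.dumps(data))
print(resp.json().get("answer"))  # non-streaming responses carry the text in an "answer" field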

litellm/main.py

@@ -47,7 +47,8 @@ from .llms import (
    petals,
    oobabooga,
    palm,
-    vertex_ai)
+    vertex_ai,
+    maritalk)
from .llms.openai import OpenAIChatCompletion
from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
import tiktoken
@@ -703,7 +704,7 @@ def completion(
            response = CustomStreamWrapper(model_response, model, custom_llm_provider="aleph_alpha", logging_obj=logging)
            return response
        response = model_response
-    elif model in litellm.cohere_models:
+    elif custom_llm_provider == "cohere":
        cohere_key = (
            api_key
            or litellm.cohere_key
@@ -738,6 +739,40 @@ def completion(
            response = CustomStreamWrapper(model_response, model, custom_llm_provider="cohere", logging_obj=logging)
            return response
        response = model_response
+    elif custom_llm_provider == "maritalk":
+        maritalk_key = (
+            api_key
+            or litellm.maritalk_key
+            or get_secret("MARITALK_API_KEY")
+            or litellm.api_key
+        )
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("MARITALK_API_BASE")
+            or "https://chat.maritaca.ai/api/chat/inference"
+        )
+
+        model_response = maritalk.completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            encoding=encoding,
+            api_key=maritalk_key,
+            logging_obj=logging
+        )
+
+        if "stream" in optional_params and optional_params["stream"] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(model_response, model, custom_llm_provider="maritalk", logging_obj=logging)
+            return response
+        response = model_response
    elif custom_llm_provider == "deepinfra": # for now this NEEDS to be above Hugging Face otherwise all calls to meta-llama/Llama-2-70b-chat-hf go to hf, we need this to go to deep infra if user sets provider to deep infra
        # this can be called with the openai python package
        api_key = (
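The maritalk branch above resolves the key in order: the api_key argument, then litellm.maritalk_key, then the MARITALK_API_KEY secret, then litellm.api_key; the endpoint falls back from the api_base argument to litellm.api_base, the MARITALK_API_BASE secret, and finally the default URL. A sketch of overriding both per call (values are placeholders):

from litellm import completion

response = completion(
    "maritalk",
    messages=[{"role": "user", "content": "Hey"}],
    api_key="...",                                            # placeholder; takes precedence over litellm.maritalk_key and the env var
    api_base="https://chat.maritaca.ai/api/chat/inference",   # same as the default shown above
)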

litellm/tests/test_completion.py

@@ -56,7 +56,7 @@ def test_completion_claude():
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

-test_completion_claude()
+# test_completion_claude()

# def test_completion_oobabooga():
#     try:
@@ -1273,6 +1273,14 @@ def test_completion_palm():
#         pytest.fail(f"Error occurred: {e}")

+def test_maritalk():
+    messages = [{"role": "user", "content": "Hey"}]
+    try:
+        response = completion("maritalk", messages=messages)
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_maritalk()

def test_completion_together_ai_stream():
    user_message = "Write 1pg about YC & litellm"

litellm/tests/test_streaming.py

@@ -724,6 +724,23 @@ def test_completion_replicate_stream_bad_key():
# test_completion_sagemaker_stream()

+def test_maritalk_streaming():
+    messages = [{"role": "user", "content": "Hey"}]
+    try:
+        response = completion("maritalk", messages=messages, stream=True)
+        complete_response = ""
+        start_time = time.time()
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            complete_response += chunk
+            if finished:
+                break
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+    except:
+        pytest.fail(f"error occurred: {traceback.format_exc()}")
+test_maritalk_streaming()

# test on openai completion call
def test_openai_text_completion_call():
    try:

litellm/utils.py

@@ -1285,8 +1285,25 @@ def get_optional_params( # use the openai defaults
            optional_params["presence_penalty"] = presence_penalty
        if stop:
            optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "perplexity":
-        optional_params[""]
+    elif custom_llm_provider == "maritalk":
+        ## check if unsupported param passed in
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "presence_penalty", "stop"]
+        _check_valid_arg(supported_params=supported_params)
+        # handle cohere params
+        if stream:
+            optional_params["stream"] = stream
+        if temperature:
+            optional_params["temperature"] = temperature
+        if max_tokens:
+            optional_params["max_tokens"] = max_tokens
+        if logit_bias != {}:
+            optional_params["logit_bias"] = logit_bias
+        if top_p:
+            optional_params["p"] = top_p
+        if presence_penalty:
+            optional_params["repetition_penalty"] = presence_penalty
+        if stop:
+            optional_params["stopping_tokens"] = stop
    elif custom_llm_provider == "replicate":
        ## check if unsupported param passed in
        supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]
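To make the translation above concrete, here is roughly how OpenAI-style arguments end up in the MariTalk request (a sketch with illustrative values; note that top_p is forwarded under the cohere-style "p" key, exactly as the block above does):

# Illustrative only; mirrors the maritalk branch of get_optional_params above.
openai_style_args = {"temperature": 0.7, "top_p": 0.95, "presence_penalty": 1.0,
                     "stop": ["\n\n"], "max_tokens": 100}
maritalk_request_params = {
    "temperature": openai_style_args["temperature"],
    "max_tokens": openai_style_args["max_tokens"],
    "p": openai_style_args["top_p"],
    "repetition_penalty": openai_style_args["presence_penalty"],
    "stopping_tokens": openai_style_args["stop"],
}
print(maritalk_request_params)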
@@ -1585,7 +1602,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
        return model, custom_llm_provider, dynamic_api_key, api_base

    # check if llm provider part of model name
-    if model.split("/",1)[0] in litellm.provider_list:
+    if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list:
        custom_llm_provider = model.split("/", 1)[0]
        model = model.split("/", 1)[1]
        if custom_llm_provider == "perplexity":
@@ -1631,6 +1648,9 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
    ## openrouter
    elif model in litellm.openrouter_models:
        custom_llm_provider = "openrouter"
+    ## openrouter
+    elif model in litellm.maritalk_models:
+        custom_llm_provider = "maritalk"
    ## vertex - text + chat models
    elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models:
        custom_llm_provider = "vertex_ai"
@@ -3328,7 +3348,7 @@ def exception_type(
        elif custom_llm_provider == "ollama":
            if "no attribute 'async_get_ollama_response_stream" in error_str:
                raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'")
-        elif custom_llm_provider == "custom_openai":
+        elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
            if hasattr(original_exception, "status_code"):
                exception_mapping_worked = True
                if original_exception.status_code == 401:
@@ -3590,6 +3610,17 @@ class CustomStreamWrapper:
        except:
            raise ValueError(f"Unable to parse response. Original response: {chunk}")

+    def handle_maritalk_chunk(self, chunk):  # fake streaming
+        chunk = chunk.decode("utf-8")
+        data_json = json.loads(chunk)
+        try:
+            text = data_json["answer"]
+            is_finished = True
+            finish_reason = "stop"
+            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
+        except:
+            raise ValueError(f"Unable to parse response. Original response: {chunk}")
+
    def handle_nlp_cloud_chunk(self, chunk):
        chunk = chunk.decode("utf-8")
        data_json = json.loads(chunk)
@@ -3776,6 +3807,12 @@ class CustomStreamWrapper:
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    model_response.choices[0].finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
+                chunk = next(self.completion_stream)
+                response_obj = self.handle_maritalk_chunk(chunk)
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
                chunk = next(self.completion_stream)
                completion_obj["content"] = chunk[0].outputs[0].text
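Since handle_maritalk_chunk above implements fake streaming, the whole answer arrives as a single finished chunk; a sketch of what the parser does with one raw line from response.iter_lines() (the payload text is made up):

import json

raw_line = b'{"answer": "Oi! Como posso ajudar?"}'  # hypothetical MariTalk payload
data = json.loads(raw_line.decode("utf-8"))
parsed = {"text": data["answer"], "is_finished": True, "finish_reason": "stop"}
print(parsed)  # the single chunk content yielded for a maritalk stream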