diff --git a/litellm/__init__.py b/litellm/__init__.py
index e1754f4681..5eba89b12b 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -23,6 +23,7 @@ azure_key: Optional[str] = None
 anthropic_key: Optional[str] = None
 replicate_key: Optional[str] = None
 cohere_key: Optional[str] = None
+maritalk_key: Optional[str] = None
 ai21_key: Optional[str] = None
 openrouter_key: Optional[str] = None
 huggingface_key: Optional[str] = None
@@ -218,6 +219,10 @@ ollama_models = [
     "llama2"
 ]
 
+maritalk_models = [
+    "maritalk"
+]
+
 model_list = (
     open_ai_chat_completion_models
     + open_ai_text_completion_models
@@ -237,6 +242,7 @@ model_list = (
     + bedrock_models
     + deepinfra_models
     + perplexity_models
+    + maritalk_models
 )
 
 provider_list: List = [
@@ -263,6 +269,7 @@ provider_list: List = [
     "deepinfra",
     "perplexity",
     "anyscale",
+    "maritalk",
     "custom", # custom apis
 ]
 
@@ -282,6 +289,7 @@ models_by_provider: dict = {
     "ollama": ollama_models,
     "deepinfra": deepinfra_models,
     "perplexity": perplexity_models,
+    "maritalk": maritalk_models
 }
 
 # mapping for those models which have larger equivalents
@@ -347,6 +355,7 @@ from .llms.petals import PetalsConfig
 from .llms.vertex_ai import VertexAIConfig
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
+from .llms.maritalk import MaritTalkConfig
 from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig
 from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig, AzureOpenAIConfig
 from .main import * # type: ignore
diff --git a/litellm/llms/maritalk.py b/litellm/llms/maritalk.py
new file mode 100644
index 0000000000..10f39aa09a
--- /dev/null
+++ b/litellm/llms/maritalk.py
@@ -0,0 +1,161 @@
+import os, types
+import json
+from enum import Enum
+import requests
+import time, traceback
+from typing import Callable, Optional, List
+from litellm.utils import ModelResponse, Choices, Message
+import litellm
+
+class MaritalkError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+class MaritTalkConfig():
+    """
+    The class `MaritTalkConfig` provides configuration for the MaritTalk API interface. Here are the parameters:
+
+    - `max_tokens` (integer): Maximum number of tokens the model will generate as part of the response. Default is 1.
+
+    - `model` (string): The model used for conversation. Default is 'maritalk'.
+
+    - `do_sample` (boolean): If set to True, the API will generate a response using sampling. Default is True.
+
+    - `temperature` (number): A non-negative float controlling the randomness in generation. Lower temperatures result in less random generations. Default is 0.7.
+
+    - `top_p` (number): Selection threshold for token inclusion based on cumulative probability. Default is 0.95.
+
+    - `repetition_penalty` (number): Penalty for repetition in the generated conversation. Default is 1.
+
+    - `stopping_tokens` (list of string): List of tokens at which generation is stopped.
+ """ + max_tokens: Optional[int] = None + model: Optional[str] = None + do_sample: Optional[bool] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + repetition_penalty: Optional[float] = None + stopping_tokens: Optional[List[str]] = None + + def __init__(self, + max_tokens: Optional[int]=None, + model: Optional[str] = None, + do_sample: Optional[bool] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + stopping_tokens: Optional[List[str]] = None) -> None: + + locals_ = locals() + for key, value in locals_.items(): + if key != 'self' and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return {k: v for k, v in cls.__dict__.items() + if not k.startswith('__') + and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod)) + and v is not None} + +def validate_environment(api_key): + headers = { + "accept": "application/json", + "content-type": "application/json", + } + if api_key: + headers["Authorization"] = f"Key {api_key}" + return headers + +def completion( + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params=None, + litellm_params=None, + logger_fn=None, +): + headers = validate_environment(api_key) + completion_url = api_base + model = model + + ## Load Config + config=litellm.MaritTalkConfig.get_config() + for k, v in config.items(): + if k not in optional_params: # completion(top_k=3) > maritalk_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + data = { + "messages": messages, + **optional_params, + } + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) + ## COMPLETION CALL + response = requests.post( + completion_url, headers=headers, data=json.dumps(data), stream=optional_params["stream"] if "stream" in optional_params else False + ) + if "stream" in optional_params and optional_params["stream"] == True: + return response.iter_lines() + else: + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + completion_response = response.json() + if "error" in completion_response: + raise MaritalkError( + message=completion_response["error"], + status_code=response.status_code, + ) + else: + try: + if len(completion_response["answer"]) > 0: + model_response["choices"][0]["message"]["content"] = completion_response["answer"] + except Exception as e: + raise MaritalkError(message=response.text, status_code=response.status_code) + + ## CALCULATING USAGE + prompt = "".join(m["content"] for m in messages) + prompt_tokens = len( + encoding.encode(prompt) + ) + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"].get("content", "")) + ) + + model_response["created"] = time.time() + model_response["model"] = model + model_response.usage.completion_tokens = completion_tokens + model_response.usage.prompt_tokens = prompt_tokens + model_response.usage.total_tokens = prompt_tokens + completion_tokens + return model_response + +def embedding( + model: str, + input: list, + api_key: Optional[str] = None, + logging_obj=None, + model_response=None, + encoding=None, +): + pass \ No 
diff --git a/litellm/main.py b/litellm/main.py
index ce026d08f3..d1c35ef1ee 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -47,7 +47,8 @@ from .llms import (
     petals,
     oobabooga,
     palm,
-    vertex_ai)
+    vertex_ai,
+    maritalk)
 from .llms.openai import OpenAIChatCompletion
 from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt
 import tiktoken
@@ -703,7 +704,7 @@ def completion(
             response = CustomStreamWrapper(model_response, model, custom_llm_provider="aleph_alpha", logging_obj=logging)
             return response
         response = model_response
-    elif model in litellm.cohere_models:
+    elif custom_llm_provider == "cohere":
         cohere_key = (
             api_key
             or litellm.cohere_key
@@ -738,6 +739,40 @@ def completion(
             response = CustomStreamWrapper(model_response, model, custom_llm_provider="cohere", logging_obj=logging)
             return response
         response = model_response
+    elif custom_llm_provider == "maritalk":
+        maritalk_key = (
+            api_key
+            or litellm.maritalk_key
+            or get_secret("MARITALK_API_KEY")
+            or litellm.api_key
+        )
+
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("MARITALK_API_BASE")
+            or "https://chat.maritaca.ai/api/chat/inference"
+        )
+
+        model_response = maritalk.completion(
+            model=model,
+            messages=messages,
+            api_base=api_base,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            encoding=encoding,
+            api_key=maritalk_key,
+            logging_obj=logging
+        )
+
+        if "stream" in optional_params and optional_params["stream"] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(model_response, model, custom_llm_provider="maritalk", logging_obj=logging)
+            return response
+        response = model_response
     elif custom_llm_provider == "deepinfra": # for now this NEEDS to be above Hugging Face otherwise all calls to meta-llama/Llama-2-70b-chat-hf go to hf, we need this to go to deep infra if user sets provider to deep infra
         # this can be called with the openai python package
         api_key = (
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index a6d44efc3e..d6ea0513a8 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -56,7 +56,7 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_claude()
+# test_completion_claude()
 
 # def test_completion_oobabooga():
 #     try:
@@ -1273,6 +1273,14 @@ def test_completion_palm():
 #         pytest.fail(f"Error occurred: {e}")
 
 
+def test_maritalk():
+    messages = [{"role": "user", "content": "Hey"}]
+    try:
+        response = completion("maritalk", messages=messages)
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+# test_maritalk()
 
 def test_completion_together_ai_stream():
     user_message = "Write 1pg about YC & litellm"
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 371f090045..0360b3e981 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -724,6 +724,23 @@ def test_completion_replicate_stream_bad_key():
 
 # test_completion_sagemaker_stream()
 
+
+def test_maritalk_streaming():
+    messages = [{"role": "user", "content": "Hey"}]
+    try:
+        response = completion("maritalk", messages=messages, stream=True)
+        complete_response = ""
+        start_time = time.time()
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            complete_response += chunk
+            if finished:
+                break
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+    except:
+        pytest.fail(f"error occurred: {traceback.format_exc()}")
+test_maritalk_streaming()
 
 # test on openai completion call
 def test_openai_text_completion_call():
     try:
diff --git a/litellm/utils.py b/litellm/utils.py
index 50b61aabe6..08e6f4cb02 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1285,8 +1285,25 @@ def get_optional_params(  # use the openai defaults
         optional_params["presence_penalty"] = presence_penalty
         if stop:
             optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "perplexity":
-        optional_params[""]
+    elif custom_llm_provider == "maritalk":
+        ## check if unsupported param passed in
+        supported_params = ["stream", "temperature", "max_tokens", "top_p", "presence_penalty", "stop"]
+        _check_valid_arg(supported_params=supported_params)
+        # handle maritalk params
+        if stream:
+            optional_params["stream"] = stream
+        if temperature:
+            optional_params["temperature"] = temperature
+        if max_tokens:
+            optional_params["max_tokens"] = max_tokens
+        if logit_bias != {}:
+            optional_params["logit_bias"] = logit_bias
+        if top_p:
+            optional_params["p"] = top_p
+        if presence_penalty:
+            optional_params["repetition_penalty"] = presence_penalty
+        if stop:
+            optional_params["stopping_tokens"] = stop
     elif custom_llm_provider == "replicate":
         ## check if unsupported param passed in
         supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"]
@@ -1585,7 +1602,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
             return model, custom_llm_provider, dynamic_api_key, api_base
 
         # check if llm provider part of model name
-        if model.split("/",1)[0] in litellm.provider_list:
+        if model.split("/",1)[0] in litellm.provider_list and model.split("/",1)[0] not in litellm.model_list:
             custom_llm_provider = model.split("/", 1)[0]
             model = model.split("/", 1)[1]
             if custom_llm_provider == "perplexity":
@@ -1631,6 +1648,9 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
         ## openrouter
         elif model in litellm.openrouter_models:
             custom_llm_provider = "openrouter"
+        ## maritalk
+        elif model in litellm.maritalk_models:
+            custom_llm_provider = "maritalk"
         ## vertex - text + chat models
         elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models:
             custom_llm_provider = "vertex_ai"
@@ -3328,7 +3348,7 @@ def exception_type(
         elif custom_llm_provider == "ollama":
             if "no attribute 'async_get_ollama_response_stream" in error_str:
                 raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'")
-        elif custom_llm_provider == "custom_openai":
+        elif custom_llm_provider == "custom_openai" or custom_llm_provider == "maritalk":
            if hasattr(original_exception, "status_code"):
                 exception_mapping_worked = True
                 if original_exception.status_code == 401:
@@ -3590,6 +3610,17 @@ class CustomStreamWrapper:
         except:
             raise ValueError(f"Unable to parse response. Original response: {chunk}")
 
+    def handle_maritalk_chunk(self, chunk): # fake streaming
+        chunk = chunk.decode("utf-8")
+        data_json = json.loads(chunk)
+        try:
+            text = data_json["answer"]
+            is_finished = True
+            finish_reason = "stop"
+            return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}
+        except:
+            raise ValueError(f"Unable to parse response. Original response: {chunk}")
+
     def handle_nlp_cloud_chunk(self, chunk):
         chunk = chunk.decode("utf-8")
         data_json = json.loads(chunk)
@@ -3776,6 +3807,12 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     model_response.choices[0].finish_reason = response_obj["finish_reason"]
+            elif self.custom_llm_provider and self.custom_llm_provider == "maritalk":
+                chunk = next(self.completion_stream)
+                response_obj = self.handle_maritalk_chunk(chunk)
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider and self.custom_llm_provider == "vllm":
                 chunk = next(self.completion_stream)
                 completion_obj["content"] = chunk[0].outputs[0].text
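Usage sketch (not part of the patch): mirrors the new tests above, assuming the patch is applied, a valid MARITALK_API_KEY is exported, and the default endpoint added in main.py (https://chat.maritaca.ai/api/chat/inference) is reachable. Streaming here is the "fake" streaming implemented by handle_maritalk_chunk: the whole answer arrives as a single chunk with finish_reason "stop".

import os
from litellm import completion

os.environ["MARITALK_API_KEY"] = "..."  # placeholder; use a real key

messages = [{"role": "user", "content": "Hey"}]

# Non-streaming: routed to maritalk.completion(), which returns a ModelResponse
# with prompt/completion token counts computed via the shared encoding.
response = completion(model="maritalk", messages=messages)
print(response["choices"][0]["message"]["content"])

# Streaming: CustomStreamWrapper wraps response.iter_lines(); expect a single chunk.
stream = completion(model="maritalk", messages=messages, stream=True)
for chunk in stream:
    print(chunk["choices"][0]["delta"])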