diff --git a/litellm/__init__.py b/litellm/__init__.py
index b53d49e26e..3b11c14ace 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -21,6 +21,7 @@ huggingface_key: Optional[str] = None
 vertex_project: Optional[str] = None
 vertex_location: Optional[str] = None
 togetherai_api_key: Optional[str] = None
+baseten_key: Optional[str] = None
 caching = False
 caching_with_models = False # if you want the caching key to be model + prompt
 model_alias_map: Dict[str, str] = {}
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 3b65dc585f..871321feb1 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 5c77739597..02f2cbf12a 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index b695df71a3..6fbfd8e4dc 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/integrations/litedebugger.py b/litellm/integrations/litedebugger.py
index c8b52a16cd..8b4c6d7517 100644
--- a/litellm/integrations/litedebugger.py
+++ b/litellm/integrations/litedebugger.py
@@ -29,7 +29,7 @@ class LiteDebugger:
         )

     def input_log_event(
-        self, model, messages, end_user, litellm_call_id, print_verbose
+        self, model, messages, end_user, litellm_call_id, print_verbose, litellm_params, optional_params
     ):
         try:
             print_verbose(
@@ -42,6 +42,8 @@ class LiteDebugger:
                 "status": "initiated",
                 "litellm_call_id": litellm_call_id,
                 "user_email": self.user_email,
+                "litellm_params": litellm_params,
+                "optional_params": optional_params
             }
             response = requests.post(
                 url=self.api_url,
diff --git a/litellm/llms/baseten.py b/litellm/llms/baseten.py
new file mode 100644
index 0000000000..b8016a32e4
--- /dev/null
+++ b/litellm/llms/baseten.py
@@ -0,0 +1,129 @@
+import os, json
+from enum import Enum
+import requests
+import time
+from typing import Callable
+from litellm.utils import ModelResponse
+
+class BasetenError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class BasetenLLM:
+    def __init__(
+        self, encoding, logging_obj, api_key=None
+    ):
+        self.encoding = encoding
+        self.completion_url_fragment_1 = "https://app.baseten.co/models/"
+        self.completion_url_fragment_2 = "/predict"
+        self.api_key = api_key
+        self.logging_obj = logging_obj
+        self.validate_environment(api_key=api_key)
+
+    def validate_environment(
+        self, api_key
+    ):  # set up the environment required to run the model
+        # set the api key
+        if self.api_key == None:
+            raise ValueError(
+                "Missing Baseten API Key - A call is being made to baseten but no key is set either in the environment variables or via params"
+            )
+        self.api_key = api_key
+        self.headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "Authorization": "Api-Key " + self.api_key,
+        }
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        optional_params=None,
+        litellm_params=None,
+        logger_fn=None,
+    ):  # logic for parsing in - calling - parsing out model completion calls
+        model = model
+        prompt = ""
+        for message in messages:
+            if "role" in message:
+                if message["role"] == "user":
+                    prompt += (
+                        f"{message['content']}"
+                    )
+                else:
+                    prompt += (
+                        f"{message['content']}"
+                    )
+            else:
+                prompt += f"{message['content']}"
+        data = {
+            "prompt": prompt,
+            # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
+            # **optional_params,
+        }
+
+        ## LOGGING
+        self.logging_obj.pre_call(
+            input=prompt,
+            api_key=self.api_key,
+            additional_args={"complete_input_dict": data},
+        )
+        ## COMPLETION CALL
+        response = requests.post(
+            self.completion_url_fragment_1 + model + self.completion_url_fragment_2, headers=self.headers, data=json.dumps(data)
+        )
+        if "stream" in optional_params and optional_params["stream"] == True:
+            return response.iter_lines()
+        else:
+            ## LOGGING
+            self.logging_obj.post_call(
+                input=prompt,
+                api_key=self.api_key,
+                original_response=response.text,
+                additional_args={"complete_input_dict": data},
+            )
+            print_verbose(f"raw model_response: {response.text}")
+            ## RESPONSE OBJECT
+            completion_response = response.json()
+            if "error" in completion_response:
+                raise BasetenError(
+                    message=completion_response["error"],
+                    status_code=response.status_code,
+                )
+            else:
+                if "model_output" in completion_response:
+                    if isinstance(completion_response["model_output"], str):
+                        model_response["choices"][0]["message"]["content"] = completion_response["model_output"]
+                    elif isinstance(completion_response["model_output"], dict) and "data" in completion_response["model_output"] and isinstance(completion_response["model_output"]["data"], list):
+                        model_response["choices"][0]["message"]["content"] = completion_response["model_output"]["data"][0]
+                else:
+                    raise ValueError(f"Unable to parse response. Original response: {response.text}")
+            ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+            prompt_tokens = len(
+                self.encoding.encode(prompt)
+            )
+            completion_tokens = len(
+                self.encoding.encode(model_response["choices"][0]["message"]["content"])
+            )
+
+            model_response["created"] = time.time()
+            model_response["model"] = model
+            model_response["usage"] = {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+            }
+            return model_response
+
+    def embedding(
+        self,
+    ):  # logic for parsing in - calling - parsing out model embedding calls
+        pass
diff --git a/litellm/main.py b/litellm/main.py
index 5062ace36a..ffd7bfc363 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -21,6 +21,7 @@ from litellm.utils import (
 )
 from .llms.anthropic import AnthropicLLM
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
+from .llms.baseten import BasetenLLM
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -73,6 +74,7 @@ def completion(
     max_tokens=float("inf"),
     presence_penalty=0,
     frequency_penalty=0,
+    num_beams=1,
     logit_bias={},
     user="",
     deployment_id=None,
@@ -681,36 +683,31 @@ def completion(
         custom_llm_provider == "baseten"
         or litellm.api_base == "https://app.baseten.co"
     ):
-        import baseten
-
-        base_ten_key = get_secret("BASETEN_API_KEY")
-        baseten.login(base_ten_key)
-
-        prompt = " ".join([message["content"] for message in messages])
-        ## LOGGING
-        logging.pre_call(input=prompt, api_key=base_ten_key, model=model)
-
-        base_ten__model = baseten.deployed_model_version_id(model)
-
-        completion_response = base_ten__model.predict({"prompt": prompt})
-        if type(completion_response) == dict:
-            completion_response = completion_response["data"]
-        if type(completion_response) == dict:
-            completion_response = completion_response["generated_text"]
-
-        ## LOGGING
-        logging.post_call(
-            input=prompt,
-            api_key=base_ten_key,
-            original_response=completion_response,
+        custom_llm_provider = "baseten"
+        baseten_key = (
+            api_key
+            or litellm.baseten_key
+            or os.environ.get("BASETEN_API_KEY")
         )
-
-        ## RESPONSE OBJECT
-        model_response["choices"][0]["message"]["content"] = completion_response
-        model_response["created"] = time.time()
-        model_response["model"] = model
+        baseten_client = BasetenLLM(
+            encoding=encoding, api_key=baseten_key, logging_obj=logging
+        )
+        model_response = baseten_client.completion(
+            model=model,
+            messages=messages,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+        )
+        if "stream" in optional_params and optional_params["stream"] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(
+                model_response, model, custom_llm_provider="huggingface"
+            )
+            return response
         response = model_response
-
     elif custom_llm_provider == "petals" or (
         litellm.api_base and "chat.petals.dev" in litellm.api_base
     ):
diff --git a/litellm/tests/test_model_alias_map.py b/litellm/tests/test_model_alias_map.py
index 368f02b1ab..9edd771b13 100644
--- a/litellm/tests/test_model_alias_map.py
+++ b/litellm/tests/test_model_alias_map.py
@@ -13,4 +13,5 @@ from litellm import embedding, completion
 litellm.set_verbose = True

 # Test: Check if the alias created via LiteDebugger is mapped correctly
-print(completion("wizard-lm", messages=[{"role": "user", "content": "Hey, how's it going?"}]))
\ No newline at end of file
+{"top_p": 0.75, "prompt": "What's the meaning of life?", "num_beams": 4, "temperature": 0.1}
+print(completion("llama-7b", messages=[{"role": "user", "content": "Hey, how's it going?"}], top_p=0.1, temperature=0, num_beams=4, max_tokens=60))
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index b06da44e1e..938f6a7c36 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -154,6 +154,7 @@ class Logging:
         self.optional_params = optional_params
         self.litellm_params = litellm_params
         self.logger_fn = litellm_params["logger_fn"]
+        print_verbose(f"self.optional_params: {self.optional_params}")
         self.model_call_details = {
             "model": model,
             "messages": messages,
@@ -214,6 +215,8 @@ class Logging:
                         end_user=litellm._thread_context.user,
                         litellm_call_id=self.
                         litellm_params["litellm_call_id"],
+                        litellm_params=self.model_call_details["litellm_params"],
+                        optional_params=self.model_call_details["optional_params"],
                         print_verbose=print_verbose,
                     )
                 except Exception as e:
@@ -539,7 +542,7 @@ def get_litellm_params(
     return litellm_params


-def get_optional_params(
+def get_optional_params(  # use the openai defaults
     # 12 optional params
     functions=[],
     function_call="",
@@ -552,6 +555,7 @@
     presence_penalty=0,
     frequency_penalty=0,
     logit_bias={},
+    num_beams=1,
     user="",
     deployment_id=None,
     model=None,
@@ -613,7 +617,13 @@
         optional_params["temperature"] = temperature
         optional_params["top_p"] = top_p
         optional_params["top_k"] = top_k
-
+    elif custom_llm_provider == "baseten":
+        optional_params["temperature"] = temperature
+        optional_params["top_p"] = top_p
+        optional_params["top_k"] = top_k
+        optional_params["num_beams"] = num_beams
+        if max_tokens != float("inf"):
+            optional_params["max_new_tokens"] = max_tokens
     else: # assume passing in params for openai/azure openai
         if functions != []:
             optional_params["functions"] = functions
diff --git a/pyproject.toml b/pyproject.toml
index 5e1ebb6648..eee1a89678 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.478"
+version = "0.1.479"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
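
For reference, a minimal usage sketch of the Baseten code path this diff introduces; it is illustrative only and not part of the change set. The model id string is a placeholder for a real Baseten model version id, and BASETEN_API_KEY is assumed to be set. The key lookup order (api_key arg, then litellm.baseten_key, then the environment variable), the num_beams kwarg, and the max_tokens to max_new_tokens mapping are taken from the diff itself.

import os
import litellm
from litellm import completion

# Key resolution added in main.py: explicit api_key, then litellm.baseten_key, then BASETEN_API_KEY.
litellm.baseten_key = os.environ.get("BASETEN_API_KEY")

response = completion(
    model="MODEL_VERSION_ID",  # placeholder for a Baseten model version id
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    custom_llm_provider="baseten",  # routes to the new BasetenLLM class in litellm/llms/baseten.py
    temperature=0.1,
    top_p=0.75,
    num_beams=4,    # new optional param forwarded for baseten by get_optional_params
    max_tokens=60,  # mapped to max_new_tokens for baseten
)
print(response["choices"][0]["message"]["content"])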