diff --git a/litellm/llms/baseten.py b/litellm/llms/baseten.py
index b11ae179dd..5ed3941619 100644
--- a/litellm/llms/baseten.py
+++ b/litellm/llms/baseten.py
@@ -1,11 +1,11 @@
-import os, json
+import os
+import json
 from enum import Enum
 import requests
 import time
 from typing import Callable
 from litellm.utils import ModelResponse
-
 class BasetenError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -14,148 +14,134 @@ class BasetenError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs
 
+def validate_environment(api_key):
+    headers = {
+        "accept": "application/json",
+        "content-type": "application/json",
+    }
+    if api_key:
+        headers["Authorization"] = f"Api-Key {api_key}"
+    return headers
 
-class BasetenLLM:
-    def __init__(self, encoding, logging_obj, api_key=None):
-        self.encoding = encoding
-        self.completion_url_fragment_1 = "https://app.baseten.co/models/"
-        self.completion_url_fragment_2 = "/predict"
-        self.api_key = api_key
-        self.logging_obj = logging_obj
-        self.validate_environment(api_key=api_key)
-
-    def validate_environment(
-        self, api_key
-    ):  # set up the environment required to run the model
-        # set the api key
-        if self.api_key == None:
-            raise ValueError(
-                "Missing Baseten API Key - A call is being made to baseten but no key is set either in the environment variables or via params"
-            )
-        self.api_key = api_key
-        self.headers = {
-            "accept": "application/json",
-            "content-type": "application/json",
-            "Authorization": "Api-Key " + self.api_key,
-        }
-
-    def completion(
-        self,
-        model: str,
-        messages: list,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        optional_params=None,
-        litellm_params=None,
-        logger_fn=None,
-    ):  # logic for parsing in - calling - parsing out model completion calls
-        model = model
-        prompt = ""
-        for message in messages:
-            if "role" in message:
-                if message["role"] == "user":
-                    prompt += f"{message['content']}"
-                else:
-                    prompt += f"{message['content']}"
-            else:
-                prompt += f"{message['content']}"
+def completion(
+    model: str,
+    messages: list,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+):
+    headers = validate_environment(api_key)
+    completion_url_fragment_1 = "https://app.baseten.co/models/"
+    completion_url_fragment_2 = "/predict"
+    prompt = ""
+    for message in messages:
+        # baseten models take a single flat prompt string rather than a chat
+        # transcript, so every message's content is appended as-is regardless
+        # of role (the old role branches were identical and have been folded)
+        prompt += f"{message['content']}"
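+    # Baseten deployments are not uniform: TGI-style models read "inputs",
+    # while others expect a "prompt" key, so the request body carries both.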
-        data = {
-            # "prompt": prompt,
-            "inputs": prompt,  # in case it's a TGI deployed model
-            # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            # **optional_params,
-            "parameters": optional_params,
-            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False
-        }
+    data = {
+        "inputs": prompt,
+        "prompt": prompt,
+        "parameters": optional_params,
+        "stream": optional_params.get("stream", False) == True,
+    }
 
-        ## LOGGING
-        self.logging_obj.pre_call(
-            input=prompt,
-            api_key=self.api_key,
-            additional_args={"complete_input_dict": data},
-        )
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
-        ## COMPLETION CALL
-        response = requests.post(
-            self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
-            headers=self.headers,
-            data=json.dumps(data),
-            stream=True if "stream" in optional_params and optional_params["stream"] == True else False
-        )
-        if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
-            return response.iter_lines()
-        else:
-            ## LOGGING
-            self.logging_obj.post_call(
-                input=prompt,
-                api_key=self.api_key,
-                original_response=response.text,
-                additional_args={"complete_input_dict": data},
-            )
-            print_verbose(f"raw model_response: {response.text}")
+    ## COMPLETION CALL
+    response = requests.post(
+        completion_url_fragment_1 + model + completion_url_fragment_2,
+        headers=headers,
+        data=json.dumps(data),
+        stream=optional_params.get("stream", False) == True,
+    )
+    # stream when the server answers with SSE or when the caller asked for it;
+    # .get() avoids a KeyError if the Content-Type header is missing
+    if "text/event-stream" in response.headers.get("Content-Type", "") or optional_params.get("stream", False) == True:
+        return response.iter_lines()
+    else:
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
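+        # Baseten responses arrive in several shapes: "model_output",
+        # "completion", or a TGI-style list carrying "generated_text".
+        # Try each known shape in turn before giving up.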
Original response: {response.text}", - status_code=response.status_code - ) - model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"] - ## GETTING LOGPROBS - if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]: - sum_logprob = 0 - for token in completion_response[0]["details"]["tokens"]: - sum_logprob += token["logprob"] - model_response["choices"][0]["message"]["logprobs"] = sum_logprob - else: + ] = completion_response["model_output"]["data"][0] + elif isinstance(completion_response["model_output"], str): + model_response["choices"][0]["message"][ + "content" + ] = completion_response["model_output"] + elif "completion" in completion_response and isinstance( + completion_response["completion"], str + ): + model_response["choices"][0]["message"][ + "content" + ] = completion_response["completion"] + elif isinstance(completion_response, list) and len(completion_response) > 0: + if "generated_text" not in completion_response: raise BasetenError( message=f"Unable to parse response. Original response: {response.text}", status_code=response.status_code ) + model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"] + ## GETTING LOGPROBS + if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]: + sum_logprob = 0 + for token in completion_response[0]["details"]["tokens"]: + sum_logprob += token["logprob"] + model_response["choices"][0]["message"]["logprobs"] = sum_logprob + else: + raise BasetenError( + message=f"Unable to parse response. Original response: {response.text}", + status_code=response.status_code + ) - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len(self.encoding.encode(prompt)) - completion_tokens = len( - self.encoding.encode(model_response["choices"][0]["message"]["content"]) - ) + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. 
+        prompt_tokens = len(encoding.encode(prompt))
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"]["content"])
+        )
 
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            model_response["usage"] = {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens,
-            }
-            return model_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        return model_response
 
-    def embedding(
-        self,
-    ):  # logic for parsing in - calling - parsing out model embedding calls
-        pass
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
diff --git a/litellm/main.py b/litellm/main.py
index 859065328d..6042a5eaa1 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -26,7 +26,7 @@ from .llms import sagemaker
 from .llms import bedrock
 from .llms import huggingface_restapi
 from .llms import aleph_alpha
-from .llms.baseten import BasetenLLM
+from .llms import baseten
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -751,10 +751,8 @@ def completion(
         baseten_key = (
             api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY")
         )
-        baseten_client = BasetenLLM(
-            encoding=encoding, api_key=baseten_key, logging_obj=logging
-        )
-        model_response = baseten_client.completion(
+
+        model_response = baseten.completion(
             model=model,
             messages=messages,
             model_response=model_response,
@@ -762,6 +760,9 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            encoding=encoding,
+            api_key=baseten_key,
+            logging_obj=logging,
         )
         if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True): # don't try to access stream object,
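
A minimal sketch of how the refactored module-level interface is driven, mirroring the main.py call site above. The tokenizer and stub logger here are assumptions for illustration: litellm's main.py builds its own tiktoken encoding and Logging object, and the placeholder model id must be replaced with a real Baseten model version id.

    import os
    import tiktoken
    from litellm.utils import ModelResponse
    from litellm.llms import baseten

    class StubLogger:
        # hypothetical stand-in for litellm's internal Logging object
        def pre_call(self, input, api_key, additional_args=None):
            pass
        def post_call(self, input, api_key, original_response, additional_args=None):
            pass

    result = baseten.completion(
        model="MODEL_VERSION_ID",  # placeholder Baseten model version id
        messages=[{"role": "user", "content": "Hello, Baseten"}],
        model_response=ModelResponse(),
        print_verbose=print,
        encoding=tiktoken.get_encoding("cl100k_base"),
        api_key=os.environ.get("BASETEN_API_KEY"),
        logging_obj=StubLogger(),
        optional_params={},  # {"stream": True} would return an iterator of SSE lines instead
    )
    print(result["choices"][0]["message"]["content"])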