baseten client mapping

ishaan-jaff 2023-09-04 15:41:36 -07:00
parent 3147bf1d99
commit db4f4c0191
2 changed files with 118 additions and 131 deletions


@@ -1,11 +1,11 @@
-import os, json
+import os
+import json
 from enum import Enum
 import requests
 import time
 from typing import Callable
 from litellm.utils import ModelResponse

 class BasetenError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -14,41 +14,30 @@ class BasetenError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs

-class BasetenLLM:
-    def __init__(self, encoding, logging_obj, api_key=None):
-        self.encoding = encoding
-        self.completion_url_fragment_1 = "https://app.baseten.co/models/"
-        self.completion_url_fragment_2 = "/predict"
-        self.api_key = api_key
-        self.logging_obj = logging_obj
-        self.validate_environment(api_key=api_key)
-
-    def validate_environment(
-        self, api_key
-    ):  # set up the environment required to run the model
-        # set the api key
-        if self.api_key == None:
-            raise ValueError(
-                "Missing Baseten API Key - A call is being made to baseten but no key is set either in the environment variables or via params"
-            )
-        self.api_key = api_key
-        self.headers = {
+def validate_environment(api_key):
+    headers = {
         "accept": "application/json",
         "content-type": "application/json",
-            "Authorization": "Api-Key " + self.api_key,
     }
+    if api_key:
+        headers["Authorization"] = f"Api-Key {api_key}"
+    return headers

 def completion(
-    self,
     model: str,
     messages: list,
     model_response: ModelResponse,
     print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
-):  # logic for parsing in - calling - parsing out model completion calls
+):
+    headers = validate_environment(api_key)
+    completion_url_fragment_1 = "https://app.baseten.co/models/"
+    completion_url_fragment_2 = "/predict"
     model = model
     prompt = ""
     for message in messages:
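The hunk above dissolves the `BasetenLLM` class into module-level functions: header construction moves into a standalone `validate_environment`, and everything the old constructor held (`encoding`, `api_key`, `logging_obj`) is now passed into `completion` directly. A minimal sketch of how the new helper behaves (illustrative key values, not from the commit):

# Illustrative only: headers produced by the new module-level
# validate_environment for a present vs. missing key.
print(validate_environment(api_key="my-key"))
# {'accept': 'application/json', 'content-type': 'application/json',
#  'Authorization': 'Api-Key my-key'}
print(validate_environment(api_key=None))
# {'accept': 'application/json', 'content-type': 'application/json'}
# Note: unlike the old class, a missing key no longer raises at this point.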
@@ -60,24 +49,22 @@ class BasetenLLM:
         else:
             prompt += f"{message['content']}"
     data = {
-            # "prompt": prompt,
-            "inputs": prompt,  # in case it's a TGI deployed model
-            # "instruction": prompt,  # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            # **optional_params,
+        "inputs": prompt,
+        "prompt": prompt,
         "parameters": optional_params,
         "stream": True if "stream" in optional_params and optional_params["stream"] == True else False
     }
     ## LOGGING
-        self.logging_obj.pre_call(
+    logging_obj.pre_call(
         input=prompt,
-            api_key=self.api_key,
+        api_key=api_key,
         additional_args={"complete_input_dict": data},
     )
     ## COMPLETION CALL
     response = requests.post(
-            self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
-            headers=self.headers,
+        completion_url_fragment_1 + model + completion_url_fragment_2,
+        headers=headers,
         data=json.dumps(data),
         stream=True if "stream" in optional_params and optional_params["stream"] == True else False
     )
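The payload change drops the commented-out alternatives and sends both "inputs" and "prompt" unconditionally, so TGI-style and prompt-style Baseten deployments can each read the field they expect. A sketch of the resulting POST (MODEL_ID and the key are placeholders; values illustrative):

import json
import requests

data = {
    "inputs": "Hello, world",   # read by TGI-style deployments
    "prompt": "Hello, world",   # read by prompt-based deployments
    "parameters": {"temperature": 0.7},
    "stream": False,
}
response = requests.post(
    "https://app.baseten.co/models/MODEL_ID/predict",  # MODEL_ID is a placeholder
    headers={
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": "Api-Key YOUR_KEY",  # placeholder key
    },
    data=json.dumps(data),
)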
@@ -85,9 +72,9 @@ class BasetenLLM:
         return response.iter_lines()
     else:
         ## LOGGING
-            self.logging_obj.post_call(
+        logging_obj.post_call(
             input=prompt,
-                api_key=self.api_key,
+            api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},
         )
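Both `pre_call` and `post_call` now go through the `logging_obj` argument instead of instance state. The real object is litellm's logging class; the following is a hypothetical stand-in with just the two hooks this file uses, handy for exercising the function in isolation:

# Hypothetical stand-in for litellm's logging object; only the two hook
# methods called in this file are sketched.
class PrintLogger:
    def pre_call(self, input, api_key, additional_args=None):
        print(f"pre_call: input={input!r} args={additional_args}")

    def post_call(self, input, api_key, original_response, additional_args=None):
        print(f"post_call: response={original_response[:100]!r}")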
@@ -141,9 +128,9 @@ class BasetenLLM:
         )
     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = len(self.encoding.encode(prompt))
+    prompt_tokens = len(encoding.encode(prompt))
     completion_tokens = len(
-            self.encoding.encode(model_response["choices"][0]["message"]["content"])
+        encoding.encode(model_response["choices"][0]["message"]["content"])
     )

     model_response["created"] = time.time()
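As the comment notes, Baseten bills on compute time rather than tokens, so the counts here only approximate usage for litellm's bookkeeping, using whatever tiktoken encoding the caller passes in. A small sketch of that math (assuming the cl100k_base encoding litellm commonly uses as its default):

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
prompt_tokens = len(encoding.encode("What is the capital of France?"))
completion_tokens = len(encoding.encode("The capital of France is Paris."))
usage = {
    "prompt_tokens": prompt_tokens,
    "completion_tokens": completion_tokens,
    "total_tokens": prompt_tokens + completion_tokens,
}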
@@ -155,7 +142,6 @@ class BasetenLLM:
     }
     return model_response

-    def embedding(
-        self,
-    ):  # logic for parsing in - calling - parsing out model embedding calls
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
     pass


@@ -26,7 +26,7 @@ from .llms import sagemaker
 from .llms import bedrock
 from .llms import huggingface_restapi
 from .llms import aleph_alpha
-from .llms.baseten import BasetenLLM
+from .llms import baseten
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -751,10 +751,8 @@ def completion(
         baseten_key = (
             api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY")
         )
-        baseten_client = BasetenLLM(
-            encoding=encoding, api_key=baseten_key, logging_obj=logging
-        )
-        model_response = baseten_client.completion(
+        model_response = baseten.completion(
             model=model,
             messages=messages,
             model_response=model_response,
@@ -762,6 +760,9 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            encoding=encoding,
+            api_key=baseten_key,
+            logging_obj=logging
         )
         if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True):
             # don't try to access stream object,