update baseten handler to handle TGI calls

Krrish Dholakia 2023-08-30 19:14:48 -07:00
parent 9cc0a094c0
commit 14d4c7ead2
8 changed files with 79 additions and 38 deletions
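For context, this change lets litellm call Baseten deployments that run Hugging Face's text-generation-inference (TGI) server. Below is a minimal sketch of how such a deployment would be called after this commit; the Baseten model id "abc12345" and the API key value are placeholders, neither appears in this diff.

import os
from litellm import completion

os.environ["BASETEN_API_KEY"] = "..."  # placeholder; set to a real Baseten API key

messages = [{"content": "write me a function to print hello world in python", "role": "user"}]

# Non-streaming: the handler sends {"inputs": ..., "parameters": ...} and parses
# the TGI-style [{"generated_text": ...}] response.
response = completion(model="baseten/abc12345", messages=messages)  # "abc12345" is a placeholder model id
print(response["choices"][0]["message"]["content"])

# Streaming: the handler asks for SSE output and litellm wraps the generator
# in CustomStreamWrapper, yielding OpenAI-style delta chunks.
for chunk in completion(model="baseten/abc12345", messages=messages, stream=True):
    print(chunk["choices"][0]["delta"])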


@@ -60,9 +60,12 @@ class BasetenLLM:
             else:
                 prompt += f"{message['content']}"
         data = {
-            "prompt": prompt,
+            # "prompt": prompt,
+            "inputs": prompt, # in case it's a TGI deployed model
             # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            **optional_params,
+            # **optional_params,
+            "parameters": optional_params,
+            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False
         }
         ## LOGGING
@@ -76,8 +79,9 @@ class BasetenLLM:
             self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
             headers=self.headers,
             data=json.dumps(data),
+            stream=True if "stream" in optional_params and optional_params["stream"] == True else False
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
             return response.iter_lines()
         else:
             ## LOGGING
@@ -117,9 +121,23 @@ class BasetenLLM:
                 model_response["choices"][0]["message"][
                     "content"
                 ] = completion_response["completion"]
+            elif isinstance(completion_response, list) and len(completion_response) > 0:
+                if "generated_text" not in completion_response:
+                    raise BasetenError(
+                        message=f"Unable to parse response. Original response: {response.text}",
+                        status_code=response.status_code
+                    )
+                model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
+                ## GETTING LOGPROBS
+                if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                    sum_logprob = 0
+                    for token in completion_response[0]["details"]["tokens"]:
+                        sum_logprob += token["logprob"]
+                    model_response["choices"][0]["message"]["logprobs"] = sum_logprob
             else:
-                raise ValueError(
-                    f"Unable to parse response. Original response: {response.text}"
+                raise BasetenError(
+                    message=f"Unable to parse response. Original response: {response.text}",
+                    status_code=response.status_code
                 )
         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
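The three hunks above switch the request body from Baseten's native {"prompt": ...} shape to TGI's {"inputs": ..., "parameters": ...} shape, forward the stream flag to requests, and add a parsing branch for TGI's list-shaped response. Below is a rough sketch of the two payloads involved; the field values are illustrative and the exact keys inside "parameters" depend on what get_optional_params passes through.

# Request body the handler now builds for a TGI-deployed Baseten model (illustrative values):
data = {
    "inputs": "write me a function to print hello world in python",
    "parameters": {"temperature": 0.7, "max_new_tokens": 256},
    "stream": False,
}

# Typical TGI response shape the new elif branch parses:
completion_response = [
    {
        "generated_text": "def hello_world():\n    print('hello world')",
        "details": {
            "tokens": [
                {"text": "def", "logprob": -0.11},
                {"text": " hello_world", "logprob": -0.05},
            ]
        },
    }
]

content = completion_response[0]["generated_text"]
# Summing per-token logprobs, mirroring the ## GETTING LOGPROBS block added above:
sum_logprob = sum(t["logprob"] for t in completion_response[0]["details"]["tokens"])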


@@ -1,4 +1,4 @@
-import os, openai, sys, json
+import os, openai, sys, json, inspect
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
@@ -682,7 +682,7 @@ def completion(
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True):
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
                     model_response, model, custom_llm_provider="baseten", logging_obj=logging
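The change above matters because the Baseten handler can now return response.iter_lines() (a generator) whenever the server answers with text/event-stream, even if the caller never passed stream=True; inspect.isgenerator catches that case so the result still gets wrapped for streaming. A small illustration of why the check works, since requests' iter_lines is a generator function:

import inspect
import requests

# Response.iter_lines is written with `yield`, so calling it returns a generator object --
# which is exactly what the new inspect.isgenerator(model_response) check detects.
print(inspect.isgeneratorfunction(requests.models.Response.iter_lines))  # True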


@@ -20,7 +20,7 @@ litellm.use_client = True
 # litellm.set_verbose = True
 # litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])

-user_message = "Hello, whats the weather in San Francisco??"
+user_message = "write me a function to print hello world in python"
 messages = [{"content": user_message, "role": "user"}]
@@ -383,6 +383,14 @@ def test_completion_with_fallbacks():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+# def test_baseten():
+#     try:
+#         response = completion(model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"


@@ -25,23 +25,24 @@ user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]

 # test on baseten completion call
-try:
-    response = completion(
-        model="wizard-lm", messages=messages, stream=True, logger_fn=logger_fn
-    )
-    print(f"response: {response}")
-    complete_response = ""
-    start_time = time.time()
-    for chunk in response:
-        chunk_time = time.time()
-        print(f"time since initial request: {chunk_time - start_time:.5f}")
-        print(chunk["choices"][0]["delta"])
-        complete_response += chunk["choices"][0]["delta"]["content"]
-    if complete_response == "":
-        raise Exception("Empty response received")
-except:
-    print(f"error occurred: {traceback.format_exc()}")
-    pass
+# try:
+#     response = completion(
+#         model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn
+#     )
+#     print(f"response: {response}")
+#     complete_response = ""
+#     start_time = time.time()
+#     for chunk in response:
+#         chunk_time = time.time()
+#         print(f"time since initial request: {chunk_time - start_time:.5f}")
+#         print(chunk["choices"][0]["delta"])
+#         complete_response += chunk["choices"][0]["delta"]["content"]
+#     if complete_response == "":
+#         raise Exception("Empty response received")
+#     print(f"complete response: {complete_response}")
+# except:
+#     print(f"error occurred: {traceback.format_exc()}")
+#     pass

 # test on openai completion call
 try:


@@ -735,7 +735,8 @@ def get_optional_params( # use the openai defaults
     elif custom_llm_provider == "baseten":
         optional_params["temperature"] = temperature
         optional_params["stream"] = stream
-        optional_params["top_p"] = top_p
+        if top_p != 1:
+            optional_params["top_p"] = top_p
         optional_params["top_k"] = top_k
         optional_params["num_beams"] = num_beams
         if max_tokens != float("inf"):
@@ -1739,18 +1740,31 @@ class CustomStreamWrapper:
             return chunk["choices"][0]["delta"]["content"]

     def handle_baseten_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
-        data_json = json.loads(chunk)
-        if "model_output" in data_json:
-            if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
-                return data_json["model_output"]["data"][0]
-            elif isinstance(data_json["model_output"], str):
-                return data_json["model_output"]
-            elif "completion" in data_json and isinstance(data_json["completion"], str):
-                return data_json["completion"]
-            else:
-                raise ValueError(f"Unable to parse response. Original response: {chunk}")
-        else:
-            return ""
+        try:
+            chunk = chunk.decode("utf-8")
+            if len(chunk) > 0:
+                if chunk.startswith("data:"):
+                    data_json = json.loads(chunk[5:])
+                    if "token" in data_json and "text" in data_json["token"]:
+                        return data_json["token"]["text"]
+                    else:
+                        return ""
+                data_json = json.loads(chunk)
+                if "model_output" in data_json:
+                    if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
+                        return data_json["model_output"]["data"][0]
+                    elif isinstance(data_json["model_output"], str):
+                        return data_json["model_output"]
+                    elif "completion" in data_json and isinstance(data_json["completion"], str):
+                        return data_json["completion"]
+                    else:
+                        raise ValueError(f"Unable to parse response. Original response: {chunk}")
+                else:
+                    return ""
+            else:
+                return ""
+        except:
+            traceback.print_exc()
+            return ""

     def __next__(self):
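The handle_baseten_chunk rewrite above adds a branch for TGI's server-sent events, where each streamed line looks like data:{"token": {...}}, while keeping the old model_output/completion parsing as a fallback. Below is a minimal sketch of the new SSE path with a made-up chunk; the byte string is illustrative, not captured from a real deployment.

import json

# Illustrative TGI-style SSE line, as bytes, similar to what response.iter_lines() yields.
chunk = b'data:{"token": {"id": 42, "text": " world", "logprob": -0.02, "special": false}}'

decoded = chunk.decode("utf-8")
if decoded.startswith("data:"):
    data_json = json.loads(decoded[5:])  # strip the "data:" prefix, then parse the JSON payload
    if "token" in data_json and "text" in data_json["token"]:
        print(repr(data_json["token"]["text"]))  # ' world'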


@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.508"
+version = "0.1.509"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"