update baseten handler to handle TGI calls

Krrish Dholakia 2023-08-30 19:14:48 -07:00
parent 9cc0a094c0
commit 14d4c7ead2
8 changed files with 79 additions and 38 deletions


@@ -60,9 +60,12 @@ class BasetenLLM:
else:
prompt += f"{message['content']}"
data = {
"prompt": prompt,
# "prompt": prompt,
"inputs": prompt, # in case it's a TGI deployed model
# "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
**optional_params,
# **optional_params,
"parameters": optional_params,
"stream": True if "stream" in optional_params and optional_params["stream"] == True else False
}
## LOGGING
@@ -76,8 +79,9 @@ class BasetenLLM:
self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
headers=self.headers,
data=json.dumps(data),
stream=True if "stream" in optional_params and optional_params["stream"] == True else False
)
if "stream" in optional_params and optional_params["stream"] == True:
if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
return response.iter_lines()
else:
## LOGGING
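For context, the handler now asks requests for a streamed response and, when the deployment answers with a server-sent-event body, hands the line iterator straight back to the caller. A rough sketch of that flow, with the URL, model id, and API key as placeholders rather than the handler's real values:

    import json
    import requests

    # Placeholder endpoint and key; the real handler builds the URL from its stored fragments.
    url = "https://app.baseten.co/models/MODEL_ID/predict"
    headers = {"Authorization": "Api-Key BASETEN_API_KEY"}
    data = {"inputs": "Hello", "parameters": {"max_new_tokens": 16}, "stream": True}

    response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
    # TGI-style deployments answer streaming requests with server-sent events,
    # so the Content-Type header is enough to decide whether to return an iterator.
    if "text/event-stream" in response.headers.get("Content-Type", ""):
        chunks = response.iter_lines()
    else:
        completion_response = response.json()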
@@ -117,9 +121,23 @@ class BasetenLLM:
model_response["choices"][0]["message"][
"content"
] = completion_response["completion"]
elif isinstance(completion_response, list) and len(completion_response) > 0:
if "generated_text" not in completion_response:
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code
)
model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
## GETTING LOGPROBS
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
sum_logprob += token["logprob"]
model_response["choices"][0]["message"]["logprobs"] = sum_logprob
else:
raise ValueError(
f"Unable to parse response. Original response: {response.text}"
raise BasetenError(
message=f"Unable to parse response. Original response: {response.text}",
status_code=response.status_code
)
## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
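The non-streaming branch now also accepts the TGI response shape: a JSON list whose first element carries generated_text and, optionally, per-token generation details. A minimal sketch of that parsing, assuming a body like the sample below (variable names are illustrative, not the handler's exact ones):

    import json

    # Hypothetical TGI-style (non-streaming) body a Baseten deployment might return.
    raw = '[{"generated_text": "hello world", "details": {"tokens": [{"logprob": -0.12}, {"logprob": -0.05}]}}]'
    completion_response = json.loads(raw)

    message = {"role": "assistant"}
    if isinstance(completion_response, list) and len(completion_response) > 0:
        first = completion_response[0]
        message["content"] = first["generated_text"]
        # When the deployment returns generation details, report the summed token logprobs.
        if "details" in first and "tokens" in first["details"]:
            message["logprobs"] = sum(t["logprob"] for t in first["details"]["tokens"])
    print(message)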


@@ -1,4 +1,4 @@
import os, openai, sys, json
import os, openai, sys, json, inspect
from typing import Any
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
@@ -682,7 +682,7 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
)
if "stream" in optional_params and optional_params["stream"] == True:
if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True):
# don't try to access stream object,
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="baseten", logging_obj=logging
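completion() previously wrapped the provider output only when the caller passed stream=True; since the Baseten handler can now return response.iter_lines() on its own (via the Content-Type check above), the condition is broadened to any generator. A small illustration of why inspect.isgenerator catches that case:

    import inspect

    def fake_iter_lines():
        # Stand-in for requests' response.iter_lines(), which is a generator function.
        yield b'data: {"token": {"text": "Hello"}}'

    model_response = fake_iter_lines()
    optional_params = {}  # the caller never asked for streaming here

    should_stream = inspect.isgenerator(model_response) or optional_params.get("stream") == True
    print(should_stream)  # True, so the response still gets wrapped in CustomStreamWrapper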


@@ -20,7 +20,7 @@ litellm.use_client = True
# litellm.set_verbose = True
# litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])
user_message = "Hello, whats the weather in San Francisco??"
user_message = "write me a function to print hello world in python"
messages = [{"content": user_message, "role": "user"}]
@@ -383,6 +383,14 @@ def test_completion_with_fallbacks():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# def test_baseten():
# try:
# response = completion(model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn)
# # Add any assertions here to check the response
# print(response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# def test_baseten_falcon_7bcompletion():
# model_name = "qvv0xeq"


@@ -25,23 +25,24 @@ user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
# test on baseten completion call
try:
response = completion(
model="wizard-lm", messages=messages, stream=True, logger_fn=logger_fn
)
print(f"response: {response}")
complete_response = ""
start_time = time.time()
for chunk in response:
chunk_time = time.time()
print(f"time since initial request: {chunk_time - start_time:.5f}")
print(chunk["choices"][0]["delta"])
complete_response += chunk["choices"][0]["delta"]["content"]
if complete_response == "":
raise Exception("Empty response received")
except:
print(f"error occurred: {traceback.format_exc()}")
pass
# try:
# response = completion(
# model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn
# )
# print(f"response: {response}")
# complete_response = ""
# start_time = time.time()
# for chunk in response:
# chunk_time = time.time()
# print(f"time since initial request: {chunk_time - start_time:.5f}")
# print(chunk["choices"][0]["delta"])
# complete_response += chunk["choices"][0]["delta"]["content"]
# if complete_response == "":
# raise Exception("Empty response received")
# print(f"complete response: {complete_response}")
# except:
# print(f"error occurred: {traceback.format_exc()}")
# pass
# test on openai completion call
try:


@@ -735,7 +735,8 @@ def get_optional_params( # use the openai defaults
elif custom_llm_provider == "baseten":
optional_params["temperature"] = temperature
optional_params["stream"] = stream
optional_params["top_p"] = top_p
if top_p != 1:
optional_params["top_p"] = top_p
optional_params["top_k"] = top_k
optional_params["num_beams"] = num_beams
if max_tokens != float("inf"):
@@ -1739,18 +1740,31 @@ class CustomStreamWrapper:
return chunk["choices"][0]["delta"]["content"]
def handle_baseten_chunk(self, chunk):
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
if "model_output" in data_json:
if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
return data_json["model_output"]["data"][0]
elif isinstance(data_json["model_output"], str):
return data_json["model_output"]
elif "completion" in data_json and isinstance(data_json["completion"], str):
return data_json["completion"]
try:
chunk = chunk.decode("utf-8")
if len(chunk) > 0:
if chunk.startswith("data:"):
data_json = json.loads(chunk[5:])
if "token" in data_json and "text" in data_json["token"]:
return data_json["token"]["text"]
else:
return ""
data_json = json.loads(chunk)
if "model_output" in data_json:
if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
return data_json["model_output"]["data"][0]
elif isinstance(data_json["model_output"], str):
return data_json["model_output"]
elif "completion" in data_json and isinstance(data_json["completion"], str):
return data_json["completion"]
else:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
else:
return ""
else:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
else:
return ""
except:
traceback.print_exc()
return ""
def __next__(self):
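handle_baseten_chunk now also understands TGI's server-sent-event lines, which arrive as data: {...} payloads. A short sketch of parsing one such chunk (the sample bytes are illustrative):

    import json

    # Example SSE line as a TGI-backed deployment might stream it.
    raw_chunk = b'data: {"token": {"id": 42, "text": " world", "logprob": -0.21}, "generated_text": null}'

    chunk = raw_chunk.decode("utf-8")
    if chunk.startswith("data:"):
        payload = json.loads(chunk[5:])  # drop the "data:" prefix before parsing the JSON
        delta_text = payload.get("token", {}).get("text", "")
    else:
        delta_text = ""
    print(delta_text)  # " world"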


@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm"
version = "0.1.508"
version = "0.1.509"
description = "Library to easily interface with LLM API providers"
authors = ["BerriAI"]
license = "MIT License"