forked from phoenix/litellm-mirror
update baseten handler to handle TGI calls
parent 9cc0a094c0
commit 14d4c7ead2
8 changed files with 79 additions and 38 deletions
Binary file not shown.
Binary file not shown.
@@ -60,9 +60,12 @@ class BasetenLLM:
             else:
                 prompt += f"{message['content']}"
         data = {
-            "prompt": prompt,
+            # "prompt": prompt,
+            "inputs": prompt, # in case it's a TGI deployed model
+            # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            **optional_params,
+            # **optional_params,
+            "parameters": optional_params,
+            "stream": True if "stream" in optional_params and optional_params["stream"] == True else False
         }

         ## LOGGING
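The payload change above switches the request body to the shape TGI-deployed Baseten models expect: the prompt goes under "inputs", generation settings under a nested "parameters" object, and "stream" becomes a top-level flag. A minimal sketch of that shape; the helper name and example values are illustrative, not from the repo:

# Minimal sketch of the TGI-style request body built above.
# build_tgi_payload is a hypothetical helper, not part of the library.
def build_tgi_payload(prompt: str, optional_params: dict) -> dict:
    return {
        "inputs": prompt,                               # TGI expects the text under "inputs"
        "parameters": optional_params,                  # generation settings live under "parameters"
        "stream": optional_params.get("stream", False) is True,
    }

payload = build_tgi_payload("Say hi", {"temperature": 0.7})
# -> {'inputs': 'Say hi', 'parameters': {'temperature': 0.7}, 'stream': False}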
@@ -76,8 +79,9 @@ class BasetenLLM:
             self.completion_url_fragment_1 + model + self.completion_url_fragment_2,
             headers=self.headers,
             data=json.dumps(data),
+            stream=True if "stream" in optional_params and optional_params["stream"] == True else False
         )
-        if "stream" in optional_params and optional_params["stream"] == True:
+        if 'text/event-stream' in response.headers['Content-Type'] or ("stream" in optional_params and optional_params["stream"] == True):
             return response.iter_lines()
         else:
             ## LOGGING
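With stream=True passed to requests.post, the handler can detect a streaming TGI response from the text/event-stream content type and hand back the raw line iterator. A rough sketch of that flow, assuming a placeholder endpoint and API key (not values from the repo):

import json
import requests

# Illustrative only: issuing a streamed Baseten/TGI call and detecting the SSE response.
url = "https://app.baseten.co/models/MODEL_ID/predict"   # hypothetical endpoint
headers = {"Authorization": "Api-Key YOUR_KEY"}           # placeholder credentials
data = {"inputs": "hello", "parameters": {}, "stream": True}

response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)

# TGI servers mark streaming responses with an event-stream content type,
# so the caller can iterate over lines instead of parsing a single JSON body.
if "text/event-stream" in response.headers.get("Content-Type", ""):
    lines = response.iter_lines()
else:
    result = response.json()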
@@ -117,9 +121,23 @@ class BasetenLLM:
                 model_response["choices"][0]["message"][
                     "content"
                 ] = completion_response["completion"]
+            elif isinstance(completion_response, list) and len(completion_response) > 0:
+                if "generated_text" not in completion_response:
+                    raise BasetenError(
+                        message=f"Unable to parse response. Original response: {response.text}",
+                        status_code=response.status_code
+                    )
+                model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
+                ## GETTING LOGPROBS
+                if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+                    sum_logprob = 0
+                    for token in completion_response[0]["details"]["tokens"]:
+                        sum_logprob += token["logprob"]
+                    model_response["choices"][0]["message"]["logprobs"] = sum_logprob
             else:
-                raise ValueError(
-                    f"Unable to parse response. Original response: {response.text}"
+                raise BasetenError(
+                    message=f"Unable to parse response. Original response: {response.text}",
+                    status_code=response.status_code
                 )

         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
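The new elif branch covers TGI-style responses, which arrive as a list whose first element carries "generated_text" and, when requested, per-token "details" with logprobs. A small sketch of that parsing path against a fabricated response:

# Sketch of the TGI response shape the new branch handles.
# The sample response below is fabricated for illustration.
completion_response = [
    {
        "generated_text": "Hello there!",
        "details": {"tokens": [{"text": "Hello", "logprob": -0.1},
                               {"text": " there!", "logprob": -0.3}]},
    }
]

content = completion_response[0]["generated_text"]
# total of the per-token logprobs, mirroring the summation in the handler
sum_logprob = sum(t["logprob"] for t in completion_response[0]["details"]["tokens"])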
@@ -1,4 +1,4 @@
-import os, openai, sys, json
+import os, openai, sys, json, inspect
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
@@ -682,7 +682,7 @@ def completion(
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
+            if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True):
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
                     model_response, model, custom_llm_provider="baseten", logging_obj=logging
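Because the Baseten handler may now return response.iter_lines() even when the caller never set "stream", completion() wraps the result whenever it is a generator. A minimal illustration of that check; maybe_wrap and fake_handler are hypothetical, and iter() stands in for CustomStreamWrapper:

import inspect

# Wrap whenever the provider handed back a generator, not only when the caller
# explicitly asked for streaming.
def maybe_wrap(model_response, optional_params):
    if inspect.isgenerator(model_response) or optional_params.get("stream") is True:
        return iter(model_response)  # stand-in for CustomStreamWrapper(...)
    return model_response

def fake_handler():
    yield b'data: {"token": {"text": "hi"}}'

print(list(maybe_wrap(fake_handler(), {})))  # wrapped even though "stream" was never set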
@@ -20,7 +20,7 @@ litellm.use_client = True
 # litellm.set_verbose = True
 # litellm.secret_manager_client = InfisicalClient(token=os.environ["INFISICAL_TOKEN"])

-user_message = "Hello, whats the weather in San Francisco??"
+user_message = "write me a function to print hello world in python"
 messages = [{"content": user_message, "role": "user"}]

@@ -383,6 +383,14 @@ def test_completion_with_fallbacks():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+# def test_baseten():
+#     try:
+
+#         response = completion(model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")

 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"
@@ -25,23 +25,24 @@ user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]

 # test on baseten completion call
-try:
-    response = completion(
-        model="wizard-lm", messages=messages, stream=True, logger_fn=logger_fn
-    )
-    print(f"response: {response}")
-    complete_response = ""
-    start_time = time.time()
-    for chunk in response:
-        chunk_time = time.time()
-        print(f"time since initial request: {chunk_time - start_time:.5f}")
-        print(chunk["choices"][0]["delta"])
-        complete_response += chunk["choices"][0]["delta"]["content"]
-    if complete_response == "":
-        raise Exception("Empty response received")
-except:
-    print(f"error occurred: {traceback.format_exc()}")
-    pass
+# try:
+#     response = completion(
+#         model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn
+#     )
+#     print(f"response: {response}")
+#     complete_response = ""
+#     start_time = time.time()
+#     for chunk in response:
+#         chunk_time = time.time()
+#         print(f"time since initial request: {chunk_time - start_time:.5f}")
+#         print(chunk["choices"][0]["delta"])
+#         complete_response += chunk["choices"][0]["delta"]["content"]
+#     if complete_response == "":
+#         raise Exception("Empty response received")
+#     print(f"complete response: {complete_response}")
+# except:
+#     print(f"error occurred: {traceback.format_exc()}")
+#     pass

 # test on openai completion call
 try:
@@ -735,7 +735,8 @@ def get_optional_params( # use the openai defaults
     elif custom_llm_provider == "baseten":
         optional_params["temperature"] = temperature
         optional_params["stream"] = stream
-        optional_params["top_p"] = top_p
+        if top_p != 1:
+            optional_params["top_p"] = top_p
         optional_params["top_k"] = top_k
         optional_params["num_beams"] = num_beams
         if max_tokens != float("inf"):
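The mapping change only forwards top_p when it differs from the default of 1; the likely motivation, inferred from the change rather than stated in the commit, is that TGI servers reject top_p == 1.0. A rough sketch with assumed default values (the function and its defaults are illustrative, not the library's API):

# Rough sketch of the adjusted Baseten parameter mapping; defaults are assumptions.
def baseten_optional_params(temperature=1.0, stream=False, top_p=1, top_k=40, num_beams=1):
    optional_params = {"temperature": temperature, "stream": stream}
    if top_p != 1:
        optional_params["top_p"] = top_p   # only forwarded when not the default
    optional_params["top_k"] = top_k
    optional_params["num_beams"] = num_beams
    return optional_params

print(baseten_optional_params(top_p=1))    # no "top_p" key
print(baseten_optional_params(top_p=0.9))  # includes "top_p": 0.9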
@@ -1739,18 +1740,31 @@ class CustomStreamWrapper:
             return chunk["choices"][0]["delta"]["content"]

     def handle_baseten_chunk(self, chunk):
-        chunk = chunk.decode("utf-8")
-        data_json = json.loads(chunk)
-        if "model_output" in data_json:
-            if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
-                return data_json["model_output"]["data"][0]
-            elif isinstance(data_json["model_output"], str):
-                return data_json["model_output"]
-        elif "completion" in data_json and isinstance(data_json["completion"], str):
-            return data_json["completion"]
-        else:
-            raise ValueError(f"Unable to parse response. Original response: {chunk}")
+        try:
+            chunk = chunk.decode("utf-8")
+            if len(chunk) > 0:
+                if chunk.startswith("data:"):
+                    data_json = json.loads(chunk[5:])
+                    if "token" in data_json and "text" in data_json["token"]:
+                        return data_json["token"]["text"]
+                    else:
+                        return ""
+                data_json = json.loads(chunk)
+                if "model_output" in data_json:
+                    if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
+                        return data_json["model_output"]["data"][0]
+                    elif isinstance(data_json["model_output"], str):
+                        return data_json["model_output"]
+                elif "completion" in data_json and isinstance(data_json["completion"], str):
+                    return data_json["completion"]
+                else:
+                    raise ValueError(f"Unable to parse response. Original response: {chunk}")
+            else:
+                return ""
+        except:
+            traceback.print_exc()
+            return ""

     def __next__(self):
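handle_baseten_chunk now recognises TGI server-sent events: lines prefixed with "data:" carry a JSON object whose token.text holds the next piece of output, while non-TGI deployments keep the older model_output/completion shapes. A stand-alone sketch of that parsing; parse_baseten_chunk is a simplified stand-in and the sample chunks are fabricated:

import json

# Simplified stand-in for the new chunk handling; sample chunks are fabricated.
def parse_baseten_chunk(chunk: bytes) -> str:
    text = chunk.decode("utf-8")
    if not text:
        return ""
    if text.startswith("data:"):
        # TGI streams SSE lines: strip the "data:" prefix and read the token text.
        data_json = json.loads(text[5:])
        return data_json.get("token", {}).get("text", "")
    data_json = json.loads(text)
    if isinstance(data_json.get("model_output"), str):
        return data_json["model_output"]
    return data_json.get("completion", "")

print(parse_baseten_chunk(b'data: {"token": {"text": "Hel"}}'))   # -> "Hel"
print(parse_baseten_chunk(b'{"completion": "lo world"}'))         # -> "lo world"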
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.508"
+version = "0.1.509"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"