move baseten to a REST endpoint call

Krrish Dholakia 2023-08-24 14:43:49 -07:00
parent 725611aa58
commit 6e30b234ac
10 changed files with 173 additions and 33 deletions
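
The summary counts 10 changed files, but this excerpt shows only the `completion()` dispatch in `litellm/main.py`; the new `llms/baseten.py` client it calls into is not included. Below is a minimal sketch of what such a REST-based client could look like. The constructor and `completion()` signatures mirror the call sites visible in the diff, and the prompt building, logging calls, and response unwrapping mirror the removed SDK path; the `/predict` endpoint URL and `Api-Key` auth header are assumptions based on Baseten's public REST API, not copied from the commit.

```python
import time

import requests


class BasetenLLM:
    """Sketch of a REST client for Baseten deployments (assumed, not from the commit)."""

    def __init__(self, encoding, api_key, logging_obj):
        self.encoding = encoding
        self.api_key = api_key
        self.logging_obj = logging_obj

    def completion(
        self,
        model,
        messages,
        model_response,
        print_verbose,
        optional_params,
        litellm_params,
        logger_fn,
    ):
        # Assumption: Baseten deployments are addressed by model version id,
        # and the REST endpoint is the public /predict route with Api-Key auth.
        url = f"https://app.baseten.co/models/{model}/predict"
        prompt = " ".join(message["content"] for message in messages)
        ## LOGGING
        self.logging_obj.pre_call(input=prompt, api_key=self.api_key, model=model)
        response = requests.post(
            url,
            headers={"Authorization": f"Api-Key {self.api_key}"},
            json={"prompt": prompt},
        )
        ## LOGGING
        self.logging_obj.post_call(
            input=prompt, api_key=self.api_key, original_response=response.text
        )
        completion_response = response.json()
        # Unwrap the nested payload shapes the removed SDK path handled.
        if isinstance(completion_response, dict):
            completion_response = completion_response.get("data", completion_response)
        if isinstance(completion_response, dict):
            completion_response = completion_response.get(
                "generated_text", completion_response
            )
        ## RESPONSE OBJECT
        model_response["choices"][0]["message"]["content"] = completion_response
        model_response["created"] = time.time()
        model_response["model"] = model
        return model_response
```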

litellm/main.py

@@ -21,6 +21,7 @@ from litellm.utils import (
 )
 from .llms.anthropic import AnthropicLLM
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
+from .llms.baseten import BasetenLLM
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -73,6 +74,7 @@ def completion(
     max_tokens=float("inf"),
     presence_penalty=0,
     frequency_penalty=0,
+    num_beams=1,
     logit_bias={},
     user="",
     deployment_id=None,
@@ -681,36 +683,31 @@ def completion(
         custom_llm_provider == "baseten"
         or litellm.api_base == "https://app.baseten.co"
     ):
-        import baseten
-        base_ten_key = get_secret("BASETEN_API_KEY")
-        baseten.login(base_ten_key)
-        prompt = " ".join([message["content"] for message in messages])
-        ## LOGGING
-        logging.pre_call(input=prompt, api_key=base_ten_key, model=model)
-        base_ten__model = baseten.deployed_model_version_id(model)
-        completion_response = base_ten__model.predict({"prompt": prompt})
-        if type(completion_response) == dict:
-            completion_response = completion_response["data"]
-            if type(completion_response) == dict:
-                completion_response = completion_response["generated_text"]
-        ## LOGGING
-        logging.post_call(
-            input=prompt,
-            api_key=base_ten_key,
-            original_response=completion_response,
-        )
-        ## RESPONSE OBJECT
-        model_response["choices"][0]["message"]["content"] = completion_response
-        model_response["created"] = time.time()
-        model_response["model"] = model
+        custom_llm_provider = "baseten"
+        baseten_key = (
+            api_key
+            or litellm.baseten_key
+            or os.environ.get("BASETEN_API_KEY")
+        )
+        baseten_client = BasetenLLM(
+            encoding=encoding, api_key=baseten_key, logging_obj=logging
+        )
+        model_response = baseten_client.completion(
+            model=model,
+            messages=messages,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+        )
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
             response = CustomStreamWrapper(
                 model_response, model, custom_llm_provider="huggingface"
             )
             return response
         response = model_response
     elif custom_llm_provider == "petals" or (
         litellm.api_base and "chat.petals.dev" in litellm.api_base
     ):
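
For context, a usage sketch of the code path the hunk above rewires. The `custom_llm_provider="baseten"` trigger is visible in the `elif` condition; the API key and model version id below are placeholders, since Baseten models are addressed by deployment-specific version ids rather than human-readable names.

```python
import os

import litellm

# Placeholder credentials; the dispatch also accepts litellm.baseten_key
# or an api_key argument, per the baseten_key fallback chain in the diff.
os.environ["BASETEN_API_KEY"] = "<your-baseten-api-key>"

response = litellm.completion(
    model="q841o8w",  # hypothetical Baseten model version id
    messages=[{"role": "user", "content": "what is baseten?"}],
    custom_llm_provider="baseten",
)
print(response["choices"][0]["message"]["content"])
```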