add baseten support

This commit is contained in:
ishaan-jaff 2023-08-15 12:15:39 -07:00
parent d7d9893cf3
commit 02a666ee02
3 changed files with 75 additions and 3 deletions

View file

@ -431,6 +431,32 @@ def completion(
generator = get_ollama_response_stream(endpoint, model, prompt)
# assume all responses are streamed
return generator
elif custom_llm_provider == "baseten" or litellm.api_base=="https://app.baseten.co":
install_and_import("baseten")
import baseten
base_ten_key = get_secret('BASETEN_API_KEY')
baseten.login(base_ten_key)
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
base_ten__model = baseten.deployed_model_version_id(model)
completion_response = base_ten__model.predict({"prompt": prompt})
if type(completion_response) == dict:
completion_response = completion_response["data"]
if type(completion_response) == dict:
completion_response = completion_response["generated_text"]
logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = completion_response
model_response["created"] = time.time()
model_response["model"] = model
response = model_response
elif custom_llm_provider == "petals":
install_and_import("transformers")
from transformers import AutoTokenizer
@ -446,7 +472,7 @@ def completion(
outputs = model.generate(
inputs=inputs,
temperature=1.0
max_new_tokens=5
)
print("got output", outputs)