add baseten support

ishaan-jaff 2023-08-15 12:15:39 -07:00
parent d7d9893cf3
commit 02a666ee02
3 changed files with 75 additions and 3 deletions

litellm/main.py

@@ -431,6 +431,32 @@ def completion(
         generator = get_ollama_response_stream(endpoint, model, prompt)
         # assume all responses are streamed
         return generator
+    elif custom_llm_provider == "baseten" or litellm.api_base=="https://app.baseten.co":
+        install_and_import("baseten")
+        import baseten
+        base_ten_key = get_secret('BASETEN_API_KEY')
+        baseten.login(base_ten_key)
+        prompt = " ".join([message["content"] for message in messages])
+        ## LOGGING
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+        base_ten__model = baseten.deployed_model_version_id(model)
+        completion_response = base_ten__model.predict({"prompt": prompt})
+        if type(completion_response) == dict:
+            completion_response = completion_response["data"]
+            if type(completion_response) == dict:
+                completion_response = completion_response["generated_text"]
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        response = model_response
     elif custom_llm_provider == "petals":
         install_and_import("transformers")
         from transformers import AutoTokenizer
@@ -446,7 +472,7 @@ def completion(
         outputs = model.generate(
             inputs=inputs,
-            temperature=1.0
+            max_new_tokens=5
         )
         print("got output", outputs)

litellm/tests/test_completion.py

@@ -59,7 +59,7 @@ def test_completion_hf_deployed_api():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_hf_deployed_api()

 def test_completion_cohere():
     try:
         response = completion(model="command-nightly", messages=messages, max_tokens=500)
@@ -213,6 +213,52 @@ def test_completion_together_ai_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+def test_baseten_falcon_7bcompletion():
+    model_name = "qvv0xeq"
+    try:
+        response = completion(model=model_name, messages=messages, custom_llm_provider="baseten")
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_baseten_falcon_7bcompletion_withbase():
+    model_name = "qvv0xeq"
+    litellm.api_base = "https://app.baseten.co"
+    try:
+        response = completion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+# def test_baseten_wizardLMcompletion_withbase():
+#     model_name = "q841o8w"
+#     litellm.api_base = "https://app.baseten.co"
+#     try:
+#         response = completion(model=model_name, messages=messages)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+# test_baseten_wizardLMcompletion_withbase()
+
+# def test_baseten_mosaic_ML_completion_withbase():
+#     model_name = "31dxrj3"
+#     litellm.api_base = "https://app.baseten.co"
+#     try:
+#         response = completion(model=model_name, messages=messages)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")

 #### Test A121 ###################
 # def test_completion_ai21():
 #     model_name = "j2-light"

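One caveat with test_baseten_falcon_7bcompletion_withbase(): litellm.api_base is module-global state, so once a test sets it to https://app.baseten.co, every later completion() call in the session will also match the Baseten routing check. A sketch of a pytest fixture that would isolate it; the fixture itself is illustrative and not part of this commit:

import pytest
import litellm

@pytest.fixture(autouse=True)
def restore_api_base():
    # Snapshot the global api_base before each test and restore it after,
    # so one test's override cannot leak into the next.
    original = litellm.api_base
    yield
    litellm.api_base = original
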
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.394"
+version = "0.1.395"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"