forked from phoenix/litellm-mirror

add baseten support

commit 02a666ee02 (parent d7d9893cf3)

3 changed files with 75 additions and 3 deletions
@@ -431,6 +431,32 @@ def completion(
         generator = get_ollama_response_stream(endpoint, model, prompt)
         # assume all responses are streamed
         return generator
+    elif custom_llm_provider == "baseten" or litellm.api_base == "https://app.baseten.co":
+        install_and_import("baseten")
+        import baseten
+        base_ten_key = get_secret('BASETEN_API_KEY')
+        baseten.login(base_ten_key)
+
+        prompt = " ".join([message["content"] for message in messages])
+        ## LOGGING
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+
+        base_ten__model = baseten.deployed_model_version_id(model)
+
+        completion_response = base_ten__model.predict({"prompt": prompt})
+        if type(completion_response) == dict:
+            completion_response = completion_response["data"]
+        if type(completion_response) == dict:
+            completion_response = completion_response["generated_text"]
+
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        response = model_response
+
     elif custom_llm_provider == "petals":
         install_and_import("transformers")
         from transformers import AutoTokenizer
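For orientation, the new branch can be driven two ways, mirroring the tests added below: explicitly via custom_llm_provider="baseten", or implicitly by pointing litellm.api_base at Baseten. A minimal usage sketch, assuming a deployed Baseten model version id (qvv0xeq is borrowed from the tests) and a BASETEN_API_KEY in the environment:

import os

import litellm
from litellm import completion

# The branch above fetches the key via get_secret("BASETEN_API_KEY"),
# so it must be present in the environment (value is a placeholder).
os.environ["BASETEN_API_KEY"] = "<your-baseten-api-key>"

messages = [{"role": "user", "content": "Hello, how are you?"}]

# Option 1: route explicitly by provider name.
response = completion(model="qvv0xeq", messages=messages, custom_llm_provider="baseten")

# Option 2: route implicitly; the elif also matches on litellm.api_base.
litellm.api_base = "https://app.baseten.co"
response = completion(model="qvv0xeq", messages=messages)

print(response["choices"][0]["message"]["content"])

Note that the branch flattens all message contents into a single prompt string, and the two consecutive dict checks unwrap a possibly nested Baseten response shape ({"data": {"generated_text": ...}}) down to plain text before it is placed on model_response.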
@@ -446,7 +472,7 @@ def completion(
         outputs = model.generate(
             inputs=inputs,
-            temperature=1.0
+            max_new_tokens=5
         )

         print("got output", outputs)
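The petals hunk above trades a sampling knob for a length cap: temperature only shapes token sampling and never terminates generation, whereas max_new_tokens bounds how many tokens generate() may append, which keeps a call through a distributed petals backend short. A self-contained sketch of the same argument on a stock transformers model (gpt2 stands in for the petals-backed model purely for illustration):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, how are you?", return_tensors="pt").input_ids

# max_new_tokens=5 caps the completion at five generated tokens,
# so the generate() call stays bounded regardless of the prompt.
outputs = model.generate(inputs=inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))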
@@ -59,7 +59,7 @@ def test_completion_hf_deployed_api():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

-test_completion_hf_deployed_api()
 def test_completion_cohere():
     try:
         response = completion(model="command-nightly", messages=messages, max_tokens=500)
@@ -213,6 +213,52 @@ def test_completion_together_ai_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+def test_baseten_falcon_7bcompletion():
+    model_name = "qvv0xeq"
+    try:
+        response = completion(model=model_name, messages=messages, custom_llm_provider="baseten")
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_baseten_falcon_7bcompletion_withbase():
+    model_name = "qvv0xeq"
+    litellm.api_base = "https://app.baseten.co"
+    try:
+        response = completion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+
+
+# def test_baseten_wizardLMcompletion_withbase():
+#     model_name = "q841o8w"
+#     litellm.api_base = "https://app.baseten.co"
+#     try:
+#         response = completion(model=model_name, messages=messages)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+# test_baseten_wizardLMcompletion_withbase()
+
+# def test_baseten_mosaic_ML_completion_withbase():
+#     model_name = "31dxrj3"
+#     litellm.api_base = "https://app.baseten.co"
+#     try:
+#         response = completion(model=model_name, messages=messages)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+
+
 #### Test A121 ###################
 # def test_completion_ai21():
 #     model_name = "j2-light"
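To run just the two live baseten tests added above, something like the following should work; the test-file path is an assumption, and a valid BASETEN_API_KEY must be exported since both tests call the real Baseten API:

import pytest

# -k selects the new tests by name substring; the path is assumed.
pytest.main(["-q", "-k", "baseten", "litellm/tests/test_completion.py"])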
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.394"
+version = "0.1.395"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"