forked from phoenix/litellm-mirror
fix import baseten + petals test
parent 3e09743173
commit f336dafd04
2 changed files with 25 additions and 26 deletions
@@ -432,7 +432,6 @@ def completion(
         # assume all responses are streamed
         return generator
     elif custom_llm_provider == "baseten" or litellm.api_base=="https://app.baseten.co":
-        install_and_import("baseten")
         import baseten
         base_ten_key = get_secret('BASETEN_API_KEY')
         baseten.login(base_ten_key)
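
The removed install_and_import("baseten") call suggests the previous code tried to pip-install the SDK at call time, while the new code assumes baseten is already installed and simply imports it before logging in with BASETEN_API_KEY. A minimal sketch of what an install-and-import helper of this kind typically looks like (hypothetical; not necessarily litellm's exact implementation):

import importlib
import subprocess
import sys

def install_and_import(package_name):
    # Try a normal import first; fall back to installing the package
    # into the current interpreter's environment and importing again.
    try:
        return importlib.import_module(package_name)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
        return importlib.import_module(package_name)
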
@@ -457,36 +456,25 @@ def completion(
         model_response["model"] = model
         response = model_response
 
-    elif custom_llm_provider == "petals":
-        install_and_import("transformers")
-        from transformers import AutoTokenizer
-        from petals import AutoDistributedModelForCausalLM
-
-        tokenizer = AutoTokenizer.from_pretrained(model)
-        model = AutoDistributedModelForCausalLM.from_pretrained(model)
-
-        print("got model", model)
+    elif custom_llm_provider == "petals" or "chat.petals.dev" in litellm.api_base:
+        url = "https://chat.petals.dev/api/v1/generate"
+        import requests
+        prompt = " ".join([message["content"] for message in messages])
+        response = requests.post(url, data={"inputs": prompt, "max_new_tokens": 100, "model": model})
+        ## LOGGING
+        #logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+        #response.text
+        print("got response", response.json())
+        print("got response text", response.text)
         # Embeddings & prompts are on your device, transformer blocks are distributed across the Internet
-
-        inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
-
-        outputs = model.generate(
-            inputs=inputs,
-            max_new_tokens=5
-        )
-
-        print("got output", outputs)
-        completion_response = tokenizer.decode(outputs[0])
-
-        print("got output text", completion_response)
-        ## LOGGING
-        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
-
         ## RESPONSE OBJECT
-        model_response["choices"][0]["message"]["content"] = completion_response
-        model_response["created"] = time.time()
-        model_response["model"] = model
-        response = model_response
+        # model_response["choices"][0]["message"]["content"] = completion_response
+        # model_response["created"] = time.time()
+        # model_response["model"] = model
+        # response = model_response
     else:
         ## LOGGING
         logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
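
After this change the petals branch no longer loads a local AutoDistributedModelForCausalLM; it posts the joined prompt to the public chat.petals.dev HTTP endpoint instead. A standalone sketch of that request, using only the URL and form fields visible in the diff (the helper name, timeout, and status check are assumptions):

import requests

def petals_generate(model, prompt, max_new_tokens=100):
    # Hypothetical helper mirroring the requests.post call added above.
    url = "https://chat.petals.dev/api/v1/generate"
    resp = requests.post(
        url,
        data={"inputs": prompt, "max_new_tokens": max_new_tokens, "model": model},
        timeout=60,
    )
    resp.raise_for_status()  # assumption: treat non-2xx responses as errors
    return resp.json()

# Example usage (assumes the public swarm is reachable and serves the model):
# print(petals_generate("stabilityai/StableBeluga2", "Hello, how are you?"))
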
@@ -213,6 +213,17 @@ def test_completion_together_ai_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
+# def test_petals():
+#     model_name = "stabilityai/StableBeluga2"
+#     try:
+#         response = completion(model=model_name, messages=messages, custom_llm_provider="petals")
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+# # test_petals()
 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"
 #     try:
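
The petals test is added commented out, presumably because it depends on the public Petals swarm and a network round trip. If it were enabled, it would look roughly like the following; the completion import and the messages payload here are assumptions standing in for whatever the test module already defines:

import pytest
from litellm import completion

# Illustrative messages payload; the test file defines its own.
messages = [{"role": "user", "content": "Hey, how's it going?"}]

def test_petals():
    model_name = "stabilityai/StableBeluga2"
    try:
        response = completion(model=model_name, messages=messages, custom_llm_provider="petals")
        # Add any assertions here to check the response
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")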