working petals implementation

ishaan-jaff 2023-08-15 14:02:29 -07:00
parent f336dafd04
commit 93a0316ab7
3 changed files with 20 additions and 23 deletions

litellm/main.py

@@ -460,21 +460,19 @@ def completion(
         url = "https://chat.petals.dev/api/v1/generate"
         import requests
         prompt = " ".join([message["content"] for message in messages])
-        response = requests.post(url, data={"inputs": prompt, "max_new_tokens": 100, "model": model})
         ## LOGGING
-        #logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+        response = requests.post(url, data={"inputs": prompt, "max_new_tokens": 100, "model": model})
+        ## LOGGING
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": response}, logger_fn=logger_fn)
         completion_response = response.json()["outputs"]
-        #response.text
-        print("got response", response.json())
-        print("got response text", response.text)
-        # Embeddings & prompts are on your device, transformer blocks are distributed across the Internet
-        ## RESPONSE OBJECT
-        # model_response["choices"][0]["message"]["content"] = completion_response
-        # model_response["created"] = time.time()
-        # model_response["model"] = model
-        # response = model_response
+        # RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        response = model_response
     else:
         ## LOGGING
         logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
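
Stripped of the litellm plumbing, the new code path boils down to a single form-encoded HTTP POST against the public Petals swarm endpoint. A minimal standalone sketch of that request (the endpoint, payload fields, and max_new_tokens=100 come from the diff; the model name, the timeout, and raise_for_status() are illustrative additions, not part of the commit):

import requests

url = "https://chat.petals.dev/api/v1/generate"
messages = [{"role": "user", "content": "Hello, how are you?"}]

# Petals takes a flat prompt string, so the chat messages are joined first,
# exactly as the diff does.
prompt = " ".join(message["content"] for message in messages)

response = requests.post(
    url,
    data={
        "inputs": prompt,
        "max_new_tokens": 100,                 # value hardcoded in the diff
        "model": "stabilityai/StableBeluga2",  # illustrative; any swarm-hosted model
    },
    timeout=120,  # illustrative; not part of the commit
)
response.raise_for_status()

# The generated text comes back under the "outputs" key.
print(response.json()["outputs"])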

litellm/tests/test_completion.py

@@ -214,15 +214,14 @@ def test_completion_together_ai_stream():
         pytest.fail(f"Error occurred: {e}")
 
-# def test_petals():
-#     model_name = "stabilityai/StableBeluga2"
-#     try:
-#         response = completion(model=model_name, messages=messages, custom_llm_provider="petals")
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-# # test_petals()
+def test_petals():
+    model_name = "stabilityai/StableBeluga2"
+    try:
+        response = completion(model=model_name, messages=messages, custom_llm_provider="petals", force_timeout=120)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+test_petals()
 
 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"

pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.398"
+version = "0.1.399"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"