From 93a0316ab76bf50f40f99b2bd4d7c2042a65ba6a Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 15 Aug 2023 14:02:29 -0700
Subject: [PATCH] working petals implementation

---
 litellm/main.py                  | 24 +++++++++++-------------
 litellm/tests/test_completion.py | 17 ++++++++---------
 pyproject.toml                   |  2 +-
 3 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index e01a307dd..18453e0c7 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -460,21 +460,19 @@ def completion(
         url = "https://chat.petals.dev/api/v1/generate"
         import requests
         prompt = " ".join([message["content"] for message in messages])
-        response = requests.post(url, data={"inputs": prompt, "max_new_tokens": 100, "model": model})
 
         ## LOGGING
-        #logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
-
-        #response.text
-        print("got response", response.json())
-        print("got response text", response.text)
-        # Embeddings & prompts are on your device, transformer blocks are distributed across the Internet
-
-        ## RESPONSE OBJECT
-        # model_response["choices"][0]["message"]["content"] = completion_response
-        # model_response["created"] = time.time()
-        # model_response["model"] = model
-        # response = model_response
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+        response = requests.post(url, data={"inputs": prompt, "max_new_tokens": 100, "model": model})
+        ## LOGGING
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": response}, logger_fn=logger_fn)
+        completion_response = response.json()["outputs"]
+
+        # RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        response = model_response
     else:
         ## LOGGING
         logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 9d3a394ef..a73599cbe 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -214,15 +214,14 @@ def test_completion_together_ai_stream():
         pytest.fail(f"Error occurred: {e}")
 
 
-# def test_petals():
-#     model_name = "stabilityai/StableBeluga2"
-#     try:
-#         response = completion(model=model_name, messages=messages, custom_llm_provider="petals")
-#         # Add any assertions here to check the response
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-# # test_petals()
+def test_petals():
+    model_name = "stabilityai/StableBeluga2"
+    try:
+        response = completion(model=model_name, messages=messages, custom_llm_provider="petals", force_timeout=120)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"
diff --git a/pyproject.toml b/pyproject.toml
index f7914b0fe..b7be3a0f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.398"
+version = "0.1.399"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
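
Usage note (illustrative, not part of the patch): the re-enabled test above shows how the new Petals branch is meant to be exercised. Below is a minimal sketch, assuming litellm 0.1.399 with this patch applied and a reachable https://chat.petals.dev endpoint; the model name and messages are only examples.

    from litellm import completion

    messages = [{"role": "user", "content": "Hello, how are you?"}]

    # Routed through the Petals branch added above: litellm joins the message contents
    # into a prompt, POSTs {"inputs", "max_new_tokens", "model"} to
    # https://chat.petals.dev/api/v1/generate, and copies response.json()["outputs"]
    # into the returned choices.
    response = completion(
        model="stabilityai/StableBeluga2",
        messages=messages,
        custom_llm_provider="petals",
        force_timeout=120,
    )
    print(response["choices"][0]["message"]["content"])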