diff --git a/litellm/main.py b/litellm/main.py
index 9046a8bc5..e01a307dd 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -432,7 +432,6 @@ def completion(
             # assume all responses are streamed
             return generator
         elif custom_llm_provider == "baseten" or litellm.api_base=="https://app.baseten.co":
-            install_and_import("baseten")
             import baseten
             base_ten_key = get_secret('BASETEN_API_KEY')
             baseten.login(base_ten_key)
@@ -457,36 +456,25 @@ def completion(
             model_response["model"] = model
             response = model_response
-        elif custom_llm_provider == "petals":
-            install_and_import("transformers")
-            from transformers import AutoTokenizer
-            from petals import AutoDistributedModelForCausalLM
+        elif custom_llm_provider == "petals" or "chat.petals.dev" in litellm.api_base:
+            url = "https://chat.petals.dev/api/v1/generate"
+            import requests
+            prompt = " ".join([message["content"] for message in messages])
+            response = requests.post(url, data={"inputs": prompt, "max_new_tokens": 100, "model": model})
-            tokenizer = AutoTokenizer.from_pretrained(model)
-            model = AutoDistributedModelForCausalLM.from_pretrained(model)
+            ## LOGGING
+            #logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
-            print("got model", model)
+            #response.text
+            print("got response", response.json())
+            print("got response text", response.text)
             # Embeddings & prompts are on your device, transformer blocks are distributed across the Internet
-            inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
-
-            outputs = model.generate(
-                inputs=inputs,
-                max_new_tokens=5
-            )
-
-            print("got output", outputs)
-            completion_response = tokenizer.decode(outputs[0])
-
-            print("got output text", completion_response)
-            ## LOGGING
-            logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
-            ## RESPONSE OBJECT
-            model_response["choices"][0]["message"]["content"] = completion_response
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            response = model_response
+            # model_response["choices"][0]["message"]["content"] = completion_response
+            # model_response["created"] = time.time()
+            # model_response["model"] = model
+            # response = model_response
         else:
             ## LOGGING
             logging(model=model, input=messages, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index f07a1dcee..9d3a394ef 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -213,6 +213,17 @@ def test_completion_together_ai_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
+# def test_petals():
+#     model_name = "stabilityai/StableBeluga2"
+#     try:
+#         response = completion(model=model_name, messages=messages, custom_llm_provider="petals")
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+# # test_petals()
+
 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"
 #     try:
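
Note: in the new Petals branch above, the response-object mapping is still commented out and the raw HTTP response is only printed, so completion() would return a bare requests.Response for this provider. Below is a minimal sketch of how that branch could be finished. It is an illustration under assumptions, not code from this diff: the helper name petals_http_completion is hypothetical, and the assumption that the chat.petals.dev generate endpoint returns JSON with an "outputs" field should be verified against the Petals HTTP API docs.

import time
import requests

def petals_http_completion(model, messages, model_response, max_new_tokens=100):
    # Hypothetical helper, not part of this diff: completes the Petals branch sketched above.
    # Join the chat messages into a single prompt, as the diff does.
    prompt = " ".join(message["content"] for message in messages)
    resp = requests.post(
        "https://chat.petals.dev/api/v1/generate",
        data={"inputs": prompt, "max_new_tokens": max_new_tokens, "model": model},
    )
    resp.raise_for_status()
    # Assumption: the endpoint returns JSON like {"ok": true, "outputs": "<generated text>"}.
    completion_response = resp.json().get("outputs", "")
    # Populate the OpenAI-style response object, mirroring the other providers in completion().
    model_response["choices"][0]["message"]["content"] = completion_response
    model_response["created"] = time.time()
    model_response["model"] = model
    return model_response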