diff --git a/litellm/main.py b/litellm/main.py
index bb7da6352..9046a8bc5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -431,6 +431,32 @@ def completion(
         generator = get_ollama_response_stream(endpoint, model, prompt)
         # assume all responses are streamed
         return generator
+    elif custom_llm_provider == "baseten" or litellm.api_base=="https://app.baseten.co":
+        install_and_import("baseten")
+        import baseten
+        base_ten_key = get_secret('BASETEN_API_KEY')
+        baseten.login(base_ten_key)
+
+        prompt = " ".join([message["content"] for message in messages])
+        ## LOGGING
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+
+        base_ten__model = baseten.deployed_model_version_id(model)
+
+        completion_response = base_ten__model.predict({"prompt": prompt})
+        if type(completion_response) == dict:
+            completion_response = completion_response["data"]
+            if type(completion_response) == dict:
+                completion_response = completion_response["generated_text"]
+
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        response = model_response
+
     elif custom_llm_provider == "petals":
         install_and_import("transformers")
         from transformers import AutoTokenizer
@@ -446,7 +472,7 @@ def completion(
 
         outputs = model.generate(
             inputs=inputs,
-            temperature=1.0
+            max_new_tokens=5
         )
         print("got output", outputs)
 
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 3719234b9..a1e5b4bf1 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -59,7 +59,7 @@ def test_completion_hf_deployed_api():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_hf_deployed_api()
+
 def test_completion_cohere():
     try:
         response = completion(model="command-nightly", messages=messages, max_tokens=500)
@@ -213,6 +213,52 @@ def test_completion_together_ai_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+def test_baseten_falcon_7bcompletion():
+    model_name = "qvv0xeq"
+    try:
+        response = completion(model=model_name, messages=messages, custom_llm_provider="baseten")
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+def test_baseten_falcon_7bcompletion_withbase():
+    model_name = "qvv0xeq"
+    litellm.api_base = "https://app.baseten.co"
+    try:
+        response = completion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+
+
+# def test_baseten_wizardLMcompletion_withbase():
+#     model_name = "q841o8w"
+#     litellm.api_base = "https://app.baseten.co"
+#     try:
+#         response = completion(model=model_name, messages=messages)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+# test_baseten_wizardLMcompletion_withbase()
+
+# def test_baseten_mosaic_ML_completion_withbase():
+#     model_name = "31dxrj3"
+#     litellm.api_base = "https://app.baseten.co"
+#     try:
+#         response = completion(model=model_name, messages=messages)
+#         # Add any assertions here to check the response
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+
+
 #### Test A121 ###################
 # def test_completion_ai21():
 #     model_name = "j2-light"
diff --git a/pyproject.toml b/pyproject.toml
index 19ae6db52..5aea23ba6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.394"
+version = "0.1.395"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
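
Usage note (not part of the patch): the new Baseten branch can be exercised the same two ways the added tests do, either by passing custom_llm_provider="baseten" explicitly or by pointing litellm.api_base at https://app.baseten.co. The sketch below assumes the Falcon-7B model version ID "qvv0xeq" from the tests, a placeholder BASETEN_API_KEY value, and an example messages payload (the tests reuse a messages variable defined elsewhere in test_completion.py).

import os
import litellm
from litellm import completion

# The new branch reads the key via get_secret("BASETEN_API_KEY"),
# so it must be present in the environment (placeholder value shown).
os.environ["BASETEN_API_KEY"] = "bt-..."

# Example payload; the test file defines its own `messages` variable.
messages = [{"role": "user", "content": "Hello, how are you?"}]

# Option 1: select the provider explicitly, as in test_baseten_falcon_7bcompletion().
response = completion(model="qvv0xeq", messages=messages, custom_llm_provider="baseten")

# Option 2: rely on the litellm.api_base check added to completion(),
# as in test_baseten_falcon_7bcompletion_withbase().
litellm.api_base = "https://app.baseten.co"
response = completion(model="qvv0xeq", messages=messages)

# The branch fills the standard response object, so the text lands here.
print(response["choices"][0]["message"]["content"])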