diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index d46f39802f..0841db134d 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -14,6 +14,7 @@ from litellm import embedding, completion, completion_cost, Timeout
 from litellm import RateLimitError
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt
 from unittest.mock import patch, MagicMock
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
 
 # litellm.num_retries=3
 litellm.cache = None
@@ -152,29 +153,63 @@ async def test_completion_databricks(sync_mode):
         response_format_tests(response=response)
 
 
+def predibase_mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "generated_text": " Is it to find happiness, to achieve success,",
+        "details": {
+            "finish_reason": "length",
+            "prompt_tokens": 8,
+            "generated_tokens": 10,
+            "seed": None,
+            "prefill": [],
+            "tokens": [
+                {"id": 2209, "text": " Is", "logprob": -1.7568359, "special": False},
+                {"id": 433, "text": " it", "logprob": -0.2220459, "special": False},
+                {"id": 311, "text": " to", "logprob": -0.6928711, "special": False},
+                {"id": 1505, "text": " find", "logprob": -0.6425781, "special": False},
+                {
+                    "id": 23871,
+                    "text": " happiness",
+                    "logprob": -0.07519531,
+                    "special": False,
+                },
+                {"id": 11, "text": ",", "logprob": -0.07110596, "special": False},
+                {"id": 311, "text": " to", "logprob": -0.79296875, "special": False},
+                {
+                    "id": 11322,
+                    "text": " achieve",
+                    "logprob": -0.7602539,
+                    "special": False,
+                },
+                {
+                    "id": 2450,
+                    "text": " success",
+                    "logprob": -0.03656006,
+                    "special": False,
+                },
+                {"id": 11, "text": ",", "logprob": -0.0011510849, "special": False},
+            ],
+        },
+    }
+    return mock_response
+
+
 # @pytest.mark.skip(reason="local only test")
-@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_completion_predibase(sync_mode):
+async def test_completion_predibase():
     try:
         litellm.set_verbose = True
-        if sync_mode:
+        with patch("requests.post", side_effect=predibase_mock_post):
             response = completion(
                 model="predibase/llama-3-8b-instruct",
                 tenant_id="c4768f95",
                 api_key=os.getenv("PREDIBASE_API_KEY"),
                 messages=[{"role": "user", "content": "What is the meaning of life?"}],
-            )
-
-            print(response)
-        else:
-            response = await litellm.acompletion(
-                model="predibase/llama-3-8b-instruct",
-                tenant_id="c4768f95",
-                api_base="https://serving.app.predibase.com",
-                api_key=os.getenv("PREDIBASE_API_KEY"),
-                messages=[{"role": "user", "content": "What is the meaning of life?"}],
+                max_tokens=10,
             )
 
             print(response)
diff --git a/litellm/utils.py b/litellm/utils.py
index 91f2b48a14..c4f51bf8c6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5457,7 +5457,7 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stop is not None:
             optional_params["stop_sequences"] = stop
-    elif custom_llm_provider == "huggingface":
+    elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -5912,7 +5912,6 @@ def get_optional_params(
             optional_params["logprobs"] = logprobs
         if top_logprobs is not None:
             optional_params["top_logprobs"] = top_logprobs
-
     elif custom_llm_provider == "openrouter":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
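As a reference for reviewers, here is a minimal, self-contained sketch of the mocking pattern the reworked test relies on: `patch("requests.post", side_effect=...)` swaps `requests.post` for a stub for the duration of the `with` block, so every call returns a canned Predibase-style payload and the test never touches the network. The `fake_post` helper and its trimmed payload below are illustrative stand-ins, not code from this PR.

```python
# Sketch of the patch(side_effect=...) mechanics used by the new test.
# fake_post and the trimmed payload are hypothetical, for illustration.
from unittest.mock import MagicMock, patch

import requests


def fake_post(url, data=None, json=None, headers=None):
    # Build an object that quacks like requests.Response.
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "generated_text": " Is it to find happiness, to achieve success,",
        "details": {"finish_reason": "length", "generated_tokens": 10},
    }
    return mock_response


with patch("requests.post", side_effect=fake_post):
    # Inside the block, requests.post is a MagicMock whose side_effect
    # forwards the real call arguments to fake_post and returns its result.
    resp = requests.post("https://serving.app.predibase.com", json={})
    assert resp.status_code == 200
    assert resp.json()["details"]["generated_tokens"] == 10
```

Because `side_effect` receives the actual call arguments, the stub could also assert on the URL, headers, or request body if the test later needs to verify exactly what litellm sends to Predibase.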
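On the `utils.py` side, the one-line change routes `custom_llm_provider == "predibase"` through the same optional-parameter branch as `huggingface`, so OpenAI-style kwargs (such as the `max_tokens=10` the test now passes) are translated for text-generation-style backends. A hedged sketch of that translation idea follows; the target key names (`max_new_tokens`, etc.) are assumptions based on Hugging Face text-generation conventions, not a copy of litellm's actual mapping.

```python
# Illustrative OpenAI-kwarg -> text-generation-param translation, the kind
# of mapping the shared huggingface/predibase branch performs. Key names
# here are assumptions, not litellm's exact implementation.
from typing import Optional


def map_openai_params(
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> dict:
    optional_params: dict = {}
    if max_tokens is not None:
        # HF text-generation backends use max_new_tokens, not max_tokens.
        optional_params["max_new_tokens"] = max_tokens
    if temperature is not None:
        optional_params["temperature"] = temperature
    if top_p is not None:
        optional_params["top_p"] = top_p
    return optional_params


assert map_openai_params(max_tokens=10) == {"max_new_tokens": 10}
```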