fix(test_completion.py): fix predibase test to be mock + fix optional param mapping for predibase

This commit is contained in:
Krrish Dholakia 2024-06-04 20:06:23 -07:00
parent 85cfbf0f86
commit 3dcf287826
2 changed files with 49 additions and 15 deletions

View file

@ -14,6 +14,7 @@ from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError from litellm import RateLimitError
from litellm.llms.prompt_templates.factory import anthropic_messages_pt from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
# litellm.num_retries=3 # litellm.num_retries=3
litellm.cache = None litellm.cache = None
@ -152,29 +153,63 @@ async def test_completion_databricks(sync_mode):
response_format_tests(response=response) response_format_tests(response=response)
def predibase_mock_post(url, data=None, json=None, headers=None):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"generated_text": " Is it to find happiness, to achieve success,",
"details": {
"finish_reason": "length",
"prompt_tokens": 8,
"generated_tokens": 10,
"seed": None,
"prefill": [],
"tokens": [
{"id": 2209, "text": " Is", "logprob": -1.7568359, "special": False},
{"id": 433, "text": " it", "logprob": -0.2220459, "special": False},
{"id": 311, "text": " to", "logprob": -0.6928711, "special": False},
{"id": 1505, "text": " find", "logprob": -0.6425781, "special": False},
{
"id": 23871,
"text": " happiness",
"logprob": -0.07519531,
"special": False,
},
{"id": 11, "text": ",", "logprob": -0.07110596, "special": False},
{"id": 311, "text": " to", "logprob": -0.79296875, "special": False},
{
"id": 11322,
"text": " achieve",
"logprob": -0.7602539,
"special": False,
},
{
"id": 2450,
"text": " success",
"logprob": -0.03656006,
"special": False,
},
{"id": 11, "text": ",", "logprob": -0.0011510849, "special": False},
],
},
}
return mock_response
# @pytest.mark.skip(reason="local only test") # @pytest.mark.skip(reason="local only test")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_completion_predibase(sync_mode): async def test_completion_predibase():
try: try:
litellm.set_verbose = True litellm.set_verbose = True
if sync_mode: with patch("requests.post", side_effect=predibase_mock_post):
response = completion( response = completion(
model="predibase/llama-3-8b-instruct", model="predibase/llama-3-8b-instruct",
tenant_id="c4768f95", tenant_id="c4768f95",
api_key=os.getenv("PREDIBASE_API_KEY"), api_key=os.getenv("PREDIBASE_API_KEY"),
messages=[{"role": "user", "content": "What is the meaning of life?"}], messages=[{"role": "user", "content": "What is the meaning of life?"}],
) max_tokens=10,
print(response)
else:
response = await litellm.acompletion(
model="predibase/llama-3-8b-instruct",
tenant_id="c4768f95",
api_base="https://serving.app.predibase.com",
api_key=os.getenv("PREDIBASE_API_KEY"),
messages=[{"role": "user", "content": "What is the meaning of life?"}],
) )
print(response) print(response)

View file

@ -5457,7 +5457,7 @@ def get_optional_params(
optional_params["top_p"] = top_p optional_params["top_p"] = top_p
if stop is not None: if stop is not None:
optional_params["stop_sequences"] = stop optional_params["stop_sequences"] = stop
elif custom_llm_provider == "huggingface": elif custom_llm_provider == "huggingface" or custom_llm_provider == "predibase":
## check if unsupported param passed in ## check if unsupported param passed in
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -5912,7 +5912,6 @@ def get_optional_params(
optional_params["logprobs"] = logprobs optional_params["logprobs"] = logprobs
if top_logprobs is not None: if top_logprobs is not None:
optional_params["top_logprobs"] = top_logprobs optional_params["top_logprobs"] = top_logprobs
elif custom_llm_provider == "openrouter": elif custom_llm_provider == "openrouter":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider