From 2b9f185363d1d31276278049df2d83e5f7de6f55 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Thu, 30 Jan 2025 15:33:11 -0800
Subject: [PATCH] revert back test

---
 tests/client-sdk/inference/test_inference.py | 106 +++++++++----------
 1 file changed, 51 insertions(+), 55 deletions(-)

diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py
index 6260d1cdf..6dff1be24 100644
--- a/tests/client-sdk/inference/test_inference.py
+++ b/tests/client-sdk/inference/test_inference.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import base64
-import pathlib
+import os
 
 import pytest
 from pydantic import BaseModel
@@ -16,6 +16,14 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::fireworks": "json",
 }
 
+PROVIDER_LOGPROBS_TOP_K = set(
+    {
+        "remote::together",
+        "remote::fireworks",
+        # "remote:vllm"
+    }
+)
+
 
 @pytest.fixture(scope="session")
 def provider_tool_format(inference_provider_type):
@@ -48,31 +56,14 @@ def get_weather_tool_definition():
     }
 
 
-# @pytest.fixture
-# def base64_image_url():
-#     image_path = os.path.join(os.path.dirname(__file__), "dog.png")
-#     with open(image_path, "rb") as image_file:
-#         # Convert the image to base64
-#         base64_string = base64.b64encode(image_file.read()).decode("utf-8")
-#         base64_url = f"data:image/png;base64,{base64_string}"
-#         return base64_url
-
-
 @pytest.fixture
-def image_path():
-    return pathlib.Path(__file__).parent / "dog.png"
-
-
-@pytest.fixture
-def base64_image_data(image_path):
-    # Convert the image to base64
-    return base64.b64encode(image_path.read_bytes()).decode("utf-8")
-
-
-@pytest.fixture
-def base64_image_url(base64_image_data, image_path):
-    # suffix includes the ., so we remove it
-    return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
+def base64_image_url():
+    image_path = os.path.join(os.path.dirname(__file__), "dog.png")
+    with open(image_path, "rb") as image_file:
+        # Convert the image to base64
+        base64_string = base64.b64encode(image_file.read()).decode("utf-8")
+        base64_url = f"data:image/png;base64,{base64_string}"
+        return base64_url
 
 
 def test_text_completion_non_streaming(llama_stack_client, text_model_id):
@@ -100,8 +91,12 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
     assert "blue" in "".join(streamed_content).lower().strip()
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_non_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
@@ -110,16 +105,22 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
-    assert 1 <= len(response.logprobs) <= 5
-    assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+    assert (
+        1 <= len(response.logprobs) <= 5
+    )  # each token has 1 logprob and here max_tokens=5
+    assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
@@ -128,7 +129,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     streamed_content = [chunk for chunk in response]
@@ -136,7 +137,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
             assert all(
-                len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
            )
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
@@ -370,30 +371,25 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
     assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
 
 
-@pytest.mark.parametrize("type_", ["url", "data"])
-def test_image_chat_completion_base64(
-    llama_stack_client, vision_model_id, base64_image_data, base64_image_url, type_
+def test_image_chat_completion_base64_url(
+    llama_stack_client, vision_model_id, base64_image_url
 ):
-    image_spec = {
-        "url": {
-            "type": "image",
-            "image": {
-                "url": {
-                    "uri": base64_image_url,
-                },
-            },
-        },
-        "data": {
-            "type": "image",
-            "image": {
-                "data": base64_image_data,
-            },
-        },
-    }[type_]
-
     message = {
         "role": "user",
-        "content": [image_spec],
+        "content": [
+            {
+                "type": "image",
+                "image": {
+                    "url": {
+                        "uri": base64_image_url,
+                    },
+                },
+            },
+            {
+                "type": "text",
+                "text": "Describe what is in this image.",
+            },
+        ],
    }
     response = llama_stack_client.inference.chat_completion(
         model_id=vision_model_id,
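
For reference, the `base64_image_url` fixture restored above and the pathlib-based helpers it replaces build the same data URL for the suite's dog.png asset. Below is a minimal standalone sketch of both styles, assuming a dog.png file sits next to the script; the helper names are illustrative and not part of the test suite.

    import base64
    import os
    import pathlib


    def data_url_os_style(image_path: str) -> str:
        # Restored style: explicit file handle, MIME subtype hard-coded to "png".
        with open(image_path, "rb") as image_file:
            base64_string = base64.b64encode(image_file.read()).decode("utf-8")
        return f"data:image/png;base64,{base64_string}"


    def data_url_pathlib_style(image_path: pathlib.Path) -> str:
        # Replaced style: the suffix includes the ".", so strip it to get the subtype.
        encoded = base64.b64encode(image_path.read_bytes()).decode("utf-8")
        return f"data:image/{image_path.suffix[1:]};base64,{encoded}"


    if __name__ == "__main__":
        # Illustrative asset path, mirroring the suite's dog.png fixture file.
        path = pathlib.Path(__file__).parent / "dog.png"
        assert data_url_os_style(os.fspath(path)) == data_url_pathlib_style(path)

The revert trades the suffix-derived MIME subtype for a single self-contained fixture; both variants yield data:image/png;base64,... here because the asset's suffix is .png.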