Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-07 02:58:21 +00:00)
revert back test
This commit is contained in:
parent 93c48588c8
commit 2b9f185363
1 changed file with 51 additions and 55 deletions
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import base64
-import pathlib
+import os
 
 import pytest
 from pydantic import BaseModel
@@ -16,6 +16,14 @@ PROVIDER_TOOL_PROMPT_FORMAT = {
     "remote::fireworks": "json",
 }
 
+PROVIDER_LOGPROBS_TOP_K = set(
+    {
+        "remote::together",
+        "remote::fireworks",
+        # "remote:vllm"
+    }
+)
+
 
 @pytest.fixture(scope="session")
 def provider_tool_format(inference_provider_type):
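Note: this hunk restores PROVIDER_LOGPROBS_TOP_K, an allow-list of provider types whose log-prob support is exercised by the tests further down in this diff. A minimal sketch of the gating pattern, assuming the inference_provider_type fixture yields strings such as "remote::together" (the maybe_xfail_logprobs helper is hypothetical, added here only for illustration):

import pytest

PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks"}

def maybe_xfail_logprobs(inference_provider_type: str) -> None:
    # Hypothetical helper: xfail (rather than skip) at run time when the
    # provider under test is not in the allow-list, so supported providers
    # still execute the assertions that follow in the real tests.
    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")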
@@ -48,31 +56,14 @@ def get_weather_tool_definition():
     }
 
 
-# @pytest.fixture
-# def base64_image_url():
-#     image_path = os.path.join(os.path.dirname(__file__), "dog.png")
-#     with open(image_path, "rb") as image_file:
-#         # Convert the image to base64
-#         base64_string = base64.b64encode(image_file.read()).decode("utf-8")
-#         base64_url = f"data:image/png;base64,{base64_string}"
-#         return base64_url
-
-
 @pytest.fixture
-def image_path():
-    return pathlib.Path(__file__).parent / "dog.png"
-
-
-@pytest.fixture
-def base64_image_data(image_path):
-    # Convert the image to base64
-    return base64.b64encode(image_path.read_bytes()).decode("utf-8")
-
-
-@pytest.fixture
-def base64_image_url(base64_image_data, image_path):
-    # suffix includes the ., so we remove it
-    return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}"
+def base64_image_url():
+    image_path = os.path.join(os.path.dirname(__file__), "dog.png")
+    with open(image_path, "rb") as image_file:
+        # Convert the image to base64
+        base64_string = base64.b64encode(image_file.read()).decode("utf-8")
+        base64_url = f"data:image/png;base64,{base64_string}"
+        return base64_url
 
 
 def test_text_completion_non_streaming(llama_stack_client, text_model_id):
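Note: the restored base64_image_url fixture reads dog.png from the test directory and returns it as a data URI, replacing the three pathlib-based fixtures (image_path, base64_image_data, base64_image_url) with a single os-based one. A rough sketch of why the two styles are interchangeable, assuming dog.png exists next to the test module:

import base64
import os
import pathlib

# Both constructions point at the same file next to the module ...
path_a = pathlib.Path(__file__).parent / "dog.png"
path_b = os.path.join(os.path.dirname(__file__), "dog.png")
assert os.fspath(path_a) == path_b  # typically identical strings

# ... and both fixture styles end up with the same data URI
# (PNG payloads always start with "iVBORw0KGgo").
encoded = base64.b64encode(path_a.read_bytes()).decode("utf-8")
base64_url = f"data:image/png;base64,{encoded}"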
@@ -100,8 +91,12 @@ def test_text_completion_streaming(llama_stack_client, text_model_id):
     assert "blue" in "".join(streamed_content).lower().strip()
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_non_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=False,
@@ -110,16 +105,22 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
-    assert 1 <= len(response.logprobs) <= 5
-    assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+    assert (
+        1 <= len(response.logprobs) <= 5
+    )  # each token has 1 logprob and here max_tokens=5
+    assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
 
 
-@pytest.mark.skip("Most inference providers don't support log probs yet")
-def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
+def test_completion_log_probs_streaming(
+    llama_stack_client, text_model_id, inference_provider_type
+):
+    if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
+        pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
+
     response = llama_stack_client.inference.completion(
         content="Complete the sentence: Micheael Jordan is born in ",
         stream=True,
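Note: with the reverted request parameters (top_k of 1, max_tokens of 5), each returned token should carry exactly one entry in logprobs_by_token, and the response should hold between one and five such entries. A stand-in illustration of data that satisfies the restored assertions (the class below is hypothetical, not the client's real response type):

from dataclasses import dataclass

@dataclass
class FakeTokenLogProbs:
    # top_k=1 means exactly one candidate per generated token.
    logprobs_by_token: dict

# max_tokens=5, so between 1 and 5 of these come back.
fake_logprobs = [
    FakeTokenLogProbs(logprobs_by_token={"Brooklyn": -0.11}),
    FakeTokenLogProbs(logprobs_by_token={",": -0.02}),
]

assert 1 <= len(fake_logprobs) <= 5
assert all(len(lp.logprobs_by_token) == 1 for lp in fake_logprobs)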
@@ -128,7 +129,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
             "max_tokens": 5,
         },
         logprobs={
-            "top_k": 3,
+            "top_k": 1,
         },
     )
     streamed_content = [chunk for chunk in response]
@@ -136,7 +137,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id):
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
             assert all(
-                len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
             )
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
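Note: the streaming variant enforces the same top_k=1 contract per chunk: a chunk that carries a text delta must also carry logprobs with a single entry per token, and a chunk without a delta must carry none. A stand-in illustration of that pairing (hypothetical chunk objects, not the client's real types):

from dataclasses import dataclass, field

@dataclass
class FakeChunk:
    delta: str = ""
    logprobs: list = field(default_factory=list)

chunks = [
    FakeChunk(delta="Brook", logprobs=[{"Brook": -0.3}]),
    FakeChunk(),  # e.g. a final chunk with no token and no logprobs
]

for chunk in chunks:
    if chunk.delta:  # a token arrived, so logprobs must accompany it
        assert chunk.logprobs and all(len(lp) == 1 for lp in chunk.logprobs)
    else:  # no token, no logprobs
        assert not chunk.logprobs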
@@ -370,30 +371,25 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id):
     assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"})
 
 
-@pytest.mark.parametrize("type_", ["url", "data"])
-def test_image_chat_completion_base64(
-    llama_stack_client, vision_model_id, base64_image_data, base64_image_url, type_
+def test_image_chat_completion_base64_url(
+    llama_stack_client, vision_model_id, base64_image_url
 ):
-    image_spec = {
-        "url": {
-            "type": "image",
-            "image": {
-                "url": {
-                    "uri": base64_image_url,
-                },
-            },
-        },
-        "data": {
-            "type": "image",
-            "image": {
-                "data": base64_image_data,
-            },
-        },
-    }[type_]
-
     message = {
         "role": "user",
-        "content": [image_spec],
+        "content": [
+            {
+                "type": "image",
+                "image": {
+                    "url": {
+                        "uri": base64_image_url,
+                    },
+                },
+            },
+            {
+                "type": "text",
+                "text": "Describe what is in this image.",
+            },
+        ],
     }
     response = llama_stack_client.inference.chat_completion(
         model_id=vision_model_id,
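Note: this hunk narrows the image test back to the URL-style payload only; the removed image_spec dict had parametrized the test over a data-URI "url" variant and a raw-bytes "data" variant. A minimal sketch of sending the restored message, assuming a configured llama_stack_client plus the vision_model_id and base64_image_url fixtures (the messages/stream arguments and the completion_message field are assumptions not shown in this hunk):

message = {
    "role": "user",
    "content": [
        {"type": "image", "image": {"url": {"uri": base64_image_url}}},
        {"type": "text", "text": "Describe what is in this image."},
    ],
}
response = llama_stack_client.inference.chat_completion(
    model_id=vision_model_id,
    messages=[message],  # assumption: remaining kwargs fall outside this hunk
    stream=False,
)
# Assumption: the non-streaming response exposes its text here.
print(response.completion_message.content)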