fixes and linting

This commit is contained in:
Hardik Shah 2025-03-28 18:33:36 -07:00
parent 021dd0d35d
commit 5251d2422d
8 changed files with 149 additions and 345 deletions

View file

@ -8,9 +8,9 @@
import os
import pytest
from pydantic import BaseModel
from llama_stack.models.llama.sku_list import resolve_model
from pydantic import BaseModel
from ..test_cases.test_case import TestCase
@ -29,9 +29,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
"remote::gemini",
"remote::groq",
):
pytest.skip(
f"Model {model_id} hosted by {provider.provider_type} doesn't support completion"
)
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
def get_llama_model(client_with_models, model_id):
@ -112,9 +110,7 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
"inference:completion:stop_sequence",
],
)
def test_text_completion_stop_sequence(
client_with_models, text_model_id, inference_provider_type, test_case
):
def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
# This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
if inference_provider_type != "remote::vllm":
@ -141,9 +137,7 @@ def test_text_completion_stop_sequence(
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_non_streaming(
client_with_models, text_model_id, inference_provider_type, test_case
):
def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@ -162,9 +156,7 @@ def test_text_completion_log_probs_non_streaming(
},
)
assert response.logprobs, "Logprobs should not be empty"
assert (
1 <= len(response.logprobs) <= 5
) # each token has 1 logprob and here max_tokens=5
assert 1 <= len(response.logprobs) <= 5 # each token has 1 logprob and here max_tokens=5
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
@ -174,9 +166,7 @@ def test_text_completion_log_probs_non_streaming(
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_streaming(
client_with_models, text_model_id, inference_provider_type, test_case
):
def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@ -198,9 +188,7 @@ def test_text_completion_log_probs_streaming(
for chunk in streamed_content:
if chunk.delta: # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
assert all(
len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
)
assert all(len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs)
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"
@ -211,13 +199,9 @@ def test_text_completion_log_probs_streaming(
"inference:completion:structured_output",
],
)
def test_text_completion_structured_output(
client_with_models, text_model_id, test_case, inference_provider_type
):
def test_text_completion_structured_output(client_with_models, text_model_id, test_case, inference_provider_type):
if inference_provider_type == "remote::tgi":
pytest.xfail(
f"{inference_provider_type} doesn't support structured outputs yet"
)
pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet")
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
class AnswerFormat(BaseModel):
@ -254,9 +238,7 @@ def test_text_completion_structured_output(
"inference:chat_completion:non_streaming_02",
],
)
def test_text_chat_completion_non_streaming(
client_with_models, text_model_id, test_case
):
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
@ -282,18 +264,15 @@ def test_text_chat_completion_non_streaming(
"inference:chat_completion:ttft",
],
)
def test_text_chat_completion_first_token_profiling(
client_with_models, text_model_id, test_case
):
def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
messages = tc["messages"]
if os.environ.get(
"DEBUG_TTFT"
): # debugging print number of tokens in input, ideally around 800
from llama_stack.apis.inference import Message
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in input, ideally around 800
from pydantic import TypeAdapter
from llama_stack.apis.inference import Message
tokenizer, formatter = get_llama_tokenizer()
typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
encoded = formatter.encode_dialog_prompt(typed_messages, None)
@ -307,9 +286,7 @@ def test_text_chat_completion_first_token_profiling(
message_content = response.completion_message.content.lower().strip()
assert len(message_content) > 0
if os.environ.get(
"DEBUG_TTFT"
): # debugging print number of tokens in response, ideally around 150
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in response, ideally around 150
tokenizer, formatter = get_llama_tokenizer()
encoded = formatter.encode_content(message_content)
raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
@ -332,9 +309,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
messages=[{"role": "user", "content": question}],
stream=True,
)
streamed_content = [
str(chunk.event.delta.text.lower().strip()) for chunk in response
]
streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response]
assert len(streamed_content) > 0
assert expected.lower() in "".join(streamed_content)
@ -345,9 +320,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_non_streaming(
client_with_models, text_model_id, test_case
):
def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
@ -361,10 +334,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
assert response.completion_message.role == "assistant"
assert len(response.completion_message.tool_calls) == 1
assert (
response.completion_message.tool_calls[0].tool_name
== tc["tools"][0]["tool_name"]
)
assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
assert response.completion_message.tool_calls[0].arguments == tc["expected"]
@ -387,9 +357,7 @@ def extract_tool_invocation_content(response):
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_streaming(
client_with_models, text_model_id, test_case
):
def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
@ -415,9 +383,7 @@ def test_text_chat_completion_with_tool_choice_required(
client_with_models, text_model_id, test_case, inference_provider_type
):
if inference_provider_type == "remote::tgi":
pytest.xfail(
f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet"
)
pytest.xfail(f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet")
tc = TestCase(test_case)
@ -442,9 +408,7 @@ def test_text_chat_completion_with_tool_choice_required(
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_choice_none(
client_with_models, text_model_id, test_case
):
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
@ -464,13 +428,9 @@ def test_text_chat_completion_with_tool_choice_none(
"inference:chat_completion:structured_output",
],
)
def test_text_chat_completion_structured_output(
client_with_models, text_model_id, test_case, inference_provider_type
):
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case, inference_provider_type):
if inference_provider_type == "remote::tgi":
pytest.xfail(
f"{inference_provider_type} doesn't support structured outputs yet"
)
pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet")
class NBAStats(BaseModel):
year_for_draft: int