Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-31 05:59:59 +00:00
fixes and linting
This commit is contained in:
parent 021dd0d35d
commit 5251d2422d
8 changed files with 149 additions and 345 deletions
@@ -8,9 +8,9 @@
 import os
 
 import pytest
+from pydantic import BaseModel
 
 from llama_stack.models.llama.sku_list import resolve_model
-from pydantic import BaseModel
 
 from ..test_cases.test_case import TestCase
 
@@ -29,9 +29,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
         "remote::gemini",
         "remote::groq",
     ):
-        pytest.skip(
-            f"Model {model_id} hosted by {provider.provider_type} doesn't support completion"
-        )
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
 
 
 def get_llama_model(client_with_models, model_id):
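For context, the helper reformatted in this hunk gates completion tests on the provider backing a model. Below is a minimal sketch of how such a helper can work; the models.list()/providers.list() lookups, the helper name, and the provider tuple membership are assumptions about the client API, not lines taken from this file.

import pytest


def skip_if_provider_doesnt_support_completion(client_with_models, model_id):
    # Resolve the model to its provider, then skip for hosted providers that
    # only expose chat-style APIs (membership shown here is illustrative only).
    models = {m.identifier: m for m in client_with_models.models.list()}
    provider_id = models[model_id].provider_id
    providers = {p.provider_id: p for p in client_with_models.providers.list()}
    provider = providers[provider_id]
    if provider.provider_type in (
        "remote::gemini",
        "remote::groq",
    ):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")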
@@ -112,9 +110,7 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
         "inference:completion:stop_sequence",
     ],
 )
-def test_text_completion_stop_sequence(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     # This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
     if inference_provider_type != "remote::vllm":
@@ -141,9 +137,7 @@ def test_text_completion_stop_sequence(
         "inference:completion:log_probs",
     ],
 )
-def test_text_completion_log_probs_non_streaming(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
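The xfail above keys off PROVIDER_LOGPROBS_TOP_K, which is defined elsewhere in the file and not shown in this diff. A plausible shape, purely illustrative (the entry listed is an example, not the file's real membership):

PROVIDER_LOGPROBS_TOP_K = {
    # provider types whose completion API is expected to return top-k logprobs
    "remote::vllm",
}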
@@ -162,9 +156,7 @@ def test_text_completion_log_probs_non_streaming(
         },
     )
     assert response.logprobs, "Logprobs should not be empty"
-    assert (
-        1 <= len(response.logprobs) <= 5
-    )  # each token has 1 logprob and here max_tokens=5
+    assert 1 <= len(response.logprobs) <= 5  # each token has 1 logprob and here max_tokens=5
     assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
 
 
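The assertions in this hunk show only the tail of the request. A hedged sketch of the kind of call that could produce them follows; the completion() keyword names, the sampling_params/logprobs dict shapes, and the prompt are all assumptions rather than lines from this file.

response = client_with_models.inference.completion(
    model_id=text_model_id,
    content="Complete the sentence using one word: Roses are red, violets are ",
    stream=False,
    sampling_params={"max_tokens": 5},  # at most 5 generated tokens -> at most 5 logprob entries
    logprobs={"top_k": 1},              # one candidate per position -> logprobs_by_token of length 1
)
assert response.logprobs and 1 <= len(response.logprobs) <= 5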
@@ -174,9 +166,7 @@ def test_text_completion_log_probs_non_streaming(
         "inference:completion:log_probs",
     ],
 )
-def test_text_completion_log_probs_streaming(
-    client_with_models, text_model_id, inference_provider_type, test_case
-):
+def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
     if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
         pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@@ -198,9 +188,7 @@ def test_text_completion_log_probs_streaming(
     for chunk in streamed_content:
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"
-            assert all(
-                len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
-            )
+            assert all(len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs)
         else:  # no token, no logprobs
             assert not chunk.logprobs, "Logprobs should be empty"
 
@@ -211,13 +199,9 @@ def test_text_completion_log_probs_streaming(
         "inference:completion:structured_output",
     ],
 )
-def test_text_completion_structured_output(
-    client_with_models, text_model_id, test_case, inference_provider_type
-):
+def test_text_completion_structured_output(client_with_models, text_model_id, test_case, inference_provider_type):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support structured outputs yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet")
     skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
 
     class AnswerFormat(BaseModel):
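The AnswerFormat pydantic model above is what drives the structured-output check. A minimal sketch of that pattern follows, with hypothetical field names, a made-up prompt, an assumed response_format shape, and an assumed completion() signature (none of it copied from this file):

from pydantic import BaseModel


class AnswerFormat(BaseModel):
    # hypothetical fields; the real test defines its own
    first_name: str
    last_name: str


response = client_with_models.inference.completion(
    model_id=text_model_id,
    content="Who won the NBA Finals MVP in 1998? Answer as JSON.",
    stream=False,
    response_format={
        "type": "json_schema",
        "json_schema": AnswerFormat.model_json_schema(),
    },
)
answer = AnswerFormat.model_validate_json(response.content)  # parse the JSON the model returned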
@@ -254,9 +238,7 @@ def test_text_completion_structured_output(
         "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_text_chat_completion_non_streaming(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
@@ -282,18 +264,15 @@ def test_text_chat_completion_non_streaming(
         "inference:chat_completion:ttft",
     ],
 )
-def test_text_chat_completion_first_token_profiling(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
 
     messages = tc["messages"]
-    if os.environ.get(
-        "DEBUG_TTFT"
-    ):  # debugging print number of tokens in input, ideally around 800
-        from llama_stack.apis.inference import Message
+    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in input, ideally around 800
         from pydantic import TypeAdapter
 
+        from llama_stack.apis.inference import Message
+
         tokenizer, formatter = get_llama_tokenizer()
         typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
         encoded = formatter.encode_dialog_prompt(typed_messages, None)
@@ -307,9 +286,7 @@ def test_text_chat_completion_first_token_profiling(
     message_content = response.completion_message.content.lower().strip()
     assert len(message_content) > 0
 
-    if os.environ.get(
-        "DEBUG_TTFT"
-    ):  # debugging print number of tokens in response, ideally around 150
+    if os.environ.get("DEBUG_TTFT"):  # debugging print number of tokens in response, ideally around 150
         tokenizer, formatter = get_llama_tokenizer()
         encoded = formatter.encode_content(message_content)
         raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
@@ -332,9 +309,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         messages=[{"role": "user", "content": question}],
         stream=True,
     )
-    streamed_content = [
-        str(chunk.event.delta.text.lower().strip()) for chunk in response
-    ]
+    streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response]
     assert len(streamed_content) > 0
     assert expected.lower() in "".join(streamed_content)
 
@@ -345,9 +320,7 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_non_streaming(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -361,10 +334,7 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(
     assert response.completion_message.role == "assistant"
 
     assert len(response.completion_message.tool_calls) == 1
-    assert (
-        response.completion_message.tool_calls[0].tool_name
-        == tc["tools"][0]["tool_name"]
-    )
+    assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
     assert response.completion_message.tool_calls[0].arguments == tc["expected"]
 
 
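The checks in this hunk compare the returned tool call against the tool definition supplied by the test case. A hedged sketch of that flow follows, with a hypothetical get_weather tool and an assumed parameter-schema shape; the real tool comes from the inference:chat_completion:tool_calling test case, not from here.

tools = [
    {
        "tool_name": "get_weather",  # hypothetical tool, not the one the test case uses
        "description": "Get the current weather for a city",
        "parameters": {
            "location": {"param_type": "string", "description": "City name", "required": True},
        },
    }
]
response = client_with_models.inference.chat_completion(
    model_id=text_model_id,
    messages=[{"role": "user", "content": "What is the weather like in Tokyo today?"}],
    tools=tools,
    stream=False,
)
call = response.completion_message.tool_calls[0]
assert call.tool_name == tools[0]["tool_name"]
# the real test additionally compares call.arguments against tc["expected"]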
@@ -387,9 +357,7 @@ def extract_tool_invocation_content(response):
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_calling_and_streaming(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
@@ -415,9 +383,7 @@ def test_text_chat_completion_with_tool_choice_required(
     client_with_models, text_model_id, test_case, inference_provider_type
 ):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet")
 
     tc = TestCase(test_case)
 
@@ -442,9 +408,7 @@ def test_text_chat_completion_with_tool_choice_required(
         "inference:chat_completion:tool_calling",
     ],
 )
-def test_text_chat_completion_with_tool_choice_none(
-    client_with_models, text_model_id, test_case
-):
+def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
     tc = TestCase(test_case)
 
     response = client_with_models.inference.chat_completion(
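The two tests touched in the last hunks exercise tool_choice "required" and "none". A rough sketch of how that option might be passed is below; whether it travels through a tool_config dict like this, rather than a dedicated keyword, is an assumption about the client API that this diff does not confirm.

# "required": the model must call a tool, so tool_calls should be non-empty
response_required = client_with_models.inference.chat_completion(
    model_id=text_model_id,
    messages=[{"role": "user", "content": "What is the weather like in Tokyo today?"}],
    tools=tools,  # same hypothetical tool list as in the earlier sketch
    tool_config={"tool_choice": "required"},  # assumed shape
    stream=False,
)
assert len(response_required.completion_message.tool_calls) > 0

# "none": tools are offered but must not be called, so plain text comes back instead
response_none = client_with_models.inference.chat_completion(
    model_id=text_model_id,
    messages=[{"role": "user", "content": "What is the weather like in Tokyo today?"}],
    tools=tools,
    tool_config={"tool_choice": "none"},  # assumed shape
    stream=False,
)
assert len(response_none.completion_message.tool_calls) == 0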
@@ -464,13 +428,9 @@ def test_text_chat_completion_with_tool_choice_none(
         "inference:chat_completion:structured_output",
     ],
 )
-def test_text_chat_completion_structured_output(
-    client_with_models, text_model_id, test_case, inference_provider_type
-):
+def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case, inference_provider_type):
     if inference_provider_type == "remote::tgi":
-        pytest.xfail(
-            f"{inference_provider_type} doesn't support structured outputs yet"
-        )
+        pytest.xfail(f"{inference_provider_type} doesn't support structured outputs yet")
 
     class NBAStats(BaseModel):
         year_for_draft: int