make TGI work well

Hardik Shah 2025-03-28 15:38:27 -07:00
parent e58c7f6c37
commit 021dd0d35d
9 changed files with 617 additions and 326 deletions

@@ -8,9 +8,9 @@
import os
import pytest
from pydantic import BaseModel
from llama_stack.models.llama.sku_list import resolve_model
from pydantic import BaseModel
from ..test_cases.test_case import TestCase
@@ -23,8 +23,15 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
provider_id = models[model_id].provider_id
providers = {p.provider_id: p for p in client_with_models.providers.list()}
provider = providers[provider_id]
if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini", "remote::groq"):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
if provider.provider_type in (
"remote::openai",
"remote::anthropic",
"remote::gemini",
"remote::groq",
):
pytest.skip(
f"Model {model_id} hosted by {provider.provider_type} doesn't support completion"
)
def get_llama_model(client_with_models, model_id):
@@ -105,7 +112,9 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
"inference:completion:stop_sequence",
],
)
def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
def test_text_completion_stop_sequence(
client_with_models, text_model_id, inference_provider_type, test_case
):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
# This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
if inference_provider_type != "remote::vllm":
@@ -132,7 +141,9 @@ def test_text_completion_stop_sequence(client_with_models, text_model_id, infere
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
def test_text_completion_log_probs_non_streaming(
client_with_models, text_model_id, inference_provider_type, test_case
):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@@ -151,7 +162,9 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_
},
)
assert response.logprobs, "Logprobs should not be empty"
assert 1 <= len(response.logprobs) <= 5 # each token has 1 logprob and here max_tokens=5
assert (
1 <= len(response.logprobs) <= 5
) # each token has 1 logprob and here max_tokens=5
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
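The completion call that produces response.logprobs sits above this hunk. A minimal sketch (not part of the commit), assuming logprobs are requested via a top_k config and that max_tokens=5 is what bounds the number of returned entries:

# Illustrative sketch, not part of the diff. top_k=1 asks for one candidate
# per generated token; max_tokens=5 explains the 1..5 bound asserted above.
response = client_with_models.inference.completion(
    content="Complete the sentence: blue is the color of",
    stream=False,
    model_id=text_model_id,
    sampling_params={"max_tokens": 5},
    logprobs={"top_k": 1},
)
assert response.logprobs and 1 <= len(response.logprobs) <= 5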
@@ -161,7 +174,9 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_
"inference:completion:log_probs",
],
)
def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
def test_text_completion_log_probs_streaming(
client_with_models, text_model_id, inference_provider_type, test_case
):
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
@@ -183,7 +198,9 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
for chunk in streamed_content:
if chunk.delta: # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
assert all(len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs)
assert all(
len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
)
else: # no token, no logprobs
assert not chunk.logprobs, "Logprobs should be empty"
@@ -194,7 +211,13 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
"inference:completion:structured_output",
],
)
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
def test_text_completion_structured_output(
client_with_models, text_model_id, test_case, inference_provider_type
):
if inference_provider_type == "remote::tgi":
pytest.xfail(
f"{inference_provider_type} doesn't support structured outputs yet"
)
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
class AnswerFormat(BaseModel):
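The structured-output request itself is elided at the hunk boundary. A minimal sketch (not part of the commit) of how a JSON-schema response format is typically passed here, with hypothetical AnswerFormat fields since the real ones are not shown:

# Illustrative sketch, not part of the diff. Fields on AnswerFormat are
# hypothetical; the real model's fields lie outside this hunk.
class AnswerFormat(BaseModel):
    name: str
    year_born: str

response = client_with_models.inference.completion(
    content="Respond in JSON with the name and birth year of Michael Jordan.",
    stream=False,
    model_id=text_model_id,
    sampling_params={"max_tokens": 50},
    response_format={
        "type": "json_schema",
        "json_schema": AnswerFormat.model_json_schema(),
    },
)
answer = AnswerFormat.model_validate_json(response.content)  # raises if the schema was not honored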
@@ -231,7 +254,9 @@ def test_text_completion_structured_output(client_with_models, text_model_id, te
"inference:chat_completion:non_streaming_02",
],
)
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
def test_text_chat_completion_non_streaming(
client_with_models, text_model_id, test_case
):
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
@@ -257,14 +282,17 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, t
"inference:chat_completion:ttft",
],
)
def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
def test_text_chat_completion_first_token_profiling(
client_with_models, text_model_id, test_case
):
tc = TestCase(test_case)
messages = tc["messages"]
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in input, ideally around 800
from pydantic import TypeAdapter
if os.environ.get(
"DEBUG_TTFT"
): # debugging print number of tokens in input, ideally around 800
from llama_stack.apis.inference import Message
from pydantic import TypeAdapter
tokenizer, formatter = get_llama_tokenizer()
typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
@@ -279,7 +307,9 @@ def test_text_chat_completion_first_token_profiling(client_with_models, text_mod
message_content = response.completion_message.content.lower().strip()
assert len(message_content) > 0
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in response, ideally around 150
if os.environ.get(
"DEBUG_TTFT"
): # debugging print number of tokens in response, ideally around 150
tokenizer, formatter = get_llama_tokenizer()
encoded = formatter.encode_content(message_content)
raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
@@ -302,7 +332,9 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
messages=[{"role": "user", "content": question}],
stream=True,
)
streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response]
streamed_content = [
str(chunk.event.delta.text.lower().strip()) for chunk in response
]
assert len(streamed_content) > 0
assert expected.lower() in "".join(streamed_content)
@@ -313,7 +345,9 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
def test_text_chat_completion_with_tool_calling_and_non_streaming(
client_with_models, text_model_id, test_case
):
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
@@ -327,7 +361,10 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_mo
assert response.completion_message.role == "assistant"
assert len(response.completion_message.tool_calls) == 1
assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
assert (
response.completion_message.tool_calls[0].tool_name
== tc["tools"][0]["tool_name"]
)
assert response.completion_message.tool_calls[0].arguments == tc["expected"]
@@ -350,7 +387,9 @@ def extract_tool_invocation_content(response):
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
def test_text_chat_completion_with_tool_calling_and_streaming(
client_with_models, text_model_id, test_case
):
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
@@ -372,7 +411,14 @@ def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
def test_text_chat_completion_with_tool_choice_required(
client_with_models, text_model_id, test_case, inference_provider_type
):
if inference_provider_type == "remote::tgi":
pytest.xfail(
f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet"
)
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
@@ -396,7 +442,9 @@ def test_text_chat_completion_with_tool_choice_required(client_with_models, text
"inference:chat_completion:tool_calling",
],
)
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
def test_text_chat_completion_with_tool_choice_none(
client_with_models, text_model_id, test_case
):
tc = TestCase(test_case)
response = client_with_models.inference.chat_completion(
@@ -416,7 +464,14 @@ def test_text_chat_completion_with_tool_choice_none(client_with_models, text_mod
"inference:chat_completion:structured_output",
],
)
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
def test_text_chat_completion_structured_output(
client_with_models, text_model_id, test_case, inference_provider_type
):
if inference_provider_type == "remote::tgi":
pytest.xfail(
f"{inference_provider_type} doesn't support structured outputs yet"
)
class NBAStats(BaseModel):
year_for_draft: int
num_seasons_in_nba: int
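The rest of the chat-completion structured-output test lies beyond the end of this hunk. A minimal sketch (not part of the commit) of how the NBAStats schema would typically drive the response format, mirroring the completion-side example above; the message contents are assumptions:

# Illustrative sketch, not part of the diff. Assumes the same json_schema
# response_format mechanism used for plain completions.
response = client_with_models.inference.chat_completion(
    model_id=text_model_id,
    messages=[
        {"role": "system", "content": "You are a helpful assistant. Answer in JSON."},
        {"role": "user", "content": "When was Michael Jordan drafted and how many NBA seasons did he play?"},
    ],
    response_format={
        "type": "json_schema",
        "json_schema": NBAStats.model_json_schema(),
    },
    stream=False,
)
stats = NBAStats.model_validate_json(response.completion_message.content)
assert isinstance(stats.year_for_draft, int)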