mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 05:53:53 +00:00
make TGI work well
This commit is contained in:
parent
e58c7f6c37
commit
021dd0d35d
9 changed files with 617 additions and 326 deletions
|
|
@ -8,9 +8,9 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..test_cases.test_case import TestCase
|
||||
|
||||
|
|
@ -23,8 +23,15 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id):
|
|||
provider_id = models[model_id].provider_id
|
||||
providers = {p.provider_id: p for p in client_with_models.providers.list()}
|
||||
provider = providers[provider_id]
|
||||
if provider.provider_type in ("remote::openai", "remote::anthropic", "remote::gemini", "remote::groq"):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support completion")
|
||||
if provider.provider_type in (
|
||||
"remote::openai",
|
||||
"remote::anthropic",
|
||||
"remote::gemini",
|
||||
"remote::groq",
|
||||
):
|
||||
pytest.skip(
|
||||
f"Model {model_id} hosted by {provider.provider_type} doesn't support completion"
|
||||
)
|
||||
|
||||
|
||||
def get_llama_model(client_with_models, model_id):
|
||||
|
|
@ -105,7 +112,9 @@ def test_text_completion_streaming(client_with_models, text_model_id, test_case)
|
|||
"inference:completion:stop_sequence",
|
||||
],
|
||||
)
|
||||
def test_text_completion_stop_sequence(client_with_models, text_model_id, inference_provider_type, test_case):
|
||||
def test_text_completion_stop_sequence(
|
||||
client_with_models, text_model_id, inference_provider_type, test_case
|
||||
):
|
||||
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
|
||||
# This is only supported/tested for remote vLLM: https://github.com/meta-llama/llama-stack/issues/1771
|
||||
if inference_provider_type != "remote::vllm":
|
||||
|
|
@ -132,7 +141,9 @@ def test_text_completion_stop_sequence(client_with_models, text_model_id, infere
|
|||
"inference:completion:log_probs",
|
||||
],
|
||||
)
|
||||
def test_text_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
|
||||
def test_text_completion_log_probs_non_streaming(
|
||||
client_with_models, text_model_id, inference_provider_type, test_case
|
||||
):
|
||||
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
|
@ -151,7 +162,9 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_
|
|||
},
|
||||
)
|
||||
assert response.logprobs, "Logprobs should not be empty"
|
||||
assert 1 <= len(response.logprobs) <= 5 # each token has 1 logprob and here max_tokens=5
|
||||
assert (
|
||||
1 <= len(response.logprobs) <= 5
|
||||
) # each token has 1 logprob and here max_tokens=5
|
||||
assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs)
|
||||
|
||||
|
||||
|
|
@ -161,7 +174,9 @@ def test_text_completion_log_probs_non_streaming(client_with_models, text_model_
|
|||
"inference:completion:log_probs",
|
||||
],
|
||||
)
|
||||
def test_text_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type, test_case):
|
||||
def test_text_completion_log_probs_streaming(
|
||||
client_with_models, text_model_id, inference_provider_type, test_case
|
||||
):
|
||||
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
|
||||
if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support log probs yet")
|
||||
|
|
@ -183,7 +198,9 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
|
|||
for chunk in streamed_content:
|
||||
if chunk.delta: # if there's a token, we expect logprobs
|
||||
assert chunk.logprobs, "Logprobs should not be empty"
|
||||
assert all(len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs)
|
||||
assert all(
|
||||
len(logprob.logprobs_by_token) == 1 for logprob in chunk.logprobs
|
||||
)
|
||||
else: # no token, no logprobs
|
||||
assert not chunk.logprobs, "Logprobs should be empty"
|
||||
|
||||
|
|
@ -194,7 +211,13 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
|
|||
"inference:completion:structured_output",
|
||||
],
|
||||
)
|
||||
def test_text_completion_structured_output(client_with_models, text_model_id, test_case):
|
||||
def test_text_completion_structured_output(
|
||||
client_with_models, text_model_id, test_case, inference_provider_type
|
||||
):
|
||||
if inference_provider_type == "remote::tgi":
|
||||
pytest.xfail(
|
||||
f"{inference_provider_type} doesn't support structured outputs yet"
|
||||
)
|
||||
skip_if_model_doesnt_support_completion(client_with_models, text_model_id)
|
||||
|
||||
class AnswerFormat(BaseModel):
|
||||
|
|
@ -231,7 +254,9 @@ def test_text_completion_structured_output(client_with_models, text_model_id, te
|
|||
"inference:chat_completion:non_streaming_02",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
|
||||
def test_text_chat_completion_non_streaming(
|
||||
client_with_models, text_model_id, test_case
|
||||
):
|
||||
tc = TestCase(test_case)
|
||||
question = tc["question"]
|
||||
expected = tc["expected"]
|
||||
|
|
@ -257,14 +282,17 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, t
|
|||
"inference:chat_completion:ttft",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
|
||||
def test_text_chat_completion_first_token_profiling(
|
||||
client_with_models, text_model_id, test_case
|
||||
):
|
||||
tc = TestCase(test_case)
|
||||
|
||||
messages = tc["messages"]
|
||||
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in input, ideally around 800
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
if os.environ.get(
|
||||
"DEBUG_TTFT"
|
||||
): # debugging print number of tokens in input, ideally around 800
|
||||
from llama_stack.apis.inference import Message
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
tokenizer, formatter = get_llama_tokenizer()
|
||||
typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
|
||||
|
|
@ -279,7 +307,9 @@ def test_text_chat_completion_first_token_profiling(client_with_models, text_mod
|
|||
message_content = response.completion_message.content.lower().strip()
|
||||
assert len(message_content) > 0
|
||||
|
||||
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in response, ideally around 150
|
||||
if os.environ.get(
|
||||
"DEBUG_TTFT"
|
||||
): # debugging print number of tokens in response, ideally around 150
|
||||
tokenizer, formatter = get_llama_tokenizer()
|
||||
encoded = formatter.encode_content(message_content)
|
||||
raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
|
||||
|
|
@ -302,7 +332,9 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
|
|||
messages=[{"role": "user", "content": question}],
|
||||
stream=True,
|
||||
)
|
||||
streamed_content = [str(chunk.event.delta.text.lower().strip()) for chunk in response]
|
||||
streamed_content = [
|
||||
str(chunk.event.delta.text.lower().strip()) for chunk in response
|
||||
]
|
||||
assert len(streamed_content) > 0
|
||||
assert expected.lower() in "".join(streamed_content)
|
||||
|
||||
|
|
@ -313,7 +345,9 @@ def test_text_chat_completion_streaming(client_with_models, text_model_id, test_
|
|||
"inference:chat_completion:tool_calling",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_models, text_model_id, test_case):
|
||||
def test_text_chat_completion_with_tool_calling_and_non_streaming(
|
||||
client_with_models, text_model_id, test_case
|
||||
):
|
||||
tc = TestCase(test_case)
|
||||
|
||||
response = client_with_models.inference.chat_completion(
|
||||
|
|
@ -327,7 +361,10 @@ def test_text_chat_completion_with_tool_calling_and_non_streaming(client_with_mo
|
|||
assert response.completion_message.role == "assistant"
|
||||
|
||||
assert len(response.completion_message.tool_calls) == 1
|
||||
assert response.completion_message.tool_calls[0].tool_name == tc["tools"][0]["tool_name"]
|
||||
assert (
|
||||
response.completion_message.tool_calls[0].tool_name
|
||||
== tc["tools"][0]["tool_name"]
|
||||
)
|
||||
assert response.completion_message.tool_calls[0].arguments == tc["expected"]
|
||||
|
||||
|
||||
|
|
@ -350,7 +387,9 @@ def extract_tool_invocation_content(response):
|
|||
"inference:chat_completion:tool_calling",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models, text_model_id, test_case):
|
||||
def test_text_chat_completion_with_tool_calling_and_streaming(
|
||||
client_with_models, text_model_id, test_case
|
||||
):
|
||||
tc = TestCase(test_case)
|
||||
|
||||
response = client_with_models.inference.chat_completion(
|
||||
|
|
@ -372,7 +411,14 @@ def test_text_chat_completion_with_tool_calling_and_streaming(client_with_models
|
|||
"inference:chat_completion:tool_calling",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_with_tool_choice_required(client_with_models, text_model_id, test_case):
|
||||
def test_text_chat_completion_with_tool_choice_required(
|
||||
client_with_models, text_model_id, test_case, inference_provider_type
|
||||
):
|
||||
if inference_provider_type == "remote::tgi":
|
||||
pytest.xfail(
|
||||
f"{inference_provider_type} doesn't support tool_choice 'required' parameter yet"
|
||||
)
|
||||
|
||||
tc = TestCase(test_case)
|
||||
|
||||
response = client_with_models.inference.chat_completion(
|
||||
|
|
@ -396,7 +442,9 @@ def test_text_chat_completion_with_tool_choice_required(client_with_models, text
|
|||
"inference:chat_completion:tool_calling",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_with_tool_choice_none(client_with_models, text_model_id, test_case):
|
||||
def test_text_chat_completion_with_tool_choice_none(
|
||||
client_with_models, text_model_id, test_case
|
||||
):
|
||||
tc = TestCase(test_case)
|
||||
|
||||
response = client_with_models.inference.chat_completion(
|
||||
|
|
@ -416,7 +464,14 @@ def test_text_chat_completion_with_tool_choice_none(client_with_models, text_mod
|
|||
"inference:chat_completion:structured_output",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case):
|
||||
def test_text_chat_completion_structured_output(
|
||||
client_with_models, text_model_id, test_case, inference_provider_type
|
||||
):
|
||||
if inference_provider_type == "remote::tgi":
|
||||
pytest.xfail(
|
||||
f"{inference_provider_type} doesn't support structured outputs yet"
|
||||
)
|
||||
|
||||
class NBAStats(BaseModel):
|
||||
year_for_draft: int
|
||||
num_seasons_in_nba: int
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue