forked from phoenix-oss/llama-stack-mirror
fix: misc fixes for tests kill horrible warnings
This commit is contained in:
parent
8b4158169f
commit
429f6de7d7
4 changed files with 12 additions and 63 deletions
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
|
@ -54,15 +53,6 @@ def get_llama_model(client_with_models, model_id):
|
|||
return model.metadata.get("llama_model", None)
|
||||
|
||||
|
||||
def get_llama_tokenizer():
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
formatter = ChatFormat(tokenizer)
|
||||
return tokenizer, formatter
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
|
@ -261,41 +251,6 @@ def test_text_chat_completion_non_streaming(client_with_models, text_model_id, t
|
|||
assert expected.lower() in message_content
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
"inference:chat_completion:ttft",
|
||||
],
|
||||
)
|
||||
def test_text_chat_completion_first_token_profiling(client_with_models, text_model_id, test_case):
|
||||
tc = TestCase(test_case)
|
||||
|
||||
messages = tc["messages"]
|
||||
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in input, ideally around 800
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
|
||||
tokenizer, formatter = get_llama_tokenizer()
|
||||
typed_messages = [TypeAdapter(Message).validate_python(m) for m in messages]
|
||||
encoded = formatter.encode_dialog_prompt(typed_messages, None)
|
||||
raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
|
||||
|
||||
response = client_with_models.inference.chat_completion(
|
||||
model_id=text_model_id,
|
||||
messages=messages,
|
||||
stream=False,
|
||||
timeout=120, # Increase timeout to 2 minutes for large conversation history
|
||||
)
|
||||
message_content = response.completion_message.content.lower().strip()
|
||||
assert len(message_content) > 0
|
||||
|
||||
if os.environ.get("DEBUG_TTFT"): # debugging print number of tokens in response, ideally around 150
|
||||
tokenizer, formatter = get_llama_tokenizer()
|
||||
encoded = formatter.encode_content(message_content)
|
||||
raise ValueError(len(encoded.tokens) if encoded and encoded.tokens else 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue