forked from phoenix-oss/llama-stack-mirror
Fix precommit check after moving to ruff (#927)
The lint check on the main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fix and ignore some additional checks.

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
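The commit message mentions moving the lint configuration into a `ruff.toml` file. As a rough sketch of what such a file can look like (the line length and rule codes below are illustrative assumptions, not the exact settings adopted in this commit):

```toml
# ruff.toml (illustrative sketch; the repository's actual values may differ)
line-length = 120                    # allow longer lines than black's default of 88

[lint]
select = ["B", "E", "F", "I", "UP"]  # flake8-bugbear, pycodestyle, pyflakes, isort, pyupgrade
ignore = ["E501", "B008"]            # hypothetical examples of ignored checks
```

With a file like this at the repository root, `ruff check .` and `ruff format .` (the commands a pre-commit hook typically runs) pick up the settings automatically; a longer line limit is what allows the single-line rewrites visible in the diff below to pass the lint check.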
This commit is contained in:
parent 4773092dd1
commit 34ab7a3b6c

217 changed files with 981 additions and 2681 deletions

@@ -12,9 +12,7 @@ from .fixtures import INFERENCE_FIXTURES
 
 def pytest_configure(config):
     for model in ["llama_8b", "llama_3b", "llama_vision"]:
-        config.addinivalue_line(
-            "markers", f"{model}: mark test to run only with the given model"
-        )
+        config.addinivalue_line("markers", f"{model}: mark test to run only with the given model")
 
     for fixture_name in INFERENCE_FIXTURES:
         config.addinivalue_line(
@@ -24,12 +22,8 @@ def pytest_configure(config):
 
 
 MODEL_PARAMS = [
-    pytest.param(
-        "meta-llama/Llama-3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"
-    ),
-    pytest.param(
-        "meta-llama/Llama-3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"
-    ),
+    pytest.param("meta-llama/Llama-3.1-8B-Instruct", marks=pytest.mark.llama_8b, id="llama_8b"),
+    pytest.param("meta-llama/Llama-3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
 ]
 
 VISION_MODEL_PARAMS = [
@@ -49,9 +43,7 @@ def pytest_generate_tests(metafunc):
         params = []
         inference_models = getattr(test_config, "inference_models", [])
         for model in inference_models:
-            if ("Vision" in cls_name and "Vision" in model) or (
-                "Vision" not in cls_name and "Vision" not in model
-            ):
+            if ("Vision" in cls_name and "Vision" in model) or ("Vision" not in cls_name and "Vision" not in model):
                 params.append(pytest.param(model, id=model))
 
         if not params:
@@ -74,10 +66,7 @@ def pytest_generate_tests(metafunc):
         fixtures = [stack.values[0]["inference"] for stack in filtered_stacks]
         if test_config:
             if custom_fixtures := [
-                (
-                    scenario.fixture_combo_id
-                    or scenario.provider_fixtures.get("inference")
-                )
+                (scenario.fixture_combo_id or scenario.provider_fixtures.get("inference"))
                 for scenario in test_config.scenarios
             ]:
                 fixtures = custom_fixtures

@@ -47,9 +47,7 @@ def inference_remote() -> ProviderFixture:
 
 @pytest.fixture(scope="session")
 def inference_meta_reference(inference_model) -> ProviderFixture:
-    inference_model = (
-        [inference_model] if isinstance(inference_model, str) else inference_model
-    )
+    inference_model = [inference_model] if isinstance(inference_model, str) else inference_model
     # If embedding dimension is set, use the 8B model for testing
     if os.getenv("EMBEDDING_DIMENSION"):
         inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
@@ -88,9 +86,7 @@ def inference_cerebras() -> ProviderFixture:
 
 @pytest.fixture(scope="session")
 def inference_ollama(inference_model) -> ProviderFixture:
-    inference_model = (
-        [inference_model] if isinstance(inference_model, str) else inference_model
-    )
+    inference_model = [inference_model] if isinstance(inference_model, str) else inference_model
     if inference_model and "Llama3.1-8B-Instruct" in inference_model:
         pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
 
@@ -99,9 +95,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
             Provider(
                 provider_id="ollama",
                 provider_type="remote::ollama",
-                config=OllamaImplConfig(
-                    host="localhost", port=os.getenv("OLLAMA_PORT", 11434)
-                ).model_dump(),
+                config=OllamaImplConfig(host="localhost", port=os.getenv("OLLAMA_PORT", 11434)).model_dump(),
             )
         ],
     )
@@ -109,9 +103,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
 
 @pytest_asyncio.fixture(scope="session")
 def inference_vllm(inference_model) -> ProviderFixture:
-    inference_model = (
-        [inference_model] if isinstance(inference_model, str) else inference_model
-    )
+    inference_model = [inference_model] if isinstance(inference_model, str) else inference_model
     return ProviderFixture(
         providers=[
             Provider(

@@ -162,9 +162,7 @@ class TestConvertChatCompletionRequest:
 
     def test_includes_stratgy(self):
         request = self._dummy_chat_completion_request()
-        request.sampling_params.strategy = TopPSamplingStrategy(
-            temperature=0.5, top_p=0.95
-        )
+        request.sampling_params.strategy = TopPSamplingStrategy(temperature=0.5, top_p=0.95)
 
         converted = convert_chat_completion_request(request)
 
@@ -375,9 +373,7 @@ class TestConvertNonStreamChatCompletionResponse:
         choices=[
             Choice(
                 index=0,
-                message=ChatCompletionMessage(
-                    role="assistant", content="Hello World"
-                ),
+                message=ChatCompletionMessage(role="assistant", content="Hello World"),
                 finish_reason="stop",
             )
         ],

@@ -29,11 +29,7 @@ class TestEmbeddings:
         assert isinstance(response, EmbeddingsResponse)
         assert len(response.embeddings) > 0
         assert all(isinstance(embedding, list) for embedding in response.embeddings)
-        assert all(
-            isinstance(value, float)
-            for embedding in response.embeddings
-            for value in embedding
-        )
+        assert all(isinstance(value, float) for embedding in response.embeddings for value in embedding)
 
     @pytest.mark.asyncio
     async def test_batch_embeddings(self, inference_model, inference_stack):
@@ -53,11 +49,7 @@ class TestEmbeddings:
         assert isinstance(response, EmbeddingsResponse)
         assert len(response.embeddings) == len(texts)
         assert all(isinstance(embedding, list) for embedding in response.embeddings)
-        assert all(
-            isinstance(value, float)
-            for embedding in response.embeddings
-            for value in embedding
-        )
+        assert all(isinstance(value, float) for embedding in response.embeddings for value in embedding)
 
         embedding_dim = len(response.embeddings[0])
         assert all(len(embedding) == embedding_dim for embedding in response.embeddings)

@@ -44,11 +44,7 @@ from .utils import group_chunks
 
 
 def get_expected_stop_reason(model: str):
-    return (
-        StopReason.end_of_message
-        if ("Llama3.1" in model or "Llama-3.1" in model)
-        else StopReason.end_of_turn
-    )
+    return StopReason.end_of_message if ("Llama3.1" in model or "Llama-3.1" in model) else StopReason.end_of_turn
 
 
 @pytest.fixture
@@ -179,13 +175,9 @@ class TestInference:
             1 <= len(chunks) <= 6
         )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
         for chunk in chunks:
-            if (
-                chunk.delta.type == "text" and chunk.delta.text
-            ):  # if there's a token, we expect logprobs
+            if chunk.delta.type == "text" and chunk.delta.text:  # if there's a token, we expect logprobs
                 assert chunk.logprobs, "Logprobs should not be empty"
-                assert all(
-                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
-                )
+                assert all(len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs)
             else:  # no token, no logprobs
                 assert not chunk.logprobs, "Logprobs should be empty"
 
@@ -236,9 +228,7 @@ class TestInference:
         assert len(response.completion_message.content) > 0
 
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_structured_output(
-        self, inference_model, inference_stack, common_params
-    ):
+    async def test_structured_output(self, inference_model, inference_stack, common_params):
         inference_impl, _ = inference_stack
 
         class AnswerFormat(BaseModel):
@@ -295,9 +285,7 @@ class TestInference:
         AnswerFormat.model_validate_json(response.completion_message.content)
 
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
-    ):
+    async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
         inference_impl, _ = inference_stack
         response = [
             r
@@ -310,9 +298,7 @@ class TestInference:
         ]
 
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
@@ -387,9 +373,7 @@ class TestInference:
             )
         ]
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
@@ -404,13 +388,10 @@ class TestInference:
 
         if "Llama3.1" in inference_model:
             assert all(
-                chunk.event.delta.type == "tool_call"
-                for chunk in grouped[ChatCompletionResponseEventType.progress]
+                chunk.event.delta.type == "tool_call" for chunk in grouped[ChatCompletionResponseEventType.progress]
             )
             first = grouped[ChatCompletionResponseEventType.progress][0]
-            if not isinstance(
-                first.event.delta.tool_call, ToolCall
-            ):  # first chunk may contain entire call
+            if not isinstance(first.event.delta.tool_call, ToolCall):  # first chunk may contain entire call
                 assert first.event.delta.parse_status == ToolCallParseStatus.started
 
         last = grouped[ChatCompletionResponseEventType.progress][-1]

@@ -73,9 +73,7 @@ class TestVisionModelInference:
         assert expected_string in response.completion_message.content
 
     @pytest.mark.asyncio
-    async def test_vision_chat_completion_streaming(
-        self, inference_model, inference_stack
-    ):
+    async def test_vision_chat_completion_streaming(self, inference_model, inference_stack):
        inference_impl, _ = inference_stack
 
         images = [
@@ -100,9 +98,7 @@ class TestVisionModelInference:
                 UserMessage(
                     content=[
                         image,
-                        TextContentItem(
-                            text="Describe this image in two sentences."
-                        ),
+                        TextContentItem(text="Describe this image in two sentences."),
                     ]
                 ),
             ],
@@ -112,18 +108,12 @@ class TestVisionModelInference:
         ]
 
         assert len(response) > 0
-        assert all(
-            isinstance(chunk, ChatCompletionResponseStreamChunk)
-            for chunk in response
-        )
+        assert all(isinstance(chunk, ChatCompletionResponseStreamChunk) for chunk in response)
         grouped = group_chunks(response)
         assert len(grouped[ChatCompletionResponseEventType.start]) == 1
         assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
         assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
 
-        content = "".join(
-            chunk.event.delta.text
-            for chunk in grouped[ChatCompletionResponseEventType.progress]
-        )
+        content = "".join(chunk.event.delta.text for chunk in grouped[ChatCompletionResponseEventType.progress])
         for expected_string in expected_strings:
             assert expected_string in content

@@ -10,7 +10,5 @@ import itertools
 def group_chunks(response):
     return {
         event_type: list(group)
-        for event_type, group in itertools.groupby(
-            response, key=lambda chunk: chunk.event.event_type
-        )
+        for event_type, group in itertools.groupby(response, key=lambda chunk: chunk.event.event_type)
     }