llama-stack-mirror/llama_stack/providers/tests/inference/test_vision_inference.py
Sixian Yi 67450e4024
bug fixes on inference tests (#774)
# What does this PR do?

Fixes two issues in providers/tests/inference.

- [ ] Addresses issue (#issue)


## Test Plan

### Before
```

===================================================================================== FAILURES =====================================================================================
__________________________________ TestVisionModelInference.test_vision_chat_completion_streaming[llama_vision-fireworks][llama_vision] ___________________________________
providers/tests/inference/test_vision_inference.py:145: in test_vision_chat_completion_streaming
    content = "".join(
E   TypeError: sequence item 0: expected str instance, TextDelta found
------------------------------------------------------------------------------ Captured log teardown -------------------------------------------------------------------------------
ERROR    asyncio:base_events.py:1858 Task was destroyed but it is pending!
task: <Task pending name='Task-5' coro=<<async_generator_athrow without __name__>()>>
============================================================================= short test summary info ==============================================================================
FAILED providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_streaming[llama_vision-fireworks] - TypeError: sequence item 0: expected str instance, TextDelta found
============================================================== 1 failed, 2 passed, 33 deselected, 7 warnings in 3.59s ==============================================================
(base) sxyi@sxyi-mbp llama_stack % 
```
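The failure in the Before log comes from `str.join`, which accepts only `str` items. The progress chunks carry `TextDelta` objects rather than plain strings, so joining the deltas themselves raises `TypeError`. The sketch below reproduces this with a stand-in `TextDelta` dataclass (a hypothetical simplification, not the llama-stack class) and shows the fix of joining each delta's `.text` payload, which is what the test's `chunk.event.delta.text` expression does.

```python
from dataclasses import dataclass


@dataclass
class TextDelta:
    # Hypothetical stand-in for llama-stack's TextDelta; only the `text` field is modeled.
    text: str


deltas = [TextDelta(text="A "), TextDelta(text="puppy.")]

# Joining the delta objects directly fails, as in the "Before" log.
error_message = ""
try:
    "".join(deltas)
except TypeError as exc:
    error_message = str(exc)

# Joining each delta's text payload works.
content = "".join(delta.text for delta in deltas)
print(error_message)
print(content)
```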

### After 
```
(base) sxyi@sxyi-mbp llama_stack % pytest -k "fireworks"  /Users/sxyi/llama-stack/llama_stack/providers/tests/inference/test_vision_inference.py
/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=============================================================================== test session starts ================================================================================
platform darwin -- Python 3.13.0, pytest-8.3.3, pluggy-1.5.0
rootdir: /Users/sxyi/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.24.0, html-4.1.1, metadata-3.1.1, dependency-0.6.0, anyio-4.6.2.post1
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 36 items / 33 deselected / 3 selected                                                                                                                                    

providers/tests/inference/test_vision_inference.py ...                                                                                                                       [100%]

=================================================================== 3 passed, 33 deselected, 7 warnings in 3.75s ===================================================================
(base) sxyi@sxyi-mbp llama_stack % 
```

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2025-01-15 15:39:05 -08:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

import pytest

from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL

from llama_stack.apis.inference import (
    ChatCompletionResponse,
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
    SamplingParams,
    UserMessage,
)

from .utils import group_chunks

THIS_DIR = Path(__file__).parent

with open(THIS_DIR / "pasta.jpeg", "rb") as f:
    PASTA_IMAGE = f.read()


class TestVisionModelInference:
    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "image, expected_strings",
        [
            (
                ImageContentItem(data=PASTA_IMAGE),
                ["spaghetti"],
            ),
            (
                ImageContentItem(
                    url=URL(
                        uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                    )
                ),
                ["puppy"],
            ),
        ],
    )
    async def test_vision_chat_completion_non_streaming(
        self, inference_model, inference_stack, image, expected_strings
    ):
        inference_impl, _ = inference_stack

        # Only these providers currently support image inputs in chat_completion().
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::together",
            "remote::fireworks",
            "remote::ollama",
            "remote::vllm",
        ):
            pytest.skip(
                "Other inference providers don't support vision chat completion() yet"
            )

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=[
                UserMessage(content="You are a helpful assistant."),
                UserMessage(
                    content=[
                        image,
                        TextContentItem(text="Describe this image in two sentences."),
                    ]
                ),
            ],
            stream=False,
            sampling_params=SamplingParams(max_tokens=100),
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)
        for expected_string in expected_strings:
            assert expected_string in response.completion_message.content

    @pytest.mark.asyncio
    async def test_vision_chat_completion_streaming(
        self, inference_model, inference_stack
    ):
        inference_impl, _ = inference_stack

        # Only these providers currently support image inputs in chat_completion().
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        if provider.__provider_spec__.provider_type not in (
            "inline::meta-reference",
            "remote::together",
            "remote::fireworks",
            "remote::ollama",
            "remote::vllm",
        ):
            pytest.skip(
                "Other inference providers don't support vision chat completion() yet"
            )

        images = [
            ImageContentItem(
                url=URL(
                    uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                )
            ),
        ]
        expected_strings_to_check = [
            ["puppy"],
        ]
        for image, expected_strings in zip(images, expected_strings_to_check):
            response = [
                r
                async for r in await inference_impl.chat_completion(
                    model_id=inference_model,
                    messages=[
                        UserMessage(content="You are a helpful assistant."),
                        UserMessage(
                            content=[
                                image,
                                TextContentItem(
                                    text="Describe this image in two sentences."
                                ),
                            ]
                        ),
                    ],
                    stream=True,
                    sampling_params=SamplingParams(max_tokens=100),
                )
            ]

            assert len(response) > 0
            assert all(
                isinstance(chunk, ChatCompletionResponseStreamChunk)
                for chunk in response
            )
            grouped = group_chunks(response)
            assert len(grouped[ChatCompletionResponseEventType.start]) == 1
            assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
            assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

            # Each progress delta is a TextDelta object; join its .text payload,
            # not the delta itself.
            content = "".join(
                chunk.event.delta.text
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            for expected_string in expected_strings:
                assert expected_string in content
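The streaming test relies on `group_chunks` from the sibling `.utils` module, whose body is not shown on this page. The following is a minimal sketch of what such a helper could look like, assuming it simply buckets chunks by `chunk.event.event_type`. `EventType`, `Event`, `Chunk`, and `TextDelta` are hypothetical stand-ins for the llama-stack types so the example runs on its own; the sketch then rebuilds the streamed text the same way the test does.

```python
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum


class EventType(Enum):
    # Hypothetical stand-in for ChatCompletionResponseEventType.
    start = "start"
    progress = "progress"
    complete = "complete"


@dataclass
class TextDelta:
    # Hypothetical stand-in: only the `text` field is modeled.
    text: str


@dataclass
class Event:
    event_type: EventType
    delta: TextDelta


@dataclass
class Chunk:
    # Hypothetical stand-in for ChatCompletionResponseStreamChunk.
    event: Event


def group_chunks(chunks):
    """Bucket stream chunks by event type (a sketch, not the real .utils helper)."""
    grouped = defaultdict(list)
    for chunk in chunks:
        grouped[chunk.event.event_type].append(chunk)
    return grouped


stream = [
    Chunk(Event(EventType.start, TextDelta(""))),
    Chunk(Event(EventType.progress, TextDelta("A fluffy "))),
    Chunk(Event(EventType.progress, TextDelta("puppy."))),
    Chunk(Event(EventType.complete, TextDelta(""))),
]

grouped = group_chunks(stream)
# Rebuild the response text from the progress chunks' TextDelta payloads.
content = "".join(chunk.event.delta.text for chunk in grouped[EventType.progress])
print(content)
```

Returning a `defaultdict(list)` means that looking up an event type that never appeared yields an empty list instead of raising `KeyError`, so a check such as `len(grouped[...]) == 1` fails as a readable assertion error.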