llama-stack/llama_stack/providers/tests/inference/test_vision_inference.py
snova-edwardm 22dc684da6
Sambanova inference provider (#555)
# What does this PR do?

This PR adds SambaNova as one of the supported inference providers.

- Add SambaNova as a provider

## Test Plan
Run the functional tests with the following command:
```
pytest -s -v --providers inference=sambanova llama_stack/providers/tests/inference/test_embeddings.py llama_stack/providers/tests/inference/test_prompt_adapter.py llama_stack/providers/tests/inference/test_text_inference.py llama_stack/providers/tests/inference/test_vision_inference.py --env SAMBANOVA_API_KEY=<sambanova-api-key>
```

Test the distribution template:
```
# Docker
LLAMA_STACK_PORT=5001
docker run -it -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  llamastack/distribution-sambanova \
  --port $LLAMA_STACK_PORT \
  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY

# Conda
llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```

## Source
[SambaNova API Documentation](https://cloud.sambanova.ai/apis)

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [Y] Ran pre-commit to handle lint / formatting issues.
- [Y] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [Y] Updated relevant documentation.
- [Y] Wrote necessary unit or integration tests.

---------

Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
2025-01-23 12:20:28 -08:00

156 lines
5.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
import pytest
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem, URL
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
SamplingParams,
UserMessage,
)
from .utils import group_chunks
# Directory that holds this test module; the image fixture sits next to it.
THIS_DIR = Path(__file__).parent

# Raw bytes of the bundled pasta photo, read once at import time so the
# parametrized non-streaming test can embed them directly in its arguments.
with (THIS_DIR / "pasta.jpeg").open("rb") as f:
    PASTA_IMAGE = f.read()
class TestVisionModelInference:
    """Checks that a provider's ``chat_completion`` can describe an image.

    Each test sends one image plus a short text instruction and asserts that
    the reply contains the expected keyword(s). The keyword match is a plain,
    case-sensitive substring check.
    """

    # Provider types these tests run against; every other provider is skipped.
    # Kept in one place so the streaming and non-streaming tests cannot drift.
    _VISION_PROVIDER_TYPES = (
        "inline::meta-reference",
        "remote::together",
        "remote::fireworks",
        "remote::ollama",
        "remote::vllm",
        "remote::sambanova",
    )

    @classmethod
    def _skip_unless_vision_supported(cls, inference_impl, inference_model):
        """Skip the calling test if the model's provider is not in the allow-list.

        Args:
            inference_impl: Inference implementation exposing ``routing_table``.
            inference_model: Model identifier used to look up the provider.
        """
        provider = inference_impl.routing_table.get_provider_impl(inference_model)
        provider_type = provider.__provider_spec__.provider_type
        if provider_type not in cls._VISION_PROVIDER_TYPES:
            pytest.skip(
                "Other inference providers don't support vision chat completion() yet"
            )

    @staticmethod
    def _describe_image_messages(image):
        """Return the conversation that asks the model to describe ``image``.

        Args:
            image: The ``ImageContentItem`` to place before the text instruction.

        Returns:
            A list of two ``UserMessage`` objects: the assistant preamble and
            the image-plus-instruction message.
        """
        # NOTE(review): the preamble is sent as a UserMessage, not a system
        # message — confirm this is intentional.
        return [
            UserMessage(content="You are a helpful assistant."),
            UserMessage(
                content=[
                    image,
                    TextContentItem(text="Describe this image in two sentences."),
                ]
            ),
        ]

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "image, expected_strings",
        [
            (
                ImageContentItem(image=dict(data=PASTA_IMAGE)),
                ["spaghetti"],
            ),
            (
                ImageContentItem(
                    image=dict(
                        url=URL(
                            uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                        )
                    )
                ),
                ["puppy"],
            ),
        ],
    )
    async def test_vision_chat_completion_non_streaming(
        self, inference_model, inference_stack, image, expected_strings
    ):
        """Non-streaming: the single response must mention every expected string."""
        inference_impl, _ = inference_stack
        self._skip_unless_vision_supported(inference_impl, inference_model)

        response = await inference_impl.chat_completion(
            model_id=inference_model,
            messages=self._describe_image_messages(image),
            stream=False,
            sampling_params=SamplingParams(max_tokens=100),
        )

        assert isinstance(response, ChatCompletionResponse)
        assert response.completion_message.role == "assistant"
        assert isinstance(response.completion_message.content, str)
        for expected_string in expected_strings:
            assert expected_string in response.completion_message.content

    @pytest.mark.asyncio
    async def test_vision_chat_completion_streaming(
        self, inference_model, inference_stack
    ):
        """Streaming: chunks must be well-formed and their joined text must match."""
        inference_impl, _ = inference_stack
        self._skip_unless_vision_supported(inference_impl, inference_model)

        images = [
            ImageContentItem(
                image=dict(
                    url=URL(
                        uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
                    )
                )
            ),
        ]
        expected_strings_to_check = [
            ["puppy"],
        ]
        for image, expected_strings in zip(images, expected_strings_to_check):
            # chat_completion(stream=True) is awaited to obtain the async
            # iterator, which is then drained into a list of chunks.
            response = [
                r
                async for r in await inference_impl.chat_completion(
                    model_id=inference_model,
                    messages=self._describe_image_messages(image),
                    stream=True,
                    sampling_params=SamplingParams(max_tokens=100),
                )
            ]

            assert len(response) > 0
            assert all(
                isinstance(chunk, ChatCompletionResponseStreamChunk)
                for chunk in response
            )

            # Exactly one start event, at least one progress event, and
            # exactly one complete event are expected.
            grouped = group_chunks(response)
            assert len(grouped[ChatCompletionResponseEventType.start]) == 1
            assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
            assert len(grouped[ChatCompletionResponseEventType.complete]) == 1

            # Rebuild the full reply from the text deltas of the progress events.
            content = "".join(
                chunk.event.delta.text
                for chunk in grouped[ChatCompletionResponseEventType.progress]
            )
            for expected_string in expected_strings:
                assert expected_string in content