From 6d21da6e485093dab2b2d42029e190d7d6cb3a4c Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 17 Jan 2025 16:17:15 -0800 Subject: [PATCH] fix vllm base64 --- .../providers/remote/inference/vllm/vllm.py | 4 +- tests/client-sdk/inference/test_inference.py | 43 ++++++++++++++++++- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 317d05207..81c746cce 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -176,10 +176,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): media_present = request_has_media(request) if isinstance(request, ChatCompletionRequest): if media_present: - # vllm does not seem to work well with image urls, so we download the images input_dict["messages"] = [ - await convert_message_to_openai_dict(m, download=True) - for m in request.messages + await convert_message_to_openai_dict(m) for m in request.messages ] else: input_dict["prompt"] = await chat_completion_request_to_prompt( diff --git a/tests/client-sdk/inference/test_inference.py b/tests/client-sdk/inference/test_inference.py index 19314e4ab..8553f94f0 100644 --- a/tests/client-sdk/inference/test_inference.py +++ b/tests/client-sdk/inference/test_inference.py @@ -4,7 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import base64 + import pytest +import requests from pydantic import BaseModel PROVIDER_TOOL_PROMPT_FORMAT = { @@ -69,6 +72,16 @@ def get_weather_tool_definition(): } +@pytest.fixture +def base64_image_url(): + downloadable_url = "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" + response = requests.get(downloadable_url) + response.raise_for_status() + base64_string = base64.b64encode(response.content).decode("utf-8") + base64_url = f"data:image;base64,{base64_string}" + return base64_url + + def test_completion_non_streaming(llama_stack_client, text_model_id): response = llama_stack_client.inference.completion( content="Complete the sentence using one word: Roses are red, violets are ", @@ -326,7 +339,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id ) message_content = response.completion_message.content.lower().strip() assert len(message_content) > 0 - assert any(expected in message_content for expected in {"dog", "puppy", "pup"}) + assert any([expected in message_content for expected in {"dog", "puppy", "pup"}]) def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): @@ -356,3 +369,31 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): streamed_content += chunk.event.delta.text.lower() assert len(streamed_content) > 0 assert any(expected in streamed_content for expected in {"dog", "puppy", "pup"}) + + +def test_image_chat_completion_base64_url( + llama_stack_client, vision_model_id, base64_image_url +): + + message = { + "role": "user", + "content": [ + { + "type": "image", + "url": { + "uri": base64_image_url, + }, + }, + { + "type": "text", + "text": "Describe what is in this image.", + }, + ], + } + response = llama_stack_client.inference.chat_completion( + model_id=vision_model_id, + messages=[message], + stream=False, + ) + message_content = response.completion_message.content.lower().strip() + assert len(message_content) > 0