diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 0259c7061..8dfe37c55 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -22,6 +22,9 @@ from llama_stack.providers.utils.inference.openai_compat import (
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
+    completion_request_to_prompt,
+    convert_message_to_dict,
+    request_has_media,
 )
 
 from .config import VLLMInferenceAdapterConfig
@@ -105,19 +108,25 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = client.completions.create(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = client.chat.completions.create(**params)
+        else:
+            r = client.completions.create(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, client: OpenAI
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         # TODO: Can we use client.completions.acreate() or maybe there is another way to directly create an async
         # generator so this wrapper is not necessary?
         async def _to_async_generator():
-            s = client.completions.create(**params)
+            if "messages" in params:
+                s = client.chat.completions.create(**params)
+            else:
+                s = client.completions.create(**params)
             for chunk in s:
                 yield chunk
 
@@ -127,7 +136,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         ):
             yield chunk
 
-    def _get_params(self, request: ChatCompletionRequest) -> dict:
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
@@ -136,9 +147,28 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         if model is None:
             raise ValueError(f"Unknown model: {request.model}")
 
+        input_dict = {}
+        media_present = request_has_media(request)
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                # vllm does not seem to work well with image urls, so we download the images
+                input_dict["messages"] = [
+                    await convert_message_to_dict(m, download=True)
+                    for m in request.messages
+                ]
+            else:
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        else:
+            assert (
+                not media_present
+            ), "vLLM does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+
         return {
             "model": model.huggingface_repo,
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            **input_dict,
             "stream": request.stream,
             **options,
         }
diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py
index 1939d6934..3e785b757 100644
--- a/llama_stack/providers/tests/inference/test_vision_inference.py
+++ b/llama_stack/providers/tests/inference/test_vision_inference.py
@@ -20,8 +20,25 @@ THIS_DIR = Path(__file__).parent
 
 class TestVisionModelInference:
     @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "image, expected_strings",
+        [
+            (
+                ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
+                ["spaghetti"],
+            ),
+            (
+                ImageMedia(
+                    image=URL(
+                        uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
+                    )
+                ),
+                ["puppy"],
+            ),
+        ],
+    )
     async def test_vision_chat_completion_non_streaming(
-        self, inference_model, inference_stack
+        self, inference_model, inference_stack, image, expected_strings
     ):
         inference_impl, _ = inference_stack
 
@@ -31,42 +48,27 @@ class TestVisionModelInference:
             "remote::together",
             "remote::fireworks",
             "remote::ollama",
+            "remote::vllm",
         ):
             pytest.skip(
                 "Other inference providers don't support vision chat completion() yet"
             )
 
-        images = [
-            ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
-            ImageMedia(
-                image=URL(
-                    uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
-                )
-            ),
-        ]
+        response = await inference_impl.chat_completion(
+            model=inference_model,
+            messages=[
+                UserMessage(content="You are a helpful assistant."),
+                UserMessage(content=[image, "Describe this image in two sentences."]),
+            ],
+            stream=False,
+            sampling_params=SamplingParams(max_tokens=100),
+        )
 
-        # These are a bit hit-and-miss, need to be careful
-        expected_strings_to_check = [
-            ["spaghetti"],
-            ["puppy"],
-        ]
-        for image, expected_strings in zip(images, expected_strings_to_check):
-            response = await inference_impl.chat_completion(
-                model=inference_model,
-                messages=[
-                    SystemMessage(content="You are a helpful assistant."),
-                    UserMessage(
-                        content=[image, "Describe this image in two sentences."]
-                    ),
-                ],
-                stream=False,
-            )
-
-            assert isinstance(response, ChatCompletionResponse)
-            assert response.completion_message.role == "assistant"
-            assert isinstance(response.completion_message.content, str)
-            for expected_string in expected_strings:
-                assert expected_string in response.completion_message.content
+        assert isinstance(response, ChatCompletionResponse)
+        assert response.completion_message.role == "assistant"
+        assert isinstance(response.completion_message.content, str)
+        for expected_string in expected_strings:
+            assert expected_string in response.completion_message.content
 
     @pytest.mark.asyncio
     async def test_vision_chat_completion_streaming(
@@ -80,6 +82,7 @@ class TestVisionModelInference:
             "remote::together",
             "remote::fireworks",
             "remote::ollama",
+            "remote::vllm",
         ):
             pytest.skip(
                 "Other inference providers don't support vision chat completion() yet"
             )
@@ -101,12 +104,13 @@ class TestVisionModelInference:
             async for r in await inference_impl.chat_completion(
                 model=inference_model,
                 messages=[
-                    SystemMessage(content="You are a helpful assistant."),
+                    UserMessage(content="You are a helpful assistant."),
                     UserMessage(
                         content=[image, "Describe this image in two sentences."]
                     ),
                 ],
                 stream=True,
+                sampling_params=SamplingParams(max_tokens=100),
             )
         ]
 
diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py
index 9decf5a00..45e43c898 100644
--- a/llama_stack/providers/utils/inference/prompt_adapter.py
+++ b/llama_stack/providers/utils/inference/prompt_adapter.py
@@ -90,13 +90,15 @@ async def convert_image_media_to_url(
     return base64.b64encode(content).decode("utf-8")
 
 
-async def convert_message_to_dict(message: Message) -> dict:
+# TODO: name this function better! this is about OpenAI compatible image
+# media conversion of the message. this should probably go in openai_compat.py
+async def convert_message_to_dict(message: Message, download: bool = False) -> dict:
     async def _convert_content(content) -> dict:
         if isinstance(content, ImageMedia):
             return {
                 "type": "image_url",
                 "image_url": {
-                    "url": await convert_image_media_to_url(content),
+                    "url": await convert_image_media_to_url(content, download=download),
                 },
             }
         else:
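
The core idea of the vllm.py change is that `_get_params` now emits either a `prompt` key (text-only requests, rendered through the formatter) or a `messages` key (media-bearing chat requests, converted via `convert_message_to_dict(..., download=True)`), and the callers route to the matching OpenAI endpoint. The sketch below illustrates that routing in isolation; it is not part of the patch, and the `base_url`, `api_key`, and model name are placeholder assumptions.

```python
# Standalone sketch of the routing behavior introduced in vllm.py.
# The server URL, api_key, and model id are placeholders, not values from this patch.
from openai import OpenAI


def create_completion(client: OpenAI, params: dict):
    # Media-bearing requests keep OpenAI-style "messages" and use
    # /v1/chat/completions; text-only requests are flattened to a "prompt"
    # and use /v1/completions, mirroring the new _get_params contract.
    if "messages" in params:
        return client.chat.completions.create(**params)
    return client.completions.create(**params)


if __name__ == "__main__":
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")  # hypothetical vLLM server

    text_params = {
        "model": "my-vision-model",  # placeholder model id
        "prompt": "Say hello.",
        "max_tokens": 32,
        "stream": False,
    }
    chat_params = {
        "model": "my-vision-model",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image in two sentences."},
                    {
                        "type": "image_url",
                        # with download=True, convert_message_to_dict inlines the
                        # downloaded image (base64) here instead of forwarding the URL.
                        "image_url": {"url": "<base64-encoded image placeholder>"},
                    },
                ],
            }
        ],
        "max_tokens": 32,
        "stream": False,
    }

    print(create_completion(client, text_params))
    print(create_completion(client, chat_params))
```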