diff --git a/llama_stack/providers/adapters/inference/fireworks/fireworks.py b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
index f6eec838d..421706650 100644
--- a/llama_stack/providers/adapters/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/adapters/inference/fireworks/fireworks.py
@@ -143,7 +143,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         params = self._get_params(request)
 
         if "messages" in params:
-            print(f"Using chat completion endpoint: {params}")
             stream = client.chat.completions.acreate(**params)
         else:
             stream = client.completion.acreate(**params)
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index 916241a7c..4c12a967f 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
+import io
 from typing import AsyncGenerator
 
 import httpx
@@ -29,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    request_has_media,
 )
 
 OLLAMA_SUPPORTED_MODELS = {
@@ -38,6 +41,7 @@ OLLAMA_SUPPORTED_MODELS = {
     "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
     "Llama-Guard-3-8B": "llama-guard3:8b",
     "Llama-Guard-3-1B": "llama-guard3:1b",
+    "Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
 }
 
 
@@ -109,22 +113,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)
 
-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        sampling_options = get_sampling_options(request.sampling_params)
-        # This is needed since the Ollama API expects num_predict to be set
-        # for early truncation instead of max_tokens.
-        if sampling_options["max_tokens"] is not None:
-            sampling_options["num_predict"] = sampling_options["max_tokens"]
-        return {
-            "model": OLLAMA_SUPPORTED_MODELS[request.model],
-            "prompt": completion_request_to_prompt(request, self.formatter),
-            "options": sampling_options,
-            "raw": True,
-            "stream": request.stream,
-        }
-
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
             s = await self.client.generate(**params)
@@ -142,7 +132,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
         r = await self.client.generate(**params)
         assert isinstance(r, dict)
 
@@ -183,26 +173,66 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_chat_completion(request)
 
-    def _get_params(self, request: ChatCompletionRequest) -> dict:
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        sampling_options = get_sampling_options(request.sampling_params)
+        # This is needed since the Ollama API expects num_predict to be set
+        # for early truncation instead of max_tokens.
+        if sampling_options.get("max_tokens") is not None:
+            sampling_options["num_predict"] = sampling_options["max_tokens"]
+
+        input_dict = {}
+        media_present = request_has_media(request)
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                contents = [
+                    await convert_message_to_dict_for_ollama(m)
+                    for m in request.messages
+                ]
+                # flatten the list of lists
+                input_dict["messages"] = [
+                    item for sublist in contents for item in sublist
+                ]
+            else:
+                input_dict["raw"] = True
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        else:
+            assert (
+                not media_present
+            ), "Ollama does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["raw"] = True
+
         return {
             "model": OLLAMA_SUPPORTED_MODELS[request.model],
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
-            "options": get_sampling_options(request.sampling_params),
-            "raw": True,
+            **input_dict,
+            "options": sampling_options,
             "stream": request.stream,
         }
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = await self.client.generate(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = await self.client.chat(**params)
+        else:
+            r = await self.client.generate(**params)
         assert isinstance(r, dict)
 
-        choice = OpenAICompatCompletionChoice(
-            finish_reason=r["done_reason"] if r["done"] else None,
-            text=r["response"],
-        )
+        if "message" in r:
+            choice = OpenAICompatCompletionChoice(
+                finish_reason=r["done_reason"] if r["done"] else None,
+                text=r["message"]["content"],
+            )
+        else:
+            choice = OpenAICompatCompletionChoice(
+                finish_reason=r["done_reason"] if r["done"] else None,
+                text=r["response"],
+            )
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
@@ -211,15 +241,24 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
+            if "messages" in params:
+                s = await self.client.chat(**params)
+            else:
+                s = await self.client.generate(**params)
             async for chunk in s:
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
-                    text=chunk["response"],
-                )
+                if "message" in chunk:
+                    choice = OpenAICompatCompletionChoice(
+                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                        text=chunk["message"]["content"],
+                    )
+                else:
+                    choice = OpenAICompatCompletionChoice(
+                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                        text=chunk["response"],
+                    )
                 yield OpenAICompatCompletionResponse(
                     choices=[choice],
                 )
@@ -236,3 +275,37 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+
+async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
+    async def _convert_content(content) -> dict:
+        if isinstance(content, ImageMedia):
+            return {
+                "role": message.role,
+                "images": [await convert_image_media_to_base64(content)],
+            }
+        else:
+            return {
+                "role": message.role,
+                "content": content,
+            }
+
+    if isinstance(message.content, list):
+        return [await _convert_content(c) for c in message.content]
+    else:
+        return [await _convert_content(message.content)]
+
+
+async def convert_image_media_to_base64(media: ImageMedia) -> str:
+    if isinstance(media.image, PIL_Image.Image):
+        bytestream = io.BytesIO()
+        media.image.save(bytestream, format=media.image.format)
+        bytestream.seek(0)
+        content = bytestream.getvalue()
+    else:
+        assert isinstance(media.image, URL)
+        async with httpx.AsyncClient() as client:
+            r = await client.get(media.image.uri)
+            content = r.content
+
+    return base64.b64encode(content).decode("utf-8")
diff --git a/llama_stack/providers/tests/inference/conftest.py b/llama_stack/providers/tests/inference/conftest.py
index 71253871d..ba60b9925 100644
--- a/llama_stack/providers/tests/inference/conftest.py
+++ b/llama_stack/providers/tests/inference/conftest.py
@@ -19,12 +19,11 @@ def pytest_addoption(parser):
 
 
 def pytest_configure(config):
-    config.addinivalue_line(
-        "markers", "llama_8b: mark test to run only with the given model"
-    )
-    config.addinivalue_line(
-        "markers", "llama_3b: mark test to run only with the given model"
-    )
+    for model in ["llama_8b", "llama_3b", "llama_vision"]:
+        config.addinivalue_line(
+            "markers", f"{model}: mark test to run only with the given model"
+        )
+
     for fixture_name in INFERENCE_FIXTURES:
         config.addinivalue_line(
             "markers",
@@ -37,6 +36,14 @@ MODEL_PARAMS = [
     pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
 ]
 
+VISION_MODEL_PARAMS = [
+    pytest.param(
+        "Llama3.2-11B-Vision-Instruct",
+        marks=pytest.mark.llama_vision,
+        id="llama_vision",
+    ),
+]
+
 
 def pytest_generate_tests(metafunc):
     if "inference_model" in metafunc.fixturenames:
@@ -44,7 +51,11 @@ def pytest_generate_tests(metafunc):
         if model:
             params = [pytest.param(model, id="")]
         else:
-            params = MODEL_PARAMS
+            cls_name = metafunc.cls.__name__
+            if "Vision" in cls_name:
+                params = VISION_MODEL_PARAMS
+            else:
+                params = MODEL_PARAMS
 
         metafunc.parametrize(
             "inference_model",
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index 777c27809..896acbad8 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -29,11 +29,6 @@ def inference_model(request):
     return request.config.getoption("--inference-model", None)
 
 
-@pytest.fixture(scope="session")
-def vision_inference_model():
-    return "Llama3.2-11B-Vision-Instruct"
-
-
 @pytest.fixture(scope="session")
 def inference_remote() -> ProviderFixture:
     return remote_stack_fixture()
diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py
index 6f823564c..1939d6934 100644
--- a/llama_stack/providers/tests/inference/test_vision_inference.py
+++ b/llama_stack/providers/tests/inference/test_vision_inference.py
@@ -21,19 +21,20 @@ THIS_DIR = Path(__file__).parent
 class TestVisionModelInference:
     @pytest.mark.asyncio
     async def test_vision_chat_completion_non_streaming(
-        self, vision_inference_model, inference_stack
+        self, inference_model, inference_stack
     ):
         inference_impl, _ = inference_stack
 
-        provider = inference_impl.routing_table.get_provider_impl(
-            vision_inference_model
-        )
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::together",
             "remote::fireworks",
+            "remote::ollama",
         ):
-            pytest.skip("Other inference providers don't support completion() yet")
+            pytest.skip(
+                "Other inference providers don't support vision chat completion() yet"
+            )
completion() yet" + ) images = [ ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")), @@ -51,7 +52,7 @@ class TestVisionModelInference: ] for image, expected_strings in zip(images, expected_strings_to_check): response = await inference_impl.chat_completion( - model=vision_inference_model, + model=inference_model, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage( @@ -69,19 +70,20 @@ class TestVisionModelInference: @pytest.mark.asyncio async def test_vision_chat_completion_streaming( - self, vision_inference_model, inference_stack + self, inference_model, inference_stack ): inference_impl, _ = inference_stack - provider = inference_impl.routing_table.get_provider_impl( - vision_inference_model - ) + provider = inference_impl.routing_table.get_provider_impl(inference_model) if provider.__provider_spec__.provider_type not in ( "meta-reference", "remote::together", "remote::fireworks", + "remote::ollama", ): - pytest.skip("Other inference providers don't support completion() yet") + pytest.skip( + "Other inference providers don't support vision chat completion() yet" + ) images = [ ImageMedia( @@ -97,7 +99,7 @@ class TestVisionModelInference: response = [ r async for r in await inference_impl.chat_completion( - model=vision_inference_model, + model=inference_model, messages=[ SystemMessage(content="You are a helpful assistant."), UserMessage(