Works with ollama 0.4.0 pre-release with the vision model

Ashwin Bharambe 2024-11-05 14:59:18 -08:00
parent 03013dafc1
commit d543eb442b
5 changed files with 137 additions and 57 deletions


@@ -143,7 +143,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         params = self._get_params(request)
         if "messages" in params:
-            print(f"Using chat completion endpoint: {params}")
             stream = client.chat.completions.acreate(**params)
         else:
             stream = client.completion.acreate(**params)

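The Fireworks hunk above only drops a debug print, but it also shows the routing convention the rest of this commit leans on: _get_params signals a chat-style request by putting a "messages" key in the params dict, and the caller picks the endpoint accordingly. A minimal sketch of that convention, assuming an OpenAI-compatible async client exposing the same chat.completions.acreate and completion.acreate methods used in the hunk:

    # Sketch only: route a prepared params dict to the matching endpoint.
    # "client" stands in for the Fireworks async client used above.
    def route_request(client, params: dict):
        if "messages" in params:
            # chat-style payload: a list of role/content message dicts
            return client.chat.completions.acreate(**params)
        # plain prompt payload
        return client.completion.acreate(**params)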

@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
+import io
 from typing import AsyncGenerator
 
 import httpx
@@ -29,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    request_has_media,
 )
 
 OLLAMA_SUPPORTED_MODELS = {
@@ -38,6 +41,7 @@ OLLAMA_SUPPORTED_MODELS = {
     "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
     "Llama-Guard-3-8B": "llama-guard3:8b",
     "Llama-Guard-3-1B": "llama-guard3:1b",
+    "Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
 }
 
@@ -109,22 +113,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)
 
-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        sampling_options = get_sampling_options(request.sampling_params)
-        # This is needed since the Ollama API expects num_predict to be set
-        # for early truncation instead of max_tokens.
-        if sampling_options["max_tokens"] is not None:
-            sampling_options["num_predict"] = sampling_options["max_tokens"]
-        return {
-            "model": OLLAMA_SUPPORTED_MODELS[request.model],
-            "prompt": completion_request_to_prompt(request, self.formatter),
-            "options": sampling_options,
-            "raw": True,
-            "stream": request.stream,
-        }
-
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
             s = await self.client.generate(**params)
@@ -142,7 +132,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
         r = await self.client.generate(**params)
         assert isinstance(r, dict)
@@ -183,26 +173,66 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_chat_completion(request)
 
-    def _get_params(self, request: ChatCompletionRequest) -> dict:
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        sampling_options = get_sampling_options(request.sampling_params)
+        # This is needed since the Ollama API expects num_predict to be set
+        # for early truncation instead of max_tokens.
+        if sampling_options.get("max_tokens") is not None:
+            sampling_options["num_predict"] = sampling_options["max_tokens"]
+
+        input_dict = {}
+        media_present = request_has_media(request)
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                contents = [
+                    await convert_message_to_dict_for_ollama(m)
+                    for m in request.messages
+                ]
+                # flatten the list of lists
+                input_dict["messages"] = [
+                    item for sublist in contents for item in sublist
+                ]
+            else:
+                input_dict["raw"] = True
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        else:
+            assert (
+                not media_present
+            ), "Ollama does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["raw"] = True
+
         return {
             "model": OLLAMA_SUPPORTED_MODELS[request.model],
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
-            "options": get_sampling_options(request.sampling_params),
-            "raw": True,
+            **input_dict,
+            "options": sampling_options,
             "stream": request.stream,
         }
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = await self.client.generate(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = await self.client.chat(**params)
+        else:
+            r = await self.client.generate(**params)
         assert isinstance(r, dict)
 
-        choice = OpenAICompatCompletionChoice(
-            finish_reason=r["done_reason"] if r["done"] else None,
-            text=r["response"],
-        )
+        if "message" in r:
+            choice = OpenAICompatCompletionChoice(
+                finish_reason=r["done_reason"] if r["done"] else None,
+                text=r["message"]["content"],
+            )
+        else:
+            choice = OpenAICompatCompletionChoice(
+                finish_reason=r["done_reason"] if r["done"] else None,
+                text=r["response"],
+            )
         response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
@@ -211,15 +241,24 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def _stream_chat_completion(
        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
+            if "messages" in params:
+                s = await self.client.chat(**params)
+            else:
+                s = await self.client.generate(**params)
             async for chunk in s:
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
-                    text=chunk["response"],
-                )
+                if "message" in chunk:
+                    choice = OpenAICompatCompletionChoice(
+                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                        text=chunk["message"]["content"],
+                    )
+                else:
+                    choice = OpenAICompatCompletionChoice(
+                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                        text=chunk["response"],
+                    )
                 yield OpenAICompatCompletionResponse(
                     choices=[choice],
                 )
@@ -236,3 +275,37 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+
+async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
+    async def _convert_content(content) -> dict:
+        if isinstance(content, ImageMedia):
+            return {
+                "role": message.role,
+                "images": [await convert_image_media_to_base64(content)],
+            }
+        else:
+            return {
+                "role": message.role,
+                "content": content,
+            }
+
+    if isinstance(message.content, list):
+        return [await _convert_content(c) for c in message.content]
+    else:
+        return [await _convert_content(message.content)]
+
+
+async def convert_image_media_to_base64(media: ImageMedia) -> str:
+    if isinstance(media.image, PIL_Image.Image):
+        bytestream = io.BytesIO()
+        media.image.save(bytestream, format=media.image.format)
+        bytestream.seek(0)
+        content = bytestream.getvalue()
+    else:
+        assert isinstance(media.image, URL)
+        async with httpx.AsyncClient() as client:
+            r = await client.get(media.image.uri)
+            content = r.content
+
+    return base64.b64encode(content).decode("utf-8")

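Taken together, the new _get_params path turns a media-bearing ChatCompletionRequest into Ollama's chat payload: convert_message_to_dict_for_ollama emits one dict per content item, carrying images as base64 strings under an "images" key, and the adapter then calls client.chat() instead of client.generate(). A rough standalone sketch of the payload shape this produces, using the same ollama AsyncClient the adapter wraps (the image path and prompt text are illustrative):

    import base64

    from ollama import AsyncClient


    async def describe_image(path: str) -> str:
        # Mirrors what convert_message_to_dict_for_ollama produces: the image is
        # base64-encoded and attached via "images"; text rides in "content".
        with open(path, "rb") as f:
            img_b64 = base64.b64encode(f.read()).decode("utf-8")

        messages = [
            {"role": "user", "images": [img_b64]},
            {"role": "user", "content": "Describe what is in this image."},
        ]
        r = await AsyncClient().chat(
            model="x/llama3.2-vision:11b-instruct-fp16",  # tag registered above
            messages=messages,
            stream=False,
        )
        return r["message"]["content"]

Note that the chat path deliberately omits the "raw": True flag used for the prompt-based generate() calls, since Ollama applies the model's chat template itself when it receives structured messages.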

@@ -19,12 +19,11 @@ def pytest_addoption(parser):
 
 def pytest_configure(config):
-    config.addinivalue_line(
-        "markers", "llama_8b: mark test to run only with the given model"
-    )
-    config.addinivalue_line(
-        "markers", "llama_3b: mark test to run only with the given model"
-    )
+    for model in ["llama_8b", "llama_3b", "llama_vision"]:
+        config.addinivalue_line(
+            "markers", f"{model}: mark test to run only with the given model"
+        )
+
     for fixture_name in INFERENCE_FIXTURES:
         config.addinivalue_line(
             "markers",
@@ -37,6 +36,14 @@ MODEL_PARAMS = [
     pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
 ]
 
+VISION_MODEL_PARAMS = [
+    pytest.param(
+        "Llama3.2-11B-Vision-Instruct",
+        marks=pytest.mark.llama_vision,
+        id="llama_vision",
+    ),
+]
+
 
 def pytest_generate_tests(metafunc):
     if "inference_model" in metafunc.fixturenames:
@@ -44,7 +51,11 @@ def pytest_generate_tests(metafunc):
         if model:
             params = [pytest.param(model, id="")]
         else:
-            params = MODEL_PARAMS
+            cls_name = metafunc.cls.__name__
+            if "Vision" in cls_name:
+                params = VISION_MODEL_PARAMS
+            else:
+                params = MODEL_PARAMS
         metafunc.parametrize(
             "inference_model",

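With this in place, any test class whose name contains "Vision" is parametrized with the Llama3.2-11B-Vision-Instruct model (marked llama_vision), while other classes keep the text-model parameters. A small illustrative sketch of how a test picks this up; the class and test names are invented for the example:

    import pytest


    class TestVisionModelSmoke:  # "Vision" in the name selects VISION_MODEL_PARAMS
        @pytest.mark.asyncio
        async def test_gets_vision_model(self, inference_model):
            # pytest_generate_tests injects "Llama3.2-11B-Vision-Instruct" here,
            # marked llama_vision so runs can be filtered with `pytest -m llama_vision`.
            assert "Vision" in inference_model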

@@ -29,11 +29,6 @@ def inference_model(request):
     return request.config.getoption("--inference-model", None)
 
 
-@pytest.fixture(scope="session")
-def vision_inference_model():
-    return "Llama3.2-11B-Vision-Instruct"
-
-
 @pytest.fixture(scope="session")
 def inference_remote() -> ProviderFixture:
     return remote_stack_fixture()


@@ -21,19 +21,20 @@ THIS_DIR = Path(__file__).parent
 class TestVisionModelInference:
     @pytest.mark.asyncio
     async def test_vision_chat_completion_non_streaming(
-        self, vision_inference_model, inference_stack
+        self, inference_model, inference_stack
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(
-            vision_inference_model
-        )
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::together",
             "remote::fireworks",
+            "remote::ollama",
         ):
-            pytest.skip("Other inference providers don't support completion() yet")
+            pytest.skip(
+                "Other inference providers don't support vision chat completion() yet"
+            )
 
         images = [
             ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
@@ -51,7 +52,7 @@ class TestVisionModelInference:
         ]
         for image, expected_strings in zip(images, expected_strings_to_check):
             response = await inference_impl.chat_completion(
-                model=vision_inference_model,
+                model=inference_model,
                 messages=[
                     SystemMessage(content="You are a helpful assistant."),
                     UserMessage(
@@ -69,19 +70,20 @@ class TestVisionModelInference:
     @pytest.mark.asyncio
     async def test_vision_chat_completion_streaming(
-        self, vision_inference_model, inference_stack
+        self, inference_model, inference_stack
     ):
         inference_impl, _ = inference_stack
-        provider = inference_impl.routing_table.get_provider_impl(
-            vision_inference_model
-        )
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "meta-reference",
             "remote::together",
             "remote::fireworks",
+            "remote::ollama",
         ):
-            pytest.skip("Other inference providers don't support completion() yet")
+            pytest.skip(
+                "Other inference providers don't support vision chat completion() yet"
+            )
 
         images = [
             ImageMedia(
@@ -97,7 +99,7 @@ class TestVisionModelInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=vision_inference_model,
+                model=inference_model,
                 messages=[
                     SystemMessage(content="You are a helpful assistant."),
                     UserMessage(
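For reference, the calling convention both tests exercise looks roughly like the sketch below. The truncated UserMessage content is not shown in the hunks above, so the prompt text here is illustrative, and the message/media types (SystemMessage, UserMessage, ImageMedia) are assumed to be imported exactly as in the test module:

    # Illustrative only: the Inference API surface the vision tests drive.
    # `inference_impl` is the routed implementation from inference_stack and
    # `image` is an ImageMedia such as the pasta.jpeg one above.
    async def run_vision_chat(inference_impl, inference_model, image):
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content=[image, "Describe this image."]),
        ]

        # Non-streaming: a single chat completion response.
        response = await inference_impl.chat_completion(
            model=inference_model, messages=messages, stream=False
        )

        # Streaming: awaiting the call returns an async iterator of chunks.
        chunks = [
            r
            async for r in await inference_impl.chat_completion(
                model=inference_model, messages=messages, stream=True
            )
        ]
        return response, chunks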