Works with ollama 0.4.0 pre-release with the vision model

This commit is contained in:
Ashwin Bharambe 2024-11-05 14:59:18 -08:00
parent 03013dafc1
commit d543eb442b
5 changed files with 137 additions and 57 deletions

View file

@ -143,7 +143,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
params = self._get_params(request)
if "messages" in params:
print(f"Using chat completion endpoint: {params}")
stream = client.chat.completions.acreate(**params)
else:
stream = client.completion.acreate(**params)

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import io
from typing import AsyncGenerator
import httpx
@ -29,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
request_has_media,
)
OLLAMA_SUPPORTED_MODELS = {
@ -38,6 +41,7 @@ OLLAMA_SUPPORTED_MODELS = {
"Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
"Llama-Guard-3-8B": "llama-guard3:8b",
"Llama-Guard-3-1B": "llama-guard3:1b",
"Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
}
@ -109,22 +113,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
return await self._nonstream_completion(request)
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options["max_tokens"] is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": completion_request_to_prompt(request, self.formatter),
"options": sampling_options,
"raw": True,
"stream": request.stream,
}
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params_for_completion(request)
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
@ -142,7 +132,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params_for_completion(request)
params = await self._get_params(request)
r = await self.client.generate(**params)
assert isinstance(r, dict)
@ -183,26 +173,66 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
else:
return await self._nonstream_chat_completion(request)
def _get_params(self, request: ChatCompletionRequest) -> dict:
async def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options.get("max_tokens") is not None:
sampling_options["num_predict"] = sampling_options["max_tokens"]
input_dict = {}
media_present = request_has_media(request)
if isinstance(request, ChatCompletionRequest):
if media_present:
contents = [
await convert_message_to_dict_for_ollama(m)
for m in request.messages
]
# flatten the list of lists
input_dict["messages"] = [
item for sublist in contents for item in sublist
]
else:
input_dict["raw"] = True
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
)
else:
assert (
not media_present
), "Ollama does not support media for Completion requests"
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
input_dict["raw"] = True
return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"options": get_sampling_options(request.sampling_params),
"raw": True,
**input_dict,
"options": sampling_options,
"stream": request.stream,
}
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await self.client.generate(**params)
params = await self._get_params(request)
if "messages" in params:
r = await self.client.chat(**params)
else:
r = await self.client.generate(**params)
assert isinstance(r, dict)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
if "message" in r:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
text=r["response"],
)
response = OpenAICompatCompletionResponse(
choices=[choice],
)
@ -211,15 +241,24 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
if "messages" in params:
s = await self.client.chat(**params)
else:
s = await self.client.generate(**params)
async for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
if "message" in chunk:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["message"]["content"],
)
else:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
text=chunk["response"],
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
@ -236,3 +275,37 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
async def _convert_content(content) -> dict:
if isinstance(content, ImageMedia):
return {
"role": message.role,
"images": [await convert_image_media_to_base64(content)],
}
else:
return {
"role": message.role,
"content": content,
}
if isinstance(message.content, list):
return [await _convert_content(c) for c in message.content]
else:
return [await _convert_content(message.content)]
async def convert_image_media_to_base64(media: ImageMedia) -> str:
if isinstance(media.image, PIL_Image.Image):
bytestream = io.BytesIO()
media.image.save(bytestream, format=media.image.format)
bytestream.seek(0)
content = bytestream.getvalue()
else:
assert isinstance(media.image, URL)
async with httpx.AsyncClient() as client:
r = await client.get(media.image.uri)
content = r.content
return base64.b64encode(content).decode("utf-8")

View file

@ -19,12 +19,11 @@ def pytest_addoption(parser):
def pytest_configure(config):
config.addinivalue_line(
"markers", "llama_8b: mark test to run only with the given model"
)
config.addinivalue_line(
"markers", "llama_3b: mark test to run only with the given model"
)
for model in ["llama_8b", "llama_3b", "llama_vision"]:
config.addinivalue_line(
"markers", f"{model}: mark test to run only with the given model"
)
for fixture_name in INFERENCE_FIXTURES:
config.addinivalue_line(
"markers",
@ -37,6 +36,14 @@ MODEL_PARAMS = [
pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
]
VISION_MODEL_PARAMS = [
pytest.param(
"Llama3.2-11B-Vision-Instruct",
marks=pytest.mark.llama_vision,
id="llama_vision",
),
]
def pytest_generate_tests(metafunc):
if "inference_model" in metafunc.fixturenames:
@ -44,7 +51,11 @@ def pytest_generate_tests(metafunc):
if model:
params = [pytest.param(model, id="")]
else:
params = MODEL_PARAMS
cls_name = metafunc.cls.__name__
if "Vision" in cls_name:
params = VISION_MODEL_PARAMS
else:
params = MODEL_PARAMS
metafunc.parametrize(
"inference_model",

View file

@ -29,11 +29,6 @@ def inference_model(request):
return request.config.getoption("--inference-model", None)
@pytest.fixture(scope="session")
def vision_inference_model():
return "Llama3.2-11B-Vision-Instruct"
@pytest.fixture(scope="session")
def inference_remote() -> ProviderFixture:
return remote_stack_fixture()

View file

@ -21,19 +21,20 @@ THIS_DIR = Path(__file__).parent
class TestVisionModelInference:
@pytest.mark.asyncio
async def test_vision_chat_completion_non_streaming(
self, vision_inference_model, inference_stack
self, inference_model, inference_stack
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(
vision_inference_model
)
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::together",
"remote::fireworks",
"remote::ollama",
):
pytest.skip("Other inference providers don't support completion() yet")
pytest.skip(
"Other inference providers don't support vision chat completion() yet"
)
images = [
ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
@ -51,7 +52,7 @@ class TestVisionModelInference:
]
for image, expected_strings in zip(images, expected_strings_to_check):
response = await inference_impl.chat_completion(
model=vision_inference_model,
model=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(
@ -69,19 +70,20 @@ class TestVisionModelInference:
@pytest.mark.asyncio
async def test_vision_chat_completion_streaming(
self, vision_inference_model, inference_stack
self, inference_model, inference_stack
):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(
vision_inference_model
)
provider = inference_impl.routing_table.get_provider_impl(inference_model)
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::together",
"remote::fireworks",
"remote::ollama",
):
pytest.skip("Other inference providers don't support completion() yet")
pytest.skip(
"Other inference providers don't support vision chat completion() yet"
)
images = [
ImageMedia(
@ -97,7 +99,7 @@ class TestVisionModelInference:
response = [
r
async for r in await inference_impl.chat_completion(
model=vision_inference_model,
model=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(