mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
Enable vision models for Together and Fireworks
This commit is contained in:
parent
8de845a96d
commit
03013dafc1
9 changed files with 297 additions and 35 deletions
|
@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
convert_message_to_dict,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import FireworksImplConfig
|
||||
|
@ -129,7 +131,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
self, request: ChatCompletionRequest, client: Fireworks
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = await client.completion.acreate(**params)
|
||||
if "messages" in params:
|
||||
r = await client.chat.completions.acreate(**params)
|
||||
else:
|
||||
r = await client.completion.acreate(**params)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
|
@ -137,24 +142,44 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
stream = client.completion.acreate(**params)
|
||||
if "messages" in params:
|
||||
print(f"Using chat completion endpoint: {params}")
|
||||
stream = client.chat.completions.acreate(**params)
|
||||
else:
|
||||
stream = client.completion.acreate(**params)
|
||||
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request) -> dict:
|
||||
prompt = ""
|
||||
if type(request) == ChatCompletionRequest:
|
||||
prompt = chat_completion_request_to_prompt(request, self.formatter)
|
||||
elif type(request) == CompletionRequest:
|
||||
prompt = completion_request_to_prompt(request, self.formatter)
|
||||
def _get_params(
|
||||
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||
) -> dict:
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
input_dict["messages"] = [
|
||||
convert_message_to_dict(m) for m in request.messages
|
||||
]
|
||||
else:
|
||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
request, self.formatter
|
||||
)
|
||||
elif isinstance(request, CompletionRequest):
|
||||
assert (
|
||||
not media_present
|
||||
), "Fireworks does not support media for Completion requests"
|
||||
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
|
||||
else:
|
||||
raise ValueError(f"Unknown request type {type(request)}")
|
||||
|
||||
# Fireworks always prepends with BOS
|
||||
if prompt.startswith("<|begin_of_text|>"):
|
||||
prompt = prompt[len("<|begin_of_text|>") :]
|
||||
if "prompt" in input_dict:
|
||||
if input_dict["prompt"].startswith("<|begin_of_text|>"):
|
||||
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
|
||||
|
||||
options = get_sampling_options(request.sampling_params)
|
||||
options.setdefault("max_tokens", 512)
|
||||
|
@ -172,9 +197,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown response format {fmt.type}")
|
||||
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": prompt,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**options,
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
convert_message_to_dict,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import TogetherImplConfig
|
||||
|
@ -102,7 +104,7 @@ class TogetherInferenceAdapter(
|
|||
return process_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = self._get_params_for_completion(request)
|
||||
params = self._get_params(request)
|
||||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
|
@ -131,14 +133,6 @@ class TogetherInferenceAdapter(
|
|||
|
||||
return options
|
||||
|
||||
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
}
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -172,7 +166,10 @@ class TogetherInferenceAdapter(
|
|||
self, request: ChatCompletionRequest
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = self._get_client().completions.create(**params)
|
||||
if "messages" in params:
|
||||
r = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
r = self._get_client().completions.create(**params)
|
||||
return process_chat_completion_response(r, self.formatter)
|
||||
|
||||
async def _stream_chat_completion(
|
||||
|
@ -182,7 +179,10 @@ class TogetherInferenceAdapter(
|
|||
|
||||
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||
async def _to_async_generator():
|
||||
s = self._get_client().completions.create(**params)
|
||||
if "messages" in params:
|
||||
s = self._get_client().chat.completions.create(**params)
|
||||
else:
|
||||
s = self._get_client().completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
|
@ -192,10 +192,29 @@ class TogetherInferenceAdapter(
|
|||
):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
def _get_params(
|
||||
self, request: Union[ChatCompletionRequest, CompletionRequest]
|
||||
) -> dict:
|
||||
input_dict = {}
|
||||
media_present = request_has_media(request)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present:
|
||||
input_dict["messages"] = [
|
||||
convert_message_to_dict(m) for m in request.messages
|
||||
]
|
||||
else:
|
||||
input_dict["prompt"] = chat_completion_request_to_prompt(
|
||||
request, self.formatter
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
not media_present
|
||||
), "Together does not support media for Completion requests"
|
||||
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
|
||||
|
||||
return {
|
||||
"model": self.map_to_provider_model(request.model),
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
}
|
||||
|
|
|
@ -29,6 +29,11 @@ def inference_model(request):
|
|||
return request.config.getoption("--inference-model", None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vision_inference_model():
|
||||
return "Llama3.2-11B-Vision-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_remote() -> ProviderFixture:
|
||||
return remote_stack_fixture()
|
||||
|
|
BIN
llama_stack/providers/tests/inference/pasta.jpeg
Normal file
BIN
llama_stack/providers/tests/inference/pasta.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 438 KiB |
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -15,6 +14,9 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from .utils import group_chunks
|
||||
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest -v -s llama_stack/providers/tests/inference/test_inference.py
|
||||
|
@ -22,15 +24,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403
|
|||
# --env FIREWORKS_API_KEY=<your_api_key>
|
||||
|
||||
|
||||
def group_chunks(response):
|
||||
return {
|
||||
event_type: list(group)
|
||||
for event_type, group in itertools.groupby(
|
||||
response, key=lambda chunk: chunk.event.event_type
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def get_expected_stop_reason(model: str):
|
||||
return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
|
||||
|
||||
|
|
126
llama_stack/providers/tests/inference/test_vision_inference.py
Normal file
126
llama_stack/providers/tests/inference/test_vision_inference.py
Normal file
|
@ -0,0 +1,126 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from .utils import group_chunks
|
||||
|
||||
THIS_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
class TestVisionModelInference:
|
||||
@pytest.mark.asyncio
|
||||
async def test_vision_chat_completion_non_streaming(
|
||||
self, vision_inference_model, inference_stack
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(
|
||||
vision_inference_model
|
||||
)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
images = [
|
||||
ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
|
||||
ImageMedia(
|
||||
image=URL(
|
||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
# These are a bit hit-and-miss, need to be careful
|
||||
expected_strings_to_check = [
|
||||
["spaghetti"],
|
||||
["puppy"],
|
||||
]
|
||||
for image, expected_strings in zip(images, expected_strings_to_check):
|
||||
response = await inference_impl.chat_completion(
|
||||
model=vision_inference_model,
|
||||
messages=[
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(
|
||||
content=[image, "Describe this image in two sentences."]
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, ChatCompletionResponse)
|
||||
assert response.completion_message.role == "assistant"
|
||||
assert isinstance(response.completion_message.content, str)
|
||||
for expected_string in expected_strings:
|
||||
assert expected_string in response.completion_message.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vision_chat_completion_streaming(
|
||||
self, vision_inference_model, inference_stack
|
||||
):
|
||||
inference_impl, _ = inference_stack
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(
|
||||
vision_inference_model
|
||||
)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::together",
|
||||
"remote::fireworks",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
images = [
|
||||
ImageMedia(
|
||||
image=URL(
|
||||
uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
|
||||
)
|
||||
),
|
||||
]
|
||||
expected_strings_to_check = [
|
||||
["puppy"],
|
||||
]
|
||||
for image, expected_strings in zip(images, expected_strings_to_check):
|
||||
response = [
|
||||
r
|
||||
async for r in await inference_impl.chat_completion(
|
||||
model=vision_inference_model,
|
||||
messages=[
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
UserMessage(
|
||||
content=[image, "Describe this image in two sentences."]
|
||||
),
|
||||
],
|
||||
stream=True,
|
||||
)
|
||||
]
|
||||
|
||||
assert len(response) > 0
|
||||
assert all(
|
||||
isinstance(chunk, ChatCompletionResponseStreamChunk)
|
||||
for chunk in response
|
||||
)
|
||||
grouped = group_chunks(response)
|
||||
assert len(grouped[ChatCompletionResponseEventType.start]) == 1
|
||||
assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
|
||||
assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
|
||||
|
||||
content = "".join(
|
||||
chunk.event.delta
|
||||
for chunk in grouped[ChatCompletionResponseEventType.progress]
|
||||
)
|
||||
for expected_string in expected_strings:
|
||||
assert expected_string in content
|
16
llama_stack/providers/tests/inference/utils.py
Normal file
16
llama_stack/providers/tests/inference/utils.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
|
||||
|
||||
def group_chunks(response):
|
||||
return {
|
||||
event_type: list(group)
|
||||
for event_type, group in itertools.groupby(
|
||||
response, key=lambda chunk: chunk.event.event_type
|
||||
)
|
||||
}
|
|
@ -46,6 +46,9 @@ def text_from_choice(choice) -> str:
|
|||
if hasattr(choice, "delta") and choice.delta:
|
||||
return choice.delta.content
|
||||
|
||||
if hasattr(choice, "message"):
|
||||
return choice.message.content
|
||||
|
||||
return choice.text
|
||||
|
||||
|
||||
|
@ -99,7 +102,6 @@ def process_chat_completion_response(
|
|||
async def process_completion_stream_response(
|
||||
stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
|
||||
) -> AsyncGenerator:
|
||||
|
||||
stop_reason = None
|
||||
|
||||
async for chunk in stream:
|
||||
|
@ -158,6 +160,10 @@ async def process_chat_completion_stream_response(
|
|||
break
|
||||
|
||||
text = text_from_choice(choice)
|
||||
if not text:
|
||||
# Sometimes you get empty chunks from providers
|
||||
continue
|
||||
|
||||
# check if its a tool call ( aka starts with <|python_tag|> )
|
||||
if not ipython and text.startswith("<|python_tag|>"):
|
||||
ipython = True
|
||||
|
|
|
@ -3,10 +3,14 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from typing import Tuple
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from PIL import Image as PIL_Image
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
@ -24,6 +28,73 @@ from llama_models.sku_list import resolve_model
|
|||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
|
||||
|
||||
def content_has_media(content: InterleavedTextMedia):
|
||||
def _has_media_content(c):
|
||||
return isinstance(c, ImageMedia)
|
||||
|
||||
if isinstance(content, list):
|
||||
return any(_has_media_content(c) for c in content)
|
||||
else:
|
||||
return _has_media_content(content)
|
||||
|
||||
|
||||
def messages_have_media(messages: List[Message]):
|
||||
return any(content_has_media(m.content) for m in messages)
|
||||
|
||||
|
||||
def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
return messages_have_media(request.messages)
|
||||
else:
|
||||
return content_has_media(request.content)
|
||||
|
||||
|
||||
def convert_image_media_to_url(media: ImageMedia) -> str:
|
||||
if isinstance(media.image, PIL_Image.Image):
|
||||
if media.image.format == "PNG":
|
||||
format = "png"
|
||||
elif media.image.format == "GIF":
|
||||
format = "gif"
|
||||
elif media.image.format == "JPEG":
|
||||
format = "jpeg"
|
||||
else:
|
||||
raise ValueError(f"Unsupported image format {media.image.format}")
|
||||
|
||||
bytestream = io.BytesIO()
|
||||
media.image.save(bytestream, format=media.image.format)
|
||||
bytestream.seek(0)
|
||||
return f"data:image/{format};base64," + base64.b64encode(
|
||||
bytestream.getvalue()
|
||||
).decode("utf-8")
|
||||
else:
|
||||
assert isinstance(media.image, URL)
|
||||
return media.image.uri
|
||||
|
||||
|
||||
def convert_message_to_dict(message: Message) -> dict:
|
||||
def _convert_content(content) -> dict:
|
||||
if isinstance(content, ImageMedia):
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": convert_image_media_to_url(content),
|
||||
},
|
||||
}
|
||||
else:
|
||||
assert isinstance(content, str)
|
||||
return {"type": "text", "text": content}
|
||||
|
||||
if isinstance(message.content, list):
|
||||
content = [_convert_content(c) for c in message.content]
|
||||
else:
|
||||
content = [_convert_content(message.content)]
|
||||
|
||||
return {
|
||||
"role": message.role,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
def completion_request_to_prompt(
|
||||
request: CompletionRequest, formatter: ChatFormat
|
||||
) -> str:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue