forked from phoenix-oss/llama-stack-mirror
add image support to NVIDIA inference provider (#907)
# What does this PR do? add support to the NVIDIA Inference provider for image inputs ## Test Plan 1. Run local [Llama 3.2 11b vision instruct](https://build.nvidia.com/meta/llama-3.2-11b-vision-instruct?snippet_tab=Docker) NIM 2. Start a stack, e.g. `llama stack run llama_stack/templates/nvidia/run.yaml --env NVIDIA_BASE_URL=http://localhost:8000` 3. Run image tests, e.g. `LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_inference.py --vision-inference-model meta-llama/Llama-3.2-11B-Vision-Instruct -k image` ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [x] Wrote necessary unit or integration tests.
This commit is contained in:
parent
439d0da84c
commit
e21c8b6d80
3 changed files with 43 additions and 6 deletions
|
@ -186,7 +186,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
|
|
||||||
await check_health(self._config) # this raises errors
|
await check_health(self._config) # this raises errors
|
||||||
|
|
||||||
request = convert_chat_completion_request(
|
request = await convert_chat_completion_request(
|
||||||
request=ChatCompletionRequest(
|
request=ChatCompletionRequest(
|
||||||
model=self.get_provider_model_id(model_id),
|
model=self.get_provider_model_id(model_id),
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional
|
from typing import Any, AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
|
||||||
|
|
||||||
from llama_models.datatypes import (
|
from llama_models.datatypes import (
|
||||||
GreedySamplingStrategy,
|
GreedySamplingStrategy,
|
||||||
|
@ -23,6 +23,8 @@ from openai import AsyncStream
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
|
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
|
||||||
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
ChatCompletionChunk as OpenAIChatCompletionChunk,
|
||||||
|
ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam,
|
||||||
|
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
|
||||||
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
|
ChatCompletionMessageParam as OpenAIChatCompletionMessage,
|
||||||
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
|
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
|
||||||
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
|
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
|
||||||
|
@ -33,6 +35,9 @@ from openai.types.chat.chat_completion import (
|
||||||
Choice as OpenAIChoice,
|
Choice as OpenAIChoice,
|
||||||
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
|
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
|
||||||
)
|
)
|
||||||
|
from openai.types.chat.chat_completion_content_part_image_param import (
|
||||||
|
ImageURL as OpenAIImageURL,
|
||||||
|
)
|
||||||
from openai.types.chat.chat_completion_message_tool_call_param import (
|
from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||||
Function as OpenAIFunction,
|
Function as OpenAIFunction,
|
||||||
)
|
)
|
||||||
|
@ -40,6 +45,9 @@ from openai.types.completion import Completion as OpenAICompletion
|
||||||
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
|
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
ImageContentItem,
|
||||||
|
InterleavedContent,
|
||||||
|
TextContentItem,
|
||||||
TextDelta,
|
TextDelta,
|
||||||
ToolCallDelta,
|
ToolCallDelta,
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
|
@ -62,6 +70,10 @@ from llama_stack.apis.inference import (
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
convert_image_content_to_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
||||||
"""
|
"""
|
||||||
|
@ -139,7 +151,7 @@ def _convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
|
async def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
|
||||||
"""
|
"""
|
||||||
Convert a Message to an OpenAI API-compatible dictionary.
|
Convert a Message to an OpenAI API-compatible dictionary.
|
||||||
"""
|
"""
|
||||||
|
@ -159,11 +171,35 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported message role: {message['role']}")
|
raise ValueError(f"Unsupported message role: {message['role']}")
|
||||||
|
|
||||||
|
# Map Llama Stack spec to OpenAI spec -
|
||||||
|
# str -> str
|
||||||
|
# {"type": "text", "text": ...} -> {"type": "text", "text": ...}
|
||||||
|
# {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}}
|
||||||
|
# {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}}
|
||||||
|
# List[...] -> List[...]
|
||||||
|
async def _convert_user_message_content(
|
||||||
|
content: InterleavedContent,
|
||||||
|
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
|
||||||
|
# Llama Stack and OpenAI spec match for str and text input
|
||||||
|
if isinstance(content, str) or isinstance(content, TextContentItem):
|
||||||
|
return content
|
||||||
|
elif isinstance(content, ImageContentItem):
|
||||||
|
return OpenAIChatCompletionContentPartImageParam(
|
||||||
|
image_url=OpenAIImageURL(
|
||||||
|
url=await convert_image_content_to_url(content)
|
||||||
|
),
|
||||||
|
type="image_url",
|
||||||
|
)
|
||||||
|
elif isinstance(content, List):
|
||||||
|
return [await _convert_user_message_content(item) for item in content]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {type(content)}")
|
||||||
|
|
||||||
out: OpenAIChatCompletionMessage = None
|
out: OpenAIChatCompletionMessage = None
|
||||||
if isinstance(message, UserMessage):
|
if isinstance(message, UserMessage):
|
||||||
out = OpenAIChatCompletionUserMessage(
|
out = OpenAIChatCompletionUserMessage(
|
||||||
role="user",
|
role="user",
|
||||||
content=message.content, # TODO(mf): handle image content
|
content=await _convert_user_message_content(message.content),
|
||||||
)
|
)
|
||||||
elif isinstance(message, CompletionMessage):
|
elif isinstance(message, CompletionMessage):
|
||||||
out = OpenAIChatCompletionAssistantMessage(
|
out = OpenAIChatCompletionAssistantMessage(
|
||||||
|
@ -198,7 +234,7 @@ def _convert_message(message: Message | Dict) -> OpenAIChatCompletionMessage:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def convert_chat_completion_request(
|
async def convert_chat_completion_request(
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
n: int = 1,
|
n: int = 1,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
@ -235,7 +271,7 @@ def convert_chat_completion_request(
|
||||||
nvext = {}
|
nvext = {}
|
||||||
payload: Dict[str, Any] = dict(
|
payload: Dict[str, Any] = dict(
|
||||||
model=request.model,
|
model=request.model,
|
||||||
messages=[_convert_message(message) for message in request.messages],
|
messages=[await _convert_message(message) for message in request.messages],
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
n=n,
|
n=n,
|
||||||
extra_body=dict(nvext=nvext),
|
extra_body=dict(nvext=nvext),
|
||||||
|
|
|
@ -186,6 +186,7 @@ async def localize_image_content(media: ImageContentItem) -> Tuple[bytes, str]:
|
||||||
return content, format
|
return content, format
|
||||||
else:
|
else:
|
||||||
# data is a base64 encoded string, decode it to bytes first
|
# data is a base64 encoded string, decode it to bytes first
|
||||||
|
# TODO(mf): do this more efficiently, decode less
|
||||||
data_bytes = base64.b64decode(image.data)
|
data_bytes = base64.b64decode(image.data)
|
||||||
pil_image = PIL_Image.open(io.BytesIO(data_bytes))
|
pil_image = PIL_Image.open(io.BytesIO(data_bytes))
|
||||||
return data_bytes, pil_image.format
|
return data_bytes, pil_image.format
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue