forked from phoenix-oss/llama-stack-mirror

Enable vision models for (Together, Fireworks, Meta-Reference, Ollama) (#376)

* Enable vision models for Together and Fireworks
* Works with ollama 0.4.0 pre-release with the vision model
* localize media for meta_reference inference
* Fix

parent db30809141
commit cde9bc1388

11 changed files with 465 additions and 81 deletions
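At the API level, this change lets chat completion requests carry ImageMedia content. Below is a minimal sketch of such a call, mirroring the new vision tests later in this diff; the image URL is a placeholder, and `inference_impl` is assumed to come from an already configured inference stack rather than anything added here.

from llama_models.llama3.api.datatypes import *  # noqa: F403  (ImageMedia, URL, SystemMessage, UserMessage)


async def describe_image(inference_impl, model="Llama3.2-11B-Vision-Instruct"):
    # Placeholder URL; the tests below use a local pasta.jpeg and a remote dog photo.
    image = ImageMedia(image=URL(uri="https://example.com/dog.jpg"))
    response = await inference_impl.chat_completion(
        model=model,
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content=[image, "Describe this image in two sentences."]),
        ],
        stream=False,
    )
    return response.completion_message.content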
@@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_message_to_dict,
+    request_has_media,
 )
 
 from .config import FireworksImplConfig
@@ -82,14 +84,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> CompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         r = await client.completion.acreate(**params)
         return process_completion_response(r, self.formatter)
 
     async def _stream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         stream = client.completion.acreate(**params)
         async for chunk in process_completion_stream_response(stream, self.formatter):
@@ -128,33 +130,55 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = await client.completion.acreate(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = await client.chat.completions.acreate(**params)
+        else:
+            r = await client.completion.acreate(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
-        stream = client.completion.acreate(**params)
+        if "messages" in params:
+            stream = client.chat.completions.acreate(**params)
+        else:
+            stream = client.completion.acreate(**params)
+
         async for chunk in process_chat_completion_stream_response(
             stream, self.formatter
         ):
             yield chunk
 
-    def _get_params(self, request) -> dict:
-        prompt = ""
-        if type(request) == ChatCompletionRequest:
-            prompt = chat_completion_request_to_prompt(request, self.formatter)
-        elif type(request) == CompletionRequest:
-            prompt = completion_request_to_prompt(request, self.formatter)
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        input_dict = {}
+        media_present = request_has_media(request)
+
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                input_dict["messages"] = [
+                    await convert_message_to_dict(m) for m in request.messages
+                ]
+            else:
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        elif isinstance(request, CompletionRequest):
+            assert (
+                not media_present
+            ), "Fireworks does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
         else:
             raise ValueError(f"Unknown request type {type(request)}")
 
         # Fireworks always prepends with BOS
-        if prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
+        if "prompt" in input_dict:
+            if input_dict["prompt"].startswith("<|begin_of_text|>"):
+                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
 
         options = get_sampling_options(request.sampling_params)
         options.setdefault("max_tokens", 512)
@@ -172,9 +196,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             }
         else:
             raise ValueError(f"Unknown response format {fmt.type}")
 
         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": prompt,
+            **input_dict,
             "stream": request.stream,
             **options,
         }
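To make the dispatch above concrete: `_get_params` now produces either a "messages"-shaped payload (media present, sent to the chat endpoint) or a "prompt"-shaped payload (text only, sent to the completion endpoint). The sketch below illustrates the two shapes; the model identifier, base64 payload, and prompt text are placeholders, not actual Fireworks values, and the message parts simply follow the new convert_message_to_dict helper further down in this diff.

# Illustrative shapes only (placeholder values), not actual Fireworks payloads.
params_with_media = {
    "model": "<provider model id>",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
                {"type": "text", "text": "Describe this image in two sentences."},
            ],
        }
    ],
    "stream": False,
    "max_tokens": 512,
}

params_text_only = {
    "model": "<provider model id>",
    "prompt": "<formatted Llama prompt, leading <|begin_of_text|> stripped>",
    "stream": False,
    "max_tokens": 512,
}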
@@ -29,6 +29,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_image_media_to_url,
+    request_has_media,
 )
 
 OLLAMA_SUPPORTED_MODELS = {
@@ -38,6 +40,7 @@ OLLAMA_SUPPORTED_MODELS = {
     "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
     "Llama-Guard-3-8B": "llama-guard3:8b",
     "Llama-Guard-3-1B": "llama-guard3:1b",
+    "Llama3.2-11B-Vision-Instruct": "x/llama3.2-vision:11b-instruct-fp16",
 }
 
 
@@ -109,22 +112,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_completion(request)
 
-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        sampling_options = get_sampling_options(request.sampling_params)
-        # This is needed since the Ollama API expects num_predict to be set
-        # for early truncation instead of max_tokens.
-        if sampling_options["max_tokens"] is not None:
-            sampling_options["num_predict"] = sampling_options["max_tokens"]
-        return {
-            "model": OLLAMA_SUPPORTED_MODELS[request.model],
-            "prompt": completion_request_to_prompt(request, self.formatter),
-            "options": sampling_options,
-            "raw": True,
-            "stream": request.stream,
-        }
-
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
             s = await self.client.generate(**params)
@@ -142,7 +131,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
         r = await self.client.generate(**params)
         assert isinstance(r, dict)
 
@@ -183,26 +172,66 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         else:
             return await self._nonstream_chat_completion(request)
 
-    def _get_params(self, request: ChatCompletionRequest) -> dict:
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        sampling_options = get_sampling_options(request.sampling_params)
+        # This is needed since the Ollama API expects num_predict to be set
+        # for early truncation instead of max_tokens.
+        if sampling_options.get("max_tokens") is not None:
+            sampling_options["num_predict"] = sampling_options["max_tokens"]
+
+        input_dict = {}
+        media_present = request_has_media(request)
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                contents = [
+                    await convert_message_to_dict_for_ollama(m)
+                    for m in request.messages
+                ]
+                # flatten the list of lists
+                input_dict["messages"] = [
+                    item for sublist in contents for item in sublist
+                ]
+            else:
+                input_dict["raw"] = True
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        else:
+            assert (
+                not media_present
+            ), "Ollama does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+            input_dict["raw"] = True
+
         return {
             "model": OLLAMA_SUPPORTED_MODELS[request.model],
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
-            "options": get_sampling_options(request.sampling_params),
-            "raw": True,
+            **input_dict,
+            "options": sampling_options,
             "stream": request.stream,
         }
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = await self.client.generate(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = await self.client.chat(**params)
+        else:
+            r = await self.client.generate(**params)
         assert isinstance(r, dict)
 
-        choice = OpenAICompatCompletionChoice(
-            finish_reason=r["done_reason"] if r["done"] else None,
-            text=r["response"],
-        )
+        if "message" in r:
+            choice = OpenAICompatCompletionChoice(
+                finish_reason=r["done_reason"] if r["done"] else None,
+                text=r["message"]["content"],
+            )
+        else:
+            choice = OpenAICompatCompletionChoice(
+                finish_reason=r["done_reason"] if r["done"] else None,
+                text=r["response"],
+            )
         response = OpenAICompatCompletionResponse(
             choices=[choice],
        )
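The `"message" in r` check works because the two Ollama client calls return differently shaped dictionaries. Roughly, with field names inferred from the accesses in this hunk (not taken from Ollama documentation) and placeholder values:

# Approximate result shapes, inferred from the field accesses above; values are placeholders.
generate_result = {"response": "It looks like a plate of spaghetti...", "done": True, "done_reason": "stop"}
chat_result = {
    "message": {"role": "assistant", "content": "It looks like a plate of spaghetti..."},
    "done": True,
    "done_reason": "stop",
}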
@@ -211,15 +240,24 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
+            if "messages" in params:
+                s = await self.client.chat(**params)
+            else:
+                s = await self.client.generate(**params)
             async for chunk in s:
-                choice = OpenAICompatCompletionChoice(
-                    finish_reason=chunk["done_reason"] if chunk["done"] else None,
-                    text=chunk["response"],
-                )
+                if "message" in chunk:
+                    choice = OpenAICompatCompletionChoice(
+                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                        text=chunk["message"]["content"],
+                    )
+                else:
+                    choice = OpenAICompatCompletionChoice(
+                        finish_reason=chunk["done_reason"] if chunk["done"] else None,
+                        text=chunk["response"],
+                    )
                 yield OpenAICompatCompletionResponse(
                     choices=[choice],
                 )
@@ -236,3 +274,26 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+
+async def convert_message_to_dict_for_ollama(message: Message) -> List[dict]:
+    async def _convert_content(content) -> dict:
+        if isinstance(content, ImageMedia):
+            return {
+                "role": message.role,
+                "images": [
+                    await convert_image_media_to_url(
+                        content, download=True, include_format=False
+                    )
+                ],
+            }
+        else:
+            return {
+                "role": message.role,
+                "content": content,
+            }
+
+    if isinstance(message.content, list):
+        return [await _convert_content(c) for c in message.content]
+    else:
+        return [await _convert_content(message.content)]
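For reference, convert_message_to_dict_for_ollama returns one dict per content item (hence the flattening in _get_params), and images are passed as bare base64 strings since include_format=False. Roughly, for a user message carrying one image and one text string; the base64 string is a placeholder:

# Illustrative output only; the base64 string is a placeholder.
ollama_messages = [
    {"role": "user", "images": ["<base64-encoded image bytes>"]},
    {"role": "user", "content": "Describe this image in two sentences."},
]
# After flattening across all request messages, this list is what self.client.chat(**params) receives.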
@@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_message_to_dict,
+    request_has_media,
 )
 
 from .config import TogetherImplConfig
@@ -97,12 +99,12 @@ class TogetherInferenceAdapter(
     async def _nonstream_completion(
         self, request: CompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
         r = self._get_client().completions.create(**params)
         return process_completion_response(r, self.formatter)
 
     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = await self._get_params(request)
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
@@ -131,14 +133,6 @@ class TogetherInferenceAdapter(
 
         return options
 
-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        return {
-            "model": self.map_to_provider_model(request.model),
-            "prompt": completion_request_to_prompt(request, self.formatter),
-            "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
-        }
-
     async def chat_completion(
         self,
         model: str,
@@ -171,18 +165,24 @@ class TogetherInferenceAdapter(
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = self._get_client().completions.create(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = self._get_client().chat.completions.create(**params)
+        else:
+            r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, self.formatter)
 
     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
 
         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self._get_client().completions.create(**params)
+            if "messages" in params:
+                s = self._get_client().chat.completions.create(**params)
+            else:
+                s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk
 
@@ -192,10 +192,29 @@ class TogetherInferenceAdapter(
         ):
             yield chunk
 
-    def _get_params(self, request: ChatCompletionRequest) -> dict:
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        input_dict = {}
+        media_present = request_has_media(request)
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                input_dict["messages"] = [
+                    await convert_message_to_dict(m) for m in request.messages
+                ]
+            else:
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        else:
+            assert (
+                not media_present
+            ), "Together does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+
         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
         }
@@ -14,6 +14,11 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
 
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_media_to_url,
+    request_has_media,
+)
+
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
@@ -87,6 +92,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             logprobs=logprobs,
         )
         self.check_model(request)
+        request = await request_with_localized_media(request)
 
         if request.stream:
             return self._stream_completion(request)
@@ -211,6 +217,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             logprobs=logprobs,
         )
         self.check_model(request)
+        request = await request_with_localized_media(request)
 
         if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
@@ -388,3 +395,31 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+
+async def request_with_localized_media(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+) -> Union[ChatCompletionRequest, CompletionRequest]:
+    if not request_has_media(request):
+        return request
+
+    async def _convert_single_content(content):
+        if isinstance(content, ImageMedia):
+            url = await convert_image_media_to_url(content, download=True)
+            return ImageMedia(image=URL(uri=url))
+        else:
+            return content
+
+    async def _convert_content(content):
+        if isinstance(content, list):
+            return [await _convert_single_content(c) for c in content]
+        else:
+            return await _convert_single_content(content)
+
+    if isinstance(request, ChatCompletionRequest):
+        for m in request.messages:
+            m.content = await _convert_content(m.content)
+    else:
+        request.content = await _convert_content(request.content)
+
+    return request
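The effect of request_with_localized_media is to pull any remote image down before generation, so the meta-reference generator never needs network access. A small illustration, assuming the URL and the truncated base64 payload below as placeholders; with download=True and the default include_format=True, convert_image_media_to_url returns an inline data URL:

# Before localization: the message references a remote image (placeholder URL).
remote = ImageMedia(image=URL(uri="https://example.com/pasta.jpeg"))

# After localization: the same content slot holds an inline base64 data URL instead.
localized = ImageMedia(image=URL(uri="data:image/jpeg;base64,/9j/4AAQSk..."))  # truncated placeholder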
@@ -19,12 +19,11 @@ def pytest_addoption(parser):
 
 
 def pytest_configure(config):
-    config.addinivalue_line(
-        "markers", "llama_8b: mark test to run only with the given model"
-    )
-    config.addinivalue_line(
-        "markers", "llama_3b: mark test to run only with the given model"
-    )
+    for model in ["llama_8b", "llama_3b", "llama_vision"]:
+        config.addinivalue_line(
+            "markers", f"{model}: mark test to run only with the given model"
+        )
+
     for fixture_name in INFERENCE_FIXTURES:
         config.addinivalue_line(
             "markers",
@@ -37,6 +36,14 @@ MODEL_PARAMS = [
     pytest.param("Llama3.2-3B-Instruct", marks=pytest.mark.llama_3b, id="llama_3b"),
 ]
 
+VISION_MODEL_PARAMS = [
+    pytest.param(
+        "Llama3.2-11B-Vision-Instruct",
+        marks=pytest.mark.llama_vision,
+        id="llama_vision",
+    ),
+]
+
 
 def pytest_generate_tests(metafunc):
     if "inference_model" in metafunc.fixturenames:
@@ -44,7 +51,11 @@ def pytest_generate_tests(metafunc):
         if model:
             params = [pytest.param(model, id="")]
         else:
-            params = MODEL_PARAMS
+            cls_name = metafunc.cls.__name__
+            if "Vision" in cls_name:
+                params = VISION_MODEL_PARAMS
+            else:
+                params = MODEL_PARAMS
 
         metafunc.parametrize(
             "inference_model",
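With the llama_vision marker and VISION_MODEL_PARAMS in place, the new vision tests can be selected the same way the existing inference tests document. A hypothetical invocation is sketched below, mirroring the "How to run this test" comments already in test_inference.py; the provider fixture markers registered from INFERENCE_FIXTURES and the exact environment variables needed are assumptions here.

# How to run the new vision tests (hypothetical invocation, following the existing
# "How to run this test" convention; provider marker names come from INFERENCE_FIXTURES):
#
#   pytest -v -s llama_stack/providers/tests/inference/test_vision_inference.py \
#       -m llama_vision --env FIREWORKS_API_KEY=<your_api_key>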
BIN  llama_stack/providers/tests/inference/pasta.jpeg  (new binary file, 438 KiB; not shown)
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import itertools
 
 import pytest
 
@@ -15,6 +14,9 @@ from llama_stack.apis.inference import *  # noqa: F403
 
 from llama_stack.distribution.datatypes import *  # noqa: F403
 
+from .utils import group_chunks
+
+
 # How to run this test:
 #
 # pytest -v -s llama_stack/providers/tests/inference/test_inference.py
@@ -22,15 +24,6 @@ from llama_stack.distribution.datatypes import *  # noqa: F403
 # --env FIREWORKS_API_KEY=<your_api_key>
 
 
-def group_chunks(response):
-    return {
-        event_type: list(group)
-        for event_type, group in itertools.groupby(
-            response, key=lambda chunk: chunk.event.event_type
-        )
-    }
-
-
 def get_expected_stop_reason(model: str):
     return StopReason.end_of_message if "Llama3.1" in model else StopReason.end_of_turn
 
llama_stack/providers/tests/inference/test_vision_inference.py  (new file, 128 lines)
@@ -0,0 +1,128 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pathlib import Path
+
+import pytest
+from PIL import Image as PIL_Image
+
+
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.inference import *  # noqa: F403
+
+from .utils import group_chunks
+
+THIS_DIR = Path(__file__).parent
+
+
+class TestVisionModelInference:
+    @pytest.mark.asyncio
+    async def test_vision_chat_completion_non_streaming(
+        self, inference_model, inference_stack
+    ):
+        inference_impl, _ = inference_stack
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type not in (
+            "meta-reference",
+            "remote::together",
+            "remote::fireworks",
+            "remote::ollama",
+        ):
+            pytest.skip(
+                "Other inference providers don't support vision chat completion() yet"
+            )
+
+        images = [
+            ImageMedia(image=PIL_Image.open(THIS_DIR / "pasta.jpeg")),
+            ImageMedia(
+                image=URL(
+                    uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
+                )
+            ),
+        ]
+
+        # These are a bit hit-and-miss, need to be careful
+        expected_strings_to_check = [
+            ["spaghetti"],
+            ["puppy"],
+        ]
+        for image, expected_strings in zip(images, expected_strings_to_check):
+            response = await inference_impl.chat_completion(
+                model=inference_model,
+                messages=[
+                    SystemMessage(content="You are a helpful assistant."),
+                    UserMessage(
+                        content=[image, "Describe this image in two sentences."]
+                    ),
+                ],
+                stream=False,
+            )
+
+            assert isinstance(response, ChatCompletionResponse)
+            assert response.completion_message.role == "assistant"
+            assert isinstance(response.completion_message.content, str)
+            for expected_string in expected_strings:
+                assert expected_string in response.completion_message.content
+
+    @pytest.mark.asyncio
+    async def test_vision_chat_completion_streaming(
+        self, inference_model, inference_stack
+    ):
+        inference_impl, _ = inference_stack
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type not in (
+            "meta-reference",
+            "remote::together",
+            "remote::fireworks",
+            "remote::ollama",
+        ):
+            pytest.skip(
+                "Other inference providers don't support vision chat completion() yet"
+            )
+
+        images = [
+            ImageMedia(
+                image=URL(
+                    uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg"
+                )
+            ),
+        ]
+        expected_strings_to_check = [
+            ["puppy"],
+        ]
+        for image, expected_strings in zip(images, expected_strings_to_check):
+            response = [
+                r
+                async for r in await inference_impl.chat_completion(
+                    model=inference_model,
+                    messages=[
+                        SystemMessage(content="You are a helpful assistant."),
+                        UserMessage(
+                            content=[image, "Describe this image in two sentences."]
+                        ),
+                    ],
+                    stream=True,
+                )
+            ]
+
+            assert len(response) > 0
+            assert all(
+                isinstance(chunk, ChatCompletionResponseStreamChunk)
+                for chunk in response
+            )
+            grouped = group_chunks(response)
+            assert len(grouped[ChatCompletionResponseEventType.start]) == 1
+            assert len(grouped[ChatCompletionResponseEventType.progress]) > 0
+            assert len(grouped[ChatCompletionResponseEventType.complete]) == 1
+
+            content = "".join(
+                chunk.event.delta
+                for chunk in grouped[ChatCompletionResponseEventType.progress]
+            )
+            for expected_string in expected_strings:
+                assert expected_string in content
llama_stack/providers/tests/inference/utils.py  (new file, 16 lines)
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import itertools
+
+
+def group_chunks(response):
+    return {
+        event_type: list(group)
+        for event_type, group in itertools.groupby(
+            response, key=lambda chunk: chunk.event.event_type
+        )
+    }
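group_chunks buckets consecutive streamed chunks by event type so the tests can assert on start/progress/complete counts. A minimal usage sketch with stand-in objects follows; the real tests pass ChatCompletionResponseStreamChunk instances, and the absolute import path is an assumption (the tests themselves use the relative form "from .utils import group_chunks").

# Minimal sketch with stand-in chunk objects; values and import path are assumptions.
from types import SimpleNamespace

from llama_stack.providers.tests.inference.utils import group_chunks  # assumed import path


def _chunk(event_type):
    return SimpleNamespace(event=SimpleNamespace(event_type=event_type))


chunks = [_chunk("start"), _chunk("progress"), _chunk("progress"), _chunk("complete")]
grouped = group_chunks(chunks)
assert [len(grouped[t]) for t in ("start", "progress", "complete")] == [1, 2, 1]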
@@ -46,6 +46,9 @@ def text_from_choice(choice) -> str:
     if hasattr(choice, "delta") and choice.delta:
         return choice.delta.content
 
+    if hasattr(choice, "message"):
+        return choice.message.content
+
     return choice.text
 
 
@@ -99,7 +102,6 @@ def process_chat_completion_response(
 async def process_completion_stream_response(
     stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat
 ) -> AsyncGenerator:
-
     stop_reason = None
 
     async for chunk in stream:
@@ -158,6 +160,10 @@ async def process_chat_completion_stream_response(
             break
 
         text = text_from_choice(choice)
+        if not text:
+            # Sometimes you get empty chunks from providers
+            continue
+
         # check if its a tool call ( aka starts with <|python_tag|> )
         if not ipython and text.startswith("<|python_tag|>"):
             ipython = True
@@ -3,10 +3,16 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
+import io
 import json
 from typing import Tuple
 
+import httpx
+
 from llama_models.llama3.api.chat_format import ChatFormat
+from PIL import Image as PIL_Image
 from termcolor import cprint
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -24,6 +30,90 @@ from llama_models.sku_list import resolve_model
 from llama_stack.providers.utils.inference import supported_inference_models
 
 
+def content_has_media(content: InterleavedTextMedia):
+    def _has_media_content(c):
+        return isinstance(c, ImageMedia)
+
+    if isinstance(content, list):
+        return any(_has_media_content(c) for c in content)
+    else:
+        return _has_media_content(content)
+
+
+def messages_have_media(messages: List[Message]):
+    return any(content_has_media(m.content) for m in messages)
+
+
+def request_has_media(request: Union[ChatCompletionRequest, CompletionRequest]):
+    if isinstance(request, ChatCompletionRequest):
+        return messages_have_media(request.messages)
+    else:
+        return content_has_media(request.content)
+
+
+async def convert_image_media_to_url(
+    media: ImageMedia, download: bool = False, include_format: bool = True
+) -> str:
+    if isinstance(media.image, PIL_Image.Image):
+        if media.image.format == "PNG":
+            format = "png"
+        elif media.image.format == "GIF":
+            format = "gif"
+        elif media.image.format == "JPEG":
+            format = "jpeg"
+        else:
+            raise ValueError(f"Unsupported image format {media.image.format}")
+
+        bytestream = io.BytesIO()
+        media.image.save(bytestream, format=media.image.format)
+        bytestream.seek(0)
+        content = bytestream.getvalue()
+    else:
+        if not download:
+            return media.image.uri
+        else:
+            assert isinstance(media.image, URL)
+            async with httpx.AsyncClient() as client:
+                r = await client.get(media.image.uri)
+            content = r.content
+            content_type = r.headers.get("content-type")
+            if content_type:
+                format = content_type.split("/")[-1]
+            else:
+                format = "png"
+
+    if include_format:
+        return f"data:image/{format};base64," + base64.b64encode(content).decode(
+            "utf-8"
+        )
+    else:
+        return base64.b64encode(content).decode("utf-8")
+
+
+async def convert_message_to_dict(message: Message) -> dict:
+    async def _convert_content(content) -> dict:
+        if isinstance(content, ImageMedia):
+            return {
+                "type": "image_url",
+                "image_url": {
+                    "url": await convert_image_media_to_url(content),
+                },
+            }
+        else:
+            assert isinstance(content, str)
+            return {"type": "text", "text": content}
+
+    if isinstance(message.content, list):
+        content = [await _convert_content(c) for c in message.content]
+    else:
+        content = [await _convert_content(message.content)]
+
+    return {
+        "role": message.role,
+        "content": content,
+    }
+
+
 def completion_request_to_prompt(
     request: CompletionRequest, formatter: ChatFormat
 ) -> str:
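convert_message_to_dict is the helper the Together and Fireworks adapters use to build OpenAI-style content parts for their chat endpoints. Roughly, its output for a user message with one ImageMedia and one text string looks like the following; the base64 payload and text are placeholders:

# Illustrative output shape only; values are placeholders.
converted = {
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
        {"type": "text", "text": "Describe this image in two sentences."},
    ],
}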