Enable vision models for Together, Fireworks, Meta-Reference, and Ollama (#376)

* Enable vision models for Together and Fireworks

* Works with the Ollama 0.4.0 pre-release and its vision model support

* Localize media for meta_reference inference

* Fix
Ashwin Bharambe authored on 2024-11-05 16:22:33 -08:00; committed by GitHub
commit cde9bc1388 (parent db30809141)
11 changed files with 465 additions and 81 deletions
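
At the API level, "vision" here means message content may interleave text with ImageMedia items, as the diff below handles. A minimal sketch of a client-side chat completion carrying an image (the `inference` handle, model id, and image URI are illustrative assumptions, not taken from this commit):

    from llama_models.llama3.api.datatypes import ImageMedia, URL, UserMessage

    # Hypothetical handle to an Inference provider; the model id is an
    # assumption for illustration.
    response = await inference.chat_completion(
        model="Llama3.2-11B-Vision-Instruct",
        messages=[
            UserMessage(
                content=[
                    ImageMedia(image=URL(uri="https://example.com/dog.jpg")),
                    "Describe this image in two sentences.",
                ]
            )
        ],
    )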


@@ -14,6 +14,11 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    convert_image_media_to_url,
+    request_has_media,
+)
 from .config import MetaReferenceInferenceConfig
 from .generation import Llama
 from .model_parallel import LlamaModelParallelGenerator
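
The two helpers imported above come from the shared prompt_adapter utilities: request_has_media decides whether a request needs any media handling at all, and convert_image_media_to_url produces a URL the local generator can consume. As a rough sketch, the predicate plausibly looks something like this (the body below is an assumption for illustration, not the actual prompt_adapter code):

    def _content_has_media(content) -> bool:
        # Content may be a plain string, a single item, or a list mixing
        # text and media; only ImageMedia counts as media here.
        if isinstance(content, list):
            return any(_content_has_media(c) for c in content)
        return isinstance(content, ImageMedia)


    def request_has_media(request) -> bool:
        if isinstance(request, ChatCompletionRequest):
            return any(_content_has_media(m.content) for m in request.messages)
        return _content_has_media(request.content)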
@@ -87,6 +92,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             logprobs=logprobs,
         )
         self.check_model(request)
+        request = await request_with_localized_media(request)
 
         if request.stream:
             return self._stream_completion(request)
@@ -211,6 +217,7 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
             logprobs=logprobs,
         )
         self.check_model(request)
+        request = await request_with_localized_media(request)
 
         if self.config.create_distributed_process_group:
             if SEMAPHORE.locked():
@@ -388,3 +395,31 @@ class MetaReferenceInferenceImpl(Inference, ModelsProtocolPrivate):
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
+
+
+async def request_with_localized_media(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+) -> Union[ChatCompletionRequest, CompletionRequest]:
+    if not request_has_media(request):
+        return request
+
+    async def _convert_single_content(content):
+        if isinstance(content, ImageMedia):
+            url = await convert_image_media_to_url(content, download=True)
+            return ImageMedia(image=URL(uri=url))
+        else:
+            return content
+
+    async def _convert_content(content):
+        if isinstance(content, list):
+            return [await _convert_single_content(c) for c in content]
+        else:
+            return await _convert_single_content(content)
+
+    if isinstance(request, ChatCompletionRequest):
+        for m in request.messages:
+            m.content = await _convert_content(m.content)
+    else:
+        request.content = await _convert_content(request.content)
+
+    return request
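
With download=True, convert_image_media_to_url is asked to pull the image bytes down so that nothing later in the meta-reference generation path touches the network. A minimal sketch of that download-and-inline step, assuming an httpx fetch and a base64 data URL as the localized form (the actual helper in prompt_adapter may differ):

    import base64

    import httpx


    async def download_as_data_url(uri: str) -> str:
        # Fetch the remote image once, up front, and inline it as a data:
        # URL so later stages can decode it without network access.
        async with httpx.AsyncClient() as client:
            resp = await client.get(uri)
            resp.raise_for_status()
        content_type = resp.headers.get("content-type", "image/png")
        encoded = base64.b64encode(resp.content).decode("utf-8")
        return f"data:{content_type};base64,{encoded}"

Doing this once per request, before streaming begins or the request is handed to the distributed process group, is exactly what the two call sites earlier in the diff guarantee.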