Enable vision models for (Together, Fireworks, Meta-Reference, Ollama) (#376)

* Enable vision models for Together and Fireworks

* Works with the ollama 0.4.0 pre-release and its vision model support

* Localize media for meta_reference inference

* Fix
Ashwin Bharambe 2024-11-05 16:22:33 -08:00 committed by GitHub
parent db30809141
commit cde9bc1388
11 changed files with 465 additions and 81 deletions
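
For orientation, here is a minimal sketch of the kind of request this commit enables. The type names (`UserMessage`, `ImageMedia`, `URL`) and the `chat_completion` signature are assumptions based on llama-stack datatypes of this period, not something shown in the diff below.

# Hypothetical usage sketch; type names and signatures are assumptions.
from llama_models.llama3.api.datatypes import URL, ImageMedia, UserMessage


async def describe_image(inference_impl) -> None:
    message = UserMessage(
        content=[
            ImageMedia(image=URL(uri="https://example.com/cat.jpg")),
            "Describe what is in this image.",
        ],
    )
    # request_has_media() sees the ImageMedia item, so adapters like the
    # Fireworks one below route the request to the provider's chat/messages
    # endpoint instead of flattening it into a plain-text prompt.
    response = await inference_impl.chat_completion(
        model="Llama3.2-11B-Vision-Instruct",
        messages=[message],
    )
    print(response.completion_message.content)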


@@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_message_to_dict,
+    request_has_media,
 )
 from .config import FireworksImplConfig
@@ -82,14 +84,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> CompletionResponse:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         r = await client.completion.acreate(**params)
         return process_completion_response(r, self.formatter)

     async def _stream_completion(
         self, request: CompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
+        params = await self._get_params(request)
         stream = client.completion.acreate(**params)

         async for chunk in process_completion_stream_response(stream, self.formatter):
@@ -128,33 +130,55 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> ChatCompletionResponse:
-        params = self._get_params(request)
-        r = await client.completion.acreate(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            r = await client.chat.completions.acreate(**params)
+        else:
+            r = await client.completion.acreate(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> AsyncGenerator:
-        params = self._get_params(request)
-        stream = client.completion.acreate(**params)
+        params = await self._get_params(request)
+        if "messages" in params:
+            stream = client.chat.completions.acreate(**params)
+        else:
+            stream = client.completion.acreate(**params)

         async for chunk in process_chat_completion_stream_response(
             stream, self.formatter
         ):
             yield chunk

-    def _get_params(self, request) -> dict:
-        prompt = ""
-        if type(request) == ChatCompletionRequest:
-            prompt = chat_completion_request_to_prompt(request, self.formatter)
-        elif type(request) == CompletionRequest:
-            prompt = completion_request_to_prompt(request, self.formatter)
+    async def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        input_dict = {}
+        media_present = request_has_media(request)
+
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                input_dict["messages"] = [
+                    await convert_message_to_dict(m) for m in request.messages
+                ]
+            else:
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        elif isinstance(request, CompletionRequest):
+            assert (
+                not media_present
+            ), "Fireworks does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+        else:
+            raise ValueError(f"Unknown request type {type(request)}")

         # Fireworks always prepends with BOS
-        if prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
+        if "prompt" in input_dict:
+            if input_dict["prompt"].startswith("<|begin_of_text|>"):
+                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

         options = get_sampling_options(request.sampling_params)
         options.setdefault("max_tokens", 512)
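
The `convert_message_to_dict` helper awaited above lives in `prompt_adapter` (see the import hunk) and is not part of this diff. A plausible reconstruction of what it does, purely for illustration and not the commit's actual code:

import base64

import httpx  # assumption: any async HTTP client would do here


async def convert_message_to_dict_sketch(message) -> dict:
    # Plausible reconstruction of convert_message_to_dict, for illustration.
    items = message.content if isinstance(message.content, list) else [message.content]
    parts = []
    for item in items:
        if isinstance(item, str):
            parts.append({"type": "text", "text": item})
        else:
            # Media item: fetch the bytes and inline them as a data URL so the
            # provider never has to reach the original host. I/O like this is
            # also why _get_params and its call sites had to become async.
            async with httpx.AsyncClient() as client:
                resp = await client.get(item.image.uri)  # assumed media shape
            encoded = base64.b64encode(resp.content).decode()
            parts.append(
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}}
            )
    return {"role": message.role, "content": parts}
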
@@ -172,9 +196,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
                 }
             else:
                 raise ValueError(f"Unknown response format {fmt.type}")

         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": prompt,
+            **input_dict,
             "stream": request.stream,
             **options,
         }
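
Taken together, `_get_params` now produces one of two parameter shapes, and the call sites above pick the matching Fireworks endpoint based on which key is present. An illustrative sketch of the two outcomes (all concrete values invented):

# Text-only request: flattened to a single prompt, BOS stripped, and sent
# through client.completion.acreate(**params).
text_params = {
    "model": "fireworks/llama-v3p2-11b-vision-instruct",  # hypothetical mapping
    "prompt": "<|start_header_id|>user<|end_header_id|>\n\nHello!",
    "stream": False,
    "max_tokens": 512,
}

# Media-bearing chat request: OpenAI-style messages, sent through
# client.chat.completions.acreate(**params).
vision_params = {
    "model": "fireworks/llama-v3p2-11b-vision-instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
                {"type": "text", "text": "Describe what is in this image."},
            ],
        }
    ],
    "stream": False,
    "max_tokens": 512,
}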