Mirror of https://github.com/meta-llama/llama-stack.git
Synced 2025-10-16 14:57:20 +00:00

Enable vision models for Together and Fireworks

parent 8de845a96d
commit 03013dafc1

9 changed files with 297 additions and 35 deletions
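In both the Fireworks and Together adapters, chat completion requests are now inspected with request_has_media. Requests that carry media are converted with convert_message_to_dict and sent as structured "messages" to the provider's chat-completions endpoint; text-only requests keep the existing path of flattening to a single "prompt" and calling the plain completions endpoint. Plain (non-chat) completion requests are asserted to be media-free.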
@@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_message_to_dict,
+    request_has_media,
 )

 from .config import FireworksImplConfig
@@ -129,7 +131,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         self, request: ChatCompletionRequest, client: Fireworks
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = await client.completion.acreate(**params)
+        if "messages" in params:
+            r = await client.chat.completions.acreate(**params)
+        else:
+            r = await client.completion.acreate(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
@@ -137,24 +142,44 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
     ) -> AsyncGenerator:
         params = self._get_params(request)

-        stream = client.completion.acreate(**params)
+        if "messages" in params:
+            print(f"Using chat completion endpoint: {params}")
+            stream = client.chat.completions.acreate(**params)
+        else:
+            stream = client.completion.acreate(**params)

         async for chunk in process_chat_completion_stream_response(
             stream, self.formatter
         ):
             yield chunk

-    def _get_params(self, request) -> dict:
-        prompt = ""
-        if type(request) == ChatCompletionRequest:
-            prompt = chat_completion_request_to_prompt(request, self.formatter)
-        elif type(request) == CompletionRequest:
-            prompt = completion_request_to_prompt(request, self.formatter)
+    def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        input_dict = {}
+        media_present = request_has_media(request)
+
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                input_dict["messages"] = [
+                    convert_message_to_dict(m) for m in request.messages
+                ]
+            else:
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        elif isinstance(request, CompletionRequest):
+            assert (
+                not media_present
+            ), "Fireworks does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+        else:
+            raise ValueError(f"Unknown request type {type(request)}")

         # Fireworks always prepends with BOS
-        if prompt.startswith("<|begin_of_text|>"):
-            prompt = prompt[len("<|begin_of_text|>") :]
+        if "prompt" in input_dict:
+            if input_dict["prompt"].startswith("<|begin_of_text|>"):
+                input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

         options = get_sampling_options(request.sampling_params)
         options.setdefault("max_tokens", 512)
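A minimal, self-contained sketch of the two parameter shapes the rewritten _get_params produces, plus the BOS stripping applied to the prompt path. The helper below and its sample data are illustrative stand-ins, not the llama-stack API; only the messages-vs-prompt split and the leading <|begin_of_text|> strip come from the hunk above.

# Toy illustration only: build_params stands in for the _get_params above.
BOS = "<|begin_of_text|>"

def build_params(messages=None, prompt=None):
    params = {}
    if messages is not None:
        # Media-bearing chat requests keep structured messages for the
        # chat-completions endpoint.
        params["messages"] = messages
    else:
        # Text-only requests are flattened to a prompt; Fireworks prepends
        # BOS itself, so a leading BOS token is stripped here.
        params["prompt"] = prompt[len(BOS):] if prompt.startswith(BOS) else prompt
    return params

print(build_params(messages=[{"role": "user", "content": "what is in this image?"}]))
print(build_params(prompt=BOS + "Hello"))  # -> {'prompt': 'Hello'}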
@@ -172,9 +197,10 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             }
         else:
             raise ValueError(f"Unknown response format {fmt.type}")

         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": prompt,
+            **input_dict,
             "stream": request.stream,
             **options,
         }
@@ -26,6 +26,8 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    convert_message_to_dict,
+    request_has_media,
 )

 from .config import TogetherImplConfig
@@ -102,7 +104,7 @@ class TogetherInferenceAdapter(
         return process_completion_response(r, self.formatter)

     async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
-        params = self._get_params_for_completion(request)
+        params = self._get_params(request)

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
@@ -131,14 +133,6 @@ class TogetherInferenceAdapter(

         return options

-    def _get_params_for_completion(self, request: CompletionRequest) -> dict:
-        return {
-            "model": self.map_to_provider_model(request.model),
-            "prompt": completion_request_to_prompt(request, self.formatter),
-            "stream": request.stream,
-            **self._build_options(request.sampling_params, request.response_format),
-        }
-
     async def chat_completion(
         self,
         model: str,
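With the dedicated completion-params helper removed, _stream_completion now builds its parameters through the unified _get_params defined further down: for a CompletionRequest it asserts the request carries no media and produces the same model/prompt/stream/options dictionary the deleted helper used to build.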
@@ -172,7 +166,10 @@ class TogetherInferenceAdapter(
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = self._get_client().completions.create(**params)
+        if "messages" in params:
+            r = self._get_client().chat.completions.create(**params)
+        else:
+            r = self._get_client().completions.create(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
@@ -182,7 +179,10 @@ class TogetherInferenceAdapter(

         # if we shift to TogetherAsyncClient, we won't need this wrapper
         async def _to_async_generator():
-            s = self._get_client().completions.create(**params)
+            if "messages" in params:
+                s = self._get_client().chat.completions.create(**params)
+            else:
+                s = self._get_client().completions.create(**params)
             for chunk in s:
                 yield chunk

@@ -192,10 +192,29 @@ class TogetherInferenceAdapter(
         ):
             yield chunk

-    def _get_params(self, request: ChatCompletionRequest) -> dict:
+    def _get_params(
+        self, request: Union[ChatCompletionRequest, CompletionRequest]
+    ) -> dict:
+        input_dict = {}
+        media_present = request_has_media(request)
+        if isinstance(request, ChatCompletionRequest):
+            if media_present:
+                input_dict["messages"] = [
+                    convert_message_to_dict(m) for m in request.messages
+                ]
+            else:
+                input_dict["prompt"] = chat_completion_request_to_prompt(
+                    request, self.formatter
+                )
+        else:
+            assert (
+                not media_present
+            ), "Together does not support media for Completion requests"
+            input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
+
         return {
             "model": self.map_to_provider_model(request.model),
-            "prompt": chat_completion_request_to_prompt(request, self.formatter),
+            **input_dict,
             "stream": request.stream,
             **self._build_options(request.sampling_params, request.response_format),
         }
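A hedged sketch of how the new routing can be exercised in isolation against a mocked Together-style client. The dispatch helper and the MagicMock client are assumptions for illustration; only the rule that a params dict containing "messages" goes to chat.completions while a "prompt" dict goes to completions is taken from the hunks above.

# Illustrative only: a stand-in for the "messages" vs "prompt" routing added
# in _chat_completion / _stream_chat_completion above.
from unittest.mock import MagicMock

def dispatch(client, params):
    # Media requests carry "messages" and hit the chat-completions endpoint;
    # flattened text-only requests keep using the plain completions endpoint.
    if "messages" in params:
        return client.chat.completions.create(**params)
    return client.completions.create(**params)

client = MagicMock()
dispatch(client, {"model": "m", "messages": [{"role": "user", "content": "hi"}]})
assert client.chat.completions.create.called

dispatch(client, {"model": "m", "prompt": "hi", "stream": False})
assert client.completions.create.called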