mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
feat(internal): add image_url download feature to OpenAIMixin
simplify Ollama inference adapter by - - moving image_url download code to OpenAIMixin - being a ModelRegistryHelper instead of having one (mypy blocks check_model_availability method assignment) testing - - add unit tests for new download feature - add integration tests for openai_chat_completion w/ image_url (close test gap)
This commit is contained in:
parent
e3f77c1004
commit
65c4ffca28
5 changed files with 257 additions and 87 deletions
|
@ -6,8 +6,7 @@
|
|||
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ollama import AsyncClient as AsyncOllamaClient
|
||||
|
@ -33,10 +32,6 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
|
@ -60,7 +55,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
|
@ -73,7 +67,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
content_has_media,
|
||||
convert_image_content_to_url,
|
||||
interleaved_content_as_str,
|
||||
localize_image_content,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -84,6 +77,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
|
|||
|
||||
class OllamaInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
|
@ -91,8 +85,10 @@ class OllamaInferenceAdapter(
|
|||
__provider_id__: str
|
||||
|
||||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self.config = config
|
||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||
self.download_images = True
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
|
||||
@property
|
||||
|
@ -171,6 +167,7 @@ class OllamaInferenceAdapter(
|
|||
model_type=ModelType.llm,
|
||||
)
|
||||
)
|
||||
self._model_cache = {m.identifier: m for m in models} # for fast check_model_availability
|
||||
return models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
|
@ -190,9 +187,6 @@ class OllamaInferenceAdapter(
|
|||
async def shutdown(self) -> None:
|
||||
self._clients.clear()
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set")
|
||||
|
@ -301,7 +295,7 @@ class OllamaInferenceAdapter(
|
|||
|
||||
input_dict: dict[str, Any] = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.register_helper.get_llama_model(request.model)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present or not llama_model:
|
||||
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
|
||||
|
@ -410,7 +404,7 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
model = await super().register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
|
||||
|
@ -441,75 +435,6 @@ class OllamaInferenceAdapter(
|
|||
|
||||
return model
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
|
||||
if isinstance(m.content, list):
|
||||
for c in m.content:
|
||||
if c.type == "image_url" and c.image_url and c.image_url.url:
|
||||
localize_result = await localize_image_content(c.image_url.url)
|
||||
if localize_result is None:
|
||||
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
|
||||
|
||||
content, format = localize_result
|
||||
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
|
||||
return m
|
||||
|
||||
messages = [await _convert_message(m) for m in messages]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
return await OpenAIMixin.openai_chat_completion(self, **params)
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||
async def _convert_content(content) -> dict:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue