chore: turn OpenAIMixin into a pydantic.BaseModel

- implement get_api_key instead of relying on LiteLLMOpenAIMixin.get_api_key
- remove use of LiteLLMOpenAIMixin
- add default initialize/shutdown methods to OpenAIMixin
- remove __init__s to allow proper pydantic construction
- remove dead code from vllm adapter and associated/duplicate unit tests
- update vllm adapter to use OpenAIMixin for model registration
- remove ModelRegistryHelper from fireworks & together adapters
- remove Inference from nvidia adapter
- complete type hints on embedding_model_metadata
- allow extra fields on OpenAIMixin, for model_store, __provider_id__, etc.
- new recordings for ollama
Matthew Farrellee 2025-10-02 20:47:54 -04:00
parent ce77c27ff8
commit 60f0056cbc
57 changed files with 12520 additions and 1254 deletions
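For orientation before the diff: a minimal sketch of the shape the commit message describes, i.e. OpenAIMixin as a pydantic.BaseModel with default lifecycle hooks, a get_api_key hook each adapter implements itself, and extra fields allowed so the resolver can attach model_store, __provider_id__, and the like. This is an illustrative assumption, not the actual llama_stack source.

from pydantic import BaseModel, ConfigDict


class OpenAIMixin(BaseModel):
    # sketch only; the real class lives in llama_stack.providers.utils.inference.openai_mixin

    # tolerate attributes attached after construction (model_store, __provider_id__, ...)
    model_config = ConfigDict(extra="allow")

    # per-provider embedding metadata, now fully type-hinted
    embedding_model_metadata: dict[str, dict[str, int]] = {}

    def get_api_key(self) -> str:
        # adapters implement this instead of inheriting LiteLLMOpenAIMixin.get_api_key
        raise NotImplementedError

    async def initialize(self) -> None:
        # default no-op lifecycle hooks, so adapters only override when needed
        pass

    async def shutdown(self) -> None:
        pass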


@@ -10,6 +10,6 @@ from .config import OllamaImplConfig
async def get_adapter_impl(config: OllamaImplConfig, _deps):
    from .ollama import OllamaInferenceAdapter

-    impl = OllamaInferenceAdapter(config)
+    impl = OllamaInferenceAdapter(config=config)
    await impl.initialize()
    return impl
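The call-site change above follows directly from the pydantic conversion: a pydantic.BaseModel subclass without a handwritten __init__ is constructed from keyword arguments only, so the adapter has to be built as OllamaInferenceAdapter(config=config). A tiny illustration with hypothetical stand-in classes:

from pydantic import BaseModel


class ExampleConfig(BaseModel):
    url: str = "http://localhost:11434"


class ExampleAdapter(BaseModel):
    config: ExampleConfig


adapter = ExampleAdapter(config=ExampleConfig())  # fields are passed by keyword
# ExampleAdapter(ExampleConfig()) would raise TypeError: BaseModel accepts keyword arguments only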


@@ -6,7 +6,6 @@
import asyncio
from typing import Any
from ollama import AsyncClient as AsyncOllamaClient
@@ -16,48 +15,30 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    GrammarResponseFormat,
    InferenceProvider,
    JsonSchemaResponseFormat,
    Message,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    convert_image_content_to_url,
    request_has_media,
)
logger = get_logger(name=__name__, category="inference::ollama")
-class OllamaInferenceAdapter(
-    OpenAIMixin,
-    ModelRegistryHelper,
-    InferenceProvider,
-    ModelsProtocolPrivate,
-):
+class OllamaInferenceAdapter(OpenAIMixin):
    config: OllamaImplConfig

    # automatically set by the resolver when instantiating the provider
    __provider_id__: str

-    embedding_model_metadata = {
+    embedding_model_metadata: dict[str, dict[str, int]] = {
"all-minilm:l6-v2": {
"embedding_dimension": 384,
"context_length": 512,
@@ -76,29 +57,8 @@ class OllamaInferenceAdapter(
        },
    }

-    def __init__(self, config: OllamaImplConfig) -> None:
-        # TODO: remove ModelRegistryHelper.__init__ when completion and
-        # chat_completion are. this exists to satisfy the input /
-        # output processing for llama models. specifically,
-        # tool_calling is handled by raw template processing,
-        # instead of using the /api/chat endpoint w/ tools=...
-        ModelRegistryHelper.__init__(
-            self,
-            model_entries=[
-                build_hf_repo_model_entry(
-                    "llama3.2:3b-instruct-fp16",
-                    CoreModelId.llama3_2_3b_instruct.value,
-                ),
-                build_hf_repo_model_entry(
-                    "llama-guard3:1b",
-                    CoreModelId.llama_guard_3_1b.value,
-                ),
-            ],
-        )
-        self.config = config
-        # Ollama does not support image urls, so we need to download the image and convert it to base64
-        self.download_images = True
-        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
+    download_images: bool = True
+    _clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}

    @property
    def ollama_client(self) -> AsyncOllamaClient:
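With the handwritten __init__ gone, per-instance state such as download_images and _clients above becomes class-level declarations. A small sketch (illustrative names, not the real adapter) of how pydantic treats such declarations: annotated attributes become fields, underscore-prefixed ones become per-instance private attributes, and mutable defaults are copied per instance rather than shared.

import asyncio

from pydantic import BaseModel


class ExampleAdapter(BaseModel):
    # annotated class attributes become pydantic fields with defaults
    download_images: bool = True
    # underscore-prefixed attributes become per-instance private attributes
    _clients: dict[asyncio.AbstractEventLoop, object] = {}


a = ExampleAdapter()
b = ExampleAdapter()
assert a._clients is not b._clients  # the mutable default is copied per instance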
@@ -142,50 +102,6 @@ class OllamaInferenceAdapter(
    async def shutdown(self) -> None:
        self._clients.clear()

-    async def _get_model(self, model_id: str) -> Model:
-        if not self.model_store:
-            raise ValueError("Model store not set")
-        return await self.model_store.get_model(model_id)
-
-    async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        sampling_options = get_sampling_options(request.sampling_params)
-        # This is needed since the Ollama API expects num_predict to be set
-        # for early truncation instead of max_tokens.
-        if sampling_options.get("max_tokens") is not None:
-            sampling_options["num_predict"] = sampling_options["max_tokens"]
-
-        input_dict: dict[str, Any] = {}
-        media_present = request_has_media(request)
-        llama_model = self.get_llama_model(request.model)
-        if media_present or not llama_model:
-            contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
-            # flatten the list of lists
-            input_dict["messages"] = [item for sublist in contents for item in sublist]
-        else:
-            input_dict["raw"] = True
-            input_dict["prompt"] = await chat_completion_request_to_prompt(
-                request,
-                llama_model,
-            )
-
-        if fmt := request.response_format:
-            if isinstance(fmt, JsonSchemaResponseFormat):
-                input_dict["format"] = fmt.json_schema
-            elif isinstance(fmt, GrammarResponseFormat):
-                raise NotImplementedError("Grammar response format is not supported")
-            else:
-                raise ValueError(f"Unknown response format type: {fmt.type}")
-
-        params = {
-            "model": request.model,
-            **input_dict,
-            "options": sampling_options,
-            "stream": request.stream,
-        }
-        logger.debug(f"params to ollama: {params}")
-        return params

    async def register_model(self, model: Model) -> Model:
        if await self.check_model_availability(model.provider_model_id):
            return model
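With ModelRegistryHelper and its static model-entry table gone, register_model (truncated above) reduces to asking the provider whether it actually serves the requested id via check_model_availability. A rough, self-contained sketch of that flow, using hypothetical stand-in types rather than the real llama_stack ones:

from pydantic import BaseModel


class Model(BaseModel):
    identifier: str
    provider_model_id: str


class UnsupportedModelError(Exception):
    pass


class ExampleOpenAIMixin(BaseModel):
    async def check_model_availability(self, model_id: str) -> bool:
        # the real mixin consults the live OpenAI-compatible endpoint's model list
        return model_id in {"llama3.2:3b-instruct-fp16", "all-minilm:l6-v2"}

    async def register_model(self, model: Model) -> Model:
        if await self.check_model_availability(model.provider_model_id):
            return model
        raise UnsupportedModelError(model.provider_model_id)

# usage, e.g.:
# asyncio.run(ExampleOpenAIMixin().register_model(
#     Model(identifier="m", provider_model_id="llama3.2:3b-instruct-fp16")))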