Mirror of https://github.com/meta-llama/llama-stack.git
chore: turn OpenAIMixin into a pydantic.BaseModel
- implement get_api_key instead of relying on LiteLLMOpenAIMixin.get_api_key
- remove use of LiteLLMOpenAIMixin
- add default initialize/shutdown methods to OpenAIMixin
- remove __init__s to allow proper pydantic construction
- remove dead code from vllm adapter and associated / duplicate unit tests
- update vllm adapter to use OpenAIMixin for model registration
- remove ModelRegistryHelper from fireworks & together adapters
- remove Inference from nvidia adapter
- complete type hints on embedding_model_metadata
- allow extra fields on OpenAIMixin, for model_store, __provider_id__, etc.
- new recordings for ollama
Parent: ce77c27ff8
Commit: 60f0056cbc
57 changed files with 12520 additions and 1254 deletions
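Taken together, the bullets above describe the new shape of a provider adapter: config becomes a declared pydantic field, get_api_key is overridden per provider, initialize/shutdown are default hooks, and extra fields are allowed so the framework can attach things like model_store and __provider_id__. The sketch below is illustrative only — ExampleConfig, ExampleOpenAIMixin, ExampleAdapter, and the use of ConfigDict(extra="allow") are assumptions drawn from the commit message, not the actual llama-stack classes.

# Illustrative sketch (assumed names), not the real OpenAIMixin
from pydantic import BaseModel, ConfigDict


class ExampleConfig(BaseModel):
    url: str = "http://localhost:11434"


class ExampleOpenAIMixin(BaseModel):
    # extra fields allowed so the framework can attach model_store, __provider_id__, etc.
    model_config = ConfigDict(extra="allow")

    config: ExampleConfig

    def get_api_key(self) -> str:
        # each adapter implements this instead of inheriting LiteLLMOpenAIMixin.get_api_key
        raise NotImplementedError

    async def initialize(self) -> None:
        # default lifecycle hook, modeled here as a no-op
        pass

    async def shutdown(self) -> None:
        # default lifecycle hook, modeled here as a no-op
        pass


class ExampleAdapter(ExampleOpenAIMixin):
    def get_api_key(self) -> str:
        return "NO_KEY"  # e.g. a local server that ignores the key


# pydantic generates the constructor, so no hand-written __init__ is needed
adapter = ExampleAdapter(config=ExampleConfig())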
llama_stack/providers/remote/inference/ollama/__init__.py

@@ -10,6 +10,6 @@ from .config import OllamaImplConfig

async def get_adapter_impl(config: OllamaImplConfig, _deps):
    from .ollama import OllamaInferenceAdapter

-    impl = OllamaInferenceAdapter(config)
+    impl = OllamaInferenceAdapter(config=config)
    await impl.initialize()
    return impl
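The call-site change from config to config=config follows directly from the pydantic conversion: the constructor that pydantic generates accepts fields as keyword arguments only, so positional construction fails. A minimal, self-contained illustration with toy classes (not llama-stack code):

from pydantic import BaseModel


class Config(BaseModel):
    url: str = "http://localhost:11434"


class Adapter(BaseModel):
    config: Config


Adapter(config=Config())  # OK: fields are passed by keyword
try:
    Adapter(Config())  # pydantic's generated __init__ rejects positional arguments
except TypeError as err:
    print(err)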
llama_stack/providers/remote/inference/ollama/ollama.py

@@ -6,7 +6,6 @@

import asyncio
from typing import Any

from ollama import AsyncClient as AsyncOllamaClient
@@ -16,48 +15,30 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    GrammarResponseFormat,
    InferenceProvider,
    JsonSchemaResponseFormat,
    Message,
)
from llama_stack.apis.models import Model
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.datatypes import (
    HealthResponse,
    HealthStatus,
    ModelsProtocolPrivate,
)
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    convert_image_content_to_url,
    request_has_media,
)

logger = get_logger(name=__name__, category="inference::ollama")


-class OllamaInferenceAdapter(
-    OpenAIMixin,
-    ModelRegistryHelper,
-    InferenceProvider,
-    ModelsProtocolPrivate,
-):
+class OllamaInferenceAdapter(OpenAIMixin):
    config: OllamaImplConfig

    # automatically set by the resolver when instantiating the provider
    __provider_id__: str

-    embedding_model_metadata = {
+    embedding_model_metadata: dict[str, dict[str, int]] = {
        "all-minilm:l6-v2": {
            "embedding_dimension": 384,
            "context_length": 512,
@@ -76,29 +57,8 @@ class OllamaInferenceAdapter(
        },
    }

-    def __init__(self, config: OllamaImplConfig) -> None:
-        # TODO: remove ModelRegistryHelper.__init__ when completion and
-        # chat_completion are. this exists to satisfy the input /
-        # output processing for llama models. specifically,
-        # tool_calling is handled by raw template processing,
-        # instead of using the /api/chat endpoint w/ tools=...
-        ModelRegistryHelper.__init__(
-            self,
-            model_entries=[
-                build_hf_repo_model_entry(
-                    "llama3.2:3b-instruct-fp16",
-                    CoreModelId.llama3_2_3b_instruct.value,
-                ),
-                build_hf_repo_model_entry(
-                    "llama-guard3:1b",
-                    CoreModelId.llama_guard_3_1b.value,
-                ),
-            ],
-        )
-        self.config = config
-        # Ollama does not support image urls, so we need to download the image and convert it to base64
-        self.download_images = True
-        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
+    download_images: bool = True
+    _clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}

    @property
    def ollama_client(self) -> AsyncOllamaClient:
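With the hand-written __init__ gone, instance state moves into annotated class-level defaults, which pydantic turns into model fields (and, for leading-underscore names like _clients, into private attributes). The toy model below illustrates that pattern and the fact that pydantic copies mutable defaults per instance; the class and field names are stand-ins, not the real adapter.

import asyncio

from pydantic import BaseModel


class ToyAdapter(BaseModel):
    # annotated defaults replace the assignments that used to live in __init__
    download_images: bool = True
    embedding_model_metadata: dict[str, dict[str, int]] = {
        "all-minilm:l6-v2": {"embedding_dimension": 384, "context_length": 512},
    }
    # the leading underscore makes this a pydantic private attribute, not a validated field
    _clients: dict[asyncio.AbstractEventLoop, object] = {}


a, b = ToyAdapter(), ToyAdapter()
a.embedding_model_metadata["all-minilm:l6-v2"]["context_length"] = 1024
print(b.embedding_model_metadata["all-minilm:l6-v2"]["context_length"])  # still 512: defaults are copied per instance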
@@ -142,50 +102,6 @@ class OllamaInferenceAdapter(
    async def shutdown(self) -> None:
        self._clients.clear()

-    async def _get_model(self, model_id: str) -> Model:
-        if not self.model_store:
-            raise ValueError("Model store not set")
-        return await self.model_store.get_model(model_id)
-
-    async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        sampling_options = get_sampling_options(request.sampling_params)
-        # This is needed since the Ollama API expects num_predict to be set
-        # for early truncation instead of max_tokens.
-        if sampling_options.get("max_tokens") is not None:
-            sampling_options["num_predict"] = sampling_options["max_tokens"]
-
-        input_dict: dict[str, Any] = {}
-        media_present = request_has_media(request)
-        llama_model = self.get_llama_model(request.model)
-        if media_present or not llama_model:
-            contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
-            # flatten the list of lists
-            input_dict["messages"] = [item for sublist in contents for item in sublist]
-        else:
-            input_dict["raw"] = True
-            input_dict["prompt"] = await chat_completion_request_to_prompt(
-                request,
-                llama_model,
-            )
-
-        if fmt := request.response_format:
-            if isinstance(fmt, JsonSchemaResponseFormat):
-                input_dict["format"] = fmt.json_schema
-            elif isinstance(fmt, GrammarResponseFormat):
-                raise NotImplementedError("Grammar response format is not supported")
-            else:
-                raise ValueError(f"Unknown response format type: {fmt.type}")
-
-        params = {
-            "model": request.model,
-            **input_dict,
-            "options": sampling_options,
-            "stream": request.stream,
-        }
-        logger.debug(f"params to ollama: {params}")
-
-        return params
-
    async def register_model(self, model: Model) -> Model:
        if await self.check_model_availability(model.provider_model_id):
            return model