mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-06 20:44:58 +00:00
Merge branch 'main' into remove-deprecated-embeddings
This commit is contained in:
commit
5c44dcdf0e
770 changed files with 176834 additions and 27431 deletions
|
@ -6,13 +6,10 @@
|
|||
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ollama import AsyncClient # type: ignore[attr-defined]
|
||||
from openai import AsyncOpenAI
|
||||
from ollama import AsyncClient as AsyncOllamaClient
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
|
@ -32,13 +29,6 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
|
@ -46,8 +36,9 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.datatypes import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
|
@ -56,124 +47,105 @@ from llama_stack.providers.datatypes import (
|
|||
from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
b64_encode_openai_embeddings_response,
|
||||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
prepare_openai_embeddings_params,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
convert_image_content_to_url,
|
||||
localize_image_content,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference::ollama")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
|
||||
embedding_model_metadata = {
|
||||
"all-minilm:l6-v2": {
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
"nomic-embed-text:latest": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:v1.5": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
"nomic-embed-text:137m-v1.5-fp16": {
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
||||
# TODO: remove ModelRegistryHelper.__init__ when completion and
|
||||
# chat_completion are. this exists to satisfy the input /
|
||||
# output processing for llama models. specifically,
|
||||
# tool_calling is handled by raw template processing,
|
||||
# instead of using the /api/chat endpoint w/ tools=...
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_entries=[
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.2:3b-instruct-fp16",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-guard3:1b",
|
||||
CoreModelId.llama_guard_3_1b.value,
|
||||
),
|
||||
],
|
||||
)
|
||||
self.config = config
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
|
||||
self._openai_client = None
|
||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||
self.download_images = True
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
def ollama_client(self) -> AsyncOllamaClient:
|
||||
# ollama client attaches itself to the current event loop (sadly?)
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop not in self._clients:
|
||||
self._clients[loop] = AsyncClient(host=self.config.url)
|
||||
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
|
||||
return self._clients[loop]
|
||||
|
||||
@property
|
||||
def openai_client(self) -> AsyncOpenAI:
|
||||
if self._openai_client is None:
|
||||
url = self.config.url.rstrip("/")
|
||||
self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
|
||||
return self._openai_client
|
||||
def get_api_key(self):
|
||||
return "NO_KEY"
|
||||
|
||||
def get_base_url(self):
|
||||
return self.config.url.rstrip("/") + "/v1"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
health_response = await self.health()
|
||||
if health_response["status"] == HealthStatus.ERROR:
|
||||
r = await self.health()
|
||||
if r["status"] == HealthStatus.ERROR:
|
||||
logger.warning(
|
||||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||
f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
provider_id = self.__provider_id__
|
||||
response = await self.client.list()
|
||||
|
||||
# always add the two embedding models which can be pulled on demand
|
||||
models = [
|
||||
Model(
|
||||
identifier="all-minilm:l6-v2",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
# add all-minilm alias
|
||||
Model(
|
||||
identifier="all-minilm",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
Model(
|
||||
identifier="nomic-embed-text",
|
||||
provider_resource_id="nomic-embed-text",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
for m in response.models:
|
||||
# kill embedding models since we don't know dimensions for them
|
||||
if "bert" in m.details.family:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
Performs a health check by verifying connectivity to the Ollama server.
|
||||
|
@ -183,7 +155,7 @@ class OllamaInferenceAdapter(
|
|||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
await self.client.ps()
|
||||
await self.ollama_client.ps()
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
except Exception as e:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
@ -191,9 +163,6 @@ class OllamaInferenceAdapter(
|
|||
async def shutdown(self) -> None:
|
||||
self._clients.clear()
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def _get_model(self, model_id: str) -> Model:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set")
|
||||
|
@ -232,7 +201,7 @@ class OllamaInferenceAdapter(
|
|||
params = await self._get_params(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.generate(**params)
|
||||
s = await self.ollama_client.generate(**params)
|
||||
async for chunk in s:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk["done_reason"] if chunk["done"] else None,
|
||||
|
@ -248,7 +217,7 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
r = await self.client.generate(**params)
|
||||
r = await self.ollama_client.generate(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r["done_reason"] if r["done"] else None,
|
||||
|
@ -302,7 +271,7 @@ class OllamaInferenceAdapter(
|
|||
|
||||
input_dict: dict[str, Any] = {}
|
||||
media_present = request_has_media(request)
|
||||
llama_model = self.register_helper.get_llama_model(request.model)
|
||||
llama_model = self.get_llama_model(request.model)
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
if media_present or not llama_model:
|
||||
contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
|
||||
|
@ -340,9 +309,9 @@ class OllamaInferenceAdapter(
|
|||
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
if "messages" in params:
|
||||
r = await self.client.chat(**params)
|
||||
r = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
r = await self.client.generate(**params)
|
||||
r = await self.ollama_client.generate(**params)
|
||||
|
||||
if "message" in r:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
|
@ -366,9 +335,9 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
if "messages" in params:
|
||||
s = await self.client.chat(**params)
|
||||
s = await self.ollama_client.chat(**params)
|
||||
else:
|
||||
s = await self.client.generate(**params)
|
||||
s = await self.ollama_client.generate(**params)
|
||||
async for chunk in s:
|
||||
if "message" in chunk:
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
|
@ -389,208 +358,16 @@ class OllamaInferenceAdapter(
|
|||
yield chunk
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
if await self.check_model_availability(model.provider_model_id):
|
||||
return model
|
||||
elif await self.check_model_availability(f"{model.provider_model_id}:latest"):
|
||||
model.provider_resource_id = f"{model.provider_model_id}:latest"
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_model_id}'"
|
||||
)
|
||||
return model
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
response = await self.client.list()
|
||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
|
||||
# we use list() here instead of ps() -
|
||||
# - ps() only lists running models, not available models
|
||||
# - models not currently running are run by the ollama server as needed
|
||||
response = await self.client.list()
|
||||
available_models = [m.model for m in response.models]
|
||||
|
||||
provider_resource_id = model.provider_resource_id
|
||||
assert provider_resource_id is not None # mypy
|
||||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise UnsupportedModelError(provider_resource_id, available_models)
|
||||
|
||||
# mutating this should be considered an anti-pattern
|
||||
model.provider_resource_id = provider_resource_id
|
||||
|
||||
return model
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
model_obj = await self._get_model(model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {model} has no provider_resource_id set")
|
||||
|
||||
# Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
|
||||
params = prepare_openai_embeddings_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
input=input,
|
||||
encoding_format=encoding_format,
|
||||
dimensions=dimensions,
|
||||
user=user,
|
||||
)
|
||||
|
||||
response = await self.openai_client.embeddings.create(**params)
|
||||
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
|
||||
|
||||
usage = OpenAIEmbeddingUsage(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
)
|
||||
# TODO: Investigate why model_obj.identifier is used instead of response.model
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.identifier,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError("Ollama does not support non-string prompts for completion")
|
||||
|
||||
model_obj = await self._get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
suffix=suffix,
|
||||
)
|
||||
return await self.openai_client.completions.create(**params) # type: ignore
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[OpenAIMessageParam],
|
||||
frequency_penalty: float | None = None,
|
||||
function_call: str | dict[str, Any] | None = None,
|
||||
functions: list[dict[str, Any]] | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
# Ollama does not support image urls, so we need to download the image and convert it to base64
|
||||
async def _convert_message(m: OpenAIMessageParam) -> OpenAIMessageParam:
|
||||
if isinstance(m.content, list):
|
||||
for c in m.content:
|
||||
if c.type == "image_url" and c.image_url and c.image_url.url:
|
||||
localize_result = await localize_image_content(c.image_url.url)
|
||||
if localize_result is None:
|
||||
raise ValueError(f"Failed to localize image content from {c.image_url.url}")
|
||||
|
||||
content, format = localize_result
|
||||
c.image_url.url = f"data:image/{format};base64,{base64.b64encode(content).decode('utf-8')}"
|
||||
return m
|
||||
|
||||
messages = [await _convert_message(m) for m in messages]
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
response = await self.openai_client.chat.completions.create(**params)
|
||||
return await self._adjust_ollama_chat_completion_response_ids(response)
|
||||
|
||||
async def _adjust_ollama_chat_completion_response_ids(
|
||||
self,
|
||||
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
id = f"chatcmpl-{uuid.uuid4()}"
|
||||
if isinstance(response, AsyncIterator):
|
||||
|
||||
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
async for chunk in response:
|
||||
chunk.id = id
|
||||
yield chunk
|
||||
|
||||
return stream_with_chunk_ids()
|
||||
else:
|
||||
response.id = id
|
||||
return response
|
||||
raise UnsupportedModelError(model.provider_model_id, list(self._model_cache.keys()))
|
||||
|
||||
|
||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue