Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-27 06:28:50 +00:00)

Merge c67bae2d07 into cbe89d2bdd
Commit 78de1af5c8
1 changed file with 43 additions and 34 deletions
@@ -20,7 +20,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContentItem,
     TextContentItem,
 )
-from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -55,7 +54,6 @@ from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
     HealthResponse,
     HealthStatus,
-    ModelsProtocolPrivate,
 )
 from llama_stack.providers.remote.inference.ollama.config import OllamaImplConfig
 from llama_stack.providers.utils.inference.model_registry import (
@@ -90,13 +88,13 @@ logger = get_logger(name=__name__, category="inference")
 
 class OllamaInferenceAdapter(
     InferenceProvider,
-    ModelsProtocolPrivate,
+    ModelRegistryHelper,
 ):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
 
     def __init__(self, config: OllamaImplConfig) -> None:
-        self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
+        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
         self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
         self._openai_client = None
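The hunk above swaps composition for inheritance: instead of holding a `ModelRegistryHelper` in `self.register_helper`, the adapter now is one, so registry lookups are called directly on `self` (as the later hunks show). A minimal sketch of that pattern, with illustrative `Registry`/`Adapter*` classes rather than the real llama-stack types:

```python
# Minimal sketch of the composition -> inheritance swap shown above.
# Registry, AdapterBefore, and AdapterAfter are illustrative, not llama-stack APIs.


class Registry:
    def __init__(self, entries: dict[str, str]) -> None:
        self._entries = entries

    def get_llama_model(self, provider_model_id: str) -> str | None:
        return self._entries.get(provider_model_id)


class AdapterBefore:
    def __init__(self, entries: dict[str, str]) -> None:
        self.register_helper = Registry(entries)  # composition: helper held as an attribute

    def resolve(self, model: str) -> str | None:
        return self.register_helper.get_llama_model(model)


class AdapterAfter(Registry):
    def __init__(self, entries: dict[str, str]) -> None:
        Registry.__init__(self, entries)  # inheritance, mirroring the diff

    def resolve(self, model: str) -> str | None:
        return self.get_llama_model(model)  # registry method called directly on self


entries = {"llama3.2:3b": "Llama-3.2-3B-Instruct"}
assert AdapterBefore(entries).resolve("llama3.2:3b") == AdapterAfter(entries).resolve("llama3.2:3b")
```

Both adapters resolve the same entry; only the call site changes, which is exactly the `self.register_helper.get_llama_model(...)` to `self.get_llama_model(...)` edit further down.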
@@ -193,6 +191,41 @@ class OllamaInferenceAdapter(
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available in Ollama.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available, False otherwise.
+        """
+        try:
+            available_models = await self._query_available_models()
+            return model in available_models
+        except Exception as e:
+            logger.error(f"Error checking model availability: {e}")
+            return False
+
+    async def _query_available_models(self) -> list[str]:
+        """
+        Query Ollama for available models.
+
+        Ollama allows omitting the `:latest` suffix, so we include some-name:latest as some-name and some-name:latest.
+
+        :return: A list of model identifiers (provider_model_ids).
+        """
+        available_models = []
+        try:
+            # we use list() here instead of ps() -
+            #  - ps() only lists running models, not available models
+            #  - models not currently running are run by the ollama server as needed
+            for m in (await self.client.list()).models:
+                available_models.append(m.model)
+                if m.model.endswith(":latest"):
+                    available_models.append(m.model[: -len(":latest")])
+        except Exception as e:
+            logger.warning(f"Failed to query available models from Ollama: {e}")
+        return available_models
+
     async def shutdown(self) -> None:
         self._clients.clear()
 
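The new `_query_available_models()` registers every model under both spellings whenever Ollama reports a `:latest` tag. A standalone sketch of just that aliasing step, using a hypothetical `expand_latest_aliases` helper (not part of the diff), so the behavior can be checked without a running Ollama server:

```python
# Standalone sketch of the ":latest" aliasing done in _query_available_models above.
# expand_latest_aliases is an illustrative helper, not part of the diff or of llama-stack.


def expand_latest_aliases(model_names: list[str]) -> list[str]:
    """Return each name, plus its suffix-free alias when it ends in ':latest'."""
    available: list[str] = []
    for name in model_names:
        available.append(name)
        if name.endswith(":latest"):
            # Ollama treats "some-name" and "some-name:latest" as the same model,
            # so both spellings should count as available.
            available.append(name[: -len(":latest")])
    return available


assert expand_latest_aliases(["llama3.2:latest", "all-minilm:l6-v2"]) == [
    "llama3.2:latest",
    "llama3.2",
    "all-minilm:l6-v2",
]
```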
@@ -307,7 +340,7 @@ class OllamaInferenceAdapter(
 
         input_dict: dict[str, Any] = {}
         media_present = request_has_media(request)
-        llama_model = self.register_helper.get_llama_model(request.model)
+        llama_model = self.get_llama_model(request.model)
         if isinstance(request, ChatCompletionRequest):
             if media_present or not llama_model:
                 contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages]
@@ -415,38 +448,14 @@ class OllamaInferenceAdapter(
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def register_model(self, model: Model) -> Model:
-        try:
-            model = await self.register_helper.register_model(model)
-        except ValueError:
-            pass  # Ignore statically unknown model, will check live listing
-
-        if model.provider_resource_id is None:
-            raise ValueError("Model provider_resource_id cannot be None")
-
         if model.model_type == ModelType.embedding:
-            response = await self.client.list()
-            if model.provider_resource_id not in [m.model for m in response.models]:
+            assert model.provider_resource_id, "Embedding models must have a provider_resource_id set"
+            logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
+            # TODO: you should pull here only if the model is not found in a list
+            if not await self.check_model_availability(model.provider_resource_id):
                 await self.client.pull(model.provider_resource_id)
 
-        # we use list() here instead of ps() -
-        #  - ps() only lists running models, not available models
-        #  - models not currently running are run by the ollama server as needed
-        response = await self.client.list()
-        available_models = [m.model for m in response.models]
-        provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
-        if provider_resource_id is None:
-            provider_resource_id = model.provider_resource_id
-        if provider_resource_id not in available_models:
-            available_models_latest = [m.model.split(":latest")[0] for m in response.models]
-            if provider_resource_id in available_models_latest:
-                logger.warning(
-                    f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
-                )
-                return model
-            raise UnsupportedModelError(model.provider_resource_id, available_models)
-        model.provider_resource_id = provider_resource_id
-
-        return model
+        return await ModelRegistryHelper.register_model(self, model)
 
     async def openai_embeddings(
         self,
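With the registry logic delegated to `ModelRegistryHelper.register_model`, the adapter's own `register_model` only has to make sure an embedding model is present, pulling it just when `check_model_availability()` reports it missing. A self-contained sketch of that control flow, with `FakeAdapter`, `FakeClient`, and `ensure_embedding_model` as illustrative stand-ins rather than llama-stack APIs:

```python
# Self-contained sketch of the new register_model control flow for embedding models:
# pull only when check_model_availability() reports the model missing.
# FakeAdapter, FakeClient, and ensure_embedding_model are illustrative, not llama-stack classes.
import asyncio


class FakeClient:
    def __init__(self) -> None:
        self.pulled: list[str] = []

    async def pull(self, model: str) -> None:
        # Records which models would be downloaded by a real Ollama client.
        self.pulled.append(model)


class FakeAdapter:
    def __init__(self, available: set[str]) -> None:
        self.client = FakeClient()
        self._available = available

    async def check_model_availability(self, model: str) -> bool:
        return model in self._available

    async def ensure_embedding_model(self, provider_resource_id: str) -> None:
        # Mirrors the diff: pull only if the model is not already available.
        if not await self.check_model_availability(provider_resource_id):
            await self.client.pull(provider_resource_id)


async def main() -> None:
    adapter = FakeAdapter(available={"all-minilm:l6-v2"})
    await adapter.ensure_embedding_model("all-minilm:l6-v2")  # already present, no pull
    await adapter.ensure_embedding_model("nomic-embed-text")  # missing, gets pulled
    assert adapter.client.pulled == ["nomic-embed-text"]


asyncio.run(main())
```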