Fixes #183: Preload models during server initialization to prevent request timeouts

Ezreal 2024-10-11 12:27:47 +08:00
parent 9fbe8852aa
commit 4ce9314fdd

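The change below makes register_model() pull the model weights at registration time, so they are already resident when the first request arrives instead of being downloaded lazily and timing out. As a rough sketch of how a server could use this to preload every configured model during initialization (the preload_models helper and the adapter/model list it takes are illustrative, not part of this commit):

import asyncio
from typing import List

from llama_models.llama3.api.datatypes import ModelDef

# Illustrative startup hook: register (and therefore pull) each model up front,
# instead of paying the download cost on the first inference request.
async def preload_models(adapter, model_defs: List[ModelDef]) -> None:
    for model_def in model_defs:
        # With this commit, register_model() calls ollama's pull() and only
        # returns once the weights are available locally.
        await adapter.register_model(model_def)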

@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional

 import httpx

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ModelDef
 from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
@@ -59,7 +59,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         pass

     async def register_model(self, model: ModelDef) -> None:
-        raise ValueError("Dynamic model registration is not supported")
+        """
+        Registers and loads the specified model using ollama.client.pull().
+        """
+        ollama_model = OLLAMA_SUPPORTED_MODELS.get(model.llama_model)
+        if not ollama_model:
+            raise ValueError(f"Model {model.llama_model} is not supported by Ollama.")
+        print(f"Registering model: {ollama_model}")
+        status = await self.client.pull(ollama_model)
+        assert (
+            status["status"] == "success"
+        ), f"Failed to register model {ollama_model} in Ollama."

     async def list_models(self) -> List[ModelDef]:
         ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
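For context, the new register_model body relies on the ollama Python client's pull() call, which resolves once the model has been downloaded and reports a status. A standalone sketch of the same call, assuming a default local Ollama server and an example model tag:

import asyncio

from ollama import AsyncClient

async def pull_model(tag: str) -> None:
    client = AsyncClient(host="http://localhost:11434")  # Ollama's default endpoint
    status = await client.pull(tag)  # returns after the model is fully downloaded
    # Mirrors the assertion added in register_model above.
    assert status["status"] == "success", f"pull of {tag} did not complete"

# Example (the tag is illustrative): asyncio.run(pull_model("llama3.1:8b-instruct-fp16"))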
@@ -88,7 +99,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model: str,
         content: InterleavedTextMedia,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -98,20 +109,20 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_choice: Optional[ToolChoice] = None,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
-            sampling_params=sampling_params,
+            sampling_params=sampling_params or SamplingParams(),
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
+            tool_choice=tool_choice or ToolChoice.auto,
+            tool_prompt_format=tool_prompt_format or ToolPromptFormat.json,
             stream=stream,
             logprobs=logprobs,
         )
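These signature changes replace default argument instances with None plus an explicit fallback inside the body. Python evaluates defaults once at definition time, so a SamplingParams() default would be shared across every call; the None-plus-or pattern builds a fresh object per call. A minimal illustration using a simplified stand-in class (not the real SamplingParams):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class FakeSamplingParams:  # simplified stand-in for SamplingParams
    stop_tokens: List[str] = field(default_factory=list)

def shared_default(params: Optional[FakeSamplingParams] = FakeSamplingParams()):
    return params  # the same object is handed back on every call

def fresh_default(params: Optional[FakeSamplingParams] = None):
    return params or FakeSamplingParams()  # mirrors `sampling_params or SamplingParams()`

assert shared_default() is shared_default()    # one shared instance
assert fresh_default() is not fresh_default()  # independent instances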
@@ -133,17 +144,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = await self.client.generate(**params)
-        assert isinstance(r, dict)
+        response = await self.client.generate(**params)
+        assert isinstance(response, dict)

         choice = OpenAICompatCompletionChoice(
-            finish_reason=r["done_reason"] if r["done"] else None,
-            text=r["response"],
+            finish_reason=response["done_reason"] if response["done"] else None,
+            text=response["response"],
         )
-        response = OpenAICompatCompletionResponse(
+        openai_response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(request, response, self.formatter)
+        return process_chat_completion_response(request, openai_response, self.formatter)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
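The rename in the non-streaming path is behaviour-preserving: it stops reusing the name response for both the raw Ollama reply and the OpenAI-compat wrapper. A tiny generic illustration of the shadowing hazard the old naming invited (the dicts are placeholders, not the real types):

def shadowed() -> None:
    response = {"response": "hello", "done": True}   # raw backend reply
    response = {"choices": [response["response"]]}   # wrapper reuses the name
    # From here on the raw fields are unreachable; response["done"] would KeyError.

def distinct_names() -> None:
    response = {"response": "hello", "done": True}   # raw backend reply
    openai_response = {"choices": [response["response"]]}
    # Both objects stay accessible, matching the renamed code above.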
@@ -151,8 +162,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         params = self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
-            async for chunk in s:
+            stream = await self.client.generate(**params)
+            async for chunk in stream:
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=chunk["done_reason"] if chunk["done"] else None,
                     text=chunk["response"],
@@ -161,9 +172,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                     choices=[choice],
                 )

-        stream = _generate_and_convert_to_openai_compat()
         async for chunk in process_chat_completion_stream_response(
-            request, stream, self.formatter
+            request, _generate_and_convert_to_openai_compat(), self.formatter
         ):
             yield chunk
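The streaming change passes the inner async generator directly to process_chat_completion_stream_response instead of first binding it to a local stream variable; chunks are still produced lazily. A generic sketch of the same shape, with placeholder helpers standing in for the real ones:

import asyncio
from typing import AsyncGenerator

async def process(chunks: AsyncGenerator[str, None]) -> AsyncGenerator[str, None]:
    # Placeholder for process_chat_completion_stream_response: transform each chunk.
    async for chunk in chunks:
        yield chunk.upper()

async def stream_chat_completion() -> AsyncGenerator[str, None]:
    async def _generate() -> AsyncGenerator[str, None]:
        for piece in ("hel", "lo"):
            yield piece

    # Calling _generate() here creates the generator object; process() drives it lazily.
    async for chunk in process(_generate()):
        yield chunk

async def main() -> None:
    print([chunk async for chunk in stream_chat_completion()])  # ['HEL', 'LO']

asyncio.run(main())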