Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)
Fixes #183: Preload models during server initialization to prevent request timeouts

parent 9fbe8852aa
commit 4ce9314fdd

1 changed file with 30 additions and 20 deletions
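The heart of the change is the new register_model implementation: instead of rejecting dynamic registration, the adapter now pulls the requested model through the Ollama client at registration time, so the weights are already downloaded before the first inference request arrives. Below is a minimal, self-contained sketch of the same preload pattern (assuming a running Ollama daemon and the ollama Python package; MODEL_MAP and the model names are illustrative stand-ins for the adapter's OLLAMA_SUPPORTED_MODELS table, and the dict-style status check mirrors how the adapter reads the pull result):

import asyncio

from ollama import AsyncClient

# Illustrative stand-in for the adapter's OLLAMA_SUPPORTED_MODELS mapping.
MODEL_MAP = {"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"}


async def preload_model(llama_model: str) -> None:
    # Resolve the llama-stack model name to an Ollama tag, as register_model does.
    ollama_model = MODEL_MAP.get(llama_model)
    if not ollama_model:
        raise ValueError(f"Model {llama_model} is not supported by Ollama.")

    client = AsyncClient()
    # pull() downloads the model if it is not already present locally; the
    # adapter treats the returned status as a dict and checks for "success".
    status = await client.pull(ollama_model)
    assert status["status"] == "success", f"Failed to pull {ollama_model}"


if __name__ == "__main__":
    asyncio.run(preload_model("Llama3.1-8B-Instruct"))

Paying the download cost at registration/startup time is what keeps the first chat completion request from timing out while a multi-gigabyte model is still being fetched.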
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional

 import httpx

 from llama_models.llama3.api.chat_format import ChatFormat
-from llama_models.llama3.api.datatypes import Message
+from llama_models.llama3.api.datatypes import Message, ModelDef
 from llama_models.llama3.api.tokenizer import Tokenizer

 from ollama import AsyncClient
@@ -59,7 +59,18 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         pass

     async def register_model(self, model: ModelDef) -> None:
-        raise ValueError("Dynamic model registration is not supported")
+        """
+        Registers and loads the specified model using ollama.client.pull().
+        """
+        ollama_model = OLLAMA_SUPPORTED_MODELS.get(model.llama_model)
+        if not ollama_model:
+            raise ValueError(f"Model {model.llama_model} is not supported by Ollama.")
+
+        print(f"Registering model: {ollama_model}")
+        status = await self.client.pull(ollama_model)
+        assert (
+            status["status"] == "success"
+        ), f"Failed to register model {ollama_model} in Ollama."

     async def list_models(self) -> List[ModelDef]:
         ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
@@ -88,7 +99,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model: str,
         content: InterleavedTextMedia,
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
@@ -98,20 +109,20 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self,
         model: str,
         messages: List[Message],
-        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        sampling_params: Optional[SamplingParams] = None,
         tools: Optional[List[ToolDefinition]] = None,
-        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_choice: Optional[ToolChoice] = None,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
             messages=messages,
-            sampling_params=sampling_params,
+            sampling_params=sampling_params or SamplingParams(),
             tools=tools or [],
-            tool_choice=tool_choice,
-            tool_prompt_format=tool_prompt_format,
+            tool_choice=tool_choice or ToolChoice.auto,
+            tool_prompt_format=tool_prompt_format or ToolPromptFormat.json,
             stream=stream,
             logprobs=logprobs,
         )
@@ -133,17 +144,17 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = self._get_params(request)
-        r = await self.client.generate(**params)
-        assert isinstance(r, dict)
+        response = await self.client.generate(**params)
+        assert isinstance(response, dict)

         choice = OpenAICompatCompletionChoice(
-            finish_reason=r["done_reason"] if r["done"] else None,
-            text=r["response"],
+            finish_reason=response["done_reason"] if response["done"] else None,
+            text=response["response"],
         )
-        response = OpenAICompatCompletionResponse(
+        openai_response = OpenAICompatCompletionResponse(
             choices=[choice],
         )
-        return process_chat_completion_response(request, response, self.formatter)
+        return process_chat_completion_response(request, openai_response, self.formatter)

     async def _stream_chat_completion(
         self, request: ChatCompletionRequest
@@ -151,8 +162,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         params = self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.generate(**params)
-            async for chunk in s:
+            stream = await self.client.generate(**params)
+            async for chunk in stream:
                 choice = OpenAICompatCompletionChoice(
                     finish_reason=chunk["done_reason"] if chunk["done"] else None,
                     text=chunk["response"],
@@ -161,9 +172,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                     choices=[choice],
                 )

-        stream = _generate_and_convert_to_openai_compat()
         async for chunk in process_chat_completion_stream_response(
-            request, stream, self.formatter
+            request, _generate_and_convert_to_openai_compat(), self.formatter
         ):
             yield chunk
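The remaining edits follow a single pattern: default argument values that construct an object or name an enum member (SamplingParams(), ToolChoice.auto, ToolPromptFormat.json) are replaced with None in the signatures, and the concrete default is filled in with an or-fallback when the ChatCompletionRequest is built. A minimal sketch of that pattern with illustrative stand-in types (not the real llama-stack classes):

from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


@dataclass
class SamplingParams:  # illustrative stand-in, not the llama-stack class
    temperature: float = 0.7


class ToolChoice(Enum):  # illustrative stand-in
    auto = "auto"
    required = "required"


def build_request(
    messages: List[str],
    sampling_params: Optional[SamplingParams] = None,
    tool_choice: Optional[ToolChoice] = None,
) -> dict:
    # Resolve the real defaults in the body rather than in the signature, so no
    # shared default instance is created once at function-definition time.
    return {
        "messages": messages,
        "sampling_params": sampling_params or SamplingParams(),
        "tool_choice": tool_choice or ToolChoice.auto,
    }


print(build_request(["hello"]))  # falls back to SamplingParams() and ToolChoice.auto

Keeping None in the signature also lets a caller pass None explicitly to mean "use the default" and avoids constructing SamplingParams() at import time.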