mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Merge remote-tracking branch 'origin/main' into distros
commit f64668319c
3 changed files with 171 additions and 137 deletions
@@ -9,7 +9,6 @@ import uuid

from typing import AsyncGenerator

import httpx

from llama_models.llama3_1.api.datatypes import (
    BuiltinTool,
    CompletionMessage,
@@ -19,6 +18,8 @@ from llama_models.llama3_1.api.datatypes import (
)
from llama_models.llama3_1.api.tool_utils import ToolUtils
from llama_models.sku_list import resolve_model

from ollama import AsyncClient

from .api.config import OllamaImplConfig
@@ -36,6 +37,13 @@ from .api.endpoints import (
    Inference,
)

# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
    # TODO: Add other variants for llama3.1
}


def get_provider_impl(config: OllamaImplConfig) -> Inference:
    assert isinstance(
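A rough, standalone sketch of how the new SKU table is meant to be consumed (illustration only: lookup_ollama_tag is a hypothetical stand-in for the resolve_ollama_model method added in the next hunk, and it skips the llama_models.sku_list.resolve_model descriptor lookup):

# Illustration only: plain-dict stand-in for the SKU -> ollama tag lookup.
OLLAMA_SUPPORTED_SKUS = {
    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    # TODO: Add other variants for llama3.1
}


def lookup_ollama_tag(model_name: str) -> str:
    # Fail loudly for SKUs with no registered ollama tag, mirroring the assert below.
    if model_name not in OLLAMA_SUPPORTED_SKUS:
        supported = ", ".join(OLLAMA_SUPPORTED_SKUS.keys())
        raise ValueError(f"Unsupported model: {model_name}, use one of: {supported}")
    return OLLAMA_SUPPORTED_SKUS[model_name]


print(lookup_ollama_tag("Meta-Llama3.1-8B-Instruct"))  # llama3.1:8b-instruct-fp16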
@@ -76,14 +84,41 @@ class OllamaInference(Inference):

        return ollama_messages

    def resolve_ollama_model(self, model_name: str) -> str:
        model = resolve_model(model_name)
        assert (
            model is not None
            and model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"

        return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))

    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
        options = {}
        if request.sampling_params is not None:
            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
                if getattr(request.sampling_params, attr):
                    options[attr] = getattr(request.sampling_params, attr)
            if (
                request.sampling_params.repetition_penalty is not None
                and request.sampling_params.repetition_penalty != 1.0
            ):
                options["repeat_penalty"] = request.sampling_params.repetition_penalty

        return options

    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        # accumulate sampling params and other options to pass to ollama
        options = self.get_ollama_chat_options(request)
        ollama_model = self.resolve_ollama_model(request.model)
        if not request.stream:
            r = await self.client.chat(
-               model=self.model,
+               model=ollama_model,
                messages=self._messages_to_ollama_messages(request.messages),
                stream=False,
                # TODO: add support for options like temp, top_p, max_seq_length, etc
                options=options,
            )
            stop_reason = None
            if r["done"]:
                if r["done_reason"] == "stop":
                    stop_reason = StopReason.end_of_turn
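For reference, a minimal standalone sketch of what the sampling-parameter translation above produces (illustration only: build_ollama_options is a hypothetical stand-in for get_ollama_chat_options, and SimpleNamespace replaces the real ChatCompletionRequest/SamplingParams types):

# Minimal sketch of the option mapping; runs without the real request types.
from types import SimpleNamespace


def build_ollama_options(sampling_params) -> dict:
    options = {}
    if sampling_params is not None:
        # Copy over the truthy sampling attributes, as the method above does.
        for attr in ("temperature", "top_p", "top_k", "max_tokens"):
            value = getattr(sampling_params, attr, None)
            if value:
                options[attr] = value
        # repetition_penalty maps to ollama's repeat_penalty, skipping the 1.0 no-op.
        penalty = getattr(sampling_params, "repetition_penalty", None)
        if penalty is not None and penalty != 1.0:
            options["repeat_penalty"] = penalty
    return options


params = SimpleNamespace(
    temperature=0.7, top_p=0.9, top_k=None, max_tokens=512, repetition_penalty=1.1
)
print(build_ollama_options(params))
# {'temperature': 0.7, 'top_p': 0.9, 'max_tokens': 512, 'repeat_penalty': 1.1}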
@@ -107,9 +142,10 @@ class OllamaInference(Inference):
            )

            stream = await self.client.chat(
-               model=self.model,
+               model=ollama_model,
                messages=self._messages_to_ollama_messages(request.messages),
                stream=True,
                options=options,
            )

            buffer = ""