forked from phoenix-oss/llama-stack-mirror
added options to ollama inference
parent 09cf3fe78b
commit d7a4cdd70d

2 changed files with 89 additions and 13 deletions
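The gist of the change: sampling parameters carried on a ChatCompletionRequest are translated into the `options` dict that Ollama's chat API accepts, and Meta model SKUs are resolved to Ollama model tags before the client is called. A minimal standalone sketch of that mapping follows; the `Sampling` dataclass and `to_ollama_options` helper are illustrative stand-ins, not part of the toolchain's API.

# Hedged sketch, not part of the diff: mirrors the idea behind get_ollama_chat_options.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Sampling:  # illustrative stand-in for llama_models' SamplingParams
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None
    repetition_penalty: Optional[float] = None

def to_ollama_options(p: Sampling) -> dict:
    options = {}
    # copy recognised fields straight through, skipping unset/zero values
    for attr in ("temperature", "top_p", "top_k", "max_tokens"):
        value = getattr(p, attr)
        if value:
            options[attr] = value
    # Ollama calls this option repeat_penalty; 1.0 means "no penalty"
    if p.repetition_penalty is not None and p.repetition_penalty != 1.0:
        options["repeat_penalty"] = p.repetition_penalty
    return options

print(to_ollama_options(Sampling(temperature=1.0, top_p=0.99)))
# -> {'temperature': 1.0, 'top_p': 0.99}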
@@ -5,6 +5,7 @@ from typing import AsyncGenerator
 
 from ollama import AsyncClient
 
+from llama_models.sku_list import resolve_model
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
     CompletionMessage,
@@ -29,6 +30,12 @@ from .api.endpoints import (
     Inference,
 )
 
+# TODO: Eventually this will move to the llama cli model list command
+# mapping of Model SKUs to ollama models
+OLLAMA_SUPPORTED_SKUS = {
+    "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16"
+    # TODO: Add other variants for llama3.1
+}
 
 
 class OllamaInference(Inference):
@@ -61,14 +68,41 @@ class OllamaInference(Inference):
         return ollama_messages
 
+    def resolve_ollama_model(self, model_name: str) -> str:
+        model = resolve_model(model_name)
+        assert (
+            model is not None and
+            model.descriptor(shorten_default_variant=True) in OLLAMA_SUPPORTED_SKUS
+        ), f"Unsupported model: {model_name}, use one of the supported models: {','.join(OLLAMA_SUPPORTED_SKUS.keys())}"
+
+        return OLLAMA_SUPPORTED_SKUS.get(model.descriptor(shorten_default_variant=True))
+
+    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
+        options = {}
+        if request.sampling_params is not None:
+            for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
+                if getattr(request.sampling_params, attr):
+                    options[attr] = getattr(request.sampling_params, attr)
+            if (
+                request.sampling_params.repetition_penalty is not None and
+                request.sampling_params.repetition_penalty != 1.0
+            ):
+                options["repeat_penalty"] = request.sampling_params.repetition_penalty
+
+        return options
+
     async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+        # accumulate sampling params and other options to pass to ollama
+        options = self.get_ollama_chat_options(request)
+        ollama_model = self.resolve_ollama_model(request.model)
         if not request.stream:
             r = await self.client.chat(
-                model=self.model,
+                model=ollama_model,
                 messages=self._messages_to_ollama_messages(request.messages),
                 stream=False,
-                #TODO: add support for options like temp, top_p, max_seq_length, etc
+                options=options,
            )
             stop_reason = None
             if r['done']:
                 if r['done_reason'] == 'stop':
                     stop_reason = StopReason.end_of_turn
@@ -92,9 +126,10 @@ class OllamaInference(Inference):
             )
 
         stream = await self.client.chat(
-            model=self.model,
+            model=ollama_model,
             messages=self._messages_to_ollama_messages(request.messages),
-            stream=True
+            stream=True,
+            options=options,
         )
 
         buffer = ""
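The hunks above are from the OllamaInference provider; the hunks below are from its unit tests. For context, a hedged sketch of how the underlying ollama client is driven with the new options dict: the model tag and option values are examples, and responses are read with the same dict-style keys the provider itself uses (r['done'], r['done_reason']).

# Hedged sketch: exercising ollama.AsyncClient the way the provider now does.
import asyncio
from ollama import AsyncClient

async def main() -> None:
    options = {"temperature": 1.0, "top_p": 0.99}
    client = AsyncClient()  # defaults to a local Ollama server

    # non-streaming: a single response with 'message', 'done', 'done_reason'
    r = await client.chat(
        model="llama3.1:8b-instruct-fp16",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=False,
        options=options,
    )
    print(r["message"]["content"])

    # streaming: an async iterator of partial chunks, accumulated by the provider into `buffer`
    stream = await client.chat(
        model="llama3.1:8b-instruct-fp16",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=True,
        options=options,
    )
    async for chunk in stream:
        print(chunk["message"]["content"], end="", flush=True)

asyncio.run(main())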
@@ -4,9 +4,10 @@ from datetime import datetime
 
 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    InstructModel,
     UserMessage,
     StopReason,
+    SamplingParams,
+    SamplingStrategy,
     SystemMessage,
 )
 from llama_toolchain.inference.api_instance import (
@@ -84,13 +85,14 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
                 """
             ),
         )
+        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"
 
     async def asyncTearDown(self):
         await self.api.shutdown()
 
     async def test_text(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -107,7 +109,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_tool_call(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -133,7 +135,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_code_execution(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -158,7 +160,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_custom_tool(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
@@ -174,7 +176,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         completion_message = response.completion_message
 
         self.assertEqual(completion_message.content, "")
-        self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertTrue(
+            completion_message.stop_reason in {
+                StopReason.end_of_turn,
+                StopReason.end_of_message,
+            }
+        )
 
         self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
         self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
@@ -186,7 +193,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_text_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -226,7 +233,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -253,7 +260,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
 
     async def test_custom_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
@@ -294,3 +301,37 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             events[-2].stop_reason,
             StopReason.end_of_turn
         )
+
+    def test_resolve_ollama_model(self):
+        ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
+        self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16")
+
+        invalid_model = "Meta-Llama3.1-8B"
+        with self.assertRaisesRegex(
+            AssertionError, f"Unsupported model: {invalid_model}"
+        ):
+            self.api.resolve_ollama_model(invalid_model)
+
+    async def test_ollama_chat_options(self):
+        request = ChatCompletionRequest(
+            model=self.valid_supported_model,
+            messages=[
+                UserMessage(
+                    content="What is the capital of France?",
+                ),
+            ],
+            stream=False,
+            sampling_params=SamplingParams(
+                sampling_strategy=SamplingStrategy.top_p,
+                top_p=0.99,
+                temperature=1.0,
+            )
+        )
+        options = self.api.get_ollama_chat_options(request)
+        self.assertEqual(
+            options,
+            {
+                "temperature": 1.0,
+                "top_p": 0.99,
            }
        )
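A small note on the relaxed stop-reason check in test_custom_tool above: unittest's assertIn expresses the same membership test and reports the actual value when it fails. A self-contained illustration; the test class and the hard-coded value are examples, not part of the diff.

# Illustrative only: equivalent to the assertTrue(... in {...}) form used in the diff,
# but with a clearer failure message.
import unittest
from llama_models.llama3_1.api.datatypes import StopReason

class StopReasonAssertionExample(unittest.TestCase):
    def test_stop_reason_membership(self):
        stop_reason = StopReason.end_of_turn  # stand-in for completion_message.stop_reason
        self.assertIn(stop_reason, {StopReason.end_of_turn, StopReason.end_of_message})

if __name__ == "__main__":
    unittest.main()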