Mirror of https://github.com/meta-llama/llama-stack.git
added options to ollama inference

commit d7a4cdd70d (parent 09cf3fe78b)
2 changed files with 89 additions and 13 deletions
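
Only the test file appears in this excerpt; the provider-side change (the second of the two changed files) is not shown. Going by the new test_ollama_chat_options test at the bottom of the diff, a minimal sketch of what the get_ollama_chat_options helper could look like is given below. Only the name get_ollama_chat_options and the expected {"temperature": 1.0, "top_p": 0.99} result come from the test; everything else about the field handling is an assumption.

    from llama_models.llama3_1.api.datatypes import SamplingStrategy

    def get_ollama_chat_options(request) -> dict:
        # Translate the request's SamplingParams into Ollama "options".
        # In the provider this is presumably a method on the inference adapter;
        # which fields are forwarded, and how unset ones are skipped, is assumed.
        options = {}
        params = request.sampling_params
        if params is not None:
            if params.temperature is not None:
                options["temperature"] = params.temperature
            if params.sampling_strategy == SamplingStrategy.top_p and params.top_p is not None:
                options["top_p"] = params.top_p
        return options
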
@@ -4,9 +4,10 @@ from datetime import datetime

 from llama_models.llama3_1.api.datatypes import (
     BuiltinTool,
-    InstructModel,
     UserMessage,
     StopReason,
+    SamplingParams,
+    SamplingStrategy,
     SystemMessage,
 )
 from llama_toolchain.inference.api_instance import (
@@ -84,13 +85,14 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
                 """
             ),
         )
+        self.valid_supported_model = "Meta-Llama3.1-8B-Instruct"

     async def asyncTearDown(self):
         await self.api.shutdown()

     async def test_text(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -107,7 +109,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

     async def test_tool_call(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -133,7 +135,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

     async def test_code_execution(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -158,7 +160,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

     async def test_custom_tool(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
@@ -174,7 +176,12 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
         completion_message = response.completion_message

         self.assertEqual(completion_message.content, "")
-        self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)
+        self.assertTrue(
+            completion_message.stop_reason in {
+                StopReason.end_of_turn,
+                StopReason.end_of_message,
+            }
+        )

         self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
         self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")
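
The relaxed assertion above accepts either stop reason, presumably because a custom tool-call turn can end with end_of_message rather than end_of_turn. An equivalent form using unittest's assertIn, shown here only as a possible alternative to the membership check in the diff:

    self.assertIn(
        completion_message.stop_reason,
        {StopReason.end_of_turn, StopReason.end_of_message},
    )
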
@@ -186,7 +193,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

     async def test_text_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 UserMessage(
                     content="What is the capital of France?",
@@ -226,7 +233,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

     async def test_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt,
                 UserMessage(
@@ -253,7 +260,7 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

     async def test_custom_tool_call_streaming(self):
         request = ChatCompletionRequest(
-            model=InstructModel.llama3_8b_chat,
+            model=self.valid_supported_model,
             messages=[
                 self.system_prompt_with_custom_tool,
                 UserMessage(
@@ -294,3 +301,37 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             events[-2].stop_reason,
             StopReason.end_of_turn
         )
+
+    def test_resolve_ollama_model(self):
+        ollama_model = self.api.resolve_ollama_model(self.valid_supported_model)
+        self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16")
+
+        invalid_model = "Meta-Llama3.1-8B"
+        with self.assertRaisesRegex(
+            AssertionError, f"Unsupported model: {invalid_model}"
+        ):
+            self.api.resolve_ollama_model(invalid_model)
+
+    async def test_ollama_chat_options(self):
+        request = ChatCompletionRequest(
+            model=self.valid_supported_model,
+            messages=[
+                UserMessage(
+                    content="What is the capital of France?",
+                ),
+            ],
+            stream=False,
+            sampling_params=SamplingParams(
+                sampling_strategy=SamplingStrategy.top_p,
+                top_p=0.99,
+                temperature=1.0,
+            )
+        )
+        options = self.api.get_ollama_chat_options(request)
+        self.assertEqual(
+            options,
+            {
+                "temperature": 1.0,
+                "top_p": 0.99,
+            }
+        )
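
test_resolve_ollama_model pins down the mapping from the Llama model descriptor to the Ollama model tag, and the error message raised for unsupported descriptors. A minimal sketch consistent with that test follows; the table name OLLAMA_SUPPORTED_SKUS and the overall structure are assumptions, since the provider file is not part of this excerpt.

    # Hypothetical mapping table; only the one entry exercised by the test is known.
    OLLAMA_SUPPORTED_SKUS = {
        "Meta-Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    }

    def resolve_ollama_model(model_name: str) -> str:
        # Fail loudly on descriptors that have no Ollama equivalent.
        assert (
            model_name in OLLAMA_SUPPORTED_SKUS
        ), f"Unsupported model: {model_name}, use one of {list(OLLAMA_SUPPORTED_SKUS)}"
        return OLLAMA_SUPPORTED_SKUS[model_name]

With a mapping like this, the unsupported descriptor "Meta-Llama3.1-8B" raises an AssertionError whose message matches the assertRaisesRegex pattern used in the test.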