diff --git a/llama_toolchain/inference/api/config.py b/llama_toolchain/inference/api/config.py
index 5a10c0360..6bac2d09d 100644
--- a/llama_toolchain/inference/api/config.py
+++ b/llama_toolchain/inference/api/config.py
@@ -23,6 +23,7 @@ from .datatypes import QuantizationConfig
class ImplType(Enum):
inline = "inline"
remote = "remote"
+ ollama = "ollama"
@json_schema_type
@@ -80,10 +81,17 @@ class RemoteImplConfig(BaseModel):
url: str = Field(..., description="The URL of the remote module")
+@json_schema_type
+class OllamaImplConfig(BaseModel):
+ impl_type: Literal[ImplType.ollama.value] = ImplType.ollama.value
+ model: str = Field(..., description="The name of the model in ollama catalog")
+ url: str = Field(..., description="The URL for the ollama server")
+
+
@json_schema_type
class InferenceConfig(BaseModel):
impl_config: Annotated[
- Union[InlineImplConfig, RemoteImplConfig],
+ Union[InlineImplConfig, RemoteImplConfig, OllamaImplConfig],
Field(discriminator="impl_type"),
]
diff --git a/llama_toolchain/inference/api_instance.py b/llama_toolchain/inference/api_instance.py
index 366e46fa1..25b5ecf4b 100644
--- a/llama_toolchain/inference/api_instance.py
+++ b/llama_toolchain/inference/api_instance.py
@@ -12,6 +12,10 @@ async def get_inference_api_instance(config: InferenceConfig):
from .inference import InferenceImpl
return InferenceImpl(config.impl_config)
+ elif config.impl_config.impl_type == ImplType.ollama.value:
+ from .inference import OllamaInference
+
+ return OllamaInference(config.impl_config)
from .client import InferenceClient
diff --git a/llama_toolchain/inference/ollama.py b/llama_toolchain/inference/ollama.py
new file mode 100644
index 000000000..485e5d558
--- /dev/null
+++ b/llama_toolchain/inference/ollama.py
@@ -0,0 +1,143 @@
+import httpx
+import uuid
+
+from typing import AsyncGenerator
+
+from ollama import AsyncClient
+
+from llama_models.llama3_1.api.datatypes import (
+ BuiltinTool,
+ CompletionMessage,
+ Message,
+ StopReason,
+ ToolCall,
+)
+from llama_models.llama3_1.api.tool_utils import ToolUtils
+
+from .api.config import OllamaImplConfig
+from .api.endpoints import (
+ ChatCompletionResponse,
+ ChatCompletionRequest,
+ ChatCompletionResponseStreamChunk,
+ CompletionRequest,
+ Inference,
+)
+
+
+
+class OllamaInference(Inference):
+
+ def __init__(self, config: OllamaImplConfig) -> None:
+ self.config = config
+ self.model = config.model
+
+ async def initialize(self) -> None:
+ self.client = AsyncClient(host=self.config.url)
+ try:
+ status = await self.client.pull(self.model)
+ assert status['status'] == 'success', f"Failed to pull model {self.model} in ollama"
+ except httpx.ConnectError:
+ print("Ollama Server is not running, start it using `ollama serve` in a separate terminal")
+ raise
+
+ async def shutdown(self) -> None:
+ pass
+
+ async def completion(self, request: CompletionRequest) -> AsyncGenerator:
+ raise NotImplementedError()
+
+ def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
+ ollama_messages = []
+ for message in messages:
+ ollama_messages.append(
+ {"role": message.role, "content": message.content}
+ )
+
+ return ollama_messages
+
+ async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+ if not request.stream:
+ r = await self.client.chat(
+ model=self.model,
+ messages=self._messages_to_ollama_messages(request.messages),
+ stream=False
+ )
+ completion_message = decode_assistant_message_from_content(
+ r['message']['content']
+ )
+
+ yield ChatCompletionResponse(
+ completion_message=completion_message,
+ logprobs=None,
+ )
+ else:
+ raise NotImplementedError()
+
+
+#TODO: Consolidate this with impl in llama-models
+def decode_assistant_message_from_content(content: str) -> CompletionMessage:
+ ipython = content.startswith("<|python_tag|>")
+ if ipython:
+ content = content[len("<|python_tag|>") :]
+
+ if content.endswith("<|eot_id|>"):
+ content = content[: -len("<|eot_id|>")]
+ stop_reason = StopReason.end_of_turn
+ elif content.endswith("<|eom_id|>"):
+ content = content[: -len("<|eom_id|>")]
+ stop_reason = StopReason.end_of_message
+ else:
+ # Ollama does not return <|eot_id|>
+ # and hence we explicitly set it as the default.
+ #TODO: Check for StopReason.out_of_tokens
+ stop_reason = StopReason.end_of_turn
+
+ tool_name = None
+ tool_arguments = {}
+
+ custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
+ if custom_tool_info is not None:
+ tool_name, tool_arguments = custom_tool_info
+ # Sometimes when agent has custom tools alongside builin tools
+ # Agent responds for builtin tool calls in the format of the custom tools
+ # This code tries to handle that case
+ if tool_name in BuiltinTool.__members__:
+ tool_name = BuiltinTool[tool_name]
+ tool_arguments = {
+ "query": list(tool_arguments.values())[0],
+ }
+ else:
+ builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
+ if builtin_tool_info is not None:
+ tool_name, query = builtin_tool_info
+ tool_arguments = {
+ "query": query,
+ }
+ if tool_name in BuiltinTool.__members__:
+ tool_name = BuiltinTool[tool_name]
+ elif ipython:
+ tool_name = BuiltinTool.code_interpreter
+ tool_arguments = {
+ "code": content,
+ }
+
+ tool_calls = []
+ if tool_name is not None and tool_arguments is not None:
+ call_id = str(uuid.uuid4())
+ tool_calls.append(
+ ToolCall(
+ call_id=call_id,
+ tool_name=tool_name,
+ arguments=tool_arguments,
+ )
+ )
+ content = ""
+
+ if stop_reason is None:
+ stop_reason = StopReason.out_of_tokens
+
+ return CompletionMessage(
+ content=content,
+ stop_reason=stop_reason,
+ tool_calls=tool_calls,
+ )
diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py
new file mode 100644
index 000000000..d37ff26c3
--- /dev/null
+++ b/tests/test_ollama_inference.py
@@ -0,0 +1,176 @@
+import textwrap
+import unittest
+from datetime import datetime
+
+from llama_models.llama3_1.api.datatypes import (
+ BuiltinTool,
+ InstructModel,
+ UserMessage,
+ StopReason,
+ SystemMessage,
+)
+
+from llama_toolchain.inference.api.endpoints import (
+ ChatCompletionRequest
+)
+from llama_toolchain.inference.api.config import (
+ OllamaImplConfig
+)
+from llama_toolchain.inference.ollama import (
+ OllamaInference
+)
+
+
+class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
+
+ async def asyncSetUp(self):
+ ollama_config = OllamaImplConfig(
+ model="llama3.1",
+ url="http://localhost:11434",
+ )
+
+ # setup ollama
+ self.inference = OllamaInference(ollama_config)
+ await self.inference.initialize()
+
+ current_date = datetime.now()
+ formatted_date = current_date.strftime("%d %B %Y")
+ self.system_prompt = SystemMessage(
+ content=textwrap.dedent(f"""
+ Environment: ipython
+ Tools: brave_search
+
+ Cutting Knowledge Date: December 2023
+ Today Date:{formatted_date}
+
+ """),
+ )
+
+ self.system_prompt_with_custom_tool = SystemMessage(
+ content=textwrap.dedent("""
+ Environment: ipython
+ Tools: brave_search, wolfram_alpha, photogen
+
+ Cutting Knowledge Date: December 2023
+ Today Date: 30 July 2024
+
+
+ You have access to the following functions:
+
+ Use the function 'get_boiling_point' to 'Get the boiling point of a imaginary liquids (eg. polyjuice)'
+ {"name": "get_boiling_point", "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)", "parameters": {"liquid_name": {"param_type": "string", "description": "The name of the liquid", "required": true}, "celcius": {"param_type": "boolean", "description": "Whether to return the boiling point in Celcius", "required": false}}}
+
+
+ Think very carefully before calling functions.
+ If you choose to call a function ONLY reply in the following format with no prefix or suffix:
+
+ {"example_name": "example_value"}
+
+ Reminder:
+ - If looking for real time information use relevant functions before falling back to brave_search
+ - Function calls MUST follow the specified format, start with
+ - Required parameters MUST be specified
+ - Only call one function at a time
+ - Put the entire function call reply on one line
+
+ """
+ ),
+ )
+
+ async def asyncTearDown(self):
+ await self.inference.shutdown()
+
+ async def test_text(self):
+ request = ChatCompletionRequest(
+ model=InstructModel.llama3_8b_chat,
+ messages=[
+ UserMessage(
+ content="What is the capital of France?",
+ ),
+ ],
+ stream=False,
+ )
+ iterator = self.inference.chat_completion(request)
+ async for r in iterator:
+ response = r
+
+ self.assertTrue("Paris" in response.completion_message.content)
+ self.assertEquals(response.completion_message.stop_reason, StopReason.end_of_turn)
+
+ async def test_tool_call(self):
+ request = ChatCompletionRequest(
+ model=InstructModel.llama3_8b_chat,
+ messages=[
+ self.system_prompt,
+ UserMessage(
+ content="Who is the current US President?",
+ ),
+ ],
+ stream=False,
+ )
+ iterator = self.inference.chat_completion(request)
+ async for r in iterator:
+ response = r
+
+ completion_message = response.completion_message
+
+ self.assertEquals(completion_message.content, "")
+ self.assertEquals(completion_message.stop_reason, StopReason.end_of_message)
+
+ self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls)
+ self.assertEquals(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search)
+ self.assertTrue(
+ "president" in completion_message.tool_calls[0].arguments["query"].lower()
+ )
+
+ async def test_code_execution(self):
+ request = ChatCompletionRequest(
+ model=InstructModel.llama3_8b_chat,
+ messages=[
+ self.system_prompt,
+ UserMessage(
+ content="Write code to compute the 5th prime number",
+ ),
+ ],
+ stream=False,
+ )
+ iterator = self.inference.chat_completion(request)
+ async for r in iterator:
+ response = r
+
+ completion_message = response.completion_message
+
+ self.assertEquals(completion_message.content, "")
+ self.assertEquals(completion_message.stop_reason, StopReason.end_of_message)
+
+ self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls)
+ self.assertEquals(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter)
+ code = completion_message.tool_calls[0].arguments["code"]
+ self.assertTrue("def " in code.lower(), code)
+
+ async def test_custom_tool(self):
+ request = ChatCompletionRequest(
+ model=InstructModel.llama3_8b_chat,
+ messages=[
+ self.system_prompt_with_custom_tool,
+ UserMessage(
+ content="Use provided function to find the boiling point of polyjuice in fahrenheit?",
+ ),
+ ],
+ stream=False,
+ )
+ iterator = self.inference.chat_completion(request)
+ async for r in iterator:
+ response = r
+
+ completion_message = response.completion_message
+
+ self.assertEqual(completion_message.content, "")
+ self.assertEquals(completion_message.stop_reason, StopReason.end_of_turn)
+
+ self.assertEquals(len(completion_message.tool_calls), 1, completion_message.tool_calls)
+ self.assertEquals(completion_message.tool_calls[0].tool_name, "get_boiling_point")
+
+ args = completion_message.tool_calls[0].arguments
+ self.assertTrue(isinstance(args, dict))
+ self.assertTrue(args["liquid_name"], "polyjuice")