diff --git a/llama_toolchain/inference/api/config.py b/llama_toolchain/inference/api/config.py
index 5a10c0360..6bac2d09d 100644
--- a/llama_toolchain/inference/api/config.py
+++ b/llama_toolchain/inference/api/config.py
@@ -23,6 +23,7 @@ from .datatypes import QuantizationConfig
 class ImplType(Enum):
     inline = "inline"
     remote = "remote"
+    ollama = "ollama"


 @json_schema_type
@@ -80,10 +81,17 @@ class RemoteImplConfig(BaseModel):
     url: str = Field(..., description="The URL of the remote module")


+@json_schema_type
+class OllamaImplConfig(BaseModel):
+    impl_type: Literal[ImplType.ollama.value] = ImplType.ollama.value
+    model: str = Field(..., description="The name of the model in the ollama catalog")
+    url: str = Field(..., description="The URL for the ollama server")
+
+
 @json_schema_type
 class InferenceConfig(BaseModel):
     impl_config: Annotated[
-        Union[InlineImplConfig, RemoteImplConfig],
+        Union[InlineImplConfig, RemoteImplConfig, OllamaImplConfig],
         Field(discriminator="impl_type"),
     ]

diff --git a/llama_toolchain/inference/api_instance.py b/llama_toolchain/inference/api_instance.py
index 366e46fa1..25b5ecf4b 100644
--- a/llama_toolchain/inference/api_instance.py
+++ b/llama_toolchain/inference/api_instance.py
@@ -12,6 +12,10 @@ async def get_inference_api_instance(config: InferenceConfig):
         from .inference import InferenceImpl

         return InferenceImpl(config.impl_config)
+    elif config.impl_config.impl_type == ImplType.ollama.value:
+        from .ollama import OllamaInference
+
+        return OllamaInference(config.impl_config)

     from .client import InferenceClient

diff --git a/llama_toolchain/inference/ollama.py b/llama_toolchain/inference/ollama.py
new file mode 100644
index 000000000..485e5d558
--- /dev/null
+++ b/llama_toolchain/inference/ollama.py
@@ -0,0 +1,143 @@
import httpx
import uuid

from typing import AsyncGenerator

from ollama import AsyncClient

from llama_models.llama3_1.api.datatypes import (
    BuiltinTool,
    CompletionMessage,
    Message,
    StopReason,
    ToolCall,
)
from llama_models.llama3_1.api.tool_utils import ToolUtils

from .api.config import OllamaImplConfig
from .api.endpoints import (
    ChatCompletionResponse,
    ChatCompletionRequest,
    ChatCompletionResponseStreamChunk,
    CompletionRequest,
    Inference,
)


class OllamaInference(Inference):

    def __init__(self, config: OllamaImplConfig) -> None:
        self.config = config
        self.model = config.model

    async def initialize(self) -> None:
        # Connect to the ollama server and make sure the requested model is available locally.
        self.client = AsyncClient(host=self.config.url)
        try:
            status = await self.client.pull(self.model)
            assert status['status'] == 'success', f"Failed to pull model {self.model} in ollama"
        except httpx.ConnectError:
            print("Ollama server is not running; start it with `ollama serve` in a separate terminal")
            raise

    async def shutdown(self) -> None:
        pass

    async def completion(self, request: CompletionRequest) -> AsyncGenerator:
        raise NotImplementedError()

    def _messages_to_ollama_messages(self, messages: list[Message]) -> list:
        ollama_messages = []
        for message in messages:
            ollama_messages.append(
                {"role": message.role, "content": message.content}
            )

        return ollama_messages

    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        if not request.stream:
            # Non-streaming: ollama returns a single message whose content we decode
            # back into a CompletionMessage (tool calls, stop reason, etc.).
            r = await self.client.chat(
                model=self.model,
                messages=self._messages_to_ollama_messages(request.messages),
                stream=False
            )
            completion_message = decode_assistant_message_from_content(
                r['message']['content']
            )

            yield ChatCompletionResponse(
                completion_message=completion_message,
                logprobs=None,
            )
        else:
            raise NotImplementedError()


# TODO: Consolidate this with the impl in llama-models
def decode_assistant_message_from_content(content: str) -> CompletionMessage:
    # Strip the llama3.1 special tokens from the raw text and recover the
    # stop reason and any builtin/custom tool call embedded in it.
    ipython = content.startswith("<|python_tag|>")
    if ipython:
        content = content[len("<|python_tag|>") :]

    if content.endswith("<|eot_id|>"):
        content = content[: -len("<|eot_id|>")]
        stop_reason = StopReason.end_of_turn
    elif content.endswith("<|eom_id|>"):
        content = content[: -len("<|eom_id|>")]
        stop_reason = StopReason.end_of_message
    else:
        # Ollama does not return <|eot_id|>,
        # so we explicitly use end_of_turn as the default.
        # TODO: Check for StopReason.out_of_tokens
        stop_reason = StopReason.end_of_turn

    tool_name = None
    tool_arguments = {}

    custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
    if custom_tool_info is not None:
        tool_name, tool_arguments = custom_tool_info
        # Sometimes when the agent has custom tools alongside builtin tools,
        # it responds to builtin tool calls in the custom-tool format.
        # This code tries to handle that case.
        if tool_name in BuiltinTool.__members__:
            tool_name = BuiltinTool[tool_name]
            tool_arguments = {
                "query": list(tool_arguments.values())[0],
            }
    else:
        builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
        if builtin_tool_info is not None:
            tool_name, query = builtin_tool_info
            tool_arguments = {
                "query": query,
            }
            if tool_name in BuiltinTool.__members__:
                tool_name = BuiltinTool[tool_name]
        elif ipython:
            tool_name = BuiltinTool.code_interpreter
            tool_arguments = {
                "code": content,
            }

    tool_calls = []
    if tool_name is not None and tool_arguments is not None:
        call_id = str(uuid.uuid4())
        tool_calls.append(
            ToolCall(
                call_id=call_id,
                tool_name=tool_name,
                arguments=tool_arguments,
            )
        )
        content = ""

    if stop_reason is None:
        stop_reason = StopReason.out_of_tokens

    return CompletionMessage(
        content=content,
        stop_reason=stop_reason,
        tool_calls=tool_calls,
    )

diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py
new file mode 100644
index 000000000..d37ff26c3
--- /dev/null
+++ b/tests/test_ollama_inference.py
@@ -0,0 +1,176 @@
import textwrap
import unittest
from datetime import datetime

from llama_models.llama3_1.api.datatypes import (
    BuiltinTool,
    InstructModel,
    UserMessage,
    StopReason,
    SystemMessage,
)

from llama_toolchain.inference.api.endpoints import (
    ChatCompletionRequest
)
from llama_toolchain.inference.api.config import (
    OllamaImplConfig
)
from llama_toolchain.inference.ollama import (
    OllamaInference
)


class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):

    async def asyncSetUp(self):
        ollama_config = OllamaImplConfig(
            model="llama3.1",
            url="http://localhost:11434",
        )

        # set up the ollama inference impl
        self.inference = OllamaInference(ollama_config)
        await self.inference.initialize()

        current_date = datetime.now()
        formatted_date = current_date.strftime("%d %B %Y")
        self.system_prompt = SystemMessage(
            content=textwrap.dedent(f"""
                Environment: ipython
                Tools: brave_search

                Cutting Knowledge Date: December 2023
                Today Date: {formatted_date}

            """),
        )

        self.system_prompt_with_custom_tool = SystemMessage(
            content=textwrap.dedent("""
                Environment: ipython
                Tools: brave_search, wolfram_alpha, photogen

                Cutting Knowledge Date: December 2023
                Today Date: 30 July 2024


                You have access to the following functions:

                Use the function 'get_boiling_point' to 'Get the boiling point of a imaginary liquids (eg. polyjuice)'
                {"name": "get_boiling_point", "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)", "parameters": {"liquid_name": {"param_type": "string", "description": "The name of the liquid", "required": true}, "celcius": {"param_type": "boolean", "description": "Whether to return the boiling point in Celcius", "required": false}}}


                Think very carefully before calling functions.
                If you choose to call a function ONLY reply in the following format with no prefix or suffix:

                <function=example_function_name>{"example_name": "example_value"}</function>

                Reminder:
                - If looking for real time information use relevant functions before falling back to brave_search
                - Function calls MUST follow the specified format, start with <function= and end with </function>
                - Required parameters MUST be specified
                - Only call one function at a time
                - Put the entire function call reply on one line

                """
            ),
        )

    async def asyncTearDown(self):
        await self.inference.shutdown()

    async def test_text(self):
        request = ChatCompletionRequest(
            model=InstructModel.llama3_8b_chat,
            messages=[
                UserMessage(
                    content="What is the capital of France?",
                ),
            ],
            stream=False,
        )
        iterator = self.inference.chat_completion(request)
        async for r in iterator:
            response = r

        self.assertTrue("Paris" in response.completion_message.content)
        self.assertEqual(response.completion_message.stop_reason, StopReason.end_of_turn)

    async def test_tool_call(self):
        request = ChatCompletionRequest(
            model=InstructModel.llama3_8b_chat,
            messages=[
                self.system_prompt,
                UserMessage(
                    content="Who is the current US President?",
                ),
            ],
            stream=False,
        )
        iterator = self.inference.chat_completion(request)
        async for r in iterator:
            response = r

        completion_message = response.completion_message

        self.assertEqual(completion_message.content, "")
        self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)

        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search)
        self.assertTrue(
            "president" in completion_message.tool_calls[0].arguments["query"].lower()
        )

    async def test_code_execution(self):
        request = ChatCompletionRequest(
            model=InstructModel.llama3_8b_chat,
            messages=[
                self.system_prompt,
                UserMessage(
                    content="Write code to compute the 5th prime number",
                ),
            ],
            stream=False,
        )
        iterator = self.inference.chat_completion(request)
        async for r in iterator:
            response = r

        completion_message = response.completion_message

        self.assertEqual(completion_message.content, "")
        self.assertEqual(completion_message.stop_reason, StopReason.end_of_message)

        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
        self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter)
        code = completion_message.tool_calls[0].arguments["code"]
        self.assertTrue("def " in code.lower(), code)

    async def test_custom_tool(self):
        request = ChatCompletionRequest(
            model=InstructModel.llama3_8b_chat,
            messages=[
                self.system_prompt_with_custom_tool,
                UserMessage(
                    content="Use provided function to find the boiling point of polyjuice in fahrenheit?",
                ),
            ],
            stream=False,
        )
        iterator = self.inference.chat_completion(request)
        async for r in iterator:
            response = r

        completion_message = response.completion_message
        self.assertEqual(completion_message.content, "")
        self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn)

        self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls)
        self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point")

        args = completion_message.tool_calls[0].arguments
        self.assertIsInstance(args, dict)
        self.assertEqual(args["liquid_name"], "polyjuice")
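
Not part of the patch: a minimal end-to-end usage sketch of the new backend, routed through get_inference_api_instance rather than constructing OllamaInference directly. It assumes an Ollama server is already running at http://localhost:11434 and can pull the llama3.1 model (the same setup the tests rely on); the diff does not show get_inference_api_instance calling initialize(), so the sketch calls it explicitly, as the tests do.

import asyncio

from llama_models.llama3_1.api.datatypes import InstructModel, UserMessage

from llama_toolchain.inference.api.config import InferenceConfig, OllamaImplConfig
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.api_instance import get_inference_api_instance


async def main():
    # Assumption: `ollama serve` is running locally and the llama3.1 model can be pulled.
    config = InferenceConfig(
        impl_config=OllamaImplConfig(
            model="llama3.1",
            url="http://localhost:11434",
        )
    )
    inference = await get_inference_api_instance(config)
    await inference.initialize()

    request = ChatCompletionRequest(
        model=InstructModel.llama3_8b_chat,
        messages=[UserMessage(content="What is the capital of France?")],
        stream=False,
    )
    # chat_completion is an async generator; the non-streaming path yields exactly one response.
    async for response in inference.chat_completion(request):
        print(response.completion_message.content)

    await inference.shutdown()


asyncio.run(main())

Because impl_config is a discriminated union, the same entry point dispatches to the inline, remote, or ollama implementation based solely on the config, so callers do not change when switching backends.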