From 156bfa0e154baad738e8c03a43415e2166ef839c Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Wed, 31 Jul 2024 22:08:37 -0700 Subject: [PATCH] Added Ollama as an inference impl (#20) * fix non-streaming api in inference server * unit test for inline inference * Added non-streaming ollama inference impl * add streaming support for ollama inference with tests * addressing comments --------- Co-authored-by: Hardik Shah --- llama_toolchain/inference/api/config.py | 10 +- llama_toolchain/inference/api_instance.py | 4 + llama_toolchain/inference/client.py | 19 +- llama_toolchain/inference/event_logger.py | 23 +- llama_toolchain/inference/inference.py | 30 ++- llama_toolchain/inference/ollama.py | 264 +++++++++++++++++++ requirements.txt | 1 + tests/test_inference.py | 307 ++++++++++++++++++++++ tests/test_ollama_inference.py | 296 +++++++++++++++++++++ 9 files changed, 921 insertions(+), 33 deletions(-) create mode 100644 llama_toolchain/inference/ollama.py create mode 100644 tests/test_inference.py create mode 100644 tests/test_ollama_inference.py diff --git a/llama_toolchain/inference/api/config.py b/llama_toolchain/inference/api/config.py index 5a10c0360..6bac2d09d 100644 --- a/llama_toolchain/inference/api/config.py +++ b/llama_toolchain/inference/api/config.py @@ -23,6 +23,7 @@ from .datatypes import QuantizationConfig class ImplType(Enum): inline = "inline" remote = "remote" + ollama = "ollama" @json_schema_type @@ -80,10 +81,17 @@ class RemoteImplConfig(BaseModel): url: str = Field(..., description="The URL of the remote module") +@json_schema_type +class OllamaImplConfig(BaseModel): + impl_type: Literal[ImplType.ollama.value] = ImplType.ollama.value + model: str = Field(..., description="The name of the model in ollama catalog") + url: str = Field(..., description="The URL for the ollama server") + + @json_schema_type class InferenceConfig(BaseModel): impl_config: Annotated[ - Union[InlineImplConfig, RemoteImplConfig], + Union[InlineImplConfig, RemoteImplConfig, OllamaImplConfig], Field(discriminator="impl_type"), ] diff --git a/llama_toolchain/inference/api_instance.py b/llama_toolchain/inference/api_instance.py index 366e46fa1..975de3446 100644 --- a/llama_toolchain/inference/api_instance.py +++ b/llama_toolchain/inference/api_instance.py @@ -12,6 +12,10 @@ async def get_inference_api_instance(config: InferenceConfig): from .inference import InferenceImpl return InferenceImpl(config.impl_config) + elif config.impl_config.impl_type == ImplType.ollama.value: + from .ollama import OllamaInference + + return OllamaInference(config.impl_config) from .client import InferenceClient diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py index 3523e1867..3dd646457 100644 --- a/llama_toolchain/inference/client.py +++ b/llama_toolchain/inference/client.py @@ -14,6 +14,7 @@ from termcolor import cprint from .api import ( ChatCompletionRequest, + ChatCompletionResponse, ChatCompletionResponseStreamChunk, CompletionRequest, Inference, @@ -50,35 +51,33 @@ class InferenceClient(Inference): if line.startswith("data:"): data = line[len("data: ") :] try: - yield ChatCompletionResponseStreamChunk(**json.loads(data)) + if request.stream: + yield ChatCompletionResponseStreamChunk(**json.loads(data)) + else: + yield ChatCompletionResponse(**json.loads(data)) except Exception as e: print(data) print(f"Error with parsing or validation: {e}") -async def run_main(host: str, port: int): +async def run_main(host: str, port: int, stream: bool): client = 
InferenceClient(f"http://{host}:{port}") message = UserMessage(content="hello world, help me out here") cprint(f"User>{message.content}", "green") - req = ChatCompletionRequest( - model=InstructModel.llama3_70b_chat, - messages=[message], - stream=True, - ) iterator = client.chat_completion( ChatCompletionRequest( model=InstructModel.llama3_8b_chat, messages=[message], - stream=True, + stream=stream, ) ) async for log in EventLogger().log(iterator): log.print() -def main(host: str, port: int): - asyncio.run(run_main(host, port)) +def main(host: str, port: int, stream: bool = True): + asyncio.run(run_main(host, port, stream)) if __name__ == "__main__": diff --git a/llama_toolchain/inference/event_logger.py b/llama_toolchain/inference/event_logger.py index 4e29c3614..9d9434b6a 100644 --- a/llama_toolchain/inference/event_logger.py +++ b/llama_toolchain/inference/event_logger.py @@ -6,7 +6,10 @@ from termcolor import cprint -from llama_toolchain.inference.api import ChatCompletionResponseEventType +from llama_toolchain.inference.api import ( + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk +) class LogEvent: @@ -25,12 +28,16 @@ class LogEvent: class EventLogger: - async def log(self, event_generator, stream=True): + async def log(self, event_generator): async for chunk in event_generator: - event = chunk.event - if event.event_type == ChatCompletionResponseEventType.start: + if isinstance(chunk, ChatCompletionResponseStreamChunk): + event = chunk.event + if event.event_type == ChatCompletionResponseEventType.start: + yield LogEvent("Assistant> ", color="cyan", end="") + elif event.event_type == ChatCompletionResponseEventType.progress: + yield LogEvent(event.delta, color="yellow", end="") + elif event.event_type == ChatCompletionResponseEventType.complete: + yield LogEvent("") + else: yield LogEvent("Assistant> ", color="cyan", end="") - elif event.event_type == ChatCompletionResponseEventType.progress: - yield LogEvent(event.delta, color="yellow", end="") - elif event.event_type == ChatCompletionResponseEventType.complete: - yield LogEvent("") + yield LogEvent(chunk.completion_message.content, color="yellow") diff --git a/llama_toolchain/inference/inference.py b/llama_toolchain/inference/inference.py index b49736208..d7211ae65 100644 --- a/llama_toolchain/inference/inference.py +++ b/llama_toolchain/inference/inference.py @@ -16,6 +16,7 @@ from .api.datatypes import ( ToolCallParseStatus, ) from .api.endpoints import ( + ChatCompletionResponse, ChatCompletionRequest, ChatCompletionResponseStreamChunk, CompletionRequest, @@ -40,12 +41,13 @@ class InferenceImpl(Inference): raise NotImplementedError() async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta="", + if request.stream: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta="", + ) ) - ) tokens = [] logprobs = [] @@ -101,13 +103,15 @@ class InferenceImpl(Inference): ) else: delta = text - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=delta, - stop_reason=stop_reason, + + if stop_reason is None: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=delta, + stop_reason=stop_reason, + ) 
) - ) if stop_reason is None: stop_reason = StopReason.out_of_tokens @@ -152,8 +156,6 @@ class InferenceImpl(Inference): # TODO(ashwin): what else do we need to send out here when everything finishes? else: yield ChatCompletionResponse( - content=message.content, - tool_calls=message.tool_calls, - stop_reason=stop_reason, + completion_message=message, logprobs=logprobs if request.logprobs else None, ) diff --git a/llama_toolchain/inference/ollama.py b/llama_toolchain/inference/ollama.py new file mode 100644 index 000000000..91727fd62 --- /dev/null +++ b/llama_toolchain/inference/ollama.py @@ -0,0 +1,264 @@ +import httpx +import uuid + +from typing import AsyncGenerator + +from ollama import AsyncClient + +from llama_models.llama3_1.api.datatypes import ( + BuiltinTool, + CompletionMessage, + Message, + StopReason, + ToolCall, +) +from llama_models.llama3_1.api.tool_utils import ToolUtils + +from .api.config import OllamaImplConfig +from .api.datatypes import ( + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, + ToolCallDelta, + ToolCallParseStatus, +) +from .api.endpoints import ( + ChatCompletionResponse, + ChatCompletionRequest, + ChatCompletionResponseStreamChunk, + CompletionRequest, + Inference, +) + + + +class OllamaInference(Inference): + + def __init__(self, config: OllamaImplConfig) -> None: + self.config = config + self.model = config.model + + async def initialize(self) -> None: + self.client = AsyncClient(host=self.config.url) + try: + status = await self.client.pull(self.model) + assert status['status'] == 'success', f"Failed to pull model {self.model} in ollama" + except httpx.ConnectError: + print("Ollama Server is not running, start it using `ollama serve` in a separate terminal") + raise + + async def shutdown(self) -> None: + pass + + async def completion(self, request: CompletionRequest) -> AsyncGenerator: + raise NotImplementedError() + + def _messages_to_ollama_messages(self, messages: list[Message]) -> list: + ollama_messages = [] + for message in messages: + ollama_messages.append( + {"role": message.role, "content": message.content} + ) + + return ollama_messages + + async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + if not request.stream: + r = await self.client.chat( + model=self.model, + messages=self._messages_to_ollama_messages(request.messages), + stream=False, + #TODO: add support for options like temp, top_p, max_seq_length, etc + ) + if r['done']: + if r['done_reason'] == 'stop': + stop_reason = StopReason.end_of_turn + elif r['done_reason'] == 'length': + stop_reason = StopReason.out_of_tokens + + completion_message = decode_assistant_message_from_content( + r['message']['content'], + stop_reason, + ) + yield ChatCompletionResponse( + completion_message=completion_message, + logprobs=None, + ) + else: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.start, + delta="", + ) + ) + + stream = await self.client.chat( + model=self.model, + messages=self._messages_to_ollama_messages(request.messages), + stream=True + ) + + buffer = "" + ipython = False + stop_reason = None + + async for chunk in stream: + # check if ollama is done + if chunk['done']: + if chunk['done_reason'] == 'stop': + stop_reason = StopReason.end_of_turn + elif chunk['done_reason'] == 'length': + stop_reason = StopReason.out_of_tokens + break + + text = chunk['message']['content'] + + # check if its a tool call ( aka starts with <|python_tag|> ) + if not ipython and 
text.startswith("<|python_tag|>"): + ipython = True + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.started, + ), + ) + ) + buffer = buffer[len("<|python_tag|>") :] + continue + + if ipython: + if text == "<|eot_id|>": + stop_reason = StopReason.end_of_turn + text = "" + continue + elif text == "<|eom_id|>": + stop_reason = StopReason.end_of_message + text = "" + continue + + buffer += text + delta = ToolCallDelta( + content=text, + parse_status=ToolCallParseStatus.in_progress, + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=delta, + stop_reason=stop_reason, + ) + ) + else: + buffer += text + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=text, + stop_reason=stop_reason, + ) + ) + + # parse tool calls and report errors + message = decode_assistant_message_from_content(buffer, stop_reason) + + parsed_tool_calls = len(message.tool_calls) > 0 + if ipython and not parsed_tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content="", + parse_status=ToolCallParseStatus.failure, + ), + stop_reason=stop_reason, + ) + ) + + for tool_call in message.tool_calls: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + content=tool_call, + parse_status=ToolCallParseStatus.success, + ), + stop_reason=stop_reason, + ) + ) + + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta="", + stop_reason=stop_reason, + ) + ) + + +#TODO: Consolidate this with impl in llama-models +def decode_assistant_message_from_content( + content: str, + stop_reason: StopReason, +) -> CompletionMessage: + ipython = content.startswith("<|python_tag|>") + if ipython: + content = content[len("<|python_tag|>") :] + + if content.endswith("<|eot_id|>"): + content = content[: -len("<|eot_id|>")] + stop_reason = StopReason.end_of_turn + elif content.endswith("<|eom_id|>"): + content = content[: -len("<|eom_id|>")] + stop_reason = StopReason.end_of_message + + tool_name = None + tool_arguments = {} + + custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content) + if custom_tool_info is not None: + tool_name, tool_arguments = custom_tool_info + # Sometimes when agent has custom tools alongside builin tools + # Agent responds for builtin tool calls in the format of the custom tools + # This code tries to handle that case + if tool_name in BuiltinTool.__members__: + tool_name = BuiltinTool[tool_name] + tool_arguments = { + "query": list(tool_arguments.values())[0], + } + else: + builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content) + if builtin_tool_info is not None: + tool_name, query = builtin_tool_info + tool_arguments = { + "query": query, + } + if tool_name in BuiltinTool.__members__: + tool_name = BuiltinTool[tool_name] + elif ipython: + tool_name = BuiltinTool.code_interpreter + tool_arguments = { + "code": content, + } + + tool_calls = [] + if tool_name is not None and tool_arguments is not None: + call_id = str(uuid.uuid4()) + tool_calls.append( + 
ToolCall( + call_id=call_id, + tool_name=tool_name, + arguments=tool_arguments, + ) + ) + content = "" + + if stop_reason is None: + stop_reason = StopReason.out_of_tokens + + return CompletionMessage( + content=content, + stop_reason=stop_reason, + tool_calls=tool_calls, + ) diff --git a/requirements.txt b/requirements.txt index a51bc74d9..05d642f81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ hydra-zen json-strong-typing llama-models matplotlib +ollama omegaconf pandas Pillow diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 100644 index 000000000..bbd4a2c1a --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,307 @@ +# Run this test using the following command: +# python -m unittest tests/test_inference.py + +import asyncio +import os +import textwrap +import unittest + +from datetime import datetime + +from llama_models.llama3_1.api.datatypes import ( + BuiltinTool, + InstructModel, + UserMessage, + StopReason, + SystemMessage, +) + +from llama_toolchain.inference.api.config import ( + ImplType, + InferenceConfig, + InlineImplConfig, + RemoteImplConfig, + ModelCheckpointConfig, + PytorchCheckpoint, + CheckpointQuantizationFormat, +) +from llama_toolchain.inference.api_instance import ( + get_inference_api_instance, +) +from llama_toolchain.inference.api.datatypes import ( + ChatCompletionResponseEventType, +) +from llama_toolchain.inference.api.endpoints import ( + ChatCompletionRequest +) +from llama_toolchain.inference.inference import InferenceImpl +from llama_toolchain.inference.event_logger import EventLogger + + +HELPER_MSG = """ +This test needs llama-3.1-8b-instruct models. +Please donwload using the llama cli + +llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token +""" + + +class InferenceTests(unittest.IsolatedAsyncioTestCase): + + @classmethod + def setUpClass(cls): + # This runs the async setup function + asyncio.run(cls.asyncSetUpClass()) + + @classmethod + async def asyncSetUpClass(cls): + # assert model exists on local + model_dir = os.path.expanduser("~/.llama/checkpoints/Meta-Llama-3.1-8B-Instruct/original/") + assert os.path.isdir(model_dir), HELPER_MSG + + tokenizer_path = os.path.join(model_dir, "tokenizer.model") + assert os.path.exists(tokenizer_path), HELPER_MSG + + inline_config = InlineImplConfig( + checkpoint_config=ModelCheckpointConfig( + checkpoint=PytorchCheckpoint( + checkpoint_dir=model_dir, + tokenizer_path=tokenizer_path, + model_parallel_size=1, + quantization_format=CheckpointQuantizationFormat.bf16, + ) + ), + max_seq_len=2048, + ) + inference_config = InferenceConfig( + impl_config=inline_config + ) + + # -- For faster testing iteration -- + # remote_config = RemoteImplConfig( + # url="http://localhost:5000" + # ) + # inference_config = InferenceConfig( + # impl_config=remote_config + # ) + + cls.api = await get_inference_api_instance(inference_config) + await cls.api.initialize() + + current_date = datetime.now() + formatted_date = current_date.strftime("%d %B %Y") + cls.system_prompt = SystemMessage( + content=textwrap.dedent(f""" + Environment: ipython + Tools: brave_search + + Cutting Knowledge Date: December 2023 + Today Date:{formatted_date} + + """), + ) + cls.system_prompt_with_custom_tool = SystemMessage( + content=textwrap.dedent(""" + Environment: ipython + Tools: brave_search, wolfram_alpha, photogen + + Cutting Knowledge Date: December 2023 + Today Date: 30 July 2024 + + + You have access to the following functions: + + Use the function 'get_boiling_point' 
to 'Get the boiling point of a imaginary liquids (eg. polyjuice)' + {"name": "get_boiling_point", "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)", "parameters": {"liquid_name": {"param_type": "string", "description": "The name of the liquid", "required": true}, "celcius": {"param_type": "boolean", "description": "Whether to return the boiling point in Celcius", "required": false}}} + + + Think very carefully before calling functions. + If you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + + """ + ), + ) + + @classmethod + def tearDownClass(cls): + # This runs the async teardown function + asyncio.run(cls.asyncTearDownClass()) + + @classmethod + async def asyncTearDownClass(cls): + await cls.api.shutdown() + + async def test_text(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + UserMessage( + content="What is the capital of France?", + ), + ], + stream=False, + ) + iterator = InferenceTests.api.chat_completion(request) + + async for chunk in iterator: + response = chunk + + result = response.completion_message.content + self.assertTrue("Paris" in result, result) + + async def test_text_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + UserMessage( + content="What is the capital of France?", + ), + ], + stream=True, + ) + iterator = InferenceTests.api.chat_completion(request) + + events = [] + async for chunk in iterator: + events.append(chunk.event) + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + + response = "" + for e in events[1:-1]: + response += e.delta + + self.assertTrue("Paris" in response, response) + + async def test_custom_tool_call(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + InferenceTests.system_prompt_with_custom_tool, + UserMessage( + content="Use provided function to find the boiling point of polyjuice in fahrenheit?", + ), + ], + stream=False, + ) + iterator = InferenceTests.api.chat_completion(request) + async for r in iterator: + response = r + + completion_message = response.completion_message + + self.assertEqual(completion_message.content, "") + + # FIXME: This test fails since there is a bug where + # custom tool calls return incoorect stop_reason as out_of_tokens + # instead of end_of_turn + # self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) + + self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls) + self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point") + + args = completion_message.tool_calls[0].arguments + self.assertTrue(isinstance(args, dict)) + self.assertTrue(args["liquid_name"], "polyjuice") + + async def test_tool_call_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + self.system_prompt, + UserMessage( + content="Who is the 
current US President?", + ), + ], + stream=True, + ) + iterator = InferenceTests.api.chat_completion(request) + + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + # last but one event should be eom with tool call + self.assertEqual( + events[-2].event_type, + ChatCompletionResponseEventType.progress + ) + self.assertEqual( + events[-2].stop_reason, + StopReason.end_of_message + ) + self.assertEqual( + events[-2].delta.content.tool_name, + BuiltinTool.brave_search + ) + + async def test_custom_tool_call_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + InferenceTests.system_prompt_with_custom_tool, + UserMessage( + content="Use provided function to find the boiling point of polyjuice?", + ), + ], + stream=True, + ) + iterator = InferenceTests.api.chat_completion(request) + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + self.assertEqual( + events[-1].stop_reason, + StopReason.end_of_turn + ) + # last but one event should be eom with tool call + self.assertEqual( + events[-2].event_type, + ChatCompletionResponseEventType.progress + ) + self.assertEqual( + events[-2].stop_reason, + StopReason.end_of_turn + ) + self.assertEqual( + events[-2].delta.content.tool_name, + "get_boiling_point" + ) diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py new file mode 100644 index 000000000..d9dbab6e9 --- /dev/null +++ b/tests/test_ollama_inference.py @@ -0,0 +1,296 @@ +import textwrap +import unittest +from datetime import datetime + +from llama_models.llama3_1.api.datatypes import ( + BuiltinTool, + InstructModel, + UserMessage, + StopReason, + SystemMessage, +) +from llama_toolchain.inference.api_instance import ( + get_inference_api_instance, +) +from llama_toolchain.inference.api.datatypes import ( + ChatCompletionResponseEventType, +) +from llama_toolchain.inference.api.endpoints import ( + ChatCompletionRequest +) +from llama_toolchain.inference.api.config import ( + InferenceConfig, + OllamaImplConfig +) +from llama_toolchain.inference.ollama import ( + OllamaInference +) + + +class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + ollama_config = OllamaImplConfig( + model="llama3.1", + url="http://localhost:11434", + ) + + # setup ollama + self.api = await get_inference_api_instance( + InferenceConfig(impl_config=ollama_config) + ) + await self.api.initialize() + + current_date = datetime.now() + formatted_date = current_date.strftime("%d %B %Y") + self.system_prompt = SystemMessage( + content=textwrap.dedent(f""" + Environment: ipython + Tools: brave_search + + Cutting Knowledge Date: December 2023 + Today Date:{formatted_date} + + """), + ) + + self.system_prompt_with_custom_tool = SystemMessage( + content=textwrap.dedent(""" + Environment: ipython + Tools: brave_search, wolfram_alpha, photogen 
+ + Cutting Knowledge Date: December 2023 + Today Date: 30 July 2024 + + + You have access to the following functions: + + Use the function 'get_boiling_point' to 'Get the boiling point of a imaginary liquids (eg. polyjuice)' + {"name": "get_boiling_point", "description": "Get the boiling point of a imaginary liquids (eg. polyjuice)", "parameters": {"liquid_name": {"param_type": "string", "description": "The name of the liquid", "required": true}, "celcius": {"param_type": "boolean", "description": "Whether to return the boiling point in Celcius", "required": false}}} + + + Think very carefully before calling functions. + If you choose to call a function ONLY reply in the following format with no prefix or suffix: + + {"example_name": "example_value"} + + Reminder: + - If looking for real time information use relevant functions before falling back to brave_search + - Function calls MUST follow the specified format, start with + - Required parameters MUST be specified + - Only call one function at a time + - Put the entire function call reply on one line + + """ + ), + ) + + async def asyncTearDown(self): + await self.api.shutdown() + + async def test_text(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + UserMessage( + content="What is the capital of France?", + ), + ], + stream=False, + ) + iterator = self.api.chat_completion(request) + async for r in iterator: + response = r + + self.assertTrue("Paris" in response.completion_message.content) + self.assertEqual(response.completion_message.stop_reason, StopReason.end_of_turn) + + async def test_tool_call(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + self.system_prompt, + UserMessage( + content="Who is the current US President?", + ), + ], + stream=False, + ) + iterator = self.api.chat_completion(request) + async for r in iterator: + response = r + + completion_message = response.completion_message + + self.assertEqual(completion_message.content, "") + self.assertEqual(completion_message.stop_reason, StopReason.end_of_message) + + self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls) + self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search) + self.assertTrue( + "president" in completion_message.tool_calls[0].arguments["query"].lower() + ) + + async def test_code_execution(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + self.system_prompt, + UserMessage( + content="Write code to compute the 5th prime number", + ), + ], + stream=False, + ) + iterator = self.api.chat_completion(request) + async for r in iterator: + response = r + + completion_message = response.completion_message + + self.assertEqual(completion_message.content, "") + self.assertEqual(completion_message.stop_reason, StopReason.end_of_message) + + self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls) + self.assertEqual(completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter) + code = completion_message.tool_calls[0].arguments["code"] + self.assertTrue("def " in code.lower(), code) + + async def test_custom_tool(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + self.system_prompt_with_custom_tool, + UserMessage( + content="Use provided function to find the boiling point of polyjuice?", + ), + ], + stream=False, + ) + iterator = self.api.chat_completion(request) + async for r in 
iterator: + response = r + + completion_message = response.completion_message + + self.assertEqual(completion_message.content, "") + self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) + + self.assertEqual(len(completion_message.tool_calls), 1, completion_message.tool_calls) + self.assertEqual(completion_message.tool_calls[0].tool_name, "get_boiling_point") + + args = completion_message.tool_calls[0].arguments + self.assertTrue(isinstance(args, dict)) + self.assertTrue(args["liquid_name"], "polyjuice") + + + async def test_text_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + UserMessage( + content="What is the capital of France?", + ), + ], + stream=True, + ) + iterator = self.api.chat_completion(request) + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + response = "" + for e in events[1:-1]: + response += e.delta + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + # last but 1 event should be of type "progress" + self.assertEqual( + events[-2].event_type, + ChatCompletionResponseEventType.progress + ) + self.assertEqual( + events[-2].stop_reason, + None, + ) + self.assertTrue("Paris" in response, response) + + async def test_tool_call_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + self.system_prompt, + UserMessage( + content="Who is the current US President?", + ), + ], + stream=True, + ) + iterator = self.api.chat_completion(request) + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + + async def test_custom_tool_call_streaming(self): + request = ChatCompletionRequest( + model=InstructModel.llama3_8b_chat, + messages=[ + self.system_prompt_with_custom_tool, + UserMessage( + content="Use provided function to find the boiling point of polyjuice?", + ), + ], + stream=True, + ) + iterator = self.api.chat_completion(request) + events = [] + async for chunk in iterator: + # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") + events.append(chunk.event) + + self.assertEqual( + events[0].event_type, + ChatCompletionResponseEventType.start + ) + # last event is of type "complete" + self.assertEqual( + events[-1].event_type, + ChatCompletionResponseEventType.complete + ) + self.assertEqual( + events[-1].stop_reason, + StopReason.end_of_turn + ) + # last but one event should be eom with tool call + self.assertEqual( + events[-2].event_type, + ChatCompletionResponseEventType.progress + ) + self.assertEqual( + events[-2].delta.content.tool_name, + "get_boiling_point" + ) + self.assertEqual( + events[-2].stop_reason, + StopReason.end_of_turn + )
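
Below is a minimal usage sketch (not part of the diff above) of how the new Ollama backend is selected through the `impl_type` discriminator on `InferenceConfig` and driven in both non-streaming and streaming modes. It assumes a local `ollama serve` at http://localhost:11434 with the `llama3.1` model available, mirroring tests/test_ollama_inference.py; names and values outside the patch are illustrative only.

# Sketch only: assumes a local `ollama serve` at http://localhost:11434
# with the `llama3.1` model pulled into the Ollama catalog.
import asyncio

from llama_models.llama3_1.api.datatypes import InstructModel, UserMessage
from llama_toolchain.inference.api.config import InferenceConfig, OllamaImplConfig
from llama_toolchain.inference.api.endpoints import ChatCompletionRequest
from llama_toolchain.inference.api_instance import get_inference_api_instance
from llama_toolchain.inference.event_logger import EventLogger


async def main() -> None:
    # The `impl_type` discriminator routes this config to OllamaInference.
    config = InferenceConfig(
        impl_config=OllamaImplConfig(model="llama3.1", url="http://localhost:11434")
    )
    api = await get_inference_api_instance(config)
    await api.initialize()  # pulls the model; raises if the Ollama server is not running

    message = UserMessage(content="What is the capital of France?")

    # Non-streaming: chat_completion is still an async generator, but it
    # yields a single ChatCompletionResponse.
    request = ChatCompletionRequest(
        model=InstructModel.llama3_8b_chat,
        messages=[message],
        stream=False,
    )
    async for response in api.chat_completion(request):
        print(response.completion_message.content)

    # Streaming: yields ChatCompletionResponseStreamChunk events; the updated
    # EventLogger handles both the streaming and non-streaming shapes.
    streaming_request = ChatCompletionRequest(
        model=InstructModel.llama3_8b_chat,
        messages=[message],
        stream=True,
    )
    iterator = api.chat_completion(streaming_request)
    async for log in EventLogger().log(iterator):
        log.print()

    await api.shutdown()


if __name__ == "__main__":
    asyncio.run(main())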