diff --git a/llama_toolchain/inference/adapters/fireworks/fireworks.py b/llama_toolchain/inference/adapters/fireworks/fireworks.py
index b0eb41017..e51a730de 100644
--- a/llama_toolchain/inference/adapters/fireworks/fireworks.py
+++ b/llama_toolchain/inference/adapters/fireworks/fireworks.py
@@ -76,7 +76,28 @@ class FireworksInferenceAdapter(Inference):
 
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
         messages = prepare_messages(request)
 
         # accumulate sampling params and other options to pass to fireworks
diff --git a/llama_toolchain/inference/adapters/ollama/ollama.py b/llama_toolchain/inference/adapters/ollama/ollama.py
index 375257ea9..92fbf7585 100644
--- a/llama_toolchain/inference/adapters/ollama/ollama.py
+++ b/llama_toolchain/inference/adapters/ollama/ollama.py
@@ -84,7 +84,28 @@ class OllamaInferenceAdapter(Inference):
 
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
         messages = prepare_messages(request)
         # accumulate sampling params and other options to pass to ollama
         options = self.get_ollama_chat_options(request)
diff --git a/llama_toolchain/inference/adapters/tgi/tgi.py b/llama_toolchain/inference/adapters/tgi/tgi.py
index 7eb36ac36..b6c699b1d 100644
--- a/llama_toolchain/inference/adapters/tgi/tgi.py
+++ b/llama_toolchain/inference/adapters/tgi/tgi.py
@@ -87,7 +87,28 @@ class TGIInferenceAdapter(Inference):
 
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
         messages = prepare_messages(request)
         model_input = self.formatter.encode_dialog_prompt(messages)
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 277cf7e8a..800046355 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -249,7 +249,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             stream=True,
             tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
         )
-        iterator = self.api.chat_completion(request)
+        iterator = self.api.chat_completion(
+            request.model,
+            request.messages,
+            stream=request.stream,
+            tools=request.tools,
+        )
 
         events = []
         async for chunk in iterator:
diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py
index f5b172e69..c3cef3a10 100644
--- a/tests/test_ollama_inference.py
+++ b/tests/test_ollama_inference.py
@@ -61,7 +61,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             ],
             stream=False,
         )
-        iterator = self.api.chat_completion(request)
+        iterator = self.api.chat_completion(
+            request.model, request.messages, stream=request.stream
+        )
         async for r in iterator:
             response = r
             print(response.completion_message.content)
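With this change, callers no longer pass a ChatCompletionRequest object: chat_completion takes the model, messages, and the optional parameters directly, and each adapter builds the request object internally. A minimal sketch of the new call pattern, mirroring the updated tests (here api is an adapter instance, and model and messages are placeholder values, not names from this patch):

    # Inside an async caller (e.g. a test), with placeholder model/messages.
    # Building the request up front still works; its fields are then passed
    # through to the new flattened signature. Only the first two positional
    # arguments are required.
    request = ChatCompletionRequest(
        model=model,
        messages=messages,
        stream=True,
        tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
    )
    iterator = api.chat_completion(
        request.model,
        request.messages,
        stream=request.stream,
        tools=request.tools,
    )
    async for chunk in iterator:
        ...  # consume streamed chunks as before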