diff --git a/llama_toolchain/inference/adapters/together/together.py b/llama_toolchain/inference/adapters/together/together.py
index 4800de6ad..76403a85b 100644
--- a/llama_toolchain/inference/adapters/together/together.py
+++ b/llama_toolchain/inference/adapters/together/together.py
@@ -76,7 +76,29 @@ class TogetherInferenceAdapter(Inference):
 
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
+
         # accumulate sampling params and other options to pass to together
         options = self.get_together_chat_options(request)
         together_model = self.resolve_together_model(request.model)
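
The diff flattens the public `chat_completion` signature from a single request object into keyword arguments, then immediately rebuilds the request internally so downstream helpers like `get_together_chat_options(request)` stay untouched. Below is a minimal, self-contained sketch of that pattern. The `Message`, `SamplingParams`, and `ChatCompletionRequest` classes here are simplified stand-ins for the real llama_toolchain API types, and `TogetherInferenceAdapterSketch`, `demo`, and the model string are hypothetical names invented for illustration, not part of the actual codebase.

```python
# Sketch of the "flat kwargs, internal request wrapper" pattern from the diff.
# All types below are simplified stand-ins, not the real llama_toolchain definitions.
from dataclasses import dataclass, field
from typing import AsyncGenerator, List, Optional


@dataclass
class Message:
    role: str
    content: str


@dataclass
class SamplingParams:
    temperature: float = 1.0
    top_p: float = 0.9


@dataclass
class ChatCompletionRequest:
    model: str
    messages: List[Message]
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    stream: bool = False


class TogetherInferenceAdapterSketch:
    async def chat_completion(
        self,
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        stream: Optional[bool] = False,
    ) -> AsyncGenerator:
        # Wrapper request, internal only and not exposed to the API: helpers
        # that already take a ChatCompletionRequest keep their old signature.
        request = ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params or SamplingParams(),
            stream=bool(stream),
        )
        yield f"would call Together with model={request.model}, {len(request.messages)} message(s)"


async def demo() -> None:
    # Call sites now pass flat arguments instead of building the request themselves.
    adapter = TogetherInferenceAdapterSketch()
    async for chunk in adapter.chat_completion(
        model="Meta-Llama3.1-8B-Instruct",
        messages=[Message(role="user", content="hello")],
    ):
        print(chunk)


if __name__ == "__main__":
    import asyncio

    asyncio.run(demo())
```

One deliberate departure in the sketch: it defaults `sampling_params` to `None` and constructs a fresh `SamplingParams()` inside the function, rather than using `SamplingParams()` or `list()` directly as default values as the diff does, since Python evaluates default argument values once at definition time and the resulting objects are shared across calls.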