update inference adapters

Xi Yan 2024-09-11 19:58:20 -07:00
parent 29d1ef3fdc
commit 7d6ebf4b72
5 changed files with 75 additions and 5 deletions


@@ -76,7 +76,28 @@ class FireworksInferenceAdapter(Inference):
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
         messages = prepare_messages(request)
         # accumulate sampling params and other options to pass to fireworks
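
This hunk flattens the adapter's chat_completion entry point: callers now pass the request fields directly, and the adapter packs them back into a ChatCompletionRequest internally. A minimal sketch of a caller migrating to the new signature; the adapter instance, model id, and message below are hypothetical, not part of this diff:

    async def demo(api) -> None:
        # `api` is any adapter implementing the new flattened signature.
        # Before this commit the caller built the request object itself:
        # iterator = api.chat_completion(
        #     ChatCompletionRequest(model="my-model", messages=[...])
        # )
        iterator = api.chat_completion(
            model="my-model",                      # hypothetical model id
            messages=[UserMessage(content="Hi")],  # single-turn example
            stream=True,
        )
        async for chunk in iterator:
            print(chunk)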


@@ -84,7 +84,28 @@ class OllamaInferenceAdapter(Inference):
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
         messages = prepare_messages(request)
         # accumulate sampling params and other options to pass to ollama
         options = self.get_ollama_chat_options(request)
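
The reconstructed request is then fed to get_ollama_chat_options, whose body this diff does not show. A minimal sketch of what such an options accumulator might look like; the SamplingParams field names are assumptions, not the actual adapter code:

    def get_ollama_chat_options(self, request: ChatCompletionRequest) -> dict:
        # Copy any sampling params that are set onto Ollama's options dict.
        # The attribute names below are assumptions about SamplingParams.
        options = {}
        if request.sampling_params is not None:
            for attr in ("temperature", "top_p", "top_k", "max_tokens"):
                value = getattr(request.sampling_params, attr, None)
                if value is not None:
                    options[attr] = value
        return options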


@@ -87,7 +87,28 @@ class TGIInferenceAdapter(Inference):
         return options
 
-    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
+    async def chat_completion(
+        self,
+        model: str,
+        messages: List[Message],
+        sampling_params: Optional[SamplingParams] = SamplingParams(),
+        tools: Optional[List[ToolDefinition]] = list(),
+        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
+        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        stream: Optional[bool] = False,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> AsyncGenerator:
+        request = ChatCompletionRequest(
+            model=model,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools,
+            tool_choice=tool_choice,
+            tool_prompt_format=tool_prompt_format,
+            stream=stream,
+            logprobs=logprobs,
+        )
         messages = prepare_messages(request)
         model_input = self.formatter.encode_dialog_prompt(messages)
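
The same parameter-to-request packing now appears verbatim in all three adapters. A shared helper (hypothetical, not part of this commit) could keep the flattened signature while removing the duplication:

    def build_chat_request(
        model: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> ChatCompletionRequest:
        # Pack the flattened chat_completion arguments back into the
        # request object each adapter consumes internally.
        return ChatCompletionRequest(
            model=model,
            messages=messages,
            sampling_params=sampling_params or SamplingParams(),
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )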


@@ -249,7 +249,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
             stream=True,
             tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
         )
-        iterator = self.api.chat_completion(request)
+        iterator = self.api.chat_completion(
+            request.model,
+            request.messages,
+            stream=request.stream,
+            tools=request.tools,
+        )
         events = []
         async for chunk in iterator:
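
The test passes model and messages positionally; the same call with keyword arguments (an equivalent variant, not part of this commit) makes the mapping onto the new signature explicit:

    iterator = self.api.chat_completion(
        model=request.model,
        messages=request.messages,
        stream=request.stream,
        tools=request.tools,
    )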


@@ -61,7 +61,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
             ],
             stream=False,
         )
-        iterator = self.api.chat_completion(request)
+        iterator = self.api.chat_completion(
+            request.model, request.messages, stream=request.stream
+        )
         async for r in iterator:
             response = r
         print(response.completion_message.content)