update inference adapters

This commit is contained in:
Xi Yan 2024-09-11 19:58:20 -07:00
parent 29d1ef3fdc
commit 7d6ebf4b72
5 changed files with 75 additions and 5 deletions

View file

@ -76,7 +76,28 @@ class FireworksInferenceAdapter(Inference):
return options
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = prepare_messages(request)
# accumulate sampling params and other options to pass to fireworks

View file

@ -84,7 +84,28 @@ class OllamaInferenceAdapter(Inference):
return options
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = prepare_messages(request)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)

View file

@ -87,7 +87,28 @@ class TGIInferenceAdapter(Inference):
return options
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
async def chat_completion(
self,
model: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = list(),
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools,
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = prepare_messages(request)
model_input = self.formatter.encode_dialog_prompt(messages)

View file

@ -249,7 +249,12 @@ class InferenceTests(unittest.IsolatedAsyncioTestCase):
stream=True,
tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)],
)
iterator = self.api.chat_completion(request)
iterator = self.api.chat_completion(
request.model,
request.messages,
stream=request.stream,
tools=request.tools,
)
events = []
async for chunk in iterator:

View file

@ -61,7 +61,9 @@ class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase):
],
stream=False,
)
iterator = self.api.chat_completion(request)
iterator = self.api.chat_completion(
request.model, request.messages, stream=request.stream
)
async for r in iterator:
response = r
print(response.completion_message.content)