Separate chat_completion stream and non-stream implementations

This is a pretty important requirement: the streaming response type is
an AsyncGenerator, while the non-stream one is a single object. So far
this has worked only _sometimes_, thanks to various pre-existing hacks,
and in some cases it simply failed.
Ashwin Bharambe authored on 2024-10-08 10:52:16 -07:00 (committed by Ashwin Bharambe)
parent f8752ab8dc
commit 0c9eb3341c
5 changed files with 346 additions and 287 deletions
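
The core of the change is that `chat_completion` becomes a plain `def` that only dispatches: when `stream=True` it returns an async generator, otherwise it returns a coroutine that resolves to a single response. Below is a minimal, self-contained sketch of that pattern; `Adapter`, `ChatRequest`, and `ChatResponse` are simplified stand-ins for illustration, not the real llama-stack types.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator


@dataclass
class ChatRequest:  # stand-in for ChatCompletionRequest
    prompt: str
    stream: bool = False


@dataclass
class ChatResponse:  # stand-in for ChatCompletionResponse / a stream chunk
    text: str


class Adapter:
    def chat_completion(self, request: ChatRequest):
        # Plain `def`: dispatch once, up front. The caller receives either an
        # async generator (streaming) or a coroutine (non-streaming); the two
        # shapes never share one code path.
        if request.stream:
            return self._stream_chat_completion(request)
        return self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(self, request: ChatRequest) -> ChatResponse:
        # One round trip, one object back.
        return ChatResponse(text=f"echo: {request.prompt}")

    async def _stream_chat_completion(self, request: ChatRequest) -> AsyncGenerator:
        # An `async def` containing `yield` is an async generator function:
        # calling it returns an AsyncGenerator without running the body yet.
        for token in request.prompt.split():
            yield ChatResponse(text=token)


async def main() -> None:
    adapter = Adapter()
    # Non-stream: await the coroutine, get one object.
    print((await adapter.chat_completion(ChatRequest(prompt="hello world"))).text)
    # Stream: iterate the async generator, get chunks.
    async for chunk in adapter.chat_completion(ChatRequest(prompt="hello world", stream=True)):
        print(chunk.text)


if __name__ == "__main__":
    asyncio.run(main())
```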

@@ -55,7 +55,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
async def shutdown(self) -> None:
pass
async def completion(
def completion(
self,
model: str,
content: InterleavedTextMedia,
@@ -79,7 +79,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
return options
async def chat_completion(
def chat_completion(
self,
model: str,
messages: List[Message],
@@ -90,24 +90,7 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
ollama_model = self.map_to_provider_model(request.model)
ollama_model = self.map_to_provider_model(model)
res = await self.client.ps()
need_model_pull = True
@@ -123,133 +106,166 @@ class OllamaInferenceAdapter(ModelRegistryHelper, Inference):
status["status"] == "success"
), f"Failed to pull model {self.model} in ollama"
common_params = {
"model": ollama_model,
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_chat_completion(request)
else:
return self._nonstream_chat_completion(request)
def _get_params(self, request: ChatCompletionRequest) -> dict:
messages = augment_messages_for_tools(request)
model_input = self.formatter.encode_dialog_prompt(messages)
prompt = self.tokenizer.decode(model_input.tokens)
# accumulate sampling params and other options to pass to ollama
options = self.get_ollama_chat_options(request)
return {
"model": self.map_to_provider_model(request.model),
"prompt": prompt,
"options": options,
"raw": True,
"stream": request.stream,
}
if not request.stream:
r = await self.client.generate(**common_params)
stop_reason = None
if r["done"]:
if r["done_reason"] == "stop":
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = self._get_params(request)
r = await self.client.generate(**params)
stop_reason = None
if r["done"]:
if r["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif r["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
completion_message = self.formatter.decode_assistant_message_from_content(
r["response"], stop_reason
)
return ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
async def _stream_chat_completion(
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
stream = await self.client.generate(**params)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
buffer = ""
ipython = False
stop_reason = None
async for chunk in stream:
if chunk["done"]:
if stop_reason is None and chunk["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif r["done_reason"] == "length":
elif stop_reason is None and chunk["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
break
completion_message = self.formatter.decode_assistant_message_from_content(
r["response"], stop_reason
)
yield ChatCompletionResponse(
completion_message=completion_message,
logprobs=None,
)
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
stream = await self.client.generate(**common_params)
buffer = ""
ipython = False
stop_reason = None
async for chunk in stream:
if chunk["done"]:
if stop_reason is None and chunk["done_reason"] == "stop":
stop_reason = StopReason.end_of_turn
elif stop_reason is None and chunk["done_reason"] == "length":
stop_reason = StopReason.out_of_tokens
break
text = chunk["response"]
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.started,
),
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
text = chunk["response"]
# check if its a tool call ( aka starts with <|python_tag|> )
if not ipython and text.startswith("<|python_tag|>"):
ipython = True
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
parse_status=ToolCallParseStatus.started,
),
stop_reason=stop_reason,
)
)
buffer += text
continue
if ipython:
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
text = ""
continue
elif text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
text = ""
continue
buffer += text
delta = ToolCallDelta(
content=text,
parse_status=ToolCallParseStatus.in_progress,
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
delta=delta,
stop_reason=stop_reason,
)
)
else:
buffer += text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=text,
stop_reason=stop_reason,
)
)
# parse tool calls and report errors
message = self.formatter.decode_assistant_message_from_content(
buffer, stop_reason
)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content="",
parse_status=ToolCallParseStatus.failure,
),
stop_reason=stop_reason,
)
)
for tool_call in message.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
content=tool_call,
parse_status=ToolCallParseStatus.success,
),
stop_reason=stop_reason,
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta="",
stop_reason=stop_reason,
)
)
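
On the consuming side, this split is what removes the need for the old hacks: the caller can branch on the returned shape instead of hoping a single implementation does the right thing. A hypothetical consumer sketch (not the actual llama-stack routing code) might look like this:

```python
import inspect
from typing import Any


async def handle_chat_completion(result: Any) -> None:
    # Hypothetical consumer: `chat_completion` now returns either an
    # AsyncGenerator of stream chunks or an awaitable single response.
    if inspect.isasyncgen(result):
        async for chunk in result:
            ...  # e.g. forward each ChatCompletionResponseStreamChunk to the client
    else:
        response = await result
        ...  # e.g. serialize the single ChatCompletionResponse
```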