add streaming support for ollama inference with tests

Hardik Shah 2024-07-31 19:33:36 -07:00
parent 0e75e73fa7
commit 0e985648f5
4 changed files with 491 additions and 61 deletions
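
For a sense of how the new streaming path is meant to be consumed, here is a
minimal sketch of a caller (the `inference` instance and `request` object are
assumed to be constructed elsewhere; the event field names follow the diff
below, not any documented interface):

    # iterate the async generator returned by chat_completion()
    # (request.stream must be True to take the new streaming branch)
    async def print_stream(inference, request):
        async for chunk in inference.chat_completion(request):
            event = chunk.event
            if event.event_type == ChatCompletionResponseEventType.progress:
                # delta is either a plain text fragment or a ToolCallDelta
                print(event.delta, end="", flush=True)
            elif event.event_type == ChatCompletionResponseEventType.complete:
                print(f"\n[stop_reason={event.stop_reason}]")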


@@ -7,14 +7,20 @@ from ollama import AsyncClient
from llama_models.llama3_1.api.datatypes import (
    BuiltinTool,
    CompletionMessage,
    Message,
    StopReason,
    ToolCall,
)
from llama_models.llama3_1.api.tool_utils import ToolUtils

from .api.config import OllamaImplConfig
from .api.datatypes import (
    ChatCompletionResponseEvent,
    ChatCompletionResponseEventType,
    ToolCallDelta,
    ToolCallParseStatus,
)
from .api.endpoints import (
    ChatCompletionResponse,
    ChatCompletionRequest,
@@ -54,28 +60,148 @@ class OllamaInference(Inference):
            )
        return ollama_messages

    async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        if not request.stream:
            r = await self.client.chat(
                model=self.model,
                messages=self._messages_to_ollama_messages(request.messages),
                stream=False,
                # TODO: add support for options like temp, top_p, max_seq_length, etc.
            )
            # default the stop reason; decode_assistant_message_from_content can
            # still infer one from trailing <|eot_id|>/<|eom_id|> markers
            stop_reason = None
            if r['done']:
                if r['done_reason'] == 'stop':
                    stop_reason = StopReason.end_of_turn
                elif r['done_reason'] == 'length':
                    stop_reason = StopReason.out_of_tokens

            completion_message = decode_assistant_message_from_content(
                r['message']['content'],
                stop_reason,
            )
            yield ChatCompletionResponse(
                completion_message=completion_message,
                logprobs=None,
            )
        else:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.start,
                    delta="",
                )
            )
            stream = await self.client.chat(
                model=self.model,
                messages=self._messages_to_ollama_messages(request.messages),
                stream=True,
            )

            buffer = ""
            ipython = False
            stop_reason = None
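
            # stream-processing state: `buffer` accumulates raw text so the
            # complete message can be re-parsed for tool calls after the loop,
            # `ipython` marks that we are inside a <|python_tag|> tool-call
            # block, and `stop_reason` is set once ollama reports it is done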
            async for chunk in stream:
                # check if ollama is done
                if chunk['done']:
                    if chunk['done_reason'] == 'stop':
                        stop_reason = StopReason.end_of_turn
                    elif chunk['done_reason'] == 'length':
                        stop_reason = StopReason.out_of_tokens
                    break

                text = chunk['message']['content']

                # check if it's a tool call (i.e. starts with <|python_tag|>)
                if not ipython and text.startswith("<|python_tag|>"):
                    ipython = True
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=ToolCallDelta(
                                content="",
                                parse_status=ToolCallParseStatus.started,
                            ),
                        )
                    )
                    # strip the tag, keeping any text that follows it in this chunk
                    buffer += text[len("<|python_tag|>"):]
                    continue
                if ipython:
                    if text == "<|eot_id|>":
                        stop_reason = StopReason.end_of_turn
                        text = ""
                        continue
                    elif text == "<|eom_id|>":
                        stop_reason = StopReason.end_of_message
                        text = ""
                        continue

                    buffer += text
                    delta = ToolCallDelta(
                        content=text,
                        parse_status=ToolCallParseStatus.in_progress,
                    )
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=delta,
                            stop_reason=stop_reason,
                        )
                    )
                else:
                    buffer += text
                    yield ChatCompletionResponseStreamChunk(
                        event=ChatCompletionResponseEvent(
                            event_type=ChatCompletionResponseEventType.progress,
                            delta=text,
                            stop_reason=stop_reason,
                        )
                    )
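
            # only raw deltas are emitted while streaming; whether the
            # accumulated buffer is a well-formed tool call is decided below,
            # once the stream is exhausted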
            # parse tool calls and report errors
            message = decode_assistant_message_from_content(buffer, stop_reason)
            parsed_tool_calls = len(message.tool_calls) > 0
            if ipython and not parsed_tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content="",
                            parse_status=ToolCallParseStatus.failure,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            for tool_call in message.tool_calls:
                yield ChatCompletionResponseStreamChunk(
                    event=ChatCompletionResponseEvent(
                        event_type=ChatCompletionResponseEventType.progress,
                        delta=ToolCallDelta(
                            content=tool_call,
                            parse_status=ToolCallParseStatus.success,
                        ),
                        stop_reason=stop_reason,
                    )
                )

            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.complete,
                    delta="",
                    stop_reason=stop_reason,
                )
            )
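
The commit title mentions tests (in the files not shown in this hunk); a
streaming test would presumably collect the emitted chunks and assert the
start/progress/complete ordering. A rough, illustrative sketch — the fixtures
`inference` and `request` are assumed, not part of this diff:

    import asyncio

    async def collect_events(inference, request):
        # assumes request.stream is True
        return [chunk.event async for chunk in inference.chat_completion(request)]

    events = asyncio.run(collect_events(inference, request))
    assert events[0].event_type == ChatCompletionResponseEventType.start
    assert events[-1].event_type == ChatCompletionResponseEventType.complete
    # assuming ollama reported a normal stop for this request
    assert events[-1].stop_reason is not None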

# TODO: Consolidate this with impl in llama-models
def decode_assistant_message_from_content(
    content: str,
    stop_reason: StopReason,
) -> CompletionMessage:
    ipython = content.startswith("<|python_tag|>")
    if ipython:
        content = content[len("<|python_tag|>") :]
@@ -86,11 +212,6 @@ def decode_assistant_message_from_content(content: str) -> CompletionMessage:
    elif content.endswith("<|eom_id|>"):
        content = content[: -len("<|eom_id|>")]
        stop_reason = StopReason.end_of_message

    tool_name = None
    tool_arguments = {}
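
For context, the content strings this decoder receives look roughly like the
following (illustrative values; the builtin-tool call syntax is the llama 3.1
format that ToolUtils is expected to parse):

    # plain answer, terminated by <|eot_id|>:
    #   "The capital of France is Paris.<|eot_id|>"
    # builtin tool call, prefixed with <|python_tag|>:
    #   '<|python_tag|>brave_search.call(query="latest llama release")<|eom_id|>'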