Added Ollama as an inference impl (#20)

* fix non-streaming api in inference server

* unit test for inline inference

* Added non-streaming ollama inference impl

* add streaming support for ollama inference with tests

* addressing comments

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
Hardik Shah 2024-07-31 22:08:37 -07:00 committed by GitHub
parent c253c1c9ad
commit 156bfa0e15
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 921 additions and 33 deletions

View file

@ -16,6 +16,7 @@ from .api.datatypes import (
ToolCallParseStatus,
)
from .api.endpoints import (
ChatCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponseStreamChunk,
CompletionRequest,
@ -40,12 +41,13 @@ class InferenceImpl(Inference):
raise NotImplementedError()
async def chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
if request.stream:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta="",
)
)
)
tokens = []
logprobs = []
@ -101,13 +103,15 @@ class InferenceImpl(Inference):
)
else:
delta = text
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
if stop_reason is None:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=delta,
stop_reason=stop_reason,
)
)
)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
@ -152,8 +156,6 @@ class InferenceImpl(Inference):
# TODO(ashwin): what else do we need to send out here when everything finishes?
else:
yield ChatCompletionResponse(
content=message.content,
tool_calls=message.tool_calls,
stop_reason=stop_reason,
completion_message=message,
logprobs=logprobs if request.logprobs else None,
)