Added Ollama as an inference impl (#20)

* fix non-streaming API in inference server

* unit test for inline inference

* Added non-streaming Ollama inference impl

* add streaming support for Ollama inference with tests

* addressing comments

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Author: Hardik Shah <hjshah@fb.com> (committed by GitHub)
Date: 2024-07-31 22:08:37 -07:00
parent c253c1c9ad
commit 156bfa0e15
9 changed files with 921 additions and 33 deletions
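
The headline change for callers: `chat_completion` now honors the request's `stream` flag when decoding the server's `data:` lines, yielding `ChatCompletionResponseStreamChunk` objects while streaming and a single `ChatCompletionResponse` otherwise. A minimal usage sketch against the patched client follows; the host/port and the absolute import paths are assumptions, not part of this commit (in-tree code uses relative imports such as `from .api import ...`):

    import asyncio

    # Assumed import paths for illustration only.
    from llama_toolchain.inference.api import (
        ChatCompletionRequest,
        InstructModel,
        UserMessage,
    )
    from llama_toolchain.inference.client import InferenceClient

    async def demo(stream: bool) -> None:
        client = InferenceClient("http://localhost:5000")  # hypothetical host:port
        request = ChatCompletionRequest(
            model=InstructModel.llama3_8b_chat,
            messages=[UserMessage(content="hello world, help me out here")],
            stream=stream,
        )
        # The client is an async generator either way: many chunks when
        # streaming, a single complete response when not.
        async for response in client.chat_completion(request):
            print(response)

    asyncio.run(demo(stream=False))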


@@ -14,6 +14,7 @@ from termcolor import cprint
 from .api import (
     ChatCompletionRequest,
+    ChatCompletionResponse,
     ChatCompletionResponseStreamChunk,
     CompletionRequest,
     Inference,
@@ -50,35 +51,33 @@ class InferenceClient(Inference):
                 if line.startswith("data:"):
                     data = line[len("data: ") :]
                     try:
-                        yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                        if request.stream:
+                            yield ChatCompletionResponseStreamChunk(**json.loads(data))
+                        else:
+                            yield ChatCompletionResponse(**json.loads(data))
                     except Exception as e:
                         print(data)
                         print(f"Error with parsing or validation: {e}")


-async def run_main(host: str, port: int):
+async def run_main(host: str, port: int, stream: bool):
     client = InferenceClient(f"http://{host}:{port}")

     message = UserMessage(content="hello world, help me out here")
     cprint(f"User>{message.content}", "green")
-    req = ChatCompletionRequest(
-        model=InstructModel.llama3_70b_chat,
-        messages=[message],
-        stream=True,
-    )
     iterator = client.chat_completion(
         ChatCompletionRequest(
             model=InstructModel.llama3_8b_chat,
             messages=[message],
-            stream=True,
+            stream=stream,
         )
     )
     async for log in EventLogger().log(iterator):
         log.print()


-def main(host: str, port: int):
-    asyncio.run(run_main(host, port))
+def main(host: str, port: int, stream: bool = True):
+    asyncio.run(run_main(host, port, stream))


 if __name__ == "__main__":
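
The Ollama implementation itself (and its unit tests) lives in the other changed files, which this excerpt does not show. For orientation, here is a rough sketch of the streaming/non-streaming bridge such an impl needs, written against the `ollama` Python package's `AsyncClient`; the class and method names are illustrative, not the commit's actual code:

    from typing import AsyncGenerator

    from ollama import AsyncClient  # official Ollama Python client

    class OllamaInferenceSketch:
        """Illustrative only: forwards chat requests to a local Ollama server."""

        def __init__(self, url: str = "http://localhost:11434") -> None:
            # 11434 is Ollama's default port.
            self.client = AsyncClient(host=url)

        async def chat_completion(
            self, model: str, messages: list[dict], stream: bool
        ) -> AsyncGenerator[str, None]:
            if stream:
                # With stream=True, chat() resolves to an async iterator
                # of incremental response chunks.
                async for chunk in await self.client.chat(
                    model=model, messages=messages, stream=True
                ):
                    yield chunk["message"]["content"]
            else:
                # With stream=False, chat() resolves to one complete response.
                response = await self.client.chat(
                    model=model, messages=messages, stream=False
                )
                yield response["message"]["content"]

Mirroring the client change above, the impl yields many times for streaming requests and exactly once otherwise, so callers can consume both shapes with a single `async for`.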