mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Added Ollama as an inference impl (#20)
* fix non-streaming api in inference server * unit test for inline inference * Added non-streaming ollama inference impl * add streaming support for ollama inference with tests * addressing comments --------- Co-authored-by: Hardik Shah <hjshah@fb.com>
This commit is contained in:
parent
c253c1c9ad
commit
156bfa0e15
9 changed files with 921 additions and 33 deletions
|
@ -14,6 +14,7 @@ from termcolor import cprint
|
|||
|
||||
from .api import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionRequest,
|
||||
Inference,
|
||||
|
@ -50,35 +51,33 @@ class InferenceClient(Inference):
|
|||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
||||
if request.stream:
|
||||
yield ChatCompletionResponseStreamChunk(**json.loads(data))
|
||||
else:
|
||||
yield ChatCompletionResponse(**json.loads(data))
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
async def run_main(host: str, port: int, stream: bool):
|
||||
client = InferenceClient(f"http://{host}:{port}")
|
||||
|
||||
message = UserMessage(content="hello world, help me out here")
|
||||
cprint(f"User>{message.content}", "green")
|
||||
req = ChatCompletionRequest(
|
||||
model=InstructModel.llama3_70b_chat,
|
||||
messages=[message],
|
||||
stream=True,
|
||||
)
|
||||
iterator = client.chat_completion(
|
||||
ChatCompletionRequest(
|
||||
model=InstructModel.llama3_8b_chat,
|
||||
messages=[message],
|
||||
stream=True,
|
||||
stream=stream,
|
||||
)
|
||||
)
|
||||
async for log in EventLogger().log(iterator):
|
||||
log.print()
|
||||
|
||||
|
||||
def main(host: str, port: int):
|
||||
asyncio.run(run_main(host, port))
|
||||
def main(host: str, port: int, stream: bool = True):
|
||||
asyncio.run(run_main(host, port, stream))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue