ollama with streaming

ishaan-jaff 2023-08-12 14:49:49 -07:00
parent 65e6b05f5b
commit 9d644a5634
4 changed files with 167 additions and 1 deletion

View file

@@ -9,6 +9,7 @@ import tiktoken
from concurrent.futures import ThreadPoolExecutor
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, read_config_args
from litellm.utils import get_ollama_response_stream, stream_to_string
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
new_response = {
@@ -426,6 +427,15 @@ def completion(
model_response["created"] = time.time()
model_response["model"] = model
response = model_response
elif custom_llm_provider == "ollama":
endpoint = litellm.api_base if litellm.api_base is not None else custom_api_base
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
generator = get_ollama_response_stream(endpoint, model, prompt)
# assume all responses are streamed
return generator
else:
## LOGGING
logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
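
For reference, here is a minimal sketch (not part of this diff) of how the new "ollama" branch could be invoked and its async generator collapsed with the stream_to_string helper imported above; the local server address, model name, and prompt are assumptions:

import asyncio
from litellm import completion
from litellm.utils import stream_to_string

messages = [{"role": "user", "content": "Why is the sky blue?"}]

# with custom_llm_provider="ollama", completion() returns the async generator
# produced by get_ollama_response_stream instead of a regular response dict
generator = completion(
    model="llama2",
    messages=messages,
    custom_api_base="http://localhost:11434",  # assumed local Ollama server
    custom_llm_provider="ollama",
)

# concatenate the streamed "content" chunks into a single string
print(asyncio.run(stream_to_string(generator)))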

View file

@@ -0,0 +1,62 @@
###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# import aiohttp
# import json
# import asyncio
# import requests
# async def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
# session = aiohttp.ClientSession()
# url = f'{api_base}/api/generate'
# data = {
# "model": model,
# "prompt": prompt,
# }
# response = ""
# try:
# async with session.post(url, json=data) as resp:
# async for line in resp.content.iter_any():
# if line:
# try:
# json_chunk = line.decode("utf-8")
# chunks = json_chunk.split("\n")
# for chunk in chunks:
# if chunk.strip() != "":
# j = json.loads(chunk)
# if "response" in j:
# print(j["response"])
# yield {
# "role": "assistant",
# "content": j["response"]
# }
# # self.responses.append(j["response"])
# # yield "blank"
# except Exception as e:
# print(f"Error decoding JSON: {e}")
# finally:
# await session.close()
# # async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
# # generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
# # response = ""
# # async for elem in generator:
# # print(elem)
# # response += elem["content"]
# # return response
# # #generator = get_ollama_response_stream()
# # result = asyncio.run(get_ollama_response_no_stream())
# # print(result)
# # # return this generator to the client for streaming requests
# # async def get_response():
# # global generator
# # async for elem in generator:
# # print(elem)
# # asyncio.run(get_response())

View file

@@ -0,0 +1,52 @@
###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# import sys, os
# import traceback
# from dotenv import load_dotenv
# load_dotenv()
# import os
# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
# import pytest
# import litellm
# from litellm import embedding, completion
# import asyncio
# user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}]
# async def get_response(generator):
# response = ""
# async for elem in generator:
# print(elem)
# response += elem["content"]
# return response
# def test_completion_ollama():
# try:
# response = completion(model="llama2", messages=messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama")
# print(response)
# string_response = asyncio.run(get_response(response))
# print(string_response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama()
# def test_completion_ollama_stream():
# try:
# response = completion(model="llama2", messages=messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)
# print(response)
# string_response = asyncio.run(get_response(response))
# print(string_response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_ollama_stream()

View file

@@ -743,4 +743,46 @@ def read_config_args(config_path):
return config
except Exception as e:
print("An error occurred while reading config:", str(e))
raise e
########## ollama implementation ############################
import aiohttp
async def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
session = aiohttp.ClientSession()
url = f'{api_base}/api/generate'
data = {
"model": model,
"prompt": prompt,
}
try:
async with session.post(url, json=data) as resp:
async for line in resp.content.iter_any():
if line:
try:
json_chunk = line.decode("utf-8")
chunks = json_chunk.split("\n")
for chunk in chunks:
if chunk.strip() != "":
j = json.loads(chunk)
if "response" in j:
print(j["response"])
yield {
"role": "assistant",
"content": j["response"]
}
# self.responses.append(j["response"])
# yield "blank"
except Exception as e:
print(f"Error decoding JSON: {e}")
finally:
await session.close()
async def stream_to_string(generator):
response = ""
async for chunk in generator:
response += chunk["content"]
return response
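
As a rough usage sketch (not part of the diff), the two helpers above could be exercised against a locally running Ollama server; the server address, model name, and prompt below are assumptions:

import asyncio
from litellm.utils import get_ollama_response_stream, stream_to_string

async def main():
    # get_ollama_response_stream yields {"role": "assistant", "content": ...} dicts
    generator = get_ollama_response_stream(
        api_base="http://localhost:11434",  # assumed local Ollama server
        model="llama2",
        prompt="Why is the sky blue?",
    )
    # stream_to_string concatenates the "content" field of every chunk
    print(await stream_to_string(generator))

asyncio.run(main())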