From 9d644a5634ef8c3b68cb1ae5323b6cb4a76133e1 Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Sat, 12 Aug 2023 14:49:49 -0700
Subject: [PATCH] ollama with streaming

---
 litellm/main.py                    | 10 +++++
 litellm/tests/test_ollama.py       | 62 ++++++++++++++++++++++++++++++
 litellm/tests/test_ollama_local.py | 52 +++++++++++++++++++++++++
 litellm/utils.py                   | 44 ++++++++++++++++++++-
 4 files changed, 167 insertions(+), 1 deletion(-)
 create mode 100644 litellm/tests/test_ollama.py
 create mode 100644 litellm/tests/test_ollama_local.py

diff --git a/litellm/main.py b/litellm/main.py
index 4c3d75bf5..9bee0edd2 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -9,6 +9,7 @@ import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, read_config_args
+from litellm.utils import get_ollama_response_stream, stream_to_string
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
 new_response = {
@@ -426,6 +427,15 @@ def completion(
         model_response["created"] = time.time()
         model_response["model"] = model
         response = model_response
+    elif custom_llm_provider == "ollama":
+        endpoint = litellm.api_base if litellm.api_base is not None else custom_api_base
+        prompt = " ".join([message["content"] for message in messages])
+
+        ## LOGGING
+        logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+        generator = get_ollama_response_stream(endpoint, model, prompt)
+        # assume all responses are streamed
+        return generator
     else:
         ## LOGGING
         logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
diff --git a/litellm/tests/test_ollama.py b/litellm/tests/test_ollama.py
new file mode 100644
index 000000000..d95414560
--- /dev/null
+++ b/litellm/tests/test_ollama.py
@@ -0,0 +1,62 @@
+###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+# import aiohttp
+# import json
+# import asyncio
+# import requests
+
+# async def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+#     session = aiohttp.ClientSession()
+#     url = f'{api_base}/api/generate'
+#     data = {
+#         "model": model,
+#         "prompt": prompt,
+#     }
+
+#     response = ""
+
+#     try:
+#         async with session.post(url, json=data) as resp:
+#             async for line in resp.content.iter_any():
+#                 if line:
+#                     try:
+#                         json_chunk = line.decode("utf-8")
+#                         chunks = json_chunk.split("\n")
+#                         for chunk in chunks:
+#                             if chunk.strip() != "":
+#                                 j = json.loads(chunk)
+#                                 if "response" in j:
+#                                     print(j["response"])
+#                                     yield {
+#                                         "role": "assistant",
+#                                         "content": j["response"]
+#                                     }
+#                                     # self.responses.append(j["response"])
+#                                     # yield "blank"
+#                     except Exception as e:
+#                         print(f"Error decoding JSON: {e}")
+#     finally:
+#         await session.close()
+
+# # async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+# #     generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
+# #     response = ""
+# #     async for elem in generator:
+# #         print(elem)
+# #         response += elem["content"]
+# #     return response
+
+# # #generator = get_ollama_response_stream()
+
+# # result = asyncio.run(get_ollama_response_no_stream())
+# # print(result)
+
+# # # return this generator to the client for streaming requests
+
+
+
+# # async def get_response():
+# #     global generator
+# #     async for elem in generator:
+# #         print(elem)
+
+# # asyncio.run(get_response())
diff --git a/litellm/tests/test_ollama_local.py b/litellm/tests/test_ollama_local.py
new file mode 100644
index 000000000..22544f4cf
--- /dev/null
+++ b/litellm/tests/test_ollama_local.py
@@ -0,0 +1,52 @@
+###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+
+# import sys, os
+# import traceback
+# from dotenv import load_dotenv
+# load_dotenv()
+# import os
+# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+# import pytest
+# import litellm
+# from litellm import embedding, completion
+# import asyncio
+
+
+
+# user_message = "respond in 20 words. who are you?"
+# messages = [{ "content": user_message,"role": "user"}]
+
+# async def get_response(generator):
+#     response = ""
+#     async for elem in generator:
+#         print(elem)
+#         response += elem["content"]
+#     return response
+
+# def test_completion_ollama():
+#     try:
+#         response = completion(model="llama2", messages=messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama")
+#         print(response)
+#         string_response = asyncio.run(get_response(response))
+#         print(string_response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+
+# # test_completion_ollama()
+
+# def test_completion_ollama_stream():
+#     try:
+#         response = completion(model="llama2", messages=messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)
+#         print(response)
+#         string_response = asyncio.run(get_response(response))
+#         print(string_response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+# test_completion_ollama_stream()
+
+
+
+
+
diff --git a/litellm/utils.py b/litellm/utils.py
index 65cd96a8e..2f8372e51 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -743,4 +743,46 @@ def read_config_args(config_path):
         return config
     except Exception as e:
         print("An error occurred while reading config:", str(e))
-        raise e
\ No newline at end of file
+        raise e
+
+
+########## ollama implementation ############################
+import aiohttp
+async def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+    session = aiohttp.ClientSession()
+    url = f'{api_base}/api/generate'
+    data = {
+        "model": model,
+        "prompt": prompt,
+    }
+    try:
+        async with session.post(url, json=data) as resp:
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "response" in j:
+                                    print(j["response"])
+                                    yield {
+                                        "role": "assistant",
+                                        "content": j["response"]
+                                    }
+                                    # self.responses.append(j["response"])
+                                    # yield "blank"
+                    except Exception as e:
+                        print(f"Error decoding JSON: {e}")
+    finally:
+        await session.close()
+
+
+async def stream_to_string(generator):
+    response = ""
+    async for chunk in generator:
+        response += chunk["content"]
+    return response
+
+
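
Editor's note (not part of the patch): the pattern below is a rough usage sketch of what this change enables, mirroring the commented-out test in test_ollama_local.py. It assumes a local Ollama server is running at http://localhost:11434 with the llama2 model already pulled; with this patch, completion(..., custom_llm_provider="ollama") returns an async generator of {"role", "content"} chunks that the caller drains itself.

# Usage sketch only -- assumes a local Ollama server and the llama2 model are available.
import asyncio
from litellm import completion

async def main():
    messages = [{"role": "user", "content": "respond in 20 words. who are you?"}]
    # As of this patch, every ollama request is returned as an async generator of chunks.
    generator = completion(
        model="llama2",
        messages=messages,
        custom_api_base="http://localhost:11434",
        custom_llm_provider="ollama",
    )
    response = ""
    async for chunk in generator:
        response += chunk["content"]
    print(response)

asyncio.run(main())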