mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00

ollama with streaming

This commit is contained in:
parent 65e6b05f5b
commit 9d644a5634

4 changed files with 167 additions and 1 deletion
@@ -9,6 +9,7 @@ import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import get_secret, install_and_import, CustomStreamWrapper, read_config_args
+from litellm.utils import get_ollama_response_stream, stream_to_string
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv() # Loading env variables using dotenv
 new_response = {
@@ -426,6 +427,15 @@ def completion(
         model_response["created"] = time.time()
         model_response["model"] = model
         response = model_response
+    elif custom_llm_provider == "ollama":
+        endpoint = litellm.api_base if litellm.api_base is not None else custom_api_base
+        prompt = " ".join([message["content"] for message in messages])
+
+        ## LOGGING
+        logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+        generator = get_ollama_response_stream(endpoint, model, prompt)
+        # assume all responses are streamed
+        return generator
     else:
         ## LOGGING
         logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
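Note: the ollama branch above returns the async generator directly instead of a model_response dict. A minimal sketch of how a caller might consume it, mirroring the commented-out tests added below (assumes a local Ollama server at http://localhost:11434; the collect_stream helper is illustrative and not part of this commit):

    import asyncio
    from litellm import completion

    messages = [{"role": "user", "content": "Why is the sky blue?"}]

    async def collect_stream(generator):
        # each chunk is a {"role": ..., "content": ...} dict yielded by get_ollama_response_stream
        text = ""
        async for chunk in generator:
            text += chunk["content"]
        return text

    # completion() returns the generator itself when custom_llm_provider="ollama"
    generator = completion(model="llama2", messages=messages,
                           custom_api_base="http://localhost:11434",
                           custom_llm_provider="ollama")
    print(asyncio.run(collect_stream(generator)))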
litellm/tests/test_ollama.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+# import aiohttp
+# import json
+# import asyncio
+# import requests
+
+# async def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+#     session = aiohttp.ClientSession()
+#     url = f'{api_base}/api/generate'
+#     data = {
+#         "model": model,
+#         "prompt": prompt,
+#     }
+
+#     response = ""
+
+#     try:
+#         async with session.post(url, json=data) as resp:
+#             async for line in resp.content.iter_any():
+#                 if line:
+#                     try:
+#                         json_chunk = line.decode("utf-8")
+#                         chunks = json_chunk.split("\n")
+#                         for chunk in chunks:
+#                             if chunk.strip() != "":
+#                                 j = json.loads(chunk)
+#                                 if "response" in j:
+#                                     print(j["response"])
+#                                     yield {
+#                                         "role": "assistant",
+#                                         "content": j["response"]
+#                                     }
+#                                     # self.responses.append(j["response"])
+#                                     # yield "blank"
+#                     except Exception as e:
+#                         print(f"Error decoding JSON: {e}")
+#     finally:
+#         await session.close()
+
+# # async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+# #     generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
+# #     response = ""
+# #     async for elem in generator:
+# #         print(elem)
+# #         response += elem["content"]
+# #     return response
+
+# # #generator = get_ollama_response_stream()
+
+# # result = asyncio.run(get_ollama_response_no_stream())
+# # print(result)
+
+# # # return this generator to the client for streaming requests
+
+
+
+# # async def get_response():
+# #     global generator
+# #     async for elem in generator:
+# #         print(elem)
+
+# # asyncio.run(get_response())
litellm/tests/test_ollama_local.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+
+# import sys, os
+# import traceback
+# from dotenv import load_dotenv
+# load_dotenv()
+# import os
+# sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path
+# import pytest
+# import litellm
+# from litellm import embedding, completion
+# import asyncio
+
+
+
+# user_message = "respond in 20 words. who are you?"
+# messages = [{ "content": user_message,"role": "user"}]
+
+# async def get_response(generator):
+#     response = ""
+#     async for elem in generator:
+#         print(elem)
+#         response += elem["content"]
+#     return response
+
+# def test_completion_ollama():
+#     try:
+#         response = completion(model="llama2", messages=messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama")
+#         print(response)
+#         string_response = asyncio.run(get_response(response))
+#         print(string_response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+
+# # test_completion_ollama()
+
+# def test_completion_ollama_stream():
+#     try:
+#         response = completion(model="llama2", messages=messages, custom_api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)
+#         print(response)
+#         string_response = asyncio.run(get_response(response))
+#         print(string_response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+
+# test_completion_ollama_stream()
+
+
+
+
+
@@ -743,4 +743,46 @@ def read_config_args(config_path):
         return config
     except Exception as e:
         print("An error occurred while reading config:", str(e))
         raise e
+
+
+########## ollama implementation ############################
+import aiohttp
+async def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+    session = aiohttp.ClientSession()
+    url = f'{api_base}/api/generate'
+    data = {
+        "model": model,
+        "prompt": prompt,
+    }
+    try:
+        async with session.post(url, json=data) as resp:
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "response" in j:
+                                    print(j["response"])
+                                    yield {
+                                        "role": "assistant",
+                                        "content": j["response"]
+                                    }
+                                    # self.responses.append(j["response"])
+                                    # yield "blank"
+                    except Exception as e:
+                        print(f"Error decoding JSON: {e}")
+    finally:
+        await session.close()
+
+
+async def stream_to_string(generator):
+    response = ""
+    async for chunk in generator:
+        response += chunk["content"]
+    return response
+
+
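Usage note: a minimal sketch of how the two helpers added above can be combined (assumes an Ollama server reachable at http://localhost:11434; the main wrapper is illustrative, not part of the diff):

    import asyncio
    from litellm.utils import get_ollama_response_stream, stream_to_string

    async def main():
        generator = get_ollama_response_stream(
            api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
        )
        # stream_to_string drains the generator and concatenates each chunk's "content"
        print(await stream_to_string(generator))

    asyncio.run(main())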