ollama upgrades, fix streaming, add non-streaming response

This commit is contained in:
ishaan-jaff 2023-09-09 14:07:11 -07:00
parent 6cb03d7c63
commit 56bd8c1c52
5 changed files with 135 additions and 86 deletions

litellm/llms/ollama.py Normal file
View file

@@ -0,0 +1,35 @@
import requests
import json

# ollama implementation
def get_ollama_response_stream(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?"
    ):
    url = f"{api_base}/api/generate"
    data = {
        "model": model,
        "prompt": prompt,
    }
    session = requests.Session()

    with session.post(url, json=data, stream=True) as resp:
        for line in resp.iter_lines():
            if line:
                try:
                    json_chunk = line.decode("utf-8")
                    chunks = json_chunk.split("\n")
                    for chunk in chunks:
                        if chunk.strip() != "":
                            j = json.loads(chunk)
                            if "response" in j:
                                completion_obj = {
                                    "role": "assistant",
                                    "content": "",
                                }
                                completion_obj["content"] = j["response"]
                                yield {"choices": [{"delta": completion_obj}]}
                except Exception as e:
                    print(f"Error decoding JSON: {e}")
    session.close()
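
The new helper streams Ollama's newline-delimited JSON output and re-emits each piece as an OpenAI-style delta chunk. A minimal consumption sketch, assuming an Ollama server is running on the default localhost:11434 with the "llama2" model pulled (the prompt is just a placeholder):

# Minimal usage sketch (assumes a local Ollama server on localhost:11434
# with the "llama2" model available).
from litellm.llms.ollama import get_ollama_response_stream

for chunk in get_ollama_response_stream(
    api_base="http://localhost:11434",
    model="llama2",
    prompt="Why is the sky blue?",
):
    # each chunk looks like {"choices": [{"delta": {"role": "assistant", "content": "..."}}]}
    print(chunk["choices"][0]["delta"]["content"], end="", flush=True)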

View file

@@ -28,6 +28,7 @@ from .llms import replicate
from .llms import aleph_alpha
from .llms import baseten
from .llms import vllm
from .llms import ollama
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict
@@ -39,9 +40,6 @@ from litellm.utils import (
    ModelResponse,
    read_config_args,
)
from litellm.utils import (
    get_ollama_response_stream,
)
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv() # Loading env variables using dotenv
@@ -728,10 +726,27 @@ def completion(
            logging.pre_call(
                input=prompt, api_key=None, additional_args={"endpoint": endpoint}
            )
            generator = get_ollama_response_stream(endpoint, model, prompt)
            # assume all responses are streamed
            generator = ollama.get_ollama_response_stream(endpoint, model, prompt)
            if optional_params.get("stream", False) == True:
                # assume all ollama responses are streamed
                return generator
            else:
                response_string = ""
                for chunk in generator:
                    response_string += chunk['choices'][0]['delta']['content']
                ## RESPONSE OBJECT
                model_response["choices"][0]["message"]["content"] = response_string
                model_response["created"] = time.time()
                model_response["model"] = "ollama/" + model
                prompt_tokens = len(encoding.encode(prompt))
                completion_tokens = len(encoding.encode(response_string))
                model_response["usage"] = {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                }
                response = model_response
        elif (
            custom_llm_provider == "baseten"
            or litellm.api_base == "https://app.baseten.co"
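
The non-streaming branch joins the streamed chunks into a single model response with token counts. A rough usage sketch of both paths, mirroring the calls in the test files below (assumes a local Ollama server with "llama2" pulled; not an authoritative example of the library's public API at this commit):

from litellm import completion

messages = [{"role": "user", "content": "Why is the sky blue?"}]

# non-streaming: chunks are aggregated into one response with usage counts
resp = completion(
    model="llama2",
    messages=messages,
    api_base="http://localhost:11434",
    custom_llm_provider="ollama",
)
print(resp["choices"][0]["message"]["content"])

# streaming: the raw generator is returned and can be iterated chunk by chunk
stream = completion(
    model="llama2",
    messages=messages,
    api_base="http://localhost:11434",
    custom_llm_provider="ollama",
    stream=True,
)
for chunk in stream:
    print(chunk["choices"][0]["delta"]["content"], end="")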

View file

@@ -1,4 +1,4 @@
###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# import aiohttp
# import json
# import asyncio
@@ -37,25 +37,64 @@
# finally:
# await session.close()
# # async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
# # generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
# # response = ""
# # async for elem in generator:
# # print(elem)
# # response += elem["content"]
# # return response
# async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
# generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
# response = ""
# async for elem in generator:
# print(elem)
# response += elem["content"]
# return response
# # #generator = get_ollama_response_stream()
# #generator = get_ollama_response_stream()
# # result = asyncio.run(get_ollama_response_no_stream())
# # print(result)
# result = asyncio.run(get_ollama_response_no_stream())
# print(result)
# # # return this generator to the client for streaming requests
# # return this generator to the client for streaming requests
# # async def get_response():
# # global generator
# # async for elem in generator:
# # print(elem)
# async def get_response():
# global generator
# async for elem in generator:
# print(elem)
# # asyncio.run(get_response())
# asyncio.run(get_response())
##### latest implementation of making raw http post requests to local ollama server
# import requests
# import json
# def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
# url = f"{api_base}/api/generate"
# data = {
# "model": model,
# "prompt": prompt,
# }
# session = requests.Session()
# with session.post(url, json=data, stream=True) as resp:
# for line in resp.iter_lines():
# if line:
# try:
# json_chunk = line.decode("utf-8")
# chunks = json_chunk.split("\n")
# for chunk in chunks:
# if chunk.strip() != "":
# j = json.loads(chunk)
# if "response" in j:
# completion_obj = {
# "role": "assistant",
# "content": "",
# }
# completion_obj["content"] = j["response"]
# yield {"choices": [{"delta": completion_obj}]}
# except Exception as e:
# print(f"Error decoding JSON: {e}")
# session.close()
# response = get_ollama_response_stream()
# for chunk in response:
# print(chunk['choices'][0]['delta'])

View file

@@ -1,4 +1,5 @@
###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# ##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# # https://ollama.ai/
# import sys, os
# import traceback
@@ -15,32 +16,36 @@
# user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}]
# async def get_response(generator):
# response = ""
# async for elem in generator:
# print(elem)
# response += elem["content"]
# return response
# def test_completion_ollama():
# try:
# response = completion(model="llama2", messages=messages, api_base="http://localhost:11434", custom_llm_provider="ollama")
# response = completion(
# model="llama2",
# messages=messages,
# api_base="http://localhost:11434",
# custom_llm_provider="ollama"
# )
# print(response)
# string_response = asyncio.run(get_response(response))
# print(string_response)
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama()
# test_completion_ollama()
# def test_completion_ollama_stream():
# user_message = "what is litellm?"
# messages = [{ "content": user_message,"role": "user"}]
# try:
# response = completion(model="llama2", messages=messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)
# response = completion(
# model="llama2",
# messages=messages,
# api_base="http://localhost:11434",
# custom_llm_provider="ollama",
# stream=True
# )
# print(response)
# string_response = asyncio.run(get_response(response))
# print(string_response)
# for chunk in response:
# print(chunk['choices'][0]['delta'])
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# test_completion_ollama_stream()
# # test_completion_ollama_stream()

View file

@@ -2217,51 +2217,6 @@ def read_config_args(config_path):
        print("An error occurred while reading config:", str(e))
        raise e

########## ollama implementation ############################
async def get_ollama_response_stream(
    api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
):
    session = aiohttp.ClientSession()
    url = f"{api_base}/api/generate"
    data = {
        "model": model,
        "prompt": prompt,
    }
    try:
        async with session.post(url, json=data) as resp:
            async for line in resp.content.iter_any():
                if line:
                    try:
                        json_chunk = line.decode("utf-8")
                        chunks = json_chunk.split("\n")
                        for chunk in chunks:
                            if chunk.strip() != "":
                                j = json.loads(chunk)
                                if "response" in j:
                                    completion_obj = {
                                        "role": "assistant",
                                        "content": "",
                                    }
                                    completion_obj["content"] = j["response"]
                                    yield {"choices": [{"delta": completion_obj}]}
                                    # self.responses.append(j["response"])
                                    # yield "blank"
                    except Exception as e:
                        print(f"Error decoding JSON: {e}")
    finally:
        await session.close()

async def stream_to_string(generator):
    response = ""
    async for chunk in generator:
        response += chunk["content"]
    return response

########## experimental completion variants ############################
def get_model_split_test(models, completion_call_id):
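
With this commit the streaming helper moves out of litellm/utils.py (async, aiohttp-based) into litellm/llms/ollama.py (synchronous, requests-based), and stream_to_string is dropped. A minimal sketch of the corresponding import change for any code that used the helper directly:

# old (removed in this commit): async generator living in litellm.utils
# from litellm.utils import get_ollama_response_stream
# new: synchronous generator in the dedicated ollama module
from litellm.llms.ollama import get_ollama_response_stream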