ollama upgrades, fix streaming, add non-streaming response

This commit is contained in:
ishaan-jaff 2023-09-09 14:07:11 -07:00
parent 6cb03d7c63
commit 56bd8c1c52
5 changed files with 135 additions and 86 deletions

litellm/llms/ollama.py (new file, +35)
View file

@ -0,0 +1,35 @@
import requests
import json

# ollama implementation
def get_ollama_response_stream(
    api_base="http://localhost:11434",
    model="llama2",
    prompt="Why is the sky blue?"
):
    url = f"{api_base}/api/generate"
    data = {
        "model": model,
        "prompt": prompt,
    }
    session = requests.Session()

    with session.post(url, json=data, stream=True) as resp:
        for line in resp.iter_lines():
            if line:
                try:
                    json_chunk = line.decode("utf-8")
                    chunks = json_chunk.split("\n")
                    for chunk in chunks:
                        if chunk.strip() != "":
                            j = json.loads(chunk)
                            if "response" in j:
                                completion_obj = {
                                    "role": "assistant",
                                    "content": "",
                                }
                                completion_obj["content"] = j["response"]
                                yield {"choices": [{"delta": completion_obj}]}
                except Exception as e:
                    print(f"Error decoding JSON: {e}")
    session.close()
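
As a quick illustration (not part of the commit), a minimal sketch of draining this generator directly, assuming a local Ollama server on the default port and the function's default model and prompt:

# Sketch: drain the streaming generator into a single string.
# Assumes an Ollama server is running locally on port 11434.
from litellm.llms.ollama import get_ollama_response_stream

text = ""
for chunk in get_ollama_response_stream(
    api_base="http://localhost:11434",
    model="llama2",
    prompt="Why is the sky blue?",
):
    # each chunk has the shape {"choices": [{"delta": {"role": ..., "content": ...}}]}
    text += chunk["choices"][0]["delta"]["content"]
print(text)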

View file

@ -28,6 +28,7 @@ from .llms import replicate
from .llms import aleph_alpha
from .llms import baseten
from .llms import vllm
from .llms import ollama
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict
@ -39,9 +40,6 @@ from litellm.utils import (
    ModelResponse,
    read_config_args,
)
from litellm.utils import (
    get_ollama_response_stream,
)
####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv()  # Loading env variables using dotenv
@ -728,10 +726,27 @@ def completion(
        logging.pre_call(
            input=prompt, api_key=None, additional_args={"endpoint": endpoint}
        )
        generator = ollama.get_ollama_response_stream(endpoint, model, prompt)
        if optional_params.get("stream", False) == True:
            # assume all ollama responses are streamed
            return generator
        else:
            response_string = ""
            for chunk in generator:
                response_string += chunk['choices'][0]['delta']['content']

            ## RESPONSE OBJECT
            model_response["choices"][0]["message"]["content"] = response_string
            model_response["created"] = time.time()
            model_response["model"] = "ollama/" + model
            prompt_tokens = len(encoding.encode(prompt))
            completion_tokens = len(encoding.encode(response_string))
            model_response["usage"] = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
            response = model_response
    elif (
        custom_llm_provider == "baseten"
        or litellm.api_base == "https://app.baseten.co"
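
To show how the new ollama branch above behaves end to end, here is a minimal sketch of calling completion() against a local Ollama server. It is not part of the commit; the model name and prompt are placeholders, and the call signature mirrors the commented-out tests further down.

from litellm import completion

messages = [{"role": "user", "content": "respond in 20 words. who are you?"}]

# Non-streaming: completion() drains the generator itself and returns a
# model_response with the full text plus token usage (new in this commit).
resp = completion(
    model="llama2",
    messages=messages,
    api_base="http://localhost:11434",
    custom_llm_provider="ollama",
)
print(resp["choices"][0]["message"]["content"])
print(resp["usage"]["total_tokens"])

# Streaming: stream=True returns the raw generator of delta chunks.
for chunk in completion(
    model="llama2",
    messages=messages,
    api_base="http://localhost:11434",
    custom_llm_provider="ollama",
    stream=True,
):
    print(chunk["choices"][0]["delta"]["content"], end="")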

View file

@ -1,4 +1,4 @@
##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# import aiohttp
# import json
# import asyncio
@ -37,25 +37,64 @@
#     finally:
#         await session.close()

# async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
#     generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
#     response = ""
#     async for elem in generator:
#         print(elem)
#         response += elem["content"]
#     return response

# #generator = get_ollama_response_stream()
# result = asyncio.run(get_ollama_response_no_stream())
# print(result)

# # return this generator to the client for streaming requests
# async def get_response():
#     global generator
#     async for elem in generator:
#         print(elem)
# asyncio.run(get_response())
##### latest implementation of making raw http post requests to local ollama server
# import requests
# import json
# def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
#     url = f"{api_base}/api/generate"
#     data = {
#         "model": model,
#         "prompt": prompt,
#     }
#     session = requests.Session()
#     with session.post(url, json=data, stream=True) as resp:
#         for line in resp.iter_lines():
#             if line:
#                 try:
#                     json_chunk = line.decode("utf-8")
#                     chunks = json_chunk.split("\n")
#                     for chunk in chunks:
#                         if chunk.strip() != "":
#                             j = json.loads(chunk)
#                             if "response" in j:
#                                 completion_obj = {
#                                     "role": "assistant",
#                                     "content": "",
#                                 }
#                                 completion_obj["content"] = j["response"]
#                                 yield {"choices": [{"delta": completion_obj}]}
#                 except Exception as e:
#                     print(f"Error decoding JSON: {e}")
#     session.close()

# response = get_ollama_response_stream()
# for chunk in response:
#     print(chunk['choices'][0]['delta'])
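
The parsing above assumes Ollama's /api/generate endpoint streams newline-delimited JSON objects, each carrying a piece of the reply under a "response" key. A tiny standalone sketch of that decode step; the sample lines and any field other than "response" are illustrative assumptions, not output captured from a server:

import json

# Illustrative stream lines in the assumed newline-delimited JSON format.
raw_lines = [
    b'{"model":"llama2","response":"The","done":false}',
    b'{"model":"llama2","response":" sky","done":false}',
    b'{"model":"llama2","response":"","done":true}',
]

text = ""
for line in raw_lines:
    j = json.loads(line.decode("utf-8"))
    if "response" in j:
        text += j["response"]
print(text)  # -> "The sky"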

View file

@ -1,4 +1,5 @@
# ##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# # https://ollama.ai/
# import sys, os
# import traceback
@ -15,32 +16,36 @@
# user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}]

# async def get_response(generator):
#     response = ""
#     async for elem in generator:
#         print(elem)
#         response += elem["content"]
#     return response
# def test_completion_ollama():
#     try:
#         response = completion(
#             model="llama2",
#             messages=messages,
#             api_base="http://localhost:11434",
#             custom_llm_provider="ollama"
#         )
#         print(response)
#         string_response = asyncio.run(get_response(response))
#         print(string_response)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama()

# def test_completion_ollama_stream():
#     user_message = "what is litellm?"
#     messages = [{ "content": user_message,"role": "user"}]
#     try:
#         response = completion(
#             model="llama2",
#             messages=messages,
#             api_base="http://localhost:11434",
#             custom_llm_provider="ollama",
#             stream=True
#         )
#         print(response)
#         for chunk in response:
#             print(chunk['choices'][0]['delta'])
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")
# # test_completion_ollama_stream()

View file

@ -2217,51 +2217,6 @@ def read_config_args(config_path):
print("An error occurred while reading config:", str(e)) print("An error occurred while reading config:", str(e))
raise e raise e
########## ollama implementation ############################
async def get_ollama_response_stream(
    api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
):
    session = aiohttp.ClientSession()
    url = f"{api_base}/api/generate"
    data = {
        "model": model,
        "prompt": prompt,
    }
    try:
        async with session.post(url, json=data) as resp:
            async for line in resp.content.iter_any():
                if line:
                    try:
                        json_chunk = line.decode("utf-8")
                        chunks = json_chunk.split("\n")
                        for chunk in chunks:
                            if chunk.strip() != "":
                                j = json.loads(chunk)
                                if "response" in j:
                                    completion_obj = {
                                        "role": "assistant",
                                        "content": "",
                                    }
                                    completion_obj["content"] = j["response"]
                                    yield {"choices": [{"delta": completion_obj}]}
                                    # self.responses.append(j["response"])
                                    # yield "blank"
                    except Exception as e:
                        print(f"Error decoding JSON: {e}")
    finally:
        await session.close()

async def stream_to_string(generator):
    response = ""
    async for chunk in generator:
        response += chunk["content"]
    return response
########## experimental completion variants ############################
def get_model_split_test(models, completion_call_id):