Mirror of https://github.com/BerriAI/litellm.git
ollama upgrades, fix streaming, add non-streaming response
Commit 56bd8c1c52 (parent 6cb03d7c63)
5 changed files with 135 additions and 86 deletions
litellm/llms/ollama.py (new file, 35 lines)

@@ -0,0 +1,35 @@
+import requests
+import json
+
+# ollama implementation
+def get_ollama_response_stream(
+        api_base="http://localhost:11434",
+        model="llama2",
+        prompt="Why is the sky blue?"
+    ):
+    url = f"{api_base}/api/generate"
+    data = {
+        "model": model,
+        "prompt": prompt,
+    }
+    session = requests.Session()
+
+    with session.post(url, json=data, stream=True) as resp:
+        for line in resp.iter_lines():
+            if line:
+                try:
+                    json_chunk = line.decode("utf-8")
+                    chunks = json_chunk.split("\n")
+                    for chunk in chunks:
+                        if chunk.strip() != "":
+                            j = json.loads(chunk)
+                            if "response" in j:
+                                completion_obj = {
+                                    "role": "assistant",
+                                    "content": "",
+                                }
+                                completion_obj["content"] = j["response"]
+                                yield {"choices": [{"delta": completion_obj}]}
+                except Exception as e:
+                    print(f"Error decoding JSON: {e}")
+    session.close()
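For reference, a minimal consumer of the new synchronous generator could look like the sketch below. It is not part of the commit; it assumes litellm is importable and that an Ollama server with the llama2 model is listening on http://localhost:11434.

    # sketch: stream the default prompt and print each delta as it arrives
    from litellm.llms.ollama import get_ollama_response_stream

    for chunk in get_ollama_response_stream(
        api_base="http://localhost:11434",  # assumed local Ollama server
        model="llama2",
        prompt="Why is the sky blue?",
    ):
        # each chunk has the shape {"choices": [{"delta": {"role": ..., "content": ...}}]}
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)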
litellm/main.py

@@ -28,6 +28,7 @@ from .llms import replicate
 from .llms import aleph_alpha
 from .llms import baseten
 from .llms import vllm
+from .llms import ollama
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, List, Optional, Dict

@@ -39,9 +40,6 @@ from litellm.utils import (
     ModelResponse,
     read_config_args,
 )
-from litellm.utils import (
-    get_ollama_response_stream,
-)
 
 ####### ENVIRONMENT VARIABLES ###################
 dotenv.load_dotenv()  # Loading env variables using dotenv

@@ -728,10 +726,27 @@ def completion(
             logging.pre_call(
                 input=prompt, api_key=None, additional_args={"endpoint": endpoint}
             )
-            generator = get_ollama_response_stream(endpoint, model, prompt)
-            # assume all responses are streamed
-            return generator
+            generator = ollama.get_ollama_response_stream(endpoint, model, prompt)
+            if optional_params.get("stream", False) == True:
+                # assume all ollama responses are streamed
+                return generator
+            else:
+                response_string = ""
+                for chunk in generator:
+                    response_string += chunk['choices'][0]['delta']['content']
+
+                ## RESPONSE OBJECT
+                model_response["choices"][0]["message"]["content"] = response_string
+                model_response["created"] = time.time()
+                model_response["model"] = "ollama/" + model
+                prompt_tokens = len(encoding.encode(prompt))
+                completion_tokens = len(encoding.encode(response_string))
+                model_response["usage"] = {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": prompt_tokens + completion_tokens,
+                }
+                response = model_response
        elif (
            custom_llm_provider == "baseten"
            or litellm.api_base == "https://app.baseten.co"
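To illustrate the two paths through completion() that this hunk introduces, a hedged usage sketch follows. It mirrors the call patterns in the commented-out tests further down, assumes a local Ollama server with llama2 pulled, and is not taken from the commit itself.

    from litellm import completion

    messages = [{"role": "user", "content": "respond in 20 words. who are you?"}]

    # stream=False (default): completion() now drains the generator and returns a
    # populated model_response with choices, created, model and usage fields
    resp = completion(model="llama2", messages=messages,
                      api_base="http://localhost:11434", custom_llm_provider="ollama")
    print(resp["choices"][0]["message"]["content"])
    print(resp["usage"])

    # stream=True: completion() returns the delta-chunk generator unchanged
    stream = completion(model="llama2", messages=messages,
                        api_base="http://localhost:11434", custom_llm_provider="ollama",
                        stream=True)
    for chunk in stream:
        print(chunk["choices"][0]["delta"]["content"], end="")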
Ollama test file (async streaming experiments)

@@ -1,4 +1,4 @@
-###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
 # import aiohttp
 # import json
 # import asyncio

@@ -37,25 +37,64 @@
 # finally:
 #     await session.close()
 
-# # async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
-# #     generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
-# #     response = ""
-# #     async for elem in generator:
-# #         print(elem)
-# #         response += elem["content"]
-# #     return response
+# async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+#     generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
+#     response = ""
+#     async for elem in generator:
+#         print(elem)
+#         response += elem["content"]
+#     return response
 
-# # #generator = get_ollama_response_stream()
+# #generator = get_ollama_response_stream()
 
-# # result = asyncio.run(get_ollama_response_no_stream())
-# # print(result)
+# result = asyncio.run(get_ollama_response_no_stream())
+# print(result)
 
-# # # return this generator to the client for streaming requests
+# # return this generator to the client for streaming requests
 
 
-# # async def get_response():
-# #     global generator
-# #     async for elem in generator:
-# #         print(elem)
+# async def get_response():
+#     global generator
+#     async for elem in generator:
+#         print(elem)
 
-# # asyncio.run(get_response())
+# asyncio.run(get_response())
+
+
+##### latest implementation of making raw http post requests to local ollama server
+
+# import requests
+# import json
+# def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+#     url = f"{api_base}/api/generate"
+#     data = {
+#         "model": model,
+#         "prompt": prompt,
+#     }
+#     session = requests.Session()
+
+#     with session.post(url, json=data, stream=True) as resp:
+#         for line in resp.iter_lines():
+#             if line:
+#                 try:
+#                     json_chunk = line.decode("utf-8")
+#                     chunks = json_chunk.split("\n")
+#                     for chunk in chunks:
+#                         if chunk.strip() != "":
+#                             j = json.loads(chunk)
+#                             if "response" in j:
+#                                 completion_obj = {
+#                                     "role": "assistant",
+#                                     "content": "",
+#                                 }
+#                                 completion_obj["content"] = j["response"]
+#                                 yield {"choices": [{"delta": completion_obj}]}
+#                 except Exception as e:
+#                     print(f"Error decoding JSON: {e}")
+#     session.close()
+
+# response = get_ollama_response_stream()
+
+# for chunk in response:
+#     print(chunk['choices'][0]['delta'])
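Both the new module and this commented-out copy parse Ollama's /api/generate stream line by line. As a hedged illustration of the line shape the parser expects (fields beyond "response" are assumptions about Ollama's output, not something this diff pins down):

    import json

    # one streamed line from /api/generate (illustrative payload, assumed shape)
    sample_line = b'{"model": "llama2", "response": "The", "done": false}'

    j = json.loads(sample_line.decode("utf-8"))
    if "response" in j:
        # this is the delta format the generator yields for each fragment
        delta = {"role": "assistant", "content": j["response"]}
        print({"choices": [{"delta": delta}]})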
Ollama test file (local completion() tests)

@@ -1,4 +1,5 @@
-###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+# ##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+# # https://ollama.ai/
 
 # import sys, os
 # import traceback

@@ -15,32 +16,36 @@
 # user_message = "respond in 20 words. who are you?"
 # messages = [{ "content": user_message,"role": "user"}]
 
-# async def get_response(generator):
-#     response = ""
-#     async for elem in generator:
-#         print(elem)
-#         response += elem["content"]
-#     return response
 
 # def test_completion_ollama():
 #     try:
-#         response = completion(model="llama2", messages=messages, api_base="http://localhost:11434", custom_llm_provider="ollama")
+#         response = completion(
+#             model="llama2",
+#             messages=messages,
+#             api_base="http://localhost:11434",
+#             custom_llm_provider="ollama"
+#         )
 #         print(response)
-#         string_response = asyncio.run(get_response(response))
-#         print(string_response)
 #     except Exception as e:
 #         pytest.fail(f"Error occurred: {e}")
 
-# # test_completion_ollama()
+# test_completion_ollama()
 
 # def test_completion_ollama_stream():
+#     user_message = "what is litellm?"
+#     messages = [{ "content": user_message,"role": "user"}]
 #     try:
-#         response = completion(model="llama2", messages=messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)
+#         response = completion(
+#             model="llama2",
+#             messages=messages,
+#             api_base="http://localhost:11434",
+#             custom_llm_provider="ollama",
+#             stream=True
+#         )
 #         print(response)
-#         string_response = asyncio.run(get_response(response))
-#         print(string_response)
+#         for chunk in response:
+#             print(chunk['choices'][0]['delta'])
 
 #     except Exception as e:
 #         pytest.fail(f"Error occurred: {e}")
 
-# test_completion_ollama_stream()
+# # test_completion_ollama_stream()
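Uncommented and run locally, the streaming test above reduces to roughly the following pytest sketch (assuming pytest, litellm, and a local Ollama server; this is an illustration, not code from the commit):

    import pytest
    from litellm import completion

    def test_completion_ollama_stream():
        messages = [{"content": "what is litellm?", "role": "user"}]
        try:
            response = completion(
                model="llama2",
                messages=messages,
                api_base="http://localhost:11434",
                custom_llm_provider="ollama",
                stream=True,
            )
            # print every streamed delta; any exception fails the test
            for chunk in response:
                print(chunk["choices"][0]["delta"])
        except Exception as e:
            pytest.fail(f"Error occurred: {e}")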
litellm/utils.py

@@ -2217,51 +2217,6 @@ def read_config_args(config_path):
         print("An error occurred while reading config:", str(e))
         raise e
 
-########## ollama implementation ############################
-
-
-async def get_ollama_response_stream(
-    api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
-):
-    session = aiohttp.ClientSession()
-    url = f"{api_base}/api/generate"
-    data = {
-        "model": model,
-        "prompt": prompt,
-    }
-    try:
-        async with session.post(url, json=data) as resp:
-            async for line in resp.content.iter_any():
-                if line:
-                    try:
-                        json_chunk = line.decode("utf-8")
-                        chunks = json_chunk.split("\n")
-                        for chunk in chunks:
-                            if chunk.strip() != "":
-                                j = json.loads(chunk)
-                                if "response" in j:
-                                    completion_obj = {
-                                        "role": "assistant",
-                                        "content": "",
-                                    }
-                                    completion_obj["content"] = j["response"]
-                                    yield {"choices": [{"delta": completion_obj}]}
-                                    # self.responses.append(j["response"])
-                                    # yield "blank"
-                    except Exception as e:
-                        print(f"Error decoding JSON: {e}")
-    finally:
-        await session.close()
-
-
-async def stream_to_string(generator):
-    response = ""
-    async for chunk in generator:
-        response += chunk["content"]
-    return response
-
-
 ########## experimental completion variants ############################
 
 def get_model_split_test(models, completion_call_id):
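With the async aiohttp implementation gone, the removed stream_to_string helper has no direct replacement in utils.py; completion() now accumulates the text inline. If an equivalent helper were still wanted, a minimal synchronous sketch (note it indexes into choices/delta, matching what the new generator actually yields) might be:

    def stream_to_string(generator):
        # collapse the sync delta-chunk generator into one response string
        response = ""
        for chunk in generator:
            response += chunk["choices"][0]["delta"]["content"]
        return response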