mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00

olla upgrades, fix streaming, add non streaming resp

This commit is contained in:
parent 6cb03d7c63
commit 56bd8c1c52

5 changed files with 135 additions and 86 deletions

litellm/llms/ollama.py  (new file, 35 additions)
@@ -0,0 +1,35 @@
import requests
import json

# ollama implementation
def get_ollama_response_stream(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?"
    ):
    url = f"{api_base}/api/generate"
    data = {
        "model": model,
        "prompt": prompt,
    }
    session = requests.Session()

    with session.post(url, json=data, stream=True) as resp:
        for line in resp.iter_lines():
            if line:
                try:
                    json_chunk = line.decode("utf-8")
                    chunks = json_chunk.split("\n")
                    for chunk in chunks:
                        if chunk.strip() != "":
                            j = json.loads(chunk)
                            if "response" in j:
                                completion_obj = {
                                    "role": "assistant",
                                    "content": "",
                                }
                                completion_obj["content"] = j["response"]
                                yield {"choices": [{"delta": completion_obj}]}
                except Exception as e:
                    print(f"Error decoding JSON: {e}")
    session.close()
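
The new module is a plain requests-based generator, replacing the aiohttp coroutine this commit removes from litellm.utils further down. As a rough usage sketch, assuming a local Ollama server on the default port with the llama2 model pulled, the generator can be consumed directly; each yielded chunk follows the OpenAI-style streaming shape:

    from litellm.llms.ollama import get_ollama_response_stream

    # assumes an Ollama server running at http://localhost:11434 with llama2 pulled
    for chunk in get_ollama_response_stream(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?",
    ):
        # each chunk looks like {"choices": [{"delta": {"role": "assistant", "content": "..."}}]}
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)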

@@ -28,6 +28,7 @@ from .llms import replicate
from .llms import aleph_alpha
from .llms import baseten
from .llms import vllm
+from .llms import ollama
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict

@@ -39,9 +40,6 @@ from litellm.utils import (
    ModelResponse,
    read_config_args,
)
-from litellm.utils import (
-    get_ollama_response_stream,
-)

####### ENVIRONMENT VARIABLES ###################
dotenv.load_dotenv()  # Loading env variables using dotenv

@@ -728,10 +726,27 @@
        logging.pre_call(
            input=prompt, api_key=None, additional_args={"endpoint": endpoint}
        )

-        generator = get_ollama_response_stream(endpoint, model, prompt)
-        # assume all responses are streamed
+        generator = ollama.get_ollama_response_stream(endpoint, model, prompt)
+        if optional_params.get("stream", False) == True:
+            # assume all ollama responses are streamed
+            return generator
+        else:
+            response_string = ""
+            for chunk in generator:
+                response_string+=chunk['choices'][0]['delta']['content']
+
+            ## RESPONSE OBJECT
+            model_response["choices"][0]["message"]["content"] = response_string
+            model_response["created"] = time.time()
+            model_response["model"] = "ollama/" + model
+            prompt_tokens = len(encoding.encode(prompt))
+            completion_tokens = len(encoding.encode(response_string))
+            model_response["usage"] = {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+            }
+            response = model_response
    elif (
        custom_llm_provider == "baseten"
        or litellm.api_base == "https://app.baseten.co"
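
With this branch, stream=True returns the generator as-is, while the default call now aggregates the streamed chunks into a single response. A minimal caller-side sketch, assuming a local Ollama server at http://localhost:11434 with llama2 pulled (mirroring the commented-out tests further down):

    from litellm import completion

    # assumes an Ollama server running locally with the llama2 model pulled
    messages = [{"role": "user", "content": "respond in 20 words. who are you?"}]

    # non-streaming: chunks are joined into model_response["choices"][0]["message"]["content"]
    response = completion(
        model="llama2",
        messages=messages,
        api_base="http://localhost:11434",
        custom_llm_provider="ollama",
    )
    print(response["choices"][0]["message"]["content"])

    # streaming: the raw generator is returned, one delta per chunk
    for chunk in completion(
        model="llama2",
        messages=messages,
        api_base="http://localhost:11434",
        custom_llm_provider="ollama",
        stream=True,
    ):
        print(chunk["choices"][0]["delta"]["content"], end="")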

@@ -1,4 +1,4 @@
-###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
# import aiohttp
# import json
# import asyncio

@@ -37,25 +37,64 @@
#     finally:
#         await session.close()

-# # async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
-# #     generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
-# #     response = ""
-# #     async for elem in generator:
-# #         print(elem)
-# #         response += elem["content"]
-# #     return response
+# async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+#     generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
+#     response = ""
+#     async for elem in generator:
+#         print(elem)
+#         response += elem["content"]
+#     return response

-# # #generator = get_ollama_response_stream()
+# #generator = get_ollama_response_stream()

-# # result = asyncio.run(get_ollama_response_no_stream())
-# # print(result)
+# result = asyncio.run(get_ollama_response_no_stream())
+# print(result)

-# # # return this generator to the client for streaming requests
+# # return this generator to the client for streaming requests


-# # async def get_response():
-# #     global generator
-# #     async for elem in generator:
-# #         print(elem)
+# async def get_response():
+#     global generator
+#     async for elem in generator:
+#         print(elem)

-# # asyncio.run(get_response())
+# asyncio.run(get_response())



+##### latest implementation of making raw http post requests to local ollama server

+# import requests
+# import json
+# def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
+#     url = f"{api_base}/api/generate"
+#     data = {
+#         "model": model,
+#         "prompt": prompt,
+#     }
+#     session = requests.Session()

+#     with session.post(url, json=data, stream=True) as resp:
+#         for line in resp.iter_lines():
+#             if line:
+#                 try:
+#                     json_chunk = line.decode("utf-8")
+#                     chunks = json_chunk.split("\n")
+#                     for chunk in chunks:
+#                         if chunk.strip() != "":
+#                             j = json.loads(chunk)
+#                             if "response" in j:
+#                                 completion_obj = {
+#                                     "role": "assistant",
+#                                     "content": "",
+#                                 }
+#                                 completion_obj["content"] = j["response"]
+#                                 yield {"choices": [{"delta": completion_obj}]}
+#                 except Exception as e:
+#                     print(f"Error decoding JSON: {e}")
+#     session.close()

+# response = get_ollama_response_stream()

+# for chunk in response:
+#     print(chunk['choices'][0]['delta'])

@@ -1,4 +1,5 @@
-###### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+# ##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
+# # https://ollama.ai/

# import sys, os
# import traceback

@@ -15,32 +16,36 @@
# user_message = "respond in 20 words. who are you?"
# messages = [{ "content": user_message,"role": "user"}]

-# async def get_response(generator):
-#     response = ""
-#     async for elem in generator:
-#         print(elem)
-#         response += elem["content"]
-#     return response

# def test_completion_ollama():
#     try:
-#         response = completion(model="llama2", messages=messages, api_base="http://localhost:11434", custom_llm_provider="ollama")
+#         response = completion(
+#             model="llama2",
+#             messages=messages,
+#             api_base="http://localhost:11434",
+#             custom_llm_provider="ollama"
+#         )
#         print(response)
-#         string_response = asyncio.run(get_response(response))
-#         print(string_response)
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")


-# # test_completion_ollama()
+# test_completion_ollama()

# def test_completion_ollama_stream():
#     user_message = "what is litellm?"
#     messages = [{ "content": user_message,"role": "user"}]
#     try:
-#         response = completion(model="llama2", messages=messages, api_base="http://localhost:11434", custom_llm_provider="ollama", stream=True)
+#         response = completion(
+#             model="llama2",
+#             messages=messages,
+#             api_base="http://localhost:11434",
+#             custom_llm_provider="ollama",
+#             stream=True
+#         )
#         print(response)
-#         string_response = asyncio.run(get_response(response))
-#         print(string_response)
+#         for chunk in response:
+#             print(chunk['choices'][0]['delta'])

#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")

-# test_completion_ollama_stream()
+# # test_completion_ollama_stream()

@@ -2217,51 +2217,6 @@ def read_config_args(config_path):
        print("An error occurred while reading config:", str(e))
        raise e


-########## ollama implementation ############################
-
-
-async def get_ollama_response_stream(
-    api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"
-):
-    session = aiohttp.ClientSession()
-    url = f"{api_base}/api/generate"
-    data = {
-        "model": model,
-        "prompt": prompt,
-    }
-    try:
-        async with session.post(url, json=data) as resp:
-            async for line in resp.content.iter_any():
-                if line:
-                    try:
-                        json_chunk = line.decode("utf-8")
-                        chunks = json_chunk.split("\n")
-                        for chunk in chunks:
-                            if chunk.strip() != "":
-                                j = json.loads(chunk)
-                                if "response" in j:
-                                    completion_obj = {
-                                        "role": "assistant",
-                                        "content": "",
-                                    }
-                                    completion_obj["content"] = j["response"]
-                                    yield {"choices": [{"delta": completion_obj}]}
-                                    # self.responses.append(j["response"])
-                                    # yield "blank"
-                    except Exception as e:
-                        print(f"Error decoding JSON: {e}")
-    finally:
-        await session.close()
-
-
-async def stream_to_string(generator):
-    response = ""
-    async for chunk in generator:
-        response += chunk["content"]
-    return response


########## experimental completion variants ############################

def get_model_split_test(models, completion_call_id):