forked from phoenix/litellm-mirror

support acompletion + stream for ollama

commit 35bb6f5a50
parent 4fa9b19af7

2 changed files with 42 additions and 2 deletions
litellm/llms/ollama.py
@@ -1,5 +1,6 @@
 import requests
 import json
+from async_generator import async_generator, yield_
 
 # ollama implementation
 def get_ollama_response_stream(
@@ -32,4 +33,38 @@ def get_ollama_response_stream(
                                 yield {"choices": [{"delta": completion_obj}]}
                 except Exception as e:
                     print(f"Error decoding JSON: {e}")
+    session.close()
+
+# ollama implementation
+@async_generator
+async def async_get_ollama_response_stream(
+    api_base="http://localhost:11434",
+    model="llama2",
+    prompt="Why is the sky blue?"
+):
+    url = f"{api_base}/api/generate"
+    data = {
+        "model": model,
+        "prompt": prompt,
+    }
+    session = requests.Session()
+
+    with session.post(url, json=data, stream=True) as resp:
+        for line in resp.iter_lines():
+            if line:
+                try:
+                    json_chunk = line.decode("utf-8")
+                    chunks = json_chunk.split("\n")
+                    for chunk in chunks:
+                        if chunk.strip() != "":
+                            j = json.loads(chunk)
+                            if "response" in j:
+                                completion_obj = {
+                                    "role": "assistant",
+                                    "content": "",
+                                }
+                                completion_obj["content"] = j["response"]
+                                await yield_({"choices": [{"delta": completion_obj}]})
+                except Exception as e:
+                    print(f"Error decoding JSON: {e}")
     session.close()
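For context, a minimal sketch of how the new async generator could be consumed. This is not part of the commit: the litellm.llms.ollama import path is an assumption based on where get_ollama_response_stream lives, and it presumes a local Ollama server on the default port; the chunk shape ({"choices": [{"delta": ...}]}) comes straight from the diff above.

# Consumption sketch (assumed import path: litellm.llms.ollama)
import asyncio
from litellm.llms import ollama

async def main():
    stream = ollama.async_get_ollama_response_stream(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?",
    )
    # The @async_generator decorator makes the function usable with `async for`;
    # each chunk mirrors the sync generator: {"choices": [{"delta": {...}}]}
    async for chunk in stream:
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)

asyncio.run(main())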
litellm/main.py
@@ -75,7 +75,7 @@ async def acompletion(*args, **kwargs):
     loop = asyncio.get_event_loop()
 
     # Use a partial function to pass your keyword arguments
-    func = partial(completion, *args, **kwargs)
+    func = partial(completion, *args, **kwargs, acompletion=True)
 
     # Add the context to the function
     ctx = contextvars.copy_context()
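Baking acompletion=True into the partial is what lets the synchronous completion() entry point hand back an async generator instead of blocking. Below is a self-contained sketch of that dispatch pattern; the run_in_executor line is an assumption about the surrounding, unchanged code, and the function bodies are illustrative stand-ins rather than litellm's exact internals.

import asyncio
import contextvars
from functools import partial

def completion(prompt, acompletion=False):
    # Stand-in for litellm.completion: the real call returns either a
    # ModelResponse or, when acompletion=True, an async generator.
    return f"acompletion={acompletion}: {prompt}"

async def acompletion(*args, **kwargs):
    loop = asyncio.get_event_loop()
    # Bake acompletion=True into the partial, as the hunk above does.
    func = partial(completion, *args, **kwargs, acompletion=True)
    # Preserve context variables across the executor boundary.
    ctx = contextvars.copy_context()
    return await loop.run_in_executor(None, ctx.run, func)

print(asyncio.run(acompletion("Why is the sky blue?")))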
@@ -180,6 +180,7 @@ def completion(
     fallbacks=[],
     caching = False,
     cache_params = {}, # optional to specify metadata for caching
+    acompletion=False,
 ) -> ModelResponse:
     """
     Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -215,7 +216,7 @@ def completion(
     """
     if mock_response:
         return mock_completion(model, messages, stream=stream, mock_response=mock_response)
 
     args = locals()
     try:
         logging = litellm_logging_obj
@@ -928,6 +929,10 @@ def completion(
             logging.pre_call(
                 input=prompt, api_key=None, additional_args={"endpoint": endpoint}
             )
+            if acompletion == True:
+                async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
+                return async_generator
+
             generator = ollama.get_ollama_response_stream(endpoint, model, prompt)
             if optional_params.get("stream", False) == True:
                 # assume all ollama responses are streamed
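Putting the two files together, the intended call path is roughly as follows. This is a hedged sketch, not part of the commit: the "ollama/llama2" model string follows litellm's provider-prefix convention, and it assumes (consistent with the hunks above) that acompletion() returns the Ollama async generator directly when streaming.

import asyncio
import litellm

async def main():
    # With this commit, the Ollama branch returns async_get_ollama_response_stream(...)
    # whenever acompletion=True, so the awaited result is an async generator.
    response = await litellm.acompletion(
        model="ollama/llama2",
        messages=[{"role": "user", "content": "Why is the sky blue?"}],
        stream=True,
    )
    async for chunk in response:
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)

asyncio.run(main())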