support acompletion + stream for ollama

This commit is contained in:
ishaan-jaff 2023-09-21 10:39:15 -07:00
parent 4fa9b19af7
commit 35bb6f5a50
2 changed files with 42 additions and 2 deletions

View file

@@ -1,5 +1,6 @@
import requests import requests
import json import json
from async_generator import async_generator, yield_
# ollama implementation # ollama implementation
def get_ollama_response_stream( def get_ollama_response_stream(
@@ -32,4 +33,38 @@ def get_ollama_response_stream(
yield {"choices": [{"delta": completion_obj}]} yield {"choices": [{"delta": completion_obj}]}
except Exception as e: except Exception as e:
print(f"Error decoding JSON: {e}") print(f"Error decoding JSON: {e}")
session.close()
# ollama implementation
async def async_get_ollama_response_stream(
    api_base="http://localhost:11434",
    model="llama2",
    prompt="Why is the sky blue?"
):
    """Stream completion chunks from a local Ollama server.

    Native async generator (PEP 525) — no need for the third-party
    ``async_generator``/``yield_`` shim on Python 3.6+.

    Args:
        api_base: Base URL of the Ollama HTTP server.
        model: Name of the Ollama model to run.
        prompt: Prompt text POSTed to ``/api/generate``.

    Yields:
        OpenAI-style delta dicts:
        ``{"choices": [{"delta": {"role": "assistant", "content": ...}}]}``

    NOTE(review): this uses blocking ``requests`` inside an ``async def``,
    which stalls the event loop while streaming; switch to an async HTTP
    client (aiohttp/httpx) when such a dependency is available.
    """
    url = f"{api_base}/api/generate"
    data = {
        "model": model,
        "prompt": prompt,
    }
    session = requests.Session()
    try:
        with session.post(url, json=data, stream=True) as resp:
            for line in resp.iter_lines():
                if not line:
                    continue
                try:
                    # A single wire line may pack several JSON objects,
                    # separated by newlines — split and parse each one.
                    for chunk in line.decode("utf-8").split("\n"):
                        if chunk.strip() == "":
                            continue
                        j = json.loads(chunk)
                        # Only "response" entries carry generated text;
                        # other status objects are silently skipped.
                        if "response" in j:
                            completion_obj = {
                                "role": "assistant",
                                "content": j["response"],
                            }
                            yield {"choices": [{"delta": completion_obj}]}
                except Exception as e:
                    print(f"Error decoding JSON: {e}")
    finally:
        # Release the HTTP session even if parsing raises or the consumer
        # closes the generator early (original leaked it on those paths).
        session.close()

View file

@@ -75,7 +75,7 @@ async def acompletion(*args, **kwargs):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# Use a partial function to pass your keyword arguments # Use a partial function to pass your keyword arguments
func = partial(completion, *args, **kwargs) func = partial(completion, *args, **kwargs, acompletion=True)
# Add the context to the function # Add the context to the function
ctx = contextvars.copy_context() ctx = contextvars.copy_context()
@@ -180,6 +180,7 @@ def completion(
fallbacks=[], fallbacks=[],
caching = False, caching = False,
cache_params = {}, # optional to specify metadata for caching cache_params = {}, # optional to specify metadata for caching
acompletion=False,
) -> ModelResponse: ) -> ModelResponse:
""" """
Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -215,7 +216,7 @@ def completion(
""" """
if mock_response: if mock_response:
return mock_completion(model, messages, stream=stream, mock_response=mock_response) return mock_completion(model, messages, stream=stream, mock_response=mock_response)
args = locals() args = locals()
try: try:
logging = litellm_logging_obj logging = litellm_logging_obj
@@ -928,6 +929,10 @@ def completion(
logging.pre_call( logging.pre_call(
input=prompt, api_key=None, additional_args={"endpoint": endpoint} input=prompt, api_key=None, additional_args={"endpoint": endpoint}
) )
if acompletion == True:
async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
return async_generator
generator = ollama.get_ollama_response_stream(endpoint, model, prompt) generator = ollama.get_ollama_response_stream(endpoint, model, prompt)
if optional_params.get("stream", False) == True: if optional_params.get("stream", False) == True:
# assume all ollama responses are streamed # assume all ollama responses are streamed