fix(ollama.py): fix sync ollama streaming

Krrish Dholakia 2023-12-16 21:23:21 -08:00
parent 13d088b72e
commit a3c7a340a5
3 changed files with 20 additions and 42 deletions


@@ -148,39 +148,21 @@ def get_ollama_response_stream(
         return response
     else:
-        return ollama_completion_stream(url=url, data=data)
+        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
-def ollama_completion_stream(url, data):
-    session = requests.Session()
-    with session.post(url, json=data, stream=True) as resp:
-        if resp.status_code != 200:
-            raise OllamaError(status_code=resp.status_code, message=resp.text)
-        for line in resp.iter_lines():
-            if line:
-                try:
-                    json_chunk = line.decode("utf-8")
-                    chunks = json_chunk.split("\n")
-                    for chunk in chunks:
-                        if chunk.strip() != "":
-                            j = json.loads(chunk)
-                            if "error" in j:
-                                completion_obj = {
-                                    "role": "assistant",
-                                    "content": "",
-                                    "error": j
-                                }
-                                yield completion_obj
-                            if "response" in j:
-                                completion_obj = {
-                                    "role": "assistant",
-                                    "content": "",
-                                }
-                                completion_obj["content"] = j["response"]
-                                yield completion_obj
-                except Exception as e:
-                    traceback.print_exc()
-        session.close()
+def ollama_completion_stream(url, data, logging_obj):
+    with httpx.stream(
+        url=url,
+        json=data,
+        method="POST",
+        timeout=litellm.request_timeout
+    ) as response:
+        if response.status_code != 200:
+            raise OllamaError(status_code=response.status_code, message=response.text)
+        streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama", logging_obj=logging_obj)
+        for transformed_chunk in streamwrapper:
+            yield transformed_chunk
 async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
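For reference, the rewritten sync path leans on httpx's streaming API rather than a requests.Session. Below is a minimal standalone sketch of that pattern; the localhost URL, model name, and timeout are illustrative assumptions, not part of this diff:

# Sketch of the httpx streaming pattern the new function uses.
# URL, model, and timeout are assumed values for illustration only.
import json
import httpx

url = "http://localhost:11434/api/generate"  # default local Ollama endpoint (assumed)
data = {"model": "llama2", "prompt": "Hey, how's it going?"}

with httpx.stream(method="POST", url=url, json=data, timeout=600) as response:
    if response.status_code != 200:
        response.read()  # a streamed body must be read before accessing .text
        raise RuntimeError(response.text)
    # Ollama streams one JSON object per line; "response" holds the text delta.
    for line in response.iter_lines():
        if line.strip():
            print(json.loads(line).get("response", ""), end="", flush=True)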


@@ -1320,14 +1320,8 @@ def completion(
             ## LOGGING
             generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
-            if acompletion is True:
+            if acompletion is True or optional_params.get("stream", False) == True:
                 return generator
-            if optional_params.get("stream", False) == True:
-                # assume all ollama responses are streamed
-                response = CustomStreamWrapper(
-                    generator, model, custom_llm_provider="ollama", logging_obj=logging
-                )
-                return response
             else:
                 response_string = ""
                 for chunk in generator:
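
With this change a synchronous streaming call returns the already-wrapped generator straight from the provider layer, so callers just iterate the result of litellm.completion. A usage sketch, assuming a local Ollama server with the llama2 model pulled:

import litellm

# Assumes `ollama serve` is running locally and llama2 is available.
response = litellm.completion(
    model="ollama/llama2",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
for chunk in response:
    print(chunk)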


@@ -33,16 +33,18 @@
 #     try:
 #         response = completion(
 #             model="ollama/llama2",
-#             messages=messages,
+#             messages=[{"role": "user", "content": "Hey, how's it going?"}],
 #             max_tokens=200,
 #             request_timeout = 10,
+#             stream=True
 #         )
+#         for chunk in response:
+#             print(chunk)
-#         print(response)
 #     except Exception as e:
 #         pytest.fail(f"Error occurred: {e}")
-# # test_completion_ollama()
+# test_completion_ollama()
 # def test_completion_ollama_with_api_base():
 #     try: