forked from phoenix/litellm-mirror
fix(ollama.py): fix sync ollama streaming
This commit is contained in:
parent 13d088b72e
commit a3c7a340a5
3 changed files with 20 additions and 42 deletions
@@ -148,39 +148,21 @@ def get_ollama_response_stream(
         return response
 
     else:
-        return ollama_completion_stream(url=url, data=data)
+        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
 
-def ollama_completion_stream(url, data):
-    session = requests.Session()
-
-    with session.post(url, json=data, stream=True) as resp:
-        if resp.status_code != 200:
-            raise OllamaError(status_code=resp.status_code, message=resp.text)
-        for line in resp.iter_lines():
-            if line:
-                try:
-                    json_chunk = line.decode("utf-8")
-                    chunks = json_chunk.split("\n")
-                    for chunk in chunks:
-                        if chunk.strip() != "":
-                            j = json.loads(chunk)
-                            if "error" in j:
-                                completion_obj = {
-                                    "role": "assistant",
-                                    "content": "",
-                                    "error": j
-                                }
-                                yield completion_obj
-                            if "response" in j:
-                                completion_obj = {
-                                    "role": "assistant",
-                                    "content": "",
-                                }
-                                completion_obj["content"] = j["response"]
-                                yield completion_obj
-                except Exception as e:
-                    traceback.print_exc()
-    session.close()
+def ollama_completion_stream(url, data, logging_obj):
+    with httpx.stream(
+        url=url,
+        json=data,
+        method="POST",
+        timeout=litellm.request_timeout
+    ) as response:
+        if response.status_code != 200:
+            raise OllamaError(status_code=response.status_code, message=response.text)
+
+        streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama", logging_obj=logging_obj)
+        for transformed_chunk in streamwrapper:
+            yield transformed_chunk
 
 
 async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
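The new handler above replaces the requests-based NDJSON parser with httpx.stream() and litellm's CustomStreamWrapper, so chunk transformation now happens inside ollama.py. As a point of reference, here is a minimal standalone sketch of the same streaming pattern against Ollama's /api/generate endpoint; it is not litellm code, the localhost URL, model name, and timeout value are assumptions, and the NDJSON "response" field is the same one the deleted parser read:

import json

import httpx


def stream_ollama_generate(url, model, prompt, timeout=600):
    """Yield assistant chunks from Ollama's /api/generate NDJSON stream."""
    data = {"model": model, "prompt": prompt}
    # Same shape as the new handler: stream the POST instead of buffering the whole body.
    with httpx.stream("POST", url, json=data, timeout=timeout) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Ollama returned status {response.status_code}")
        for line in response.iter_lines():
            if not line.strip():
                continue
            j = json.loads(line)
            # Each NDJSON line carries a partial completion under "response".
            if "response" in j:
                yield {"role": "assistant", "content": j["response"]}


if __name__ == "__main__":
    # Requires a running local Ollama server with the model pulled (assumed setup).
    for chunk in stream_ollama_generate(
        "http://localhost:11434/api/generate", "llama2", "Hey, how's it going?"
    ):
        print(chunk["content"], end="", flush=True)

Reading the response line by line keeps memory flat however long the generation runs, which is the point of issuing a streamed request here.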
@@ -1320,14 +1320,8 @@ def completion(
 
         ## LOGGING
         generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
-        if acompletion is True:
+        if acompletion is True or optional_params.get("stream", False) == True:
             return generator
-        if optional_params.get("stream", False) == True:
-            # assume all ollama responses are streamed
-            response = CustomStreamWrapper(
-                generator, model, custom_llm_provider="ollama", logging_obj=logging
-            )
-            return response
         else:
             response_string = ""
             for chunk in generator:
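The call-site hunk above only works together with the ollama.py change: because ollama_completion_stream() now wraps its output in CustomStreamWrapper itself, completion() can hand the generator back unchanged for both the async case and the sync stream=True case instead of wrapping it a second time. A toy sketch of that consolidated branch, with illustrative names rather than litellm code:

from typing import Iterator, Union


def provider_stream(prompt: str) -> Iterator[dict]:
    # Stand-in for ollama_completion_stream(): it already yields chunks in their
    # final, transformed shape, so the caller has nothing left to wrap.
    for word in prompt.split():
        yield {"role": "assistant", "content": word + " "}


def dispatch(prompt: str, acompletion: bool = False, stream: bool = False) -> Union[Iterator[dict], str]:
    generator = provider_stream(prompt)
    if acompletion is True or stream is True:
        # Mirrors the consolidated branch: hand the generator straight back.
        return generator
    # Non-streaming path: drain the generator into a single response string.
    return "".join(chunk["content"] for chunk in generator)


print(dispatch("hello streaming world"))
for chunk in dispatch("hello streaming world", stream=True):
    print(chunk)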
@@ -33,16 +33,18 @@
 # try:
 #     response = completion(
 #         model="ollama/llama2",
-#         messages=messages,
+#         messages=[{"role": "user", "content": "Hey, how's it going?"}],
 #         max_tokens=200,
 #         request_timeout = 10,
-
+#         stream=True
 #     )
+#     for chunk in response:
+#         print(chunk)
 #     print(response)
 # except Exception as e:
 #     pytest.fail(f"Error occurred: {e}")
 
-# # test_completion_ollama()
+# test_completion_ollama()
 
 # def test_completion_ollama_with_api_base():
 #     try:
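For anyone wanting to exercise the sync streaming path locally, an uncommented version of the test above might look like the sketch below; the pytest skip marker is an addition of this sketch (the test needs a running Ollama server with llama2 pulled) and is not part of the diff:

import pytest
from litellm import completion


@pytest.mark.skip(reason="needs a local Ollama server with llama2 pulled")
def test_completion_ollama():
    try:
        response = completion(
            model="ollama/llama2",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            max_tokens=200,
            request_timeout=10,
            stream=True,
        )
        # With the sync streaming fix, this loop should yield chunks incrementally.
        for chunk in response:
            print(chunk)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")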