(feat) debug ollama POST request

ishaan-jaff 2023-11-14 17:53:48 -08:00
parent 7c317b78eb
commit e82b8ed7e2
2 changed files with 18 additions and 7 deletions

@@ -113,7 +113,8 @@ def get_ollama_response_stream(
     api_base="http://localhost:11434",
     model="llama2",
     prompt="Why is the sky blue?",
-    optional_params=None
+    optional_params=None,
+    logging_obj=None,
 ):
     if api_base.endswith("/api/generate"):
         url = api_base
@@ -131,6 +132,12 @@ def get_ollama_response_stream(
         "prompt": prompt,
         **optional_params
     }
+    ## LOGGING
+    logging_obj.pre_call(
+        input=None,
+        api_key=None,
+        additional_args={"api_base": url, "complete_input_dict": data},
+    )
     session = requests.Session()

     with session.post(url, json=data, stream=True) as resp:
@@ -169,7 +176,8 @@ if async_generator_imported:
         api_base="http://localhost:11434",
         model="llama2",
         prompt="Why is the sky blue?",
-        optional_params=None
+        optional_params=None,
+        logging_obj=None,
     ):
         url = f"{api_base}/api/generate"
@@ -184,6 +192,12 @@ if async_generator_imported:
             "prompt": prompt,
             **optional_params
         }
+        ## LOGGING
+        logging_obj.pre_call(
+            input=None,
+            api_key=None,
+            additional_args={"api_base": url, "complete_input_dict": data},
+        )
         session = requests.Session()

         with session.post(url, json=data, stream=True) as resp:
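For reference, a minimal sketch of an object that would satisfy the new logging_obj parameter. It assumes only the pre_call keyword arguments visible in this diff (input, api_key, additional_args); the DebugLogger class and the usage shown are hypothetical, not litellm's actual Logging implementation.

class DebugLogger:
    """Hypothetical stand-in for the object passed as logging_obj."""

    def pre_call(self, input=None, api_key=None, additional_args=None):
        # Print the Ollama endpoint and the exact POST payload before the
        # request is sent, mirroring the keyword arguments used above.
        additional_args = additional_args or {}
        print("POST", additional_args.get("api_base"))
        print("payload:", additional_args.get("complete_input_dict"))

# Hypothetical call with the updated signature:
# for chunk in get_ollama_response_stream(
#     api_base="http://localhost:11434",
#     model="llama2",
#     prompt="Why is the sky blue?",
#     optional_params={},
#     logging_obj=DebugLogger(),
# ):
#     ...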

@@ -1235,16 +1235,13 @@ def completion(
             prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
         ## LOGGING
-        logging.pre_call(
-            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": custom_prompt_dict}
-        )
         if kwargs.get('acompletion', False) == True:
             if optional_params.get("stream", False) == True:
                 # assume all ollama responses are streamed
-                async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params)
+                async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
                 return async_generator
-        generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params)
+        generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
         if optional_params.get("stream", False) == True:
             # assume all ollama responses are streamed
             response = CustomStreamWrapper(
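With pre_call moved into the Ollama helpers, the logged additional_args now describe the request itself rather than only the prompt. A rough sketch of what ends up in complete_input_dict, assuming the body is assembled as in the hunks above; the example optional_params values are illustrative, and fields defined above those hunks (such as the model name) are omitted.

api_base = "http://localhost:11434"
url = api_base if api_base.endswith("/api/generate") else f"{api_base}/api/generate"

optional_params = {"temperature": 0.7, "stream": True}  # example values only
data = {
    "prompt": "Why is the sky blue?",
    **optional_params,
}
# The helper then reports the full POST body, not just the prompt:
# logging_obj.pre_call(
#     input=None,
#     api_key=None,
#     additional_args={"api_base": url, "complete_input_dict": data},
# )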