ollama default api_base to http://localhost:11434

This commit is contained in:
ishaan-jaff 2023-10-05 11:03:36 -07:00
parent 1492916a37
commit f0d6d713e0
2 changed files with 34 additions and 18 deletions


@@ -1041,10 +1041,11 @@ def completion(
         ## RESPONSE OBJECT
         response = model_response
     elif custom_llm_provider == "ollama":
-        endpoint = (
-            litellm.api_base
-            or api_base
-            or "http://localhost:11434"
+        api_base = (
+            litellm.api_base or
+            api_base or
+            "http://localhost:11434"
         )
         if model in litellm.custom_prompt_dict:
            # check if the model has a registered custom prompt
@@ -1060,13 +1061,13 @@ def completion(
         ## LOGGING
         logging.pre_call(
-            input=prompt, api_key=None, additional_args={"endpoint": endpoint, "custom_prompt_dict": litellm.custom_prompt_dict}
+            input=prompt, api_key=None, additional_args={"api_base": api_base, "custom_prompt_dict": litellm.custom_prompt_dict}
         )
         if kwargs.get('acompletion', False) == True:
-            async_generator = ollama.async_get_ollama_response_stream(endpoint, model, prompt)
+            async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt)
             return async_generator
-        generator = ollama.get_ollama_response_stream(endpoint, model, prompt)
+        generator = ollama.get_ollama_response_stream(api_base, model, prompt)
         if optional_params.get("stream", False) == True:
             # assume all ollama responses are streamed
             return generator
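
For context on the change above: the base URL used by the Ollama handler is now resolved from three sources in a fixed order, falling back to a local server when nothing is configured. A minimal sketch of that precedence (resolve_ollama_api_base is a hypothetical helper used only for illustration; it mirrors the expression in the diff):

import litellm

def resolve_ollama_api_base(api_base=None):
    # Hypothetical helper mirroring the fallback chain added in completion():
    #   1. the module-level setting litellm.api_base, if set
    #   2. the api_base argument passed to completion(), if given
    #   3. the new default: a local Ollama server
    return litellm.api_base or api_base or "http://localhost:11434"

# Assuming litellm.api_base is left unset (None):
print(resolve_ollama_api_base())                       # http://localhost:11434
print(resolve_ollama_api_base("http://ollama:11434"))  # http://ollama:11434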


@@ -16,18 +16,33 @@
 # user_message = "respond in 20 words. who are you?"
 # messages = [{ "content": user_message,"role": "user"}]
-# # def test_completion_ollama():
-# #     try:
-# #         response = completion(
-# #             model="ollama/llama2",
-# #             messages=messages,
-# #             api_base="http://localhost:11434"
-# #         )
-# #         print(response)
-# #     except Exception as e:
-# #         pytest.fail(f"Error occurred: {e}")
-# # test_completion_ollama()
+# def test_completion_ollama():
+#     try:
+#         response = completion(
+#             model="ollama/llama2",
+#             messages=messages,
+#             max_tokens=200,
+#             request_timeout = 10,
+#         )
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+# test_completion_ollama()
+
+# def test_completion_ollama_with_api_base():
+#     try:
+#         response = completion(
+#             model="ollama/llama2",
+#             messages=messages,
+#             api_base="http://localhost:11434"
+#         )
+#         print(response)
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
+# test_completion_ollama_with_api_base()
 # # def test_completion_ollama_stream():
 # #     user_message = "what is litellm?"
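
For manual verification, the re-enabled test above can be run un-commented against a locally running Ollama server that has llama2 pulled (a sketch based on the new test body, not part of the commit; it relies on the new default api_base of http://localhost:11434):

import pytest
from litellm import completion

messages = [{"content": "respond in 20 words. who are you?", "role": "user"}]

def test_completion_ollama():
    # No api_base passed: the handler falls back to http://localhost:11434
    try:
        response = completion(
            model="ollama/llama2",
            messages=messages,
            max_tokens=200,
            request_timeout=10,
        )
        print(response)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")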