From 3c6b6355c7ffaad28fe8aab3e39f8e380fd5266b Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 19 Apr 2024 13:01:52 -0700
Subject: [PATCH] fix(ollama_chat.py): accept api key as a param for ollama
 calls

allows user to call hosted ollama endpoint using bearer token for auth
---
 litellm/__init__.py                   |  1 +
 litellm/llms/ollama_chat.py           | 64 ++++++++++++++++++++-------
 litellm/main.py                       |  7 +++
 litellm/proxy/_new_secret_config.yaml | 12 ++---
 4 files changed, 63 insertions(+), 21 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 5ef78dce4..21f98e8b3 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -51,6 +51,7 @@ replicate_key: Optional[str] = None
 cohere_key: Optional[str] = None
 maritalk_key: Optional[str] = None
 ai21_key: Optional[str] = None
+ollama_key: Optional[str] = None
 openrouter_key: Optional[str] = None
 huggingface_key: Optional[str] = None
 vertex_project: Optional[str] = None
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index d442ba5aa..aea00a303 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -184,6 +184,7 @@ class OllamaChatConfig:
 # ollama implementation
 def get_ollama_response(
     api_base="http://localhost:11434",
+    api_key: Optional[str] = None,
     model="llama2",
     messages=None,
     optional_params=None,
@@ -236,6 +237,7 @@ def get_ollama_response(
         if stream == True:
             response = ollama_async_streaming(
                 url=url,
+                api_key=api_key,
                 data=data,
                 model_response=model_response,
                 encoding=encoding,
@@ -244,6 +246,7 @@ def get_ollama_response(
         else:
             response = ollama_acompletion(
                 url=url,
+                api_key=api_key,
                 data=data,
                 model_response=model_response,
                 encoding=encoding,
@@ -252,12 +255,17 @@ def get_ollama_response(
             )
         return response
     elif stream == True:
-        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
+        return ollama_completion_stream(
+            url=url, api_key=api_key, data=data, logging_obj=logging_obj
+        )
 
-    response = requests.post(
-        url=f"{url}",
-        json=data,
-    )
+    _request = {
+        "url": f"{url}",
+        "json": data,
+    }
+    if api_key is not None:
+        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
+    response = requests.post(**_request)  # type: ignore
     if response.status_code != 200:
         raise OllamaError(status_code=response.status_code, message=response.text)
 
@@ -307,10 +315,16 @@ def get_ollama_response(
     return model_response
 
-def ollama_completion_stream(url, data, logging_obj):
-    with httpx.stream(
-        url=url, json=data, method="POST", timeout=litellm.request_timeout
-    ) as response:
+def ollama_completion_stream(url, api_key, data, logging_obj):
+    _request = {
+        "url": f"{url}",
+        "json": data,
+        "method": "POST",
+        "timeout": litellm.request_timeout,
+    }
+    if api_key is not None:
+        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
+    with httpx.stream(**_request) as response:
         try:
             if response.status_code != 200:
                 raise OllamaError(
                     status_code=response.status_code, message=response.text
@@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
         raise e
 
 
-async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
+async def ollama_async_streaming(
+    url, api_key, data, model_response, encoding, logging_obj
+):
     try:
         client = httpx.AsyncClient()
-        async with client.stream(
-            url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
-        ) as response:
+        _request = {
+            "url": f"{url}",
+            "json": data,
+            "method": "POST",
+            "timeout": litellm.request_timeout,
+        }
+        if api_key is not None:
+            _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
+        async with client.stream(**_request) as response:
             if response.status_code != 200:
                 raise OllamaError(
                     status_code=response.status_code, message=response.text
@@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
 
 
 async def ollama_acompletion(
-    url, data, model_response, encoding, logging_obj, function_name
+    url,
+    api_key: Optional[str],
+    data,
+    model_response,
+    encoding,
+    logging_obj,
+    function_name,
 ):
     data["stream"] = False
     try:
         timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
         async with aiohttp.ClientSession(timeout=timeout) as session:
-            resp = await session.post(url, json=data)
+            _request = {
+                "url": f"{url}",
+                "json": data,
+            }
+            if api_key is not None:
+                _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
+            resp = await session.post(**_request)
 
             if resp.status != 200:
                 text = await resp.text()
diff --git a/litellm/main.py b/litellm/main.py
index b1e75f744..65696b3c0 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1941,9 +1941,16 @@ def completion(
             or "http://localhost:11434"
         )
 
+        api_key = (
+            api_key
+            or litellm.ollama_key
+            or os.environ.get("OLLAMA_API_KEY")
+            or litellm.api_key
+        )
         ## LOGGING
         generator = ollama_chat.get_ollama_response(
             api_base,
+            api_key,
             model,
             messages,
             optional_params,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index ca8b4c539..0f7c24576 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -31,12 +31,12 @@ litellm_settings:
   upperbound_key_generate_params:
     max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
 
-# router_settings:
-#   routing_strategy: usage-based-routing-v2
-#   redis_host: os.environ/REDIS_HOST
-#   redis_password: os.environ/REDIS_PASSWORD
-#   redis_port: os.environ/REDIS_PORT
-#   enable_pre_call_checks: True
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: True
 
 general_settings:
   master_key: sk-1234
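
A minimal usage sketch, assuming the patch above is applied. The endpoint URL, bearer
token, and model name below are placeholders, not values taken from the patch:

    import litellm

    # Placeholder values -- substitute your hosted Ollama endpoint and token.
    OLLAMA_BASE = "https://my-hosted-ollama.example.com"
    OLLAMA_TOKEN = "my-ollama-bearer-token"

    # Option 1: pass the key per call; completion() forwards it to
    # ollama_chat.get_ollama_response(), which sends it as a bearer token.
    response = litellm.completion(
        model="ollama_chat/llama2",
        messages=[{"role": "user", "content": "Hello!"}],
        api_base=OLLAMA_BASE,
        api_key=OLLAMA_TOKEN,
    )
    print(response)

    # Option 2: set the key globally; per the fallback chain added in main.py,
    # litellm.ollama_key or the OLLAMA_API_KEY environment variable is used
    # when no api_key argument is passed.
    litellm.ollama_key = OLLAMA_TOKEN

Either route ends with the key being sent as an Authorization bearer header on the
request to the hosted Ollama endpoint.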