fix(ollama_chat.py): accept api key as a param for ollama calls

allows users to call a hosted ollama endpoint using a bearer token for auth
Krrish Dholakia 2024-04-19 13:01:52 -07:00
parent b2bdc99474
commit 3c6b6355c7
4 changed files with 63 additions and 21 deletions
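For context, a minimal sketch of how the new parameter is meant to be exercised from litellm's public completion() entry point; the endpoint URL and token below are placeholders, not values from this commit:

```python
import litellm

# Hypothetical hosted Ollama endpoint and bearer token, for illustration only.
response = litellm.completion(
    model="ollama_chat/llama2",
    messages=[{"role": "user", "content": "Hello!"}],
    api_base="https://ollama.example.com",
    api_key="my-bearer-token",
)
print(response.choices[0].message.content)
```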


@@ -51,6 +51,7 @@ replicate_key: Optional[str] = None
 cohere_key: Optional[str] = None
 maritalk_key: Optional[str] = None
 ai21_key: Optional[str] = None
+ollama_key: Optional[str] = None
 openrouter_key: Optional[str] = None
 huggingface_key: Optional[str] = None
 vertex_project: Optional[str] = None
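With the new module-level ollama_key in place, the token can also be configured once instead of per call; a sketch assuming a litellm version that includes this change and the fallback order added to completion() further down in this commit:

```python
import os
import litellm

# Either of these (placeholder values) authenticates every ollama_chat call
# that does not pass api_key explicitly.
litellm.ollama_key = "my-bearer-token"
os.environ["OLLAMA_API_KEY"] = "my-bearer-token"
```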


@@ -184,6 +184,7 @@ class OllamaChatConfig:
 # ollama implementation
 def get_ollama_response(
     api_base="http://localhost:11434",
+    api_key: Optional[str] = None,
     model="llama2",
     messages=None,
     optional_params=None,
@@ -236,6 +237,7 @@ def get_ollama_response(
         if stream == True:
             response = ollama_async_streaming(
                 url=url,
+                api_key=api_key,
                 data=data,
                 model_response=model_response,
                 encoding=encoding,
@@ -244,6 +246,7 @@ def get_ollama_response(
         else:
             response = ollama_acompletion(
                 url=url,
+                api_key=api_key,
                 data=data,
                 model_response=model_response,
                 encoding=encoding,
@@ -252,12 +255,17 @@ def get_ollama_response(
         )
         return response
     elif stream == True:
-        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
-
-    response = requests.post(
-        url=f"{url}",
-        json=data,
-    )
+        return ollama_completion_stream(
+            url=url, api_key=api_key, data=data, logging_obj=logging_obj
+        )
+
+    _request = {
+        "url": f"{url}",
+        "json": data,
+    }
+    if api_key is not None:
+        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
+    response = requests.post(**_request)  # type: ignore
     if response.status_code != 200:
         raise OllamaError(status_code=response.status_code, message=response.text)
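The non-streaming path builds the requests.post kwargs as a dict so the Authorization header is only attached when a key was supplied. A standalone sketch of that pattern (URL, payload, and helper name are illustrative, not from this commit):

```python
from typing import Optional

import requests

def post_with_optional_bearer(url: str, data: dict, api_key: Optional[str] = None):
    # Assemble kwargs incrementally: unauthenticated calls send no auth header.
    _request = {"url": url, "json": data}
    if api_key is not None:
        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
    return requests.post(**_request)
```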
@@ -307,10 +315,16 @@ def get_ollama_response(
     return model_response


-def ollama_completion_stream(url, data, logging_obj):
-    with httpx.stream(
-        url=url, json=data, method="POST", timeout=litellm.request_timeout
-    ) as response:
+def ollama_completion_stream(url, api_key, data, logging_obj):
+    _request = {
+        "url": f"{url}",
+        "json": data,
+        "method": "POST",
+        "timeout": litellm.request_timeout,
+    }
+    if api_key is not None:
+        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
+    with httpx.stream(**_request) as response:
         try:
             if response.status_code != 200:
                 raise OllamaError(
@@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
         raise e


-async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
+async def ollama_async_streaming(
+    url, api_key, data, model_response, encoding, logging_obj
+):
     try:
         client = httpx.AsyncClient()
-        async with client.stream(
-            url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
-        ) as response:
+        _request = {
+            "url": f"{url}",
+            "json": data,
+            "method": "POST",
+            "timeout": litellm.request_timeout,
+        }
+        if api_key is not None:
+            _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
+        async with client.stream(**_request) as response:
             if response.status_code != 200:
                 raise OllamaError(
                     status_code=response.status_code, message=response.text
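The async streaming variant applies the same optional-header pattern to httpx.AsyncClient.stream. A sketch of how such a stream might be consumed (function name, URL, and payload are placeholders):

```python
import asyncio

import httpx

async def stream_ollama(url: str, payload: dict, api_key=None):
    _request = {"url": url, "json": payload, "method": "POST"}
    if api_key is not None:
        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
    async with httpx.AsyncClient() as client:
        # Iterate server-sent chunks line by line as they arrive.
        async with client.stream(**_request) as response:
            async for line in response.aiter_lines():
                print(line)

# asyncio.run(stream_ollama("http://localhost:11434/api/chat",
#                           {"model": "llama2", "messages": [], "stream": True}))
```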
@@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):


 async def ollama_acompletion(
-    url, data, model_response, encoding, logging_obj, function_name
+    url,
+    api_key: Optional[str],
+    data,
+    model_response,
+    encoding,
+    logging_obj,
+    function_name,
 ):
     data["stream"] = False
     try:
         timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
         async with aiohttp.ClientSession(timeout=timeout) as session:
-            resp = await session.post(url, json=data)
+            _request = {
+                "url": f"{url}",
+                "json": data,
+            }
+            if api_key is not None:
+                _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
+            resp = await session.post(**_request)

             if resp.status != 200:
                 text = await resp.text()
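ollama_acompletion does the same over aiohttp. In sketch form (function name and URL are illustrative):

```python
import aiohttp

async def post_ollama(url: str, payload: dict, api_key=None):
    _request = {"url": url, "json": payload}
    if api_key is not None:
        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
    async with aiohttp.ClientSession() as session:
        async with session.post(**_request) as resp:
            resp.raise_for_status()
            return await resp.json()
```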


@@ -1941,9 +1941,16 @@ def completion(
                 or "http://localhost:11434"
             )

+            api_key = (
+                api_key
+                or litellm.ollama_key
+                or os.environ.get("OLLAMA_API_KEY")
+                or litellm.api_key
+            )
             ## LOGGING
             generator = ollama_chat.get_ollama_response(
                 api_base,
+                api_key,
                 model,
                 messages,
                 optional_params,
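The resolution order is: explicit argument, then litellm.ollama_key, then the OLLAMA_API_KEY environment variable, then the global litellm.api_key. Reproduced standalone, assuming a litellm version with this commit applied:

```python
import os

import litellm

explicit_key = None  # whatever the caller passed to completion()
resolved_key = (
    explicit_key
    or litellm.ollama_key
    or os.environ.get("OLLAMA_API_KEY")
    or litellm.api_key
)
print(resolved_key)  # None unless one of the fallbacks is configured
```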


@@ -31,12 +31,12 @@ litellm_settings:
   upperbound_key_generate_params:
     max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET

-# router_settings:
-#   routing_strategy: usage-based-routing-v2
-#   redis_host: os.environ/REDIS_HOST
-#   redis_password: os.environ/REDIS_PASSWORD
-#   redis_port: os.environ/REDIS_PORT
-#   enable_pre_call_checks: True
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: True

 general_settings:
   master_key: sk-1234