fix(ollama_chat.py): accept api key as a param for ollama calls

allows user to call hosted ollama endpoint using bearer token for auth
This commit is contained in:
Krrish Dholakia 2024-04-19 13:01:52 -07:00
parent b2bdc99474
commit 3c6b6355c7
4 changed files with 63 additions and 21 deletions

View file

@ -51,6 +51,7 @@ replicate_key: Optional[str] = None
cohere_key: Optional[str] = None
maritalk_key: Optional[str] = None
ai21_key: Optional[str] = None
ollama_key: Optional[str] = None
openrouter_key: Optional[str] = None
huggingface_key: Optional[str] = None
vertex_project: Optional[str] = None

View file

@ -184,6 +184,7 @@ class OllamaChatConfig:
# ollama implementation
def get_ollama_response(
api_base="http://localhost:11434",
api_key: Optional[str] = None,
model="llama2",
messages=None,
optional_params=None,
@ -236,6 +237,7 @@ def get_ollama_response(
if stream == True:
response = ollama_async_streaming(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,
@ -244,6 +246,7 @@ def get_ollama_response(
else:
response = ollama_acompletion(
url=url,
api_key=api_key,
data=data,
model_response=model_response,
encoding=encoding,
@ -252,12 +255,17 @@ def get_ollama_response(
)
return response
elif stream == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
return ollama_completion_stream(
url=url, api_key=api_key, data=data, logging_obj=logging_obj
)
response = requests.post(
url=f"{url}",
json=data,
)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
response = requests.post(**_request) # type: ignore
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
@ -307,10 +315,16 @@ def get_ollama_response(
return model_response
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
url=url, json=data, method="POST", timeout=litellm.request_timeout
) as response:
def ollama_completion_stream(url, api_key, data, logging_obj):
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
with httpx.stream(**_request) as response:
try:
if response.status_code != 200:
raise OllamaError(
@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
async def ollama_async_streaming(
url, api_key, data, model_response, encoding, logging_obj
):
try:
client = httpx.AsyncClient()
async with client.stream(
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
) as response:
_request = {
"url": f"{url}",
"json": data,
"method": "POST",
"timeout": litellm.request_timeout,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
async with client.stream(**_request) as response:
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code, message=response.text
@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
async def ollama_acompletion(
url, data, model_response, encoding, logging_obj, function_name
url,
api_key: Optional[str],
data,
model_response,
encoding,
logging_obj,
function_name,
):
data["stream"] = False
try:
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
async with aiohttp.ClientSession(timeout=timeout) as session:
resp = await session.post(url, json=data)
_request = {
"url": f"{url}",
"json": data,
}
if api_key is not None:
_request["headers"] = "Bearer {}".format(api_key)
resp = await session.post(**_request)
if resp.status != 200:
text = await resp.text()

View file

@ -1941,9 +1941,16 @@ def completion(
or "http://localhost:11434"
)
api_key = (
api_key
or litellm.ollama_key
or os.environ.get("OLLAMA_API_KEY")
or litellm.api_key
)
## LOGGING
generator = ollama_chat.get_ollama_response(
api_base,
api_key,
model,
messages,
optional_params,

View file

@ -31,12 +31,12 @@ litellm_settings:
upperbound_key_generate_params:
max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET
# router_settings:
# routing_strategy: usage-based-routing-v2
# redis_host: os.environ/REDIS_HOST
# redis_password: os.environ/REDIS_PASSWORD
# redis_port: os.environ/REDIS_PORT
# enable_pre_call_checks: True
router_settings:
routing_strategy: usage-based-routing-v2
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: True
general_settings:
master_key: sk-1234