forked from phoenix/litellm-mirror
fix(ollama_chat.py): accept api key as a param for ollama calls
Allows users to call a hosted Ollama endpoint using a bearer token for auth.
This commit is contained in:
parent b2bdc99474
commit 3c6b6355c7

4 changed files with 63 additions and 21 deletions
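A minimal usage sketch of what this enables once the commit is applied; the endpoint URL, model name, and token below are illustrative placeholders, not values taken from the diff:

```python
import litellm

# Call a hosted Ollama server that sits behind bearer-token auth.
# The "ollama_chat/" prefix routes the call through ollama_chat.get_ollama_response.
response = litellm.completion(
    model="ollama_chat/llama2",
    messages=[{"role": "user", "content": "Hello"}],
    api_base="https://ollama.example.com",  # hypothetical hosted endpoint
    api_key="my-bearer-token",              # forwarded as "Bearer <token>"
)
print(response)
```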
```diff
@@ -51,6 +51,7 @@ replicate_key: Optional[str] = None
 cohere_key: Optional[str] = None
 maritalk_key: Optional[str] = None
 ai21_key: Optional[str] = None
+ollama_key: Optional[str] = None
 openrouter_key: Optional[str] = None
 huggingface_key: Optional[str] = None
 vertex_project: Optional[str] = None
```
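Because `ollama_key` now lives at module level, the token can also be configured once instead of per call; a small sketch (the token string is a placeholder):

```python
import litellm

# Picked up by completion() for ollama_chat calls, per the key-resolution hunk further down.
litellm.ollama_key = "my-bearer-token"
```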
```diff
@@ -184,6 +184,7 @@ class OllamaChatConfig:
 # ollama implementation
 def get_ollama_response(
     api_base="http://localhost:11434",
+    api_key: Optional[str] = None,
     model="llama2",
     messages=None,
     optional_params=None,
```
```diff
@@ -236,6 +237,7 @@ def get_ollama_response(
         if stream == True:
             response = ollama_async_streaming(
                 url=url,
+                api_key=api_key,
                 data=data,
                 model_response=model_response,
                 encoding=encoding,
```
```diff
@@ -244,6 +246,7 @@ def get_ollama_response(
         else:
             response = ollama_acompletion(
                 url=url,
+                api_key=api_key,
                 data=data,
                 model_response=model_response,
                 encoding=encoding,
```
```diff
@@ -252,12 +255,17 @@ def get_ollama_response(
             )
         return response
     elif stream == True:
-        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
-    response = requests.post(
-        url=f"{url}",
-        json=data,
-    )
+        return ollama_completion_stream(
+            url=url, api_key=api_key, data=data, logging_obj=logging_obj
+        )

+    _request = {
+        "url": f"{url}",
+        "json": data,
+    }
+    if api_key is not None:
+        _request["headers"] = "Bearer {}".format(api_key)
+    response = requests.post(**_request)  # type: ignore
     if response.status_code != 200:
         raise OllamaError(status_code=response.status_code, message=response.text)
```
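One caveat worth noting: `requests.post`, `httpx.stream`, and aiohttp's `session.post` all expect `headers` to be a mapping of header names to values, while the diff assigns a bare string. A sketch of the presumably intended form (my assumption, not what this commit ships):

```python
# Hypothetical corrected header construction for the _request dicts in these hunks.
if api_key is not None:
    _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
```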
```diff
@@ -307,10 +315,16 @@ def get_ollama_response(
     return model_response


-def ollama_completion_stream(url, data, logging_obj):
-    with httpx.stream(
-        url=url, json=data, method="POST", timeout=litellm.request_timeout
-    ) as response:
+def ollama_completion_stream(url, api_key, data, logging_obj):
+    _request = {
+        "url": f"{url}",
+        "json": data,
+        "method": "POST",
+        "timeout": litellm.request_timeout,
+    }
+    if api_key is not None:
+        _request["headers"] = "Bearer {}".format(api_key)
+    with httpx.stream(**_request) as response:
         try:
             if response.status_code != 200:
                 raise OllamaError(
```
```diff
@@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
             raise e


-async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
+async def ollama_async_streaming(
+    url, api_key, data, model_response, encoding, logging_obj
+):
     try:
         client = httpx.AsyncClient()
-        async with client.stream(
-            url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
-        ) as response:
+        _request = {
+            "url": f"{url}",
+            "json": data,
+            "method": "POST",
+            "timeout": litellm.request_timeout,
+        }
+        if api_key is not None:
+            _request["headers"] = "Bearer {}".format(api_key)
+        async with client.stream(**_request) as response:
             if response.status_code != 200:
                 raise OllamaError(
                     status_code=response.status_code, message=response.text
```
```diff
@@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):


 async def ollama_acompletion(
-    url, data, model_response, encoding, logging_obj, function_name
+    url,
+    api_key: Optional[str],
+    data,
+    model_response,
+    encoding,
+    logging_obj,
+    function_name,
 ):
     data["stream"] = False
     try:
         timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
         async with aiohttp.ClientSession(timeout=timeout) as session:
-            resp = await session.post(url, json=data)
+            _request = {
+                "url": f"{url}",
+                "json": data,
+            }
+            if api_key is not None:
+                _request["headers"] = "Bearer {}".format(api_key)
+            resp = await session.post(**_request)

             if resp.status != 200:
                 text = await resp.text()
```
```diff
@@ -1941,9 +1941,16 @@ def completion(
             or "http://localhost:11434"
         )

+        api_key = (
+            api_key
+            or litellm.ollama_key
+            or os.environ.get("OLLAMA_API_KEY")
+            or litellm.api_key
+        )
         ## LOGGING
         generator = ollama_chat.get_ollama_response(
             api_base,
+            api_key,
             model,
             messages,
             optional_params,
```
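The new fallback chain means a per-call key takes precedence over module and environment configuration. A short sketch of that precedence, with placeholder token values:

```python
import os
import litellm

os.environ["OLLAMA_API_KEY"] = "env-token"   # placeholder
litellm.ollama_key = "module-token"          # placeholder; checked before the env var

# Per the diff: api_key argument > litellm.ollama_key > OLLAMA_API_KEY > litellm.api_key,
# so "per-call-token" is the value sent as the bearer token here.
response = litellm.completion(
    model="ollama_chat/llama2",
    messages=[{"role": "user", "content": "hi"}],
    api_key="per-call-token",
)
```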
```diff
@@ -31,12 +31,12 @@ litellm_settings:
 upperbound_key_generate_params:
   max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET

-# router_settings:
-# routing_strategy: usage-based-routing-v2
-# redis_host: os.environ/REDIS_HOST
-# redis_password: os.environ/REDIS_PASSWORD
-# redis_port: os.environ/REDIS_PORT
-# enable_pre_call_checks: True
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: True

 general_settings:
   master_key: sk-1234
```