forked from phoenix/litellm-mirror
fix(ollama_chat.py): accept api key as a param for ollama calls
allows users to call a hosted Ollama endpoint using a bearer token for auth
This commit is contained in:
parent
b2bdc99474
commit
3c6b6355c7
4 changed files with 63 additions and 21 deletions
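
As a usage sketch (not part of the diff below), assuming the ollama_chat/ provider prefix routes through the updated get_ollama_response, a call against a hosted endpoint with bearer auth might look like this; the endpoint URL and token are placeholders:

import litellm

# placeholder endpoint and token, for illustration only
response = litellm.completion(
    model="ollama_chat/llama2",
    messages=[{"role": "user", "content": "Hello!"}],
    api_base="https://my-hosted-ollama.example.com",
    api_key="sk-my-ollama-proxy-token",
)
print(response.choices[0].message.content)
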
@@ -51,6 +51,7 @@ replicate_key: Optional[str] = None
 cohere_key: Optional[str] = None
 maritalk_key: Optional[str] = None
 ai21_key: Optional[str] = None
+ollama_key: Optional[str] = None
 openrouter_key: Optional[str] = None
 huggingface_key: Optional[str] = None
 vertex_project: Optional[str] = None
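
The new ollama_key mirrors the other module-level provider keys above. Per the completion() hunk further down (@@ -1941), it is one of the fallbacks consulted when no api_key argument is passed; a minimal sketch of setting it once (token is a placeholder):

import litellm

# placeholder token; consulted when api_key is not passed per call
litellm.ollama_key = "sk-my-ollama-proxy-token"
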
@@ -184,6 +184,7 @@ class OllamaChatConfig:
 # ollama implementation
 def get_ollama_response(
     api_base="http://localhost:11434",
+    api_key: Optional[str] = None,
     model="llama2",
     messages=None,
     optional_params=None,
@@ -236,6 +237,7 @@ def get_ollama_response(
         if stream == True:
             response = ollama_async_streaming(
                 url=url,
+                api_key=api_key,
                 data=data,
                 model_response=model_response,
                 encoding=encoding,
@@ -244,6 +246,7 @@ def get_ollama_response(
         else:
             response = ollama_acompletion(
                 url=url,
+                api_key=api_key,
                 data=data,
                 model_response=model_response,
                 encoding=encoding,
@@ -252,12 +255,17 @@ def get_ollama_response(
             )
         return response
     elif stream == True:
-        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
+        return ollama_completion_stream(
+            url=url, api_key=api_key, data=data, logging_obj=logging_obj
+        )

-    response = requests.post(
-        url=f"{url}",
-        json=data,
-    )
+    _request = {
+        "url": f"{url}",
+        "json": data,
+    }
+    if api_key is not None:
+        _request["headers"] = "Bearer {}".format(api_key)
+    response = requests.post(**_request)  # type: ignore
     if response.status_code != 200:
         raise OllamaError(status_code=response.status_code, message=response.text)
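
The hunk above, and the ones that follow, repeat one pattern: collect the request kwargs in a _request dict and attach the bearer credential only when api_key is set. Below is a standalone sketch of that pattern; it assumes a conventional headers mapping with an Authorization field, and the helper name, endpoint, and token are hypothetical:

import requests

def post_with_optional_bearer(url, data, api_key=None):
    # build the kwargs once, then add auth only when a key is provided
    request_kwargs = {"url": url, "json": data}
    if api_key is not None:
        # requests expects a mapping of header names to values (assumption for this sketch)
        request_kwargs["headers"] = {"Authorization": "Bearer {}".format(api_key)}
    return requests.post(**request_kwargs)

# placeholder endpoint and token
resp = post_with_optional_bearer(
    "https://my-hosted-ollama.example.com/api/chat",
    {"model": "llama2", "messages": [{"role": "user", "content": "hi"}]},
    api_key="sk-my-ollama-proxy-token",
)
print(resp.status_code)
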
@@ -307,10 +315,16 @@ def get_ollama_response(
     return model_response


-def ollama_completion_stream(url, data, logging_obj):
-    with httpx.stream(
-        url=url, json=data, method="POST", timeout=litellm.request_timeout
-    ) as response:
+def ollama_completion_stream(url, api_key, data, logging_obj):
+    _request = {
+        "url": f"{url}",
+        "json": data,
+        "method": "POST",
+        "timeout": litellm.request_timeout,
+    }
+    if api_key is not None:
+        _request["headers"] = "Bearer {}".format(api_key)
+    with httpx.stream(**_request) as response:
         try:
             if response.status_code != 200:
                 raise OllamaError(
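
For the streaming helper above, a usage sketch, assuming stream=True on a synchronous ollama_chat call is served by ollama_completion_stream (endpoint and token are placeholders):

import litellm

stream = litellm.completion(
    model="ollama_chat/llama2",
    messages=[{"role": "user", "content": "Tell me a joke"}],
    api_base="https://my-hosted-ollama.example.com",  # placeholder endpoint
    api_key="sk-my-ollama-proxy-token",  # placeholder token
    stream=True,
)
for chunk in stream:
    print(chunk)
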
@@ -329,12 +343,20 @@ def ollama_completion_stream(url, data, logging_obj):
             raise e


-async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
+async def ollama_async_streaming(
+    url, api_key, data, model_response, encoding, logging_obj
+):
     try:
         client = httpx.AsyncClient()
-        async with client.stream(
-            url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
-        ) as response:
+        _request = {
+            "url": f"{url}",
+            "json": data,
+            "method": "POST",
+            "timeout": litellm.request_timeout,
+        }
+        if api_key is not None:
+            _request["headers"] = "Bearer {}".format(api_key)
+        async with client.stream(**_request) as response:
             if response.status_code != 200:
                 raise OllamaError(
                     status_code=response.status_code, message=response.text
@@ -353,13 +375,25 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob


 async def ollama_acompletion(
-    url, data, model_response, encoding, logging_obj, function_name
+    url,
+    api_key: Optional[str],
+    data,
+    model_response,
+    encoding,
+    logging_obj,
+    function_name,
 ):
     data["stream"] = False
     try:
         timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
         async with aiohttp.ClientSession(timeout=timeout) as session:
-            resp = await session.post(url, json=data)
+            _request = {
+                "url": f"{url}",
+                "json": data,
+            }
+            if api_key is not None:
+                _request["headers"] = "Bearer {}".format(api_key)
+            resp = await session.post(**_request)

             if resp.status != 200:
                 text = await resp.text()
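
The async helpers get the same treatment. A sketch of a call that would exercise ollama_acompletion (or ollama_async_streaming with stream=True), assuming litellm.acompletion routes ollama_chat/* models through them; endpoint and token are placeholders:

import asyncio
import litellm

async def main():
    response = await litellm.acompletion(
        model="ollama_chat/llama2",
        messages=[{"role": "user", "content": "Hello!"}],
        api_base="https://my-hosted-ollama.example.com",  # placeholder endpoint
        api_key="sk-my-ollama-proxy-token",  # placeholder token
    )
    print(response)

asyncio.run(main())
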
@@ -1941,9 +1941,16 @@ def completion(
             or "http://localhost:11434"
         )

+        api_key = (
+            api_key
+            or litellm.ollama_key
+            or os.environ.get("OLLAMA_API_KEY")
+            or litellm.api_key
+        )
         ## LOGGING
         generator = ollama_chat.get_ollama_response(
             api_base,
+            api_key,
             model,
             messages,
             optional_params,
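
The fallback chain above resolves the key as: the explicit api_key argument, then litellm.ollama_key, then the OLLAMA_API_KEY environment variable, then litellm.api_key. A sketch that relies on the environment variable (value and endpoint are placeholders):

import os
import litellm

os.environ["OLLAMA_API_KEY"] = "sk-my-ollama-proxy-token"  # placeholder; used when no api_key is passed

response = litellm.completion(
    model="ollama_chat/llama2",
    messages=[{"role": "user", "content": "Hello!"}],
    api_base="https://my-hosted-ollama.example.com",  # placeholder endpoint
)
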
@@ -31,12 +31,12 @@ litellm_settings:
   upperbound_key_generate_params:
     max_budget: os.environ/LITELLM_UPPERBOUND_KEYS_MAX_BUDGET

-# router_settings:
-#   routing_strategy: usage-based-routing-v2
-#   redis_host: os.environ/REDIS_HOST
-#   redis_password: os.environ/REDIS_PASSWORD
-#   redis_port: os.environ/REDIS_PORT
-#   enable_pre_call_checks: True
+router_settings:
+  routing_strategy: usage-based-routing-v2
+  redis_host: os.environ/REDIS_HOST
+  redis_password: os.environ/REDIS_PASSWORD
+  redis_port: os.environ/REDIS_PORT
+  enable_pre_call_checks: True

 general_settings:
   master_key: sk-1234