litellm-mirror/tests/local_testing/test_ollama.py
Krish Dholakia 1e011b66d3
Ollama ssl verify = False + Spend Logs reliability fixes (#7931)
* fix(http_handler.py): support passing ssl verify dynamically and using the correct httpx client based on passed ssl verify param

Fixes https://github.com/BerriAI/litellm/issues/6499

* feat(llm_http_handler.py): support passing `ssl_verify=False` dynamically in call args

Closes https://github.com/BerriAI/litellm/issues/6499

* fix(proxy/utils.py): prevent bad logs from breaking all cost tracking + reset list regardless of success/failure

prevents malformed logs from causing all spend tracking to break since they're constantly retried

* test(test_proxy_utils.py): add test to ensure bad log is dropped

* test(test_proxy_utils.py): ensure in-memory spend logs reset after bad log error

* test(test_user_api_key_auth.py): add unit test to ensure end user id as str works

* fix(auth_utils.py): ensure extracted end user id is always a str

prevents db cost tracking errors

* test(test_auth_utils.py): ensure get end user id from request body always returns a string

* test: update tests

* test: skip bedrock test- behaviour now supported

* test: fix testing

* refactor(spend_tracking_utils.py): reduce size of get_logging_payload

* test: fix test

* bump: version 1.59.4 → 1.59.5

* Revert "bump: version 1.59.4 → 1.59.5"

This reverts commit 1182b46b2e.

* fix(utils.py): fix spend logs retry logic

* fix(spend_tracking_utils.py): fix get tags

* fix(spend_tracking_utils.py): fix end user id spend tracking on pass-through endpoints
2025-01-23 23:05:41 -08:00
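
As a quick illustration of the ssl_verify change described above (a minimal sketch, assuming a local Ollama server on the default http://localhost:11434 with a llama3.1 model pulled), the flag is passed straight through the completion call args and ends up on the httpx client litellm builds for the request:

    import litellm

    response = litellm.completion(
        model="ollama/llama3.1",
        messages=[{"role": "user", "content": "hello"}],
        ssl_verify=False,  # certificate verification disabled for this call
    )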


import asyncio
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
import io
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
from unittest import mock

import pytest

import litellm

## for ollama we can't test making the completion call
from litellm.utils import EmbeddingResponse, get_llm_provider, get_optional_params


def test_get_ollama_params():
    try:
        converted_params = get_optional_params(
            custom_llm_provider="ollama",
            model="llama2",
            max_tokens=20,
            temperature=0.5,
            stream=True,
        )
        print("Converted params", converted_params)
        # build the expected dict once so the assert message renders cleanly
        expected_params = {
            "num_predict": 20,
            "stream": True,
            "temperature": 0.5,
        }
        assert (
            converted_params == expected_params
        ), f"{converted_params} != {expected_params}"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_get_ollama_params()
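
# Note: for the "ollama" provider, get_optional_params maps the OpenAI-style
# `max_tokens` argument to Ollama's `num_predict` option, which is what the
# assertion above checks.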


def test_get_ollama_model():
    try:
        model, custom_llm_provider, _, _ = get_llm_provider("ollama/code-llama-22")
        print("Model", "custom_llm_provider", model, custom_llm_provider)
        assert custom_llm_provider == "ollama", f"{custom_llm_provider} != ollama"
        assert model == "code-llama-22", f"{model} != code-llama-22"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_get_ollama_model()


def test_ollama_json_mode():
    # assert that format: json gets passed as is to ollama
    try:
        converted_params = get_optional_params(
            custom_llm_provider="ollama", model="llama2", format="json", temperature=0.5
        )
        print("Converted params", converted_params)
        # build the expected dict once so the assert message renders cleanly
        expected_params = {
            "temperature": 0.5,
            "format": "json",
            "stream": False,
        }
        assert (
            converted_params == expected_params
        ), f"{converted_params} != {expected_params}"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_ollama_json_mode()


mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")


@mock.patch(
    "litellm.llms.ollama.completion.handler.ollama_embeddings",
    return_value=mock_ollama_embedding_response,
)
def test_ollama_embeddings(mock_embeddings):
    # assert that ollama_embeddings is called with the right parameters
    try:
        embeddings = litellm.embedding(
            model="ollama/nomic-embed-text", input=["hello world"]
        )
        print(embeddings)
        mock_embeddings.assert_called_once_with(
            api_base="http://localhost:11434",
            model="nomic-embed-text",
            prompts=["hello world"],
            optional_params=mock.ANY,
            logging_obj=mock.ANY,
            model_response=mock.ANY,
            encoding=mock.ANY,
        )
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_ollama_embeddings()


@mock.patch(
    "litellm.llms.ollama.completion.handler.ollama_aembeddings",
    return_value=mock_ollama_embedding_response,
)
def test_ollama_aembeddings(mock_aembeddings):
    # assert that ollama_aembeddings is called with the right parameters
    try:
        embeddings = asyncio.run(
            litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"])
        )
        print(embeddings)
        mock_aembeddings.assert_called_once_with(
            api_base="http://localhost:11434",
            model="nomic-embed-text",
            prompts=["hello world"],
            optional_params=mock.ANY,
            logging_obj=mock.ANY,
            model_response=mock.ANY,
            encoding=mock.ANY,
        )
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# test_ollama_aembeddings()
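
# Both embedding tests above patch the sync/async handlers under
# litellm.llms.ollama.completion.handler, so no local Ollama server is needed;
# they only check the arguments litellm forwards to those handlers.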


@pytest.mark.skip(reason="local only test")
def test_ollama_chat_function_calling():
    import json

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {"type": "string"},
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["location"],
                },
            },
        },
    ]

    messages = [
        {"role": "user", "content": "What's the weather like in San Francisco?"}
    ]

    response = litellm.completion(
        model="ollama_chat/llama3.1",
        messages=messages,
        tools=tools,
    )

    tool_calls = response.choices[0].message.get("tool_calls", None)
    assert tool_calls is not None

    print(json.loads(tool_calls[0].function.arguments))
    print(response)
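
# The function-calling test above is skipped in CI ("local only test"); running it
# presumably requires a local Ollama server with the llama3.1 model pulled, since
# the request targets ollama_chat/llama3.1.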


def test_ollama_ssl_verify():
    from litellm.llms.custom_httpx.http_handler import HTTPHandler
    import ssl
    import httpx

    try:
        response = litellm.completion(
            model="ollama/llama3.1",
            messages=[
                {
                    "role": "user",
                    "content": "What's the weather like in San Francisco?",
                }
            ],
            ssl_verify=False,
        )
    except Exception as e:
        print(e)

    client: HTTPHandler = litellm.in_memory_llm_clients_cache.get_cache(
        "httpx_clientssl_verify_False"
    )

    test_client = httpx.Client(verify=False)
    print(client)
    assert (
        client.client._transport._pool._ssl_context.verify_mode
        == test_client._transport._pool._ssl_context.verify_mode
    )
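
# Both SSL tests fetch the httpx client that litellm cached for the request from
# litellm.in_memory_llm_clients_cache; the keys used here
# ("httpx_clientssl_verify_False" / "async_httpx_clientssl_verify_Falseollama")
# appear to mirror the handlers' cache-key naming, and the assertion only compares
# the resulting verify_mode against a plain httpx client built with verify=False.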


@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.asyncio
async def test_async_ollama_ssl_verify(stream):
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
    import httpx

    try:
        response = await litellm.acompletion(
            model="ollama/llama3.1",
            messages=[
                {
                    "role": "user",
                    "content": "What's the weather like in San Francisco?",
                }
            ],
            ssl_verify=False,
            stream=stream,
        )
    except Exception as e:
        print(e)

    client: AsyncHTTPHandler = litellm.in_memory_llm_clients_cache.get_cache(
        "async_httpx_clientssl_verify_Falseollama"
    )

    test_client = httpx.AsyncClient(verify=False)
    print(client)
    assert (
        client.client._transport._pool._ssl_context.verify_mode
        == test_client._transport._pool._ssl_context.verify_mode
    )