test: fix nvidia nim test

Krrish Dholakia 2024-11-30 01:10:37 -08:00
parent e90ff0f350
commit 9c35a3b554


@@ -12,6 +12,7 @@ sys.path.insert(
 import httpx
 import pytest
 from respx import MockRouter
+from unittest.mock import patch, MagicMock, AsyncMock
 
 import litellm
 from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
@@ -19,88 +20,71 @@ from litellm import completion
 @pytest.mark.respx
-def test_completion_nvidia_nim(respx_mock: MockRouter):
+def test_completion_nvidia_nim():
+    from openai import OpenAI
+
     litellm.set_verbose = True
-    mock_response = ModelResponse(
-        id="cmpl-mock",
-        choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
-        created=int(datetime.now().timestamp()),
-        model="databricks/dbrx-instruct",
-    )
     model_name = "nvidia_nim/databricks/dbrx-instruct"
+    client = OpenAI(
+        api_key="fake-api-key",
+    )
 
-    mock_request = respx_mock.post(
-        "https://integrate.api.nvidia.com/v1/chat/completions"
-    ).mock(return_value=httpx.Response(200, json=mock_response.dict()))
-    try:
-        response = completion(
-            model=model_name,
-            messages=[
-                {
-                    "role": "user",
-                    "content": "What's the weather like in Boston today in Fahrenheit?",
-                }
-            ],
-            presence_penalty=0.5,
-            frequency_penalty=0.1,
-        )
-        # Add any assertions here to check the response
-        print(response)
-        assert response.choices[0].message.content is not None
-        assert len(response.choices[0].message.content) > 0
-
-        assert mock_request.called
-        request_body = json.loads(mock_request.calls[0].request.content)
-        print("request_body: ", request_body)
-
-        assert request_body == {
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "What's the weather like in Boston today in Fahrenheit?",
-                }
-            ],
-            "model": "databricks/dbrx-instruct",
-            "frequency_penalty": 0.1,
-            "presence_penalty": 0.5,
-        }
-    except litellm.exceptions.Timeout as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+    with patch.object(
+        client.chat.completions.with_raw_response, "create"
+    ) as mock_client:
+        try:
+            completion(
+                model=model_name,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": "What's the weather like in Boston today in Fahrenheit?",
+                    }
+                ],
+                presence_penalty=0.5,
+                frequency_penalty=0.1,
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+        # Add any assertions here to check the response
+
+        mock_client.assert_called_once()
+        request_body = mock_client.call_args.kwargs
+        print("request_body: ", request_body)
+
+        assert request_body["messages"] == [
+            {
+                "role": "user",
+                "content": "What's the weather like in Boston today in Fahrenheit?",
+            },
+        ]
+        assert request_body["model"] == "databricks/dbrx-instruct"
+        assert request_body["frequency_penalty"] == 0.1
+        assert request_body["presence_penalty"] == 0.5
 
 
-def test_embedding_nvidia_nim(respx_mock: MockRouter):
+def test_embedding_nvidia_nim():
     litellm.set_verbose = True
-    mock_response = EmbeddingResponse(
-        model="nvidia_nim/databricks/dbrx-instruct",
-        data=[
-            {
-                "embedding": [0.1, 0.2, 0.3],
-                "index": 0,
-            }
-        ],
-        usage=Usage(
-            prompt_tokens=10,
-            completion_tokens=0,
-            total_tokens=10,
-        ),
-    )
-    mock_request = respx_mock.post(
-        "https://integrate.api.nvidia.com/v1/embeddings"
-    ).mock(return_value=httpx.Response(200, json=mock_response.dict()))
-    response = litellm.embedding(
-        model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
-        input="What is the meaning of life?",
-        input_type="passage",
-    )
-    assert mock_request.called
-    request_body = json.loads(mock_request.calls[0].request.content)
-    print("request_body: ", request_body)
-    assert request_body == {
-        "input": "What is the meaning of life?",
-        "model": "nvidia/nv-embedqa-e5-v5",
-        "input_type": "passage",
-        "encoding_format": "base64",
-    }
+    from openai import OpenAI
+
+    client = OpenAI(
+        api_key="fake-api-key",
+    )
+    with patch.object(client.embeddings.with_raw_response, "create") as mock_client:
+        try:
+            litellm.embedding(
+                model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
+                input="What is the meaning of life?",
+                input_type="passage",
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+        mock_client.assert_called_once()
+        request_body = mock_client.call_args.kwargs
+        print("request_body: ", request_body)
+        assert request_body["input"] == "What is the meaning of life?"
+        assert request_body["model"] == "nvidia/nv-embedqa-e5-v5"
+        assert request_body["extra_body"]["input_type"] == "passage"
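For context, a minimal self-contained sketch (not part of the commit) of the mocking pattern the updated tests rely on: patch the OpenAI client's with_raw_response "create" method, pass the client into litellm via client=, then assert on the kwargs litellm forwarded to the SDK through mock.call_args.kwargs. The message content below is a placeholder; the model name and assertion mirror the completion test above.

from unittest.mock import patch

import litellm
from openai import OpenAI

# Dummy key; no network call is made because `create` is patched.
client = OpenAI(api_key="fake-api-key")

with patch.object(
    client.chat.completions.with_raw_response, "create"
) as mock_create:
    try:
        # litellm routes the request through the supplied client, so the
        # patched `create` receives the kwargs litellm would send to NVIDIA NIM.
        litellm.completion(
            model="nvidia_nim/databricks/dbrx-instruct",
            messages=[{"role": "user", "content": "hello"}],  # placeholder message
            client=client,
        )
    except Exception as e:
        # The mock returns a MagicMock, so response parsing may raise;
        # only the outgoing request is inspected below.
        print(e)

    mock_create.assert_called_once()
    request_kwargs = mock_create.call_args.kwargs
    # The nvidia_nim/ provider prefix is stripped before the request reaches the SDK.
    assert request_kwargs["model"] == "databricks/dbrx-instruct"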