From 9c35a3b55404c54112364a928908f8ba0a9ba0a4 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sat, 30 Nov 2024 01:10:37 -0800
Subject: [PATCH] test: fix nvidia nim test

---
 tests/llm_translation/test_nvidia_nim.py | 134 ++++++++++-------------
 1 file changed, 59 insertions(+), 75 deletions(-)

diff --git a/tests/llm_translation/test_nvidia_nim.py b/tests/llm_translation/test_nvidia_nim.py
index 76cb5764c..52ef1043f 100644
--- a/tests/llm_translation/test_nvidia_nim.py
+++ b/tests/llm_translation/test_nvidia_nim.py
@@ -12,6 +12,7 @@ sys.path.insert(
 import httpx
 import pytest
 from respx import MockRouter
+from unittest.mock import patch, MagicMock, AsyncMock
 
 import litellm
 from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
@@ -19,88 +20,71 @@ from litellm import completion
 
 
 @pytest.mark.respx
-def test_completion_nvidia_nim(respx_mock: MockRouter):
+def test_completion_nvidia_nim():
+    from openai import OpenAI
+
     litellm.set_verbose = True
-    mock_response = ModelResponse(
-        id="cmpl-mock",
-        choices=[Choices(message=Message(content="Mocked response", role="assistant"))],
-        created=int(datetime.now().timestamp()),
-        model="databricks/dbrx-instruct",
-    )
     model_name = "nvidia_nim/databricks/dbrx-instruct"
+    client = OpenAI(
+        api_key="fake-api-key",
+    )
 
-    mock_request = respx_mock.post(
-        "https://integrate.api.nvidia.com/v1/chat/completions"
-    ).mock(return_value=httpx.Response(200, json=mock_response.dict()))
-    try:
-        response = completion(
-            model=model_name,
-            messages=[
-                {
-                    "role": "user",
-                    "content": "What's the weather like in Boston today in Fahrenheit?",
-                }
-            ],
-            presence_penalty=0.5,
-            frequency_penalty=0.1,
-        )
+    with patch.object(
+        client.chat.completions.with_raw_response, "create"
+    ) as mock_client:
+        try:
+            completion(
+                model=model_name,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": "What's the weather like in Boston today in Fahrenheit?",
+                    }
+                ],
+                presence_penalty=0.5,
+                frequency_penalty=0.1,
+                client=client,
+            )
+        except Exception as e:
+            print(e)
         # Add any assertions here to check the response
-        print(response)
-        assert response.choices[0].message.content is not None
-        assert len(response.choices[0].message.content) > 0
 
-        assert mock_request.called
-        request_body = json.loads(mock_request.calls[0].request.content)
+        mock_client.assert_called_once()
+        request_body = mock_client.call_args.kwargs
 
         print("request_body: ", request_body)
 
-        assert request_body == {
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "What's the weather like in Boston today in Fahrenheit?",
-                }
-            ],
-            "model": "databricks/dbrx-instruct",
-            "frequency_penalty": 0.1,
-            "presence_penalty": 0.5,
-        }
-    except litellm.exceptions.Timeout as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
-def test_embedding_nvidia_nim(respx_mock: MockRouter):
-    litellm.set_verbose = True
-    mock_response = EmbeddingResponse(
-        model="nvidia_nim/databricks/dbrx-instruct",
-        data=[
+        assert request_body["messages"] == [
             {
-                "embedding": [0.1, 0.2, 0.3],
-                "index": 0,
-            }
-        ],
-        usage=Usage(
-            prompt_tokens=10,
-            completion_tokens=0,
-            total_tokens=10,
-        ),
+                "role": "user",
+                "content": "What's the weather like in Boston today in Fahrenheit?",
+            },
+        ]
+        assert request_body["model"] == "databricks/dbrx-instruct"
+        assert request_body["frequency_penalty"] == 0.1
+        assert request_body["presence_penalty"] == 0.5
+
+
+def test_embedding_nvidia_nim():
+    litellm.set_verbose = True
+    from openai import OpenAI
+
+    client = OpenAI(
+        api_key="fake-api-key",
     )
-    mock_request = respx_mock.post(
-        "https://integrate.api.nvidia.com/v1/embeddings"
-    ).mock(return_value=httpx.Response(200, json=mock_response.dict()))
-    response = litellm.embedding(
-        model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
-        input="What is the meaning of life?",
-        input_type="passage",
-    )
-    assert mock_request.called
-    request_body = json.loads(mock_request.calls[0].request.content)
-    print("request_body: ", request_body)
-    assert request_body == {
-        "input": "What is the meaning of life?",
-        "model": "nvidia/nv-embedqa-e5-v5",
-        "input_type": "passage",
-        "encoding_format": "base64",
-    }
+    with patch.object(client.embeddings.with_raw_response, "create") as mock_client:
+        try:
+            litellm.embedding(
+                model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
+                input="What is the meaning of life?",
+                input_type="passage",
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+        mock_client.assert_called_once()
+        request_body = mock_client.call_args.kwargs
+        print("request_body: ", request_body)
+        assert request_body["input"] == "What is the meaning of life?"
+        assert request_body["model"] == "nvidia/nv-embedqa-e5-v5"
+        assert request_body["extra_body"]["input_type"] == "passage"
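
The pattern this patch switches to — patching `create` on the OpenAI client's
`with_raw_response` resource, passing that client to litellm via `client=`, and
asserting on `mock_client.call_args.kwargs` instead of mocking HTTP with respx —
is shown below as a minimal standalone sketch. The model string and fake API key
are taken from the tests above; everything else is plain `unittest.mock` usage,
an illustration rather than part of the patch:

    from unittest.mock import patch

    import litellm
    from openai import OpenAI

    client = OpenAI(api_key="fake-api-key")

    with patch.object(
        client.chat.completions.with_raw_response, "create"
    ) as mock_create:
        try:
            # litellm reuses the patched client instead of building its own,
            # so no network request is made
            litellm.completion(
                model="nvidia_nim/databricks/dbrx-instruct",
                messages=[{"role": "user", "content": "hi"}],
                client=client,
            )
        except Exception:
            # the patched create() returns a MagicMock, so litellm's response
            # parsing may raise; only the outbound request matters here
            pass

        mock_create.assert_called_once()
        # kwargs litellm forwarded to the provider, e.g. the stripped model name
        print(mock_create.call_args.kwargs["model"])  # databricks/dbrx-instruct

Because the assertion targets the kwargs handed to the client rather than a
serialized HTTP body, the test keeps working even if the transport layer changes.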