diff --git a/litellm/tests/test_router_init.py b/litellm/tests/test_router_init.py
index bd70a33d1..3208b70b0 100644
--- a/litellm/tests/test_router_init.py
+++ b/litellm/tests/test_router_init.py
@@ -200,3 +200,43 @@ def test_stream_timeouts_router():
 
 
 # test_stream_timeouts_router()
+
+
+def test_xinference_embedding():
+    # [Test Init Xinference] this tests if we init xinference on the router correctly
+    # [Test Exception Mapping] tests that xinference is an OpenAI-compatible provider
+    print("Testing init xinference")
+    print(
+        "this tests if we create an OpenAI client for Xinference, with the correct API BASE"
+    )
+
+    model_list = [
+        {
+            "model_name": "xinference",
+            "litellm_params": {
+                "model": "xinference/bge-base-en",
+                "api_base": "os.environ/XINFERENCE_API_BASE",
+            },
+        }
+    ]
+
+    router = Router(model_list=model_list)
+
+    print(router.model_list)
+    print(router.model_list[0])
+
+    assert (
+        router.model_list[0]["litellm_params"]["api_base"] == "http://0.0.0.0:9997"
+    )  # set in env
+
+    openai_client = router._get_client(
+        deployment=router.model_list[0],
+        kwargs={"input": ["hello"], "model": "xinference"},
+    )
+
+    assert openai_client._base_url == "http://0.0.0.0:9997"
+    assert "xinference" in litellm.openai_compatible_providers
+    print("passed")
+
+
+# test_xinference_embedding()
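
For reviewers who want to exercise the new test locally, below is a minimal sketch. The server URL (`http://0.0.0.0:9997`) and the `XINFERENCE_API_BASE` variable name are assumptions taken from the values the test asserts, not anything this diff configures; the `os.environ/...` prefix is litellm's convention for telling the Router to resolve a param from the environment at init time.

```python
# Minimal sketch for running the new test locally (assumption: a
# Xinference server is reachable at http://0.0.0.0:9997, matching the
# value the test asserts).
import os

# The test expects XINFERENCE_API_BASE to be set before the Router is
# constructed, since "os.environ/XINFERENCE_API_BASE" is resolved at init.
os.environ["XINFERENCE_API_BASE"] = "http://0.0.0.0:9997"

import litellm
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "xinference",
            "litellm_params": {
                "model": "xinference/bge-base-en",
                "api_base": "os.environ/XINFERENCE_API_BASE",
            },
        }
    ]
)

# After init, the "os.environ/..." indirection should be resolved to the
# raw URL, which is what the test's first assertion checks.
print(router.model_list[0]["litellm_params"]["api_base"])
```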