fix(nvidia_nim/embed.py): add 'dimensions' support (#8302)

* fix(nvidia_nim/embed.py): add 'dimensions' support

Fixes https://github.com/BerriAI/litellm/issues/8238

* fix(proxy_server.py): initialize router redis cache if setup on proxy

Fixes https://github.com/BerriAI/litellm/issues/6602

* test: add unit testing for new helper function
This commit is contained in:
Krish Dholakia 2025-02-07 16:19:32 -08:00 committed by GitHub
parent 942446d826
commit 024237077b
5 changed files with 36 additions and 2 deletions

View file

@ -58,7 +58,7 @@ class NvidiaNimEmbeddingConfig:
def get_supported_openai_params( def get_supported_openai_params(
self, self,
): ):
return ["encoding_format", "user"] return ["encoding_format", "user", "dimensions"]
def map_openai_params( def map_openai_params(
self, self,
@ -73,6 +73,8 @@ class NvidiaNimEmbeddingConfig:
optional_params["extra_body"].update({"input_type": v}) optional_params["extra_body"].update({"input_type": v})
elif k == "truncate": elif k == "truncate":
optional_params["extra_body"].update({"truncate": v}) optional_params["extra_body"].update({"truncate": v})
else:
optional_params[k] = v
if kwargs is not None: if kwargs is not None:
# pass kwargs in extra_body # pass kwargs in extra_body

View file

@ -1631,7 +1631,7 @@ class ProxyConfig:
self, self,
cache_params: dict, cache_params: dict,
): ):
global redis_usage_cache global redis_usage_cache, llm_router
from litellm import Cache from litellm import Cache
if "default_in_memory_ttl" in cache_params: if "default_in_memory_ttl" in cache_params:
@ -1646,6 +1646,10 @@ class ProxyConfig:
## INIT PROXY REDIS USAGE CLIENT ## ## INIT PROXY REDIS USAGE CLIENT ##
redis_usage_cache = litellm.cache.cache redis_usage_cache = litellm.cache.cache
## INIT ROUTER REDIS CACHE ##
if llm_router is not None:
llm_router._update_redis_cache(cache=redis_usage_cache)
async def get_config(self, config_file_path: Optional[str] = None) -> dict: async def get_config(self, config_file_path: Optional[str] = None) -> dict:
""" """
Load config file Load config file

View file

@ -573,6 +573,20 @@ class Router:
litellm.amoderation, call_type="moderation" litellm.amoderation, call_type="moderation"
) )
def _update_redis_cache(self, cache: RedisCache):
    """
    Attach *cache* as the router's redis cache when none is configured yet.

    Lets a proxy user enable caching with nothing more than:

    ```yaml
    litellm_settings:
        cache: true
    ```

    A redis cache that is already set on the router is left untouched.
    """
    # Guard clause: never overwrite an existing redis cache.
    if self.cache.redis_cache is not None:
        return
    self.cache.redis_cache = cache
def initialize_assistants_endpoint(self): def initialize_assistants_endpoint(self):
## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ## ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
self.acreate_assistants = self.factory_function(litellm.acreate_assistants) self.acreate_assistants = self.factory_function(litellm.acreate_assistants)

View file

@ -77,6 +77,7 @@ def test_embedding_nvidia_nim():
model="nvidia_nim/nvidia/nv-embedqa-e5-v5", model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
input="What is the meaning of life?", input="What is the meaning of life?",
input_type="passage", input_type="passage",
dimensions=1024,
client=client, client=client,
) )
except Exception as e: except Exception as e:
@ -87,3 +88,4 @@ def test_embedding_nvidia_nim():
assert request_body["input"] == "What is the meaning of life?" assert request_body["input"] == "What is the meaning of life?"
assert request_body["model"] == "nvidia/nv-embedqa-e5-v5" assert request_body["model"] == "nvidia/nv-embedqa-e5-v5"
assert request_body["extra_body"]["input_type"] == "passage" assert request_body["extra_body"]["input_type"] == "passage"
assert request_body["dimensions"] == 1024

View file

@ -384,3 +384,15 @@ def test_router_get_model_access_groups(potential_access_group, expected_result)
model_access_group=potential_access_group model_access_group=potential_access_group
) )
assert access_groups == expected_result assert access_groups == expected_result
def test_router_redis_cache():
    """A freshly-built router picks up the cache passed to _update_redis_cache."""
    wildcard_deployment = {
        "model_name": "gemini/*",
        "litellm_params": {"model": "gemini/*"},
    }
    router = Router(model_list=[wildcard_deployment])

    mock_cache = MagicMock()
    router._update_redis_cache(cache=mock_cache)

    # The router's cache should now reference the injected redis cache.
    assert router.cache.redis_cache == mock_cache