fix(nvidia_nim/embed.py): add 'dimensions' support (#8302)
* fix(nvidia_nim/embed.py): add 'dimensions' support. Fixes https://github.com/BerriAI/litellm/issues/8238
* fix(proxy_server.py): initialize router redis cache if set up on proxy. Fixes https://github.com/BerriAI/litellm/issues/6602
* test: add unit testing for new helper function
This commit is contained in:
parent 942446d826
commit 024237077b

5 changed files with 36 additions and 2 deletions
nvidia_nim/embed.py:

```diff
@@ -58,7 +58,7 @@ class NvidiaNimEmbeddingConfig:
     def get_supported_openai_params(
         self,
     ):
-        return ["encoding_format", "user"]
+        return ["encoding_format", "user", "dimensions"]
 
     def map_openai_params(
         self,
@@ -73,6 +73,8 @@ class NvidiaNimEmbeddingConfig:
                 optional_params["extra_body"].update({"input_type": v})
            elif k == "truncate":
                 optional_params["extra_body"].update({"truncate": v})
+            else:
+                optional_params[k] = v
 
         if kwargs is not None:
             # pass kwargs in extra_body
```
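With `dimensions` added to the supported OpenAI params and routed through the new `else` branch, it is forwarded as a top-level request parameter instead of being dropped. A minimal caller-side sketch (model name taken from the test below; assumes NVIDIA NIM credentials are configured in the environment):

```python
import litellm

# "dimensions" now reaches NVIDIA NIM embedding models; "input_type" is
# NIM-specific and travels via extra_body as before.
response = litellm.embedding(
    model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
    input="What is the meaning of life?",
    input_type="passage",
    dimensions=1024,
)
```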
proxy_server.py:

```diff
@@ -1631,7 +1631,7 @@ class ProxyConfig:
         self,
         cache_params: dict,
     ):
-        global redis_usage_cache
+        global redis_usage_cache, llm_router
         from litellm import Cache
 
         if "default_in_memory_ttl" in cache_params:
@@ -1646,6 +1646,10 @@ class ProxyConfig:
         ## INIT PROXY REDIS USAGE CLIENT ##
         redis_usage_cache = litellm.cache.cache
 
+        ## INIT ROUTER REDIS CACHE ##
+        if llm_router is not None:
+            llm_router._update_redis_cache(cache=redis_usage_cache)
+
     async def get_config(self, config_file_path: Optional[str] = None) -> dict:
         """
         Load config file
```
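A condensed sketch of the wiring this hunk adds (not the proxy's literal startup code; the `Cache` constructor arguments and the model entry are placeholders): once `litellm.cache` is built from `cache_params`, the same Redis client is handed to the router so router-level caching works without separate configuration.

```python
import litellm
from litellm import Cache, Router

# Stand-in for the router the proxy builds from its config file.
router = Router(
    model_list=[{"model_name": "gpt-4o", "litellm_params": {"model": "gpt-4o"}}]
)

# ProxyConfig builds the global cache from cache_params (args are placeholders).
litellm.cache = Cache(type="redis", host="localhost", port="6379")
redis_usage_cache = litellm.cache.cache

# The new step: offer the proxy's redis cache to the router.
if router is not None:
    router._update_redis_cache(cache=redis_usage_cache)
```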
router.py:

````diff
@@ -573,6 +573,20 @@ class Router:
             litellm.amoderation, call_type="moderation"
         )
 
+    def _update_redis_cache(self, cache: RedisCache):
+        """
+        Update the redis cache for the router, if none set.
+
+        Allows proxy user to just do
+        ```yaml
+        litellm_settings:
+            cache: true
+        ```
+        and caching to just work.
+        """
+        if self.cache.redis_cache is None:
+            self.cache.redis_cache = cache
+
     def initialize_assistants_endpoint(self):
         ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
         self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
````
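The `if self.cache.redis_cache is None` guard makes the helper a no-op when the router already has a Redis cache, so an explicitly configured router cache is never overwritten. A small illustration, mirroring the new unit test below but with a pre-set cache (the mock names are just for readability):

```python
from unittest.mock import MagicMock

from litellm import Router

router = Router(
    model_list=[{"model_name": "gemini/*", "litellm_params": {"model": "gemini/*"}}]
)

# Pretend the router was already configured with its own redis cache.
existing_cache = MagicMock(name="existing_redis_cache")
router.cache.redis_cache = existing_cache

# The proxy's cache is offered, but the guard leaves the existing one in place.
router._update_redis_cache(cache=MagicMock(name="proxy_redis_cache"))
assert router.cache.redis_cache is existing_cache
```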
NVIDIA NIM embedding test (test_embedding_nvidia_nim):

```diff
@@ -77,6 +77,7 @@ def test_embedding_nvidia_nim():
             model="nvidia_nim/nvidia/nv-embedqa-e5-v5",
             input="What is the meaning of life?",
             input_type="passage",
+            dimensions=1024,
             client=client,
         )
     except Exception as e:
@@ -87,3 +88,4 @@ def test_embedding_nvidia_nim():
     assert request_body["input"] == "What is the meaning of life?"
     assert request_body["model"] == "nvidia/nv-embedqa-e5-v5"
     assert request_body["extra_body"]["input_type"] == "passage"
+    assert request_body["dimensions"] == 1024
```
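For reference, the shape of the captured request body these assertions imply (reconstructed from the assertions themselves):

```python
request_body = {
    "model": "nvidia/nv-embedqa-e5-v5",
    "input": "What is the meaning of life?",
    "dimensions": 1024,                       # top-level, via the new else branch
    "extra_body": {"input_type": "passage"},  # NIM-specific params stay in extra_body
}
```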
Router test (test_router_redis_cache):

```diff
@@ -384,3 +384,15 @@ def test_router_get_model_access_groups(potential_access_group, expected_result)
         model_access_group=potential_access_group
     )
     assert access_groups == expected_result
+
+
+def test_router_redis_cache():
+    router = Router(
+        model_list=[{"model_name": "gemini/*", "litellm_params": {"model": "gemini/*"}}]
+    )
+
+    redis_cache = MagicMock()
+
+    router._update_redis_cache(cache=redis_cache)
+
+    assert router.cache.redis_cache == redis_cache
```