Merge branch 'BerriAI:main' into feature/watsonx-integration

This commit is contained in:
Simon S. Viloria 2024-04-20 21:02:54 +02:00 committed by GitHub
commit 7b2bd2e0e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 12 additions and 4 deletions

View file

@ -110,7 +110,7 @@ def _get_redis_client_logic(**env_overrides):
redis_kwargs.pop("password", None) redis_kwargs.pop("password", None)
elif "host" not in redis_kwargs or redis_kwargs["host"] is None: elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
raise ValueError("Either 'host' or 'url' must be specified for redis.") raise ValueError("Either 'host' or 'url' must be specified for redis.")
litellm.print_verbose(f"redis_kwargs: {redis_kwargs}") # litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
return redis_kwargs return redis_kwargs
@ -142,6 +142,7 @@ def get_redis_async_client(**env_overrides):
) )
) )
return async_redis.Redis.from_url(**url_kwargs) return async_redis.Redis.from_url(**url_kwargs)
return async_redis.Redis( return async_redis.Redis(
socket_timeout=5, socket_timeout=5,
**redis_kwargs, **redis_kwargs,
@ -154,4 +155,9 @@ def get_redis_connection_pool(**env_overrides):
return async_redis.BlockingConnectionPool.from_url( return async_redis.BlockingConnectionPool.from_url(
timeout=5, url=redis_kwargs["url"] timeout=5, url=redis_kwargs["url"]
) )
connection_class = async_redis.Connection
if "ssl" in redis_kwargs and redis_kwargs["ssl"] is not None:
connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs) return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)

View file

@ -228,7 +228,7 @@ def get_ollama_response(
model_response["choices"][0]["message"]["content"] = response_json["response"] model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model model_response["model"] = "ollama/" + model
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt))) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", ""))) completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage( model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
@ -330,7 +330,7 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
] ]
model_response["created"] = int(time.time()) model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data["model"] model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"]))) # type: ignore prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=()))) # type: ignore
completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", ""))) completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
model_response["usage"] = litellm.Usage( model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,

View file

@ -148,7 +148,7 @@ class OllamaChatConfig:
if param == "top_p": if param == "top_p":
optional_params["top_p"] = value optional_params["top_p"] = value
if param == "frequency_penalty": if param == "frequency_penalty":
optional_params["repeat_penalty"] = param optional_params["repeat_penalty"] = value
if param == "stop": if param == "stop":
optional_params["stop"] = value optional_params["stop"] = value
if param == "response_format" and value["type"] == "json_object": if param == "response_format" and value["type"] == "json_object":

View file

@ -7927,6 +7927,8 @@ def exception_type(
elif ( elif (
"429 Quota exceeded" in error_str "429 Quota exceeded" in error_str
or "IndexError: list index out of range" in error_str or "IndexError: list index out of range" in error_str
or "429 Unable to submit request because the service is temporarily out of capacity."
in error_str
): ):
exception_mapping_worked = True exception_mapping_worked = True
raise RateLimitError( raise RateLimitError(