fix(main.py): fix lm_studio/ embedding routing (#7658)

* fix(main.py): fix lm_studio/ embedding routing

adds the provider mapping for embedding calls and updates the docs with an example

* docs(self_serve.md): update doc to show how to auto-add sso users to teams

* fix(streaming_handler.py): simplify the async iterator check to just verify the streaming response is an async iterable
Krish Dholakia 2025-01-09 23:03:24 -08:00 committed by GitHub
parent 14d4b695df
commit afdcbe3d64
6 changed files with 109 additions and 38 deletions

@@ -11,6 +11,14 @@ https://lmstudio.ai/docs/basics/server
:::
| Property | Details |
|-------|-------|
| Description | Discover, download, and run local LLMs. |
| Provider Route on LiteLLM | `lm_studio/` |
| Provider Doc | [LM Studio ↗](https://lmstudio.ai/docs/api/openai-api) |
| Supported OpenAI Endpoints | `/chat/completions`, `/embeddings`, `/completions` |
## API Key

```python
# env variable
@@ -42,7 +50,7 @@ print(response)
from litellm import completion
import os

-os.environ['XAI_API_KEY'] = ""
+os.environ['LM_STUDIO_API_KEY'] = ""

response = completion(
    model="lm_studio/llama-3-8b-instruct",
    messages=[
@@ -131,3 +139,17 @@ Here's how to call an XAI model with the LiteLLM Proxy Server
## Supported Parameters
See [Supported Parameters](../completion/input.md#translated-openai-params) for supported parameters.
## Embedding
```python
from litellm import embedding
import os
os.environ['LM_STUDIO_API_BASE'] = "http://localhost:8000"
response = embedding(
    model="lm_studio/jina-embeddings-v3",
    input=["Hello world"],
)
print(response)
```
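The response follows the OpenAI embedding response shape, so, continuing the snippet above (a usage sketch, not captured output), the vector can be read back like this:

```python
# response.data holds OpenAI-style embedding entries
vector = response.data[0]["embedding"]
print(f"dimensions: {len(vector)}")
```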

@@ -196,6 +196,41 @@ This budget does not apply to keys created under non-default teams.
[**Go Here**](./team_budgets.md)
### Auto-add SSO users to teams
1. Specify the JWT field that contains the team ids the user belongs to.
```yaml
general_settings:
  master_key: sk-1234
  litellm_jwtauth:
    team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
```
This assumes your SSO token looks like this:
```
{
  ...,
  "groups": ["team_id_1", "team_id_2"]
}
```
2. Create the teams on LiteLLM
```bash
curl -X POST '<PROXY_BASE_URL>/team/new' \
  -H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
  -H 'Content-Type: application/json' \
  -d '{
    "team_alias": "team_1",
    "team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
  }'
```
3. Test the SSO flow
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
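As a quick sanity check for step 1, you can decode the ID token your IdP issues and confirm the claim named in `team_ids_jwt_field` really carries the team ids. A minimal sketch; the helper and the fabricated token are illustrative, not part of LiteLLM:

```python
import base64
import json

def jwt_claims(token: str) -> dict:
    """Decode a JWT payload without verifying the signature (debugging only)."""
    payload = token.split(".")[1]
    payload += "=" * (-len(payload) % 4)  # restore stripped base64 padding
    return json.loads(base64.urlsafe_b64decode(payload))

# Fabricated token whose payload carries a "groups" claim, for illustration:
claims = {"sub": "user@example.com", "groups": ["team_id_1", "team_id_2"]}
encoded = base64.urlsafe_b64encode(json.dumps(claims).encode()).rstrip(b"=").decode()
token = f"eyJhbGciOiJSUzI1NiJ9.{encoded}.signature"

print(jwt_claims(token)["groups"])  # ['team_id_1', 'team_id_2'] -> matches team_ids_jwt_field
```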
## **All Settings for Self Serve / SSO Flow**

```yaml

@@ -1,4 +1,5 @@
import asyncio
import collections.abc
import json
import threading
import time
@@ -34,6 +35,19 @@ MAX_THREADS = 100
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)


def is_async_iterable(obj: Any) -> bool:
    """
    Check if an object is an async iterable (can be used with 'async for').

    Args:
        obj: Any Python object to check

    Returns:
        bool: True if the object is async iterable, False otherwise
    """
    return isinstance(obj, collections.abc.AsyncIterable)


def print_verbose(print_statement):
    try:
        if litellm.set_verbose:
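For context, a quick sketch of what the new helper accepts (illustrative, not part of the diff): an async generator passes the check, while a plain list does not.

```python
import collections.abc

def is_async_iterable(obj) -> bool:
    return isinstance(obj, collections.abc.AsyncIterable)

async def chunk_stream():
    yield "chunk"

print(is_async_iterable(chunk_stream()))  # True -> safe to use with 'async for'
print(is_async_iterable(["chunk"]))       # False -> sync iterable only
```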
@@ -1530,36 +1544,7 @@ class CustomStreamWrapper:
        if self.completion_stream is None:
            await self.fetch_stream()

-       if (
-           self.custom_llm_provider == "openai"
-           or self.custom_llm_provider == "azure"
-           or self.custom_llm_provider == "custom_openai"
-           or self.custom_llm_provider == "text-completion-openai"
-           or self.custom_llm_provider == "text-completion-codestral"
-           or self.custom_llm_provider == "azure_text"
-           or self.custom_llm_provider == "cohere_chat"
-           or self.custom_llm_provider == "cohere"
-           or self.custom_llm_provider == "anthropic"
-           or self.custom_llm_provider == "anthropic_text"
-           or self.custom_llm_provider == "huggingface"
-           or self.custom_llm_provider == "ollama"
-           or self.custom_llm_provider == "ollama_chat"
-           or self.custom_llm_provider == "vertex_ai"
-           or self.custom_llm_provider == "vertex_ai_beta"
-           or self.custom_llm_provider == "sagemaker"
-           or self.custom_llm_provider == "sagemaker_chat"
-           or self.custom_llm_provider == "gemini"
-           or self.custom_llm_provider == "replicate"
-           or self.custom_llm_provider == "cached_response"
-           or self.custom_llm_provider == "predibase"
-           or self.custom_llm_provider == "databricks"
-           or self.custom_llm_provider == "bedrock"
-           or self.custom_llm_provider == "triton"
-           or self.custom_llm_provider == "watsonx"
-           or self.custom_llm_provider == "cloudflare"
-           or self.custom_llm_provider in litellm.openai_compatible_providers
-           or self.custom_llm_provider in litellm._custom_providers
-       ):
+       if is_async_iterable(self.completion_stream):
            async for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    raise Exception

@@ -3218,6 +3218,7 @@ def embedding(  # noqa: PLR0915
        api_base=api_base,
        api_key=api_key,
    )

    if dynamic_api_key is not None:
        api_key = dynamic_api_key
@@ -3395,18 +3396,19 @@ def embedding(  # noqa: PLR0915
        custom_llm_provider == "openai_like"
        or custom_llm_provider == "jina_ai"
        or custom_llm_provider == "hosted_vllm"
+       or custom_llm_provider == "lm_studio"
    ):
        api_base = (
            api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
        )

        # set API KEY
-       api_key = (
-           api_key
-           or litellm.api_key
-           or litellm.openai_like_key
-           or get_secret_str("OPENAI_LIKE_API_KEY")
-       )
+       if api_key is None:
+           api_key = (
+               litellm.api_key
+               or litellm.openai_like_key
+               or get_secret_str("OPENAI_LIKE_API_KEY")
+           )

        ## EMBEDDING CALL
        response = openai_like_embedding.embedding(
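One way to read the `if api_key is None` guard: a key that has already been resolved (including an explicitly passed empty string, which a local LM Studio server typically doesn't check anyway) is no longer overwritten by the generic OpenAI-like fallbacks. A simplified sketch of that ordering, with function and parameter names that are illustrative rather than LiteLLM internals:

```python
def resolve_api_key(passed_key, litellm_key=None, env_key=None):
    # fallbacks apply only when no key was resolved earlier;
    # an explicit value, even "", is kept as-is
    if passed_key is None:
        return litellm_key or env_key
    return passed_key

print(repr(resolve_api_key("")))                    # '' -> kept, no fallback applied
print(repr(resolve_api_key(None, env_key="sk-x")))  # 'sk-x' -> env fallback used
```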

@@ -25,10 +25,15 @@ model_list:
      identifier: deepseek-ai/DeepSeek-V3-Base
      revision: main
      auth_token: os.environ/HUGGINGFACE_API_KEY
  - model_name: watsonx/ibm/granite-13b-chat-v2 # tried to keep original name for backwards compatibility but I've also tried watsonx_text
    litellm_params:
      model: watsonx_text/ibm/granite-13b-chat-v2
    model_info:
      input_cost_per_token: 0.0000006
      output_cost_per_token: 0.0000006

# litellm_settings:
#   key_generation_settings:
#     personal_key_generation: # maps to 'Default Team' on UI
#       allowed_user_roles: ["proxy_admin"]

@@ -1019,6 +1019,28 @@ def test_hosted_vllm_embedding(monkeypatch):
    assert json_data["model"] == "jina-embeddings-v3"


def test_lm_studio_embedding(monkeypatch):
    monkeypatch.setenv("LM_STUDIO_API_BASE", "http://localhost:8000")
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    client = HTTPHandler()
    with patch.object(client, "post") as mock_post:
        try:
            embedding(
                model="lm_studio/jina-embeddings-v3",
                input=["Hello world"],
                client=client,
            )
        except Exception as e:
            print(e)

        mock_post.assert_called_once()
        json_data = json.loads(mock_post.call_args.kwargs["data"])
        assert json_data["input"] == ["Hello world"]
        assert json_data["model"] == "jina-embeddings-v3"


@pytest.mark.parametrize(
    "model",
    [