mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(main.py): fix lm_studio/ embedding routing (#7658)
* fix(main.py): fix lm_studio/ embedding routing adds the mapping + updates docs with example * docs(self_serve.md): update doc to show how to auto-add sso users to teams * fix(streaming_handler.py): simplify async iterator check, to just check if streaming response is an async iterable
This commit is contained in:
parent
14d4b695df
commit
afdcbe3d64
6 changed files with 109 additions and 38 deletions
|
@ -11,6 +11,14 @@ https://lmstudio.ai/docs/basics/server
|
|||
|
||||
:::
|
||||
|
||||
|
||||
| Property | Details |
|
||||
|-------|-------|
|
||||
| Description | Discover, download, and run local LLMs. |
|
||||
| Provider Route on LiteLLM | `lm_studio/` |
|
||||
| Provider Doc | [LM Studio ↗](https://lmstudio.ai/docs/api/openai-api) |
|
||||
| Supported OpenAI Endpoints | `/chat/completions`, `/embeddings`, `/completions` |
|
||||
|
||||
## API Key
|
||||
```python
|
||||
# env variable
|
||||
|
@ -42,7 +50,7 @@ print(response)
|
|||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ['XAI_API_KEY'] = ""
|
||||
os.environ['LM_STUDIO_API_KEY'] = ""
|
||||
response = completion(
|
||||
model="lm_studio/llama-3-8b-instruct",
|
||||
messages=[
|
||||
|
@ -131,3 +139,17 @@ Here's how to call a XAI model with the LiteLLM Proxy Server
|
|||
## Supported Parameters
|
||||
|
||||
See [Supported Parameters](../completion/input.md#translated-openai-params) for supported parameters.
|
||||
|
||||
## Embedding
|
||||
|
||||
```python
|
||||
from litellm import embedding
|
||||
import os
|
||||
|
||||
os.environ['LM_STUDIO_API_BASE'] = "http://localhost:8000"
|
||||
response = embedding(
|
||||
model="lm_studio/jina-embeddings-v3",
|
||||
input=["Hello world"],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
|
|
@ -196,6 +196,41 @@ This budget does not apply to keys created under non-default teams.
|
|||
|
||||
[**Go Here**](./team_budgets.md)
|
||||
|
||||
### Auto-add SSO users to teams
|
||||
|
||||
1. Specify the JWT field that contains the team ids, that the user belongs to.
|
||||
|
||||
```yaml
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
litellm_jwtauth:
|
||||
team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
|
||||
```
|
||||
|
||||
This is assuming your SSO token looks like this:
|
||||
```
|
||||
{
|
||||
...,
|
||||
"groups": ["team_id_1", "team_id_2"]
|
||||
}
|
||||
```
|
||||
|
||||
2. Create the teams on LiteLLM
|
||||
|
||||
```bash
|
||||
curl -X POST '<PROXY_BASE_URL>/team/new' \
|
||||
-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-D '{
|
||||
"team_alias": "team_1",
|
||||
"team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
|
||||
}'
|
||||
```
|
||||
|
||||
3. Test the SSO flow
|
||||
|
||||
Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
|
||||
|
||||
## **All Settings for Self Serve / SSO Flow**
|
||||
|
||||
```yaml
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import collections.abc
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
|
@ -34,6 +35,19 @@ MAX_THREADS = 100
|
|||
executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
|
||||
|
||||
|
||||
def is_async_iterable(obj: Any) -> bool:
|
||||
"""
|
||||
Check if an object is an async iterable (can be used with 'async for').
|
||||
|
||||
Args:
|
||||
obj: Any Python object to check
|
||||
|
||||
Returns:
|
||||
bool: True if the object is async iterable, False otherwise
|
||||
"""
|
||||
return isinstance(obj, collections.abc.AsyncIterable)
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
try:
|
||||
if litellm.set_verbose:
|
||||
|
@ -1530,36 +1544,7 @@ class CustomStreamWrapper:
|
|||
if self.completion_stream is None:
|
||||
await self.fetch_stream()
|
||||
|
||||
if (
|
||||
self.custom_llm_provider == "openai"
|
||||
or self.custom_llm_provider == "azure"
|
||||
or self.custom_llm_provider == "custom_openai"
|
||||
or self.custom_llm_provider == "text-completion-openai"
|
||||
or self.custom_llm_provider == "text-completion-codestral"
|
||||
or self.custom_llm_provider == "azure_text"
|
||||
or self.custom_llm_provider == "cohere_chat"
|
||||
or self.custom_llm_provider == "cohere"
|
||||
or self.custom_llm_provider == "anthropic"
|
||||
or self.custom_llm_provider == "anthropic_text"
|
||||
or self.custom_llm_provider == "huggingface"
|
||||
or self.custom_llm_provider == "ollama"
|
||||
or self.custom_llm_provider == "ollama_chat"
|
||||
or self.custom_llm_provider == "vertex_ai"
|
||||
or self.custom_llm_provider == "vertex_ai_beta"
|
||||
or self.custom_llm_provider == "sagemaker"
|
||||
or self.custom_llm_provider == "sagemaker_chat"
|
||||
or self.custom_llm_provider == "gemini"
|
||||
or self.custom_llm_provider == "replicate"
|
||||
or self.custom_llm_provider == "cached_response"
|
||||
or self.custom_llm_provider == "predibase"
|
||||
or self.custom_llm_provider == "databricks"
|
||||
or self.custom_llm_provider == "bedrock"
|
||||
or self.custom_llm_provider == "triton"
|
||||
or self.custom_llm_provider == "watsonx"
|
||||
or self.custom_llm_provider == "cloudflare"
|
||||
or self.custom_llm_provider in litellm.openai_compatible_providers
|
||||
or self.custom_llm_provider in litellm._custom_providers
|
||||
):
|
||||
if is_async_iterable(self.completion_stream):
|
||||
async for chunk in self.completion_stream:
|
||||
if chunk == "None" or chunk is None:
|
||||
raise Exception
|
||||
|
|
|
@ -3218,6 +3218,7 @@ def embedding( # noqa: PLR0915
|
|||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
if dynamic_api_key is not None:
|
||||
api_key = dynamic_api_key
|
||||
|
||||
|
@ -3395,15 +3396,16 @@ def embedding( # noqa: PLR0915
|
|||
custom_llm_provider == "openai_like"
|
||||
or custom_llm_provider == "jina_ai"
|
||||
or custom_llm_provider == "hosted_vllm"
|
||||
or custom_llm_provider == "lm_studio"
|
||||
):
|
||||
api_base = (
|
||||
api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
|
||||
)
|
||||
|
||||
# set API KEY
|
||||
if api_key is None:
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
litellm.api_key
|
||||
or litellm.openai_like_key
|
||||
or get_secret_str("OPENAI_LIKE_API_KEY")
|
||||
)
|
||||
|
|
|
@ -25,10 +25,15 @@ model_list:
|
|||
identifier: deepseek-ai/DeepSeek-V3-Base
|
||||
revision: main
|
||||
auth_token: os.environ/HUGGINGFACE_API_KEY
|
||||
- model_name: watsonx/ibm/granite-13b-chat-v2 # tried to keep original name for backwards compatibility but I've also tried watsonx_text
|
||||
litellm_params:
|
||||
model: watsonx_text/ibm/granite-13b-chat-v2
|
||||
model_info:
|
||||
input_cost_per_token: 0.0000006
|
||||
output_cost_per_token: 0.0000006
|
||||
|
||||
# litellm_settings:
|
||||
# key_generation_settings:
|
||||
# personal_key_generation: # maps to 'Default Team' on UI
|
||||
# allowed_user_roles: ["proxy_admin"]
|
||||
|
||||
|
||||
|
|
|
@ -1019,6 +1019,28 @@ def test_hosted_vllm_embedding(monkeypatch):
|
|||
assert json_data["model"] == "jina-embeddings-v3"
|
||||
|
||||
|
||||
def test_lm_studio_embedding(monkeypatch):
|
||||
monkeypatch.setenv("LM_STUDIO_API_BASE", "http://localhost:8000")
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
client = HTTPHandler()
|
||||
with patch.object(client, "post") as mock_post:
|
||||
try:
|
||||
embedding(
|
||||
model="lm_studio/jina-embeddings-v3",
|
||||
input=["Hello world"],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_post.assert_called_once()
|
||||
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert json_data["input"] == ["Hello world"]
|
||||
assert json_data["model"] == "jina-embeddings-v3"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue