Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 19:24:27 +00:00
fix(main.py): fix lm_studio/ embedding routing (#7658)

* fix(main.py): fix lm_studio/ embedding routing

  adds the mapping + updates docs with example

* docs(self_serve.md): update doc to show how to auto-add sso users to teams

* fix(streaming_handler.py): simplify async iterator check, to just check if streaming response is an async iterable

Parent: 14d4b695df
Commit: afdcbe3d64

6 changed files with 109 additions and 38 deletions

LM Studio provider docs:

````diff
@@ -11,6 +11,14 @@ https://lmstudio.ai/docs/basics/server
 :::
 
+| Property | Details |
+|-------|-------|
+| Description | Discover, download, and run local LLMs. |
+| Provider Route on LiteLLM | `lm_studio/` |
+| Provider Doc | [LM Studio ↗](https://lmstudio.ai/docs/api/openai-api) |
+| Supported OpenAI Endpoints | `/chat/completions`, `/embeddings`, `/completions` |
+
 ## API Key
 
 ```python
 # env variable
@@ -42,7 +50,7 @@ print(response)
 from litellm import completion
 import os
 
-os.environ['XAI_API_KEY'] = ""
+os.environ['LM_STUDIO_API_KEY'] = ""
 response = completion(
     model="lm_studio/llama-3-8b-instruct",
     messages=[
@@ -131,3 +139,17 @@ Here's how to call a XAI model with the LiteLLM Proxy Server
 ## Supported Parameters
 
 See [Supported Parameters](../completion/input.md#translated-openai-params) for supported parameters.
+
+## Embedding
+
+```python
+from litellm import embedding
+import os
+
+os.environ['LM_STUDIO_API_BASE'] = "http://localhost:8000"
+response = embedding(
+    model="lm_studio/jina-embeddings-v3",
+    input=["Hello world"],
+)
+print(response)
+```
````

self_serve.md:

````diff
@@ -196,6 +196,41 @@ This budget does not apply to keys created under non-default teams.
 
 [**Go Here**](./team_budgets.md)
 
+### Auto-add SSO users to teams
+
+1. Specify the JWT field that contains the team IDs that the user belongs to.
+
+```yaml
+general_settings:
+  master_key: sk-1234
+  litellm_jwtauth:
+    team_ids_jwt_field: "groups" # 👈 CAN BE ANY FIELD
+```
+
+This assumes your SSO token looks like this:
+
+```
+{
+  ...,
+  "groups": ["team_id_1", "team_id_2"]
+}
+```
+
+2. Create the teams on LiteLLM
+
+```bash
+curl -X POST '<PROXY_BASE_URL>/team/new' \
+-H 'Authorization: Bearer <PROXY_MASTER_KEY>' \
+-H 'Content-Type: application/json' \
+-d '{
+    "team_alias": "team_1",
+    "team_id": "team_id_1" # 👈 MUST BE THE SAME AS THE SSO GROUP ID
+}'
+```
+
+3. Test the SSO flow
+
+Here's a walkthrough of [how it works](https://www.loom.com/share/8959be458edf41fd85937452c29a33f3?sid=7ebd6d37-569a-4023-866e-e0cde67cb23e)
+
 ## **All Settings for Self Serve / SSO Flow**
 
 ```yaml
````
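
Before wiring up `team_ids_jwt_field`, it can help to confirm which claim your identity provider actually puts the group IDs in. A minimal sketch (not part of the commit, standard library only; it skips signature verification, so use it for inspection, not auth):

```python
import base64
import json

def jwt_payload(token: str) -> dict:
    """Decode the payload segment of a JWT without verifying its signature."""
    payload_b64 = token.split(".")[1]
    payload_b64 += "=" * (-len(payload_b64) % 4)  # restore base64url padding
    return json.loads(base64.urlsafe_b64decode(payload_b64))

# claims = jwt_payload(sso_token)
# claims.get("groups") should hold the team IDs, e.g. ["team_id_1", "team_id_2"]
```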

streaming_handler.py:

````diff
@@ -1,4 +1,5 @@
 import asyncio
+import collections.abc
 import json
 import threading
 import time
````
````diff
@@ -34,6 +35,19 @@ MAX_THREADS = 100
 executor = ThreadPoolExecutor(max_workers=MAX_THREADS)
 
 
+def is_async_iterable(obj: Any) -> bool:
+    """
+    Check if an object is an async iterable (can be used with 'async for').
+
+    Args:
+        obj: Any Python object to check
+
+    Returns:
+        bool: True if the object is async iterable, False otherwise
+    """
+    return isinstance(obj, collections.abc.AsyncIterable)
+
+
 def print_verbose(print_statement):
     try:
         if litellm.set_verbose:
````
````diff
@@ -1530,36 +1544,7 @@ class CustomStreamWrapper:
         if self.completion_stream is None:
             await self.fetch_stream()
 
-        if (
-            self.custom_llm_provider == "openai"
-            or self.custom_llm_provider == "azure"
-            or self.custom_llm_provider == "custom_openai"
-            or self.custom_llm_provider == "text-completion-openai"
-            or self.custom_llm_provider == "text-completion-codestral"
-            or self.custom_llm_provider == "azure_text"
-            or self.custom_llm_provider == "cohere_chat"
-            or self.custom_llm_provider == "cohere"
-            or self.custom_llm_provider == "anthropic"
-            or self.custom_llm_provider == "anthropic_text"
-            or self.custom_llm_provider == "huggingface"
-            or self.custom_llm_provider == "ollama"
-            or self.custom_llm_provider == "ollama_chat"
-            or self.custom_llm_provider == "vertex_ai"
-            or self.custom_llm_provider == "vertex_ai_beta"
-            or self.custom_llm_provider == "sagemaker"
-            or self.custom_llm_provider == "sagemaker_chat"
-            or self.custom_llm_provider == "gemini"
-            or self.custom_llm_provider == "replicate"
-            or self.custom_llm_provider == "cached_response"
-            or self.custom_llm_provider == "predibase"
-            or self.custom_llm_provider == "databricks"
-            or self.custom_llm_provider == "bedrock"
-            or self.custom_llm_provider == "triton"
-            or self.custom_llm_provider == "watsonx"
-            or self.custom_llm_provider == "cloudflare"
-            or self.custom_llm_provider in litellm.openai_compatible_providers
-            or self.custom_llm_provider in litellm._custom_providers
-        ):
+        if is_async_iterable(self.completion_stream):
             async for chunk in self.completion_stream:
                 if chunk == "None" or chunk is None:
                     raise Exception
````
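
As an aside (not from the commit), the new helper works because `collections.abc.AsyncIterable` does a structural check: anything that defines `__aiter__` passes, while plain sync iterators do not, which is what lets one `isinstance` call replace the whole provider whitelist. A quick sketch:

```python
import collections.abc

def is_async_iterable(obj) -> bool:
    return isinstance(obj, collections.abc.AsyncIterable)

async def agen():
    yield "chunk"

class CustomStream:
    # no registration needed: defining __aiter__ is enough for the ABC check
    def __aiter__(self):
        return self
    async def __anext__(self):
        raise StopAsyncIteration

print(is_async_iterable(agen()))          # True: async generators define __aiter__
print(is_async_iterable(CustomStream()))  # True: structural match on __aiter__
print(is_async_iterable(iter([1, 2])))    # False: a sync iterator only has __iter__
```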

main.py:

````diff
@@ -3218,6 +3218,7 @@ def embedding( # noqa: PLR0915
         api_base=api_base,
         api_key=api_key,
     )
 
     if dynamic_api_key is not None:
         api_key = dynamic_api_key
````
````diff
@@ -3395,18 +3396,19 @@ def embedding( # noqa: PLR0915
         custom_llm_provider == "openai_like"
         or custom_llm_provider == "jina_ai"
         or custom_llm_provider == "hosted_vllm"
+        or custom_llm_provider == "lm_studio"
     ):
         api_base = (
             api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
         )
 
         # set API KEY
-        api_key = (
-            api_key
-            or litellm.api_key
-            or litellm.openai_like_key
-            or get_secret_str("OPENAI_LIKE_API_KEY")
-        )
+        if api_key is None:
+            api_key = (
+                litellm.api_key
+                or litellm.openai_like_key
+                or get_secret_str("OPENAI_LIKE_API_KEY")
+            )
 
         ## EMBEDDING CALL
         response = openai_like_embedding.embedding(
````
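
A side note on the reworked key fallback, sketched below (an illustration, not from the commit): `if api_key is None` only applies the fallback chain when no key was passed at all, whereas the old `api_key or ...` chain would also override an explicitly passed empty string.

```python
import os

def resolve_key_old(api_key):
    # old style: any falsy value (None or "") triggers the fallback chain
    return api_key or os.environ.get("OPENAI_LIKE_API_KEY")

def resolve_key_new(api_key):
    # new style: only a missing key (None) falls back
    if api_key is None:
        api_key = os.environ.get("OPENAI_LIKE_API_KEY")
    return api_key

os.environ["OPENAI_LIKE_API_KEY"] = "sk-env"
print(resolve_key_old(""))  # 'sk-env': the caller's empty string is silently replaced
print(resolve_key_new(""))  # '': the caller's explicit value is preserved
```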

Proxy config example:

````diff
@@ -25,10 +25,15 @@ model_list:
       identifier: deepseek-ai/DeepSeek-V3-Base
       revision: main
       auth_token: os.environ/HUGGINGFACE_API_KEY
+  - model_name: watsonx/ibm/granite-13b-chat-v2 # tried to keep original name for backwards compatibility but I've also tried watsonx_text
+    litellm_params:
+      model: watsonx_text/ibm/granite-13b-chat-v2
+    model_info:
+      input_cost_per_token: 0.0000006
+      output_cost_per_token: 0.0000006
 
 # litellm_settings:
 #   key_generation_settings:
 #     personal_key_generation: # maps to 'Default Team' on UI
 #       allowed_user_roles: ["proxy_admin"]
````

Embedding tests:

````diff
@@ -1019,6 +1019,28 @@ def test_hosted_vllm_embedding(monkeypatch):
     assert json_data["model"] == "jina-embeddings-v3"
 
 
+def test_lm_studio_embedding(monkeypatch):
+    monkeypatch.setenv("LM_STUDIO_API_BASE", "http://localhost:8000")
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+    with patch.object(client, "post") as mock_post:
+        try:
+            embedding(
+                model="lm_studio/jina-embeddings-v3",
+                input=["Hello world"],
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+
+        assert json_data["input"] == ["Hello world"]
+        assert json_data["model"] == "jina-embeddings-v3"
+
+
 @pytest.mark.parametrize(
     "model",
     [
````
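
The test follows the usual pattern of patching the HTTP client's `post` method so no live LM Studio server is needed; a stripped-down sketch of the same idea (names here are illustrative, not LiteLLM APIs):

```python
import json
from unittest.mock import patch

class FakeClient:
    def post(self, url, data=None):
        raise RuntimeError("network disabled under test")

client = FakeClient()
with patch.object(client, "post") as mock_post:
    # the code under test builds a request and sends it through the client
    client.post(
        "http://localhost:8000/v1/embeddings",
        data=json.dumps({"model": "jina-embeddings-v3", "input": ["Hello world"]}),
    )

    mock_post.assert_called_once()  # the endpoint was hit exactly once
    payload = json.loads(mock_post.call_args.kwargs["data"])
    assert payload["input"] == ["Hello world"]  # payload reached the wire unchanged
```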