feat: Add Ollama as a provider in the proxy UI

This commit is contained in:
sha-ahammed 2024-06-05 16:48:38 +05:30
parent 499933f943
commit faa4dfe03e
3 changed files with 21 additions and 1 deletion

View file

@ -2,8 +2,9 @@ from itertools import chain
import requests, types, time # type: ignore import requests, types, time # type: ignore
import json, uuid import json, uuid
import traceback import traceback
from typing import Optional from typing import Optional, List
import litellm import litellm
from litellm.types.utils import ProviderField
import httpx, aiohttp, asyncio # type: ignore import httpx, aiohttp, asyncio # type: ignore
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
@ -124,6 +125,18 @@ class OllamaConfig:
) )
and v is not None and v is not None
} }
def get_required_params(self) -> List[ProviderField]:
    """For a given provider, return its required fields with a description.

    Used by the proxy UI to render a dynamic credentials form for the
    Ollama provider.

    Returns:
        List[ProviderField]: one entry per required field; each carries the
        field name, type, a human-readable description, and a default value
        shown to the user.
    """
    return [
        ProviderField(
            field_name="base_url",
            field_type="string",
            field_description="Your Ollama API Base",
            # Default to the standard local Ollama endpoint. The previous
            # value (http://10.10.11.249:11434) was a private LAN address
            # leaked from a dev machine and is meaningless to other users.
            field_value="http://localhost:11434",
        )
    ]
def get_supported_openai_params( def get_supported_openai_params(
self, self,
): ):

View file

@ -7344,6 +7344,10 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
if custom_llm_provider == "databricks": if custom_llm_provider == "databricks":
return litellm.DatabricksConfig().get_required_params() return litellm.DatabricksConfig().get_required_params()
elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_required_params()
else: else:
return [] return []

View file

@ -145,6 +145,7 @@ enum Providers {
OpenAI_Compatible = "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)", OpenAI_Compatible = "OpenAI-Compatible Endpoints (Groq, Together AI, Mistral AI, etc.)",
Vertex_AI = "Vertex AI (Anthropic, Gemini, etc.)", Vertex_AI = "Vertex AI (Anthropic, Gemini, etc.)",
Databricks = "Databricks", Databricks = "Databricks",
Ollama = "Ollama",
} }
const provider_map: Record<string, string> = { const provider_map: Record<string, string> = {
@ -156,6 +157,7 @@ const provider_map: Record<string, string> = {
OpenAI_Compatible: "openai", OpenAI_Compatible: "openai",
Vertex_AI: "vertex_ai", Vertex_AI: "vertex_ai",
Databricks: "databricks", Databricks: "databricks",
Ollama: "ollama",
}; };
const retry_policy_map: Record<string, string> = { const retry_policy_map: Record<string, string> = {
@ -1747,6 +1749,7 @@ const ModelDashboard: React.FC<ModelDashboardProps> = ({
)} )}
{selectedProvider != Providers.Bedrock && {selectedProvider != Providers.Bedrock &&
selectedProvider != Providers.Vertex_AI && selectedProvider != Providers.Vertex_AI &&
selectedProvider != Providers.Ollama &&
(dynamicProviderForm === undefined || (dynamicProviderForm === undefined ||
dynamicProviderForm.fields.length == 0) && ( dynamicProviderForm.fields.length == 0) && (
<Form.Item <Form.Item