Merge pull request #2720 from onukura/ollama-batch-embedding

Batch embedding for Ollama
Krish Dholakia 2024-03-28 14:58:55 -07:00 committed by GitHub
commit 28905c85b6
2 changed files with 52 additions and 62 deletions


@@ -344,9 +344,9 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
 async def ollama_aembeddings(
-    api_base="http://localhost:11434",
-    model="llama2",
-    prompt="Why is the sky blue?",
+    api_base: str,
+    model: str,
+    prompts: list[str],
     optional_params=None,
     logging_obj=None,
     model_response=None,
@ -365,52 +365,52 @@ async def ollama_aembeddings(
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in ): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v optional_params[k] = v
data = { total_input_tokens = 0
"model": model, output_data = []
"prompt": prompt,
}
## LOGGING
logging_obj.pre_call(
input=None,
api_key=None,
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}},
)
timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes timeout = aiohttp.ClientTimeout(total=litellm.request_timeout) # 10 minutes
async with aiohttp.ClientSession(timeout=timeout) as session: async with aiohttp.ClientSession(timeout=timeout) as session:
response = await session.post(url, json=data) for idx, prompt in enumerate(prompts):
data = {
if response.status != 200: "model": model,
text = await response.text() "prompt": prompt,
raise OllamaError(status_code=response.status, message=text) }
## LOGGING
## LOGGING logging_obj.pre_call(
logging_obj.post_call( input=None,
input=prompt, api_key=None,
api_key="", additional_args={"api_base": url, "complete_input_dict": data, "headers": {}},
original_response=response.text,
additional_args={
"headers": None,
"api_base": api_base,
},
)
response_json = await response.json()
embeddings = response_json["embedding"]
embeddings = [embeddings] # Ollama currently does not support batch embedding
## RESPONSE OBJECT
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{"object": "embedding", "index": idx, "embedding": embedding}
) )
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
input_tokens = len(encoding.encode(prompt)) response = await session.post(url, json=data)
if response.status != 200:
text = await response.text()
raise OllamaError(status_code=response.status, message=text)
model_response["usage"] = { ## LOGGING
"prompt_tokens": input_tokens, logging_obj.post_call(
"total_tokens": input_tokens, input=prompt,
} api_key="",
return model_response original_response=response.text,
additional_args={
"headers": None,
"api_base": api_base,
},
)
response_json = await response.json()
embeddings: list[float] = response_json["embedding"]
output_data.append(
{"object": "embedding", "index": idx, "embedding": embeddings}
)
input_tokens = len(encoding.encode(prompt))
total_input_tokens += input_tokens
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
model_response["usage"] = {
"prompt_tokens": total_input_tokens,
"total_tokens": total_input_tokens,
}
return model_response
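For context, the per-prompt loop above reflects the fact that Ollama's embeddings endpoint accepts a single prompt per request, so "batch" embedding is handled client-side: one POST per prompt, with the results collected in order and the token counts summed for usage. Below is a minimal standalone sketch of that same pattern, not the library code itself; the function name batch_ollama_embeddings is made up for illustration, and it assumes a local Ollama server exposing /api/embeddings (the diff only shows a prebuilt url).

import asyncio

import aiohttp


async def batch_ollama_embeddings(
    prompts: list[str],
    model: str = "llama2",
    api_base: str = "http://localhost:11434",
) -> list[list[float]]:
    # One request per prompt: Ollama's embeddings endpoint does not take a list,
    # so batching is emulated by looping and collecting results in input order.
    url = f"{api_base}/api/embeddings"  # assumed endpoint; not shown in this diff
    results: list[list[float]] = []
    timeout = aiohttp.ClientTimeout(total=600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        for prompt in prompts:
            async with session.post(url, json={"model": model, "prompt": prompt}) as resp:
                if resp.status != 200:
                    raise RuntimeError(await resp.text())
                body = await resp.json()
                results.append(body["embedding"])
    return results


if __name__ == "__main__":
    vectors = asyncio.run(
        batch_ollama_embeddings(["Why is the sky blue?", "Why is the grass green?"])
    )
    print(len(vectors), "embeddings of dimension", len(vectors[0]))

Because each prompt is a separate request, latency grows linearly with batch size; issuing the requests sequentially (as both the PR and this sketch do) keeps the logging calls per prompt straightforward.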


@@ -2796,29 +2796,19 @@ def embedding(
             or get_secret("OLLAMA_API_BASE")
             or "http://localhost:11434"
         )
-        ollama_input = None
-        if isinstance(input, list) and len(input) > 1:
-            raise litellm.BadRequestError(
-                message=f"Ollama Embeddings don't support batch embeddings",
-                model=model,  # type: ignore
-                llm_provider="ollama",  # type: ignore
-            )
-        if isinstance(input, list) and len(input) == 1:
-            ollama_input = "".join(input[0])
-        elif isinstance(input, str):
-            ollama_input = input
-        else:
+        if isinstance(input, str):
+            input = [input]
+        if not all(isinstance(item, str) for item in input):
             raise litellm.BadRequestError(
                 message=f"Invalid input for ollama embeddings. input={input}",
                 model=model,  # type: ignore
                 llm_provider="ollama",  # type: ignore
             )
-
-        if aembedding == True:
+        if aembedding:
             response = ollama.ollama_aembeddings(
                 api_base=api_base,
                 model=model,
-                prompt=ollama_input,
+                prompts=input,
                 encoding=encoding,
                 logging_obj=logging,
                 optional_params=optional_params,
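Caller-facing, the intent of this hunk is that a list input for an Ollama embedding model is normalized and forwarded instead of raising BadRequestError. A hedged usage sketch of the new behavior; the model name is illustrative and assumes an Ollama server on localhost with that model pulled:

import litellm

# Illustrative only: "ollama/llama2" assumes a locally available Ollama model.
response = litellm.embedding(
    model="ollama/llama2",
    input=["Why is the sky blue?", "Why is the grass green?"],
)
# With batch support, one embedding object is expected per input string.
print(len(response.data))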