Using Vector Database Adapters
Using Elasticsearch
To use Elasticsearch, you need to install the elasticsearch package.
import embed_anything
import os
from typing import Dict, List
from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel
from embed_anything.vectordb import Adapter
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
class ElasticsearchAdapter(Adapter):
def __init__(self, api_key: str, cloud_id: str, index_name: str = "anything"):
self.es = Elasticsearch(cloud_id=cloud_id, api_key=api_key)
self.index_name = index_name
def create_index(
self, dimension: int, metric: str, mappings={}, settings={}, **kwargs
):
if "index_name" in kwargs:
self.index_name = kwargs["index_name"]
self.es.indices.create(
index=self.index_name, mappings=mappings, settings=settings
)
    def convert(self, embeddings: List[EmbedData]) -> List[Dict]:
data = []
for embedding in embeddings:
data.append(
{
"text": embedding.text,
"embeddings": embedding.embedding,
"metadata": {
"file_name": embedding.metadata["file_name"],
"modified": embedding.metadata["modified"],
"created": embedding.metadata["created"],
},
}
)
return data
def delete_index(self, index_name: str):
self.es.indices.delete(index=index_name)
def gendata(self, data):
for doc in data:
yield doc
    def upsert(self, data: List[EmbedData]):
        data = self.convert(data)
        bulk(client=self.es, index=self.index_name, actions=self.gendata(data))
index_name = "anything"
elastic_api_key = os.environ.get("ELASTIC_API_KEY")
elastic_cloud_id = os.environ.get("ELASTIC_CLOUD_ID")
# Initialize the ElasticsearchAdapter Class
elasticsearch_adapter = ElasticsearchAdapter(
api_key=elastic_api_key,
cloud_id=elastic_cloud_id,
index_name=index_name,
)
# Parse the PDF and insert the documents into Elasticsearch.
model = EmbeddingModel.from_pretrained_hf(
WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
)
data = embed_anything.embed_file(
"/home/sonamAI/projects/EmbedAnything/test_files/attention.pdf",
embedder=model,
adapter=elasticsearch_adapter
)
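# Optional follow-up sketch: query the documents that were just inserted with an
# approximate kNN search. This assumes Elasticsearch 8.x (which supports the `knn`
# option of `search`) and that the "embeddings" field is mapped as an indexed
# dense_vector, as in the explicit mappings shown further below.
query_vector = embed_anything.embed_query(["What is attention"], embedder=model)[0].embedding
knn_response = elasticsearch_adapter.es.search(
    index=index_name,
    knn={
        "field": "embeddings",
        "query_vector": query_vector,
        "k": 3,
        "num_candidates": 50,
    },
)
for hit in knn_response["hits"]["hits"]:
    print(hit["_source"]["text"])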
# Create an Index with explicit mappings.
mappings = {
"properties": {
"embeddings": {"type": "dense_vector", "dims": 384},
"text": {"type": "text"},
}
}
settings = {}
elasticsearch_adapter.create_index(
dimension=384,
metric="cosine",
mappings=mappings,
settings=settings,
)
# Delete an Index
elasticsearch_adapter.delete_index(index_name=index_name)
Using Weaviate
To use Weaviate, you need to install the weaviate-client package.
import weaviate, os
import weaviate.classes as wvc
from tqdm.auto import tqdm
import embed_anything
from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel
from embed_anything.vectordb import Adapter
import textwrap
## Weaviate Adapter
from typing import List
class WeaviateAdapter(Adapter):
def __init__(self, api_key, url):
super().__init__(api_key)
self.client = weaviate.connect_to_weaviate_cloud(
cluster_url=url, auth_credentials=wvc.init.Auth.api_key(api_key)
)
if self.client.is_ready():
print("Weaviate is ready")
def create_index(self, index_name: str):
self.index_name = index_name
self.collection = self.client.collections.create(
index_name, vectorizer_config=wvc.config.Configure.Vectorizer.none()
)
return self.collection
def convert(self, embeddings: List[EmbedData]):
data = []
for embedding in embeddings:
            properties = embedding.metadata
            properties["text"] = embedding.text
            data.append(
                wvc.data.DataObject(properties=properties, vector=embedding.embedding)
)
return data
def upsert(self, data_):
data_ = self.convert(data_)
self.client.collections.get(self.index_name).data.insert_many(data_)
def delete_index(self, index_name: str):
self.client.collections.delete(index_name)
URL = "URL"
API_KEY = "API_KEY"
weaviate_adapter = WeaviateAdapter(API_KEY, URL)
# create index
index_name = "Test_index"
if index_name in weaviate_adapter.client.collections.list_all():
weaviate_adapter.delete_index(index_name)
weaviate_adapter.create_index(index_name)
model = EmbeddingModel.from_pretrained_hf(
WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
)
data = embed_anything.embed_file(
"/home/sonamAI/projects/EmbedAnything/test_files/attention.pdf",
embedder=model,
adapter=weaviate_adapter,
)
query_vector = embed_anything.embed_query(["What is attention"], embedder=model)[
0
].embedding
response = weaviate_adapter.collection.query.near_vector(
near_vector=query_vector,
limit=2,
return_metadata=wvc.query.MetadataQuery(certainty=True),
)
for res in response.objects:
    print(textwrap.fill(res.properties["text"], width=120), end="\n\n")
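# Optional cleanup sketch: the v4 weaviate-client keeps an open connection, so it is
# good practice to close it once you are done.
weaviate_adapter.client.close()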
Using Pinecone
To use Pinecone, you need to install the pinecone package.
import re
from typing import Dict, List
import uuid
import embed_anything
import os
from embed_anything.vectordb import Adapter
from pinecone import Pinecone, ServerlessSpec
from embed_anything import EmbedData, EmbeddingModel, WhichModel, TextEmbedConfig
class PineconeAdapter(Adapter):
"""
Adapter class for interacting with Pinecone, a vector database service.
"""
def __init__(self, api_key: str):
"""
Initializes a new instance of the PineconeAdapter class.
Args:
api_key (str): The API key for accessing the Pinecone service.
"""
super().__init__(api_key)
self.pc = Pinecone(api_key=self.api_key)
self.index_name = None
def create_index(
self,
dimension: int,
metric: str = "cosine",
index_name: str = "anything",
spec=ServerlessSpec(cloud="aws", region="us-east-1"),
):
"""
Creates a new index in Pinecone.
Args:
dimension (int): The dimensionality of the embeddings.
metric (str, optional): The distance metric to use for similarity search. Defaults to "cosine".
index_name (str, optional): The name of the index. Defaults to "anything".
spec (ServerlessSpec, optional): The serverless specification for the index. Defaults to AWS in us-east-1 region.
"""
self.index_name = index_name
self.pc.create_index(
name=index_name, dimension=dimension, metric=metric, spec=spec
)
def delete_index(self, index_name: str):
"""
Deletes an existing index from Pinecone.
Args:
index_name (str): The name of the index to delete.
"""
self.pc.delete_index(name=index_name)
def convert(self, embeddings: List[EmbedData]) -> List[Dict]:
"""
Converts a list of embeddings into the required format for upserting into Pinecone.
Args:
embeddings (List[EmbedData]): The list of embeddings to convert.
Returns:
List[Dict]: The converted data in the required format for upserting into Pinecone.
"""
data_emb = []
for embedding in embeddings:
data_emb.append(
{
"id": str(uuid.uuid4()),
"values": embedding.embedding,
"metadata": {
"text": embedding.text,
"file": re.split(
r"/|\\", embedding.metadata.get("file_name", "")
)[-1],
},
}
)
return data_emb
    def upsert(self, data: List[EmbedData]):
"""
Upserts data into the specified index in Pinecone.
Args:
            data (List[EmbedData]): The embeddings to upsert into Pinecone.
Raises:
ValueError: If the index has not been created before upserting data.
"""
data = self.convert(data)
if not self.index_name:
raise ValueError("Index must be created before upserting data")
self.pc.Index(name=self.index_name).upsert(data)
# Initialize the PineconeAdapter class
api_key = os.environ.get("PINECONE_API_KEY")
index_name = "anything"
pinecone_adapter = PineconeAdapter(api_key)
try:
    pinecone_adapter.delete_index("anything")
except Exception:
    pass
# Create a new 512-dimensional index to match the CLIP embeddings
pinecone_adapter.create_index(dimension=512, metric="cosine")
model = EmbeddingModel.from_pretrained_hf(
    WhichModel.Clip, model_id="openai/clip-vit-base-patch32"
)
data = embed_anything.embed_file(
"/home/sonamAI/projects/EmbedAnything/test_files/attention.pdf",
embedder=model,
adapter=pinecone_adapter,
)
data = embed_anything.embed_image_directory(
"test_files",
embedder=model,
adapter=pinecone_adapter
)
print(data)
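# Optional query sketch: embed a text query with the same model and search the index.
# Index.query with include_metadata=True is the standard Pinecone client call; the
# query text and top_k here are just illustrative values.
query_vector = embed_anything.embed_query(["What is attention"], embedder=model)[0].embedding
results = pinecone_adapter.pc.Index(name=index_name).query(
    vector=query_vector, top_k=3, include_metadata=True
)
for match in results.matches:
    print(match.metadata["text"])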
Using Qdrant
To use Qdrant, you need to install the qdrant-client package.
import uuid
from typing import List, Dict
from qdrant_client import QdrantClient
from qdrant_client.models import (
Distance,
VectorParams,
PointStruct,
)
import embed_anything
from embed_anything import EmbedData, EmbeddingModel, WhichModel
from embed_anything.vectordb import Adapter
class QdrantAdapter(Adapter):
"""
Adapter class for interacting with [Qdrant](https://qdrant.tech/).
"""
def __init__(self, client: QdrantClient):
"""
Initializes a new instance of the QdrantAdapter class.
Args:
client : An instance of qdrant_client.QdrantClient
"""
self.client = client
def create_index(
self,
dimension: int,
metric: Distance = Distance.COSINE,
index_name: str = "embed-anything",
**kwargs,
):
self.collection_name = index_name
if not self.client.collection_exists(index_name):
self.client.create_collection(
collection_name=index_name,
vectors_config=VectorParams(size=dimension, distance=metric),
)
def delete_index(self, index_name: str):
self.client.delete_collection(collection_name=index_name)
def convert(self, embeddings: List[EmbedData]) -> List[PointStruct]:
points = []
for embedding in embeddings:
points.append(
PointStruct(
id=str(uuid.uuid4()),
vector=embedding.embedding,
payload={
"text": embedding.text,
"file_name": embedding.metadata["file_name"],
"modified": embedding.metadata["modified"],
"created": embedding.metadata["created"],
},
)
)
return points
    def upsert(self, data: List[EmbedData]):
points = self.convert(data)
self.client.upsert(
collection_name=self.collection_name,
points=points,
)
def main():
adapter = QdrantAdapter(QdrantClient(location=":memory:"))
adapter.create_index(dimension=384)
model = EmbeddingModel.from_pretrained_hf(
WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
)
embed_anything.embed_file(
"test_files/attention.pdf",
embedder=model,
adapter=adapter,
)
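    # Optional query sketch: once the file is embedded, the in-memory collection can be
    # searched through the underlying QdrantClient; the query text and limit are just
    # illustrative values (newer clients also offer query_points as an equivalent).
    query_vector = embed_anything.embed_query(["What is attention"], embedder=model)[0].embedding
    hits = adapter.client.search(
        collection_name="embed-anything",
        query_vector=query_vector,
        limit=3,
    )
    for hit in hits:
        print(hit.payload["text"])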
if __name__ == "__main__":
main()
Using Milvus
To use Milvus, you need to install the pymilvus package.
from pymilvus import MilvusClient, DataType
import os
from typing import Dict, List
import embed_anything
from embed_anything.vectordb import Adapter
from embed_anything import EmbedData, EmbeddingModel, WhichModel
print("Milvus Vector DB - Adapter")
# Default embedding dimension
EMBEDDINGS_DIM = 384
# Maximum VARCHAR field length for text content
TEXT_CONTENT_VARCHARS = 4098
# Type annotation for embeddings
VectorEmbeddings = List[List[EmbedData]]
class MilvusVectorAdapter(Adapter):
def __init__(self, uri: str = './milvus.db', token: str = '', collection_name: str = "embed_anything_collection"):
"""
Initialize the MilvusVectorAdapter.
Args:
uri (str): The URI to connect to, comes in the form of
"https://address:port" for Milvus or Zilliz Cloud service,
or "path/to/local/milvus.db" for the lite local Milvus. Defaults to
"./milvus.db".
token (str): The token for log in. Defaults to "".
collection_name (str): Name of the collection to use. Defaults to
"embed_anything_collection".
"""
self.collection_name = collection_name
self.client = MilvusClient(uri=uri, token=token)
print("Ok - Milvus DB connection established.")
def create_index(self, dimension: int = EMBEDDINGS_DIM):
"""
Create a collection and index for embeddings.
Args:
dimension: Dimension of the embedding vectors.
"""
# Delete collection if it exists
if self.client.has_collection(self.collection_name):
self.delete_index()
# Create collection schema
schema = self.client.create_schema(auto_id=True)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(
field_name="embeddings",
datatype=DataType.FLOAT_VECTOR,
dim=dimension
)
schema.add_field(
field_name="text",
datatype=DataType.VARCHAR,
max_length=TEXT_CONTENT_VARCHARS
)
schema.add_field(
field_name="file_name",
datatype=DataType.VARCHAR,
max_length=255
)
schema.add_field(
field_name="modified",
datatype=DataType.VARCHAR,
max_length=50
)
schema.add_field(
field_name="created",
datatype=DataType.VARCHAR,
max_length=50
)
# Create the collection
self.client.create_collection(
collection_name=self.collection_name,
schema=schema
)
# Create the index
index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="embeddings",
index_type="IVF_FLAT",
metric_type="L2",
params={"nlist": 1024}
)
# Apply the index
self.client.create_index(
collection_name=self.collection_name,
index_params=index_params
)
# Load the collection
self.client.load_collection(
collection_name=self.collection_name
)
print(f"Collection '{self.collection_name}' created with index.")
def convert(self, embeddings: List[EmbedData]) -> List[Dict]:
"""
Convert EmbedData objects to a format compatible with Milvus.
Args:
embeddings: List of EmbedData objects.
Returns:
List of dictionaries with data formatted for Milvus.
"""
ret_data = []
        for embedding in embeddings:
data_dict = {
"embeddings": embedding.embedding,
"text": embedding.text,
"file_name": embedding.metadata["file_name"],
"modified": embedding.metadata["modified"],
"created": embedding.metadata["created"],
}
ret_data.append(data_dict)
print(f"Converted {len(ret_data)} embeddings for insertion.")
return ret_data
def delete_index(self):
"""
Delete the collection and its index.
"""
try:
self.client.drop_collection(self.collection_name)
print(f"Collection '{self.collection_name}' dropped.")
except Exception as e:
print(f"Failed to drop collection: {e}")
def upsert(self, data: List[EmbedData]):
"""
Insert or update embeddings in the collection.
Args:
data: List of EmbedData objects to insert.
"""
# Convert data to Milvus format
formatted_data = self.convert(data)
# Insert data
self.client.insert(
collection_name=self.collection_name,
data=formatted_data
)
print(f"Successfully inserted {len(formatted_data)} embeddings.")
if __name__ == "__main__":
# Initialize the MilvusVectorAdapter class
index_name = "embed_anything_milvus_collection"
milvus_adapter = MilvusVectorAdapter(uri='./milvus.db', collection_name=index_name)
    # Drop the existing collection if it already exists (delete_index takes no arguments)
    try:
        milvus_adapter.delete_index()
    except Exception:
        pass
# Create a new index
milvus_adapter.create_index()
# Initialize the embedding model
model = EmbeddingModel.from_pretrained_hf(
WhichModel.Bert,
model_id="sentence-transformers/all-MiniLM-L12-v2"
)
# Embed a PDF file
data = embed_anything.embed_file(
"/Users/jinhonglin/Desktop/sample.pdf",
embedder=model,
adapter=milvus_adapter,
)
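    # Optional query sketch: run a vector search against the loaded collection with
    # MilvusClient.search; the query text and limit here are just illustrative values.
    query_vector = embed_anything.embed_query(["What is attention"], embedder=model)[0].embedding
    results = milvus_adapter.client.search(
        collection_name=index_name,
        data=[query_vector],
        limit=3,
        output_fields=["text"],
    )
    for hit in results[0]:
        print(hit["entity"]["text"])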