Skip to content

Using Vector Database Adapters

Using Elasticsearch

To use Elasticsearch, you need to install the elasticsearch package.

pip install elasticsearch

import embed_anything
import os

from typing import Dict, List
from embed_anything import EmbedData
from embed_anything.vectordb import Adapter
from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel

from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk


class ElasticsearchAdapter(Adapter):
    """Adapter that stores EmbedAnything embeddings in an Elasticsearch index."""

    def __init__(self, api_key: str, cloud_id: str, index_name: str = "anything"):
        """Connect to an Elastic Cloud deployment.

        Args:
            api_key: Elastic API key used for authentication.
            cloud_id: Cloud ID of the Elastic Cloud deployment.
            index_name: Index that create_index() creates and upsert() writes to.
        """
        self.es = Elasticsearch(cloud_id=cloud_id, api_key=api_key)
        self.index_name = index_name

    def create_index(
        self, dimension: int, metric: str, mappings=None, settings=None, **kwargs
    ):
        """Create the adapter's index.

        Args:
            dimension: Embedding dimensionality. Not used directly; express the
                vector size in ``mappings`` (e.g. ``dense_vector`` ``dims``).
            metric: Similarity metric. Not used directly; set it in ``mappings``.
            mappings: Elasticsearch index mappings. Defaults to an empty mapping.
            settings: Elasticsearch index settings. Defaults to empty settings.
            **kwargs: If ``index_name`` is given, it replaces the adapter's
                index name before the index is created.
        """
        # Build fresh dicts per call instead of sharing mutable default arguments.
        mappings = {} if mappings is None else mappings
        settings = {} if settings is None else settings

        if "index_name" in kwargs:
            self.index_name = kwargs["index_name"]

        self.es.indices.create(
            index=self.index_name, mappings=mappings, settings=settings
        )

    def convert(self, embeddings: List[EmbedData]) -> List[Dict]:
        """Turn EmbedData objects into Elasticsearch documents.

        Args:
            embeddings: Embeddings produced by EmbedAnything.

        Returns:
            One dict per embedding with ``text``, ``embeddings`` and a
            ``metadata`` sub-dict (``file_name``, ``modified``, ``created``).

        Raises:
            KeyError: If an embedding's metadata lacks one of those keys.
        """
        return [
            {
                "text": embedding.text,
                "embeddings": embedding.embedding,
                "metadata": {
                    "file_name": embedding.metadata["file_name"],
                    "modified": embedding.metadata["modified"],
                    "created": embedding.metadata["created"],
                },
            }
            for embedding in embeddings
        ]

    def delete_index(self, index_name: str):
        """Delete the index named ``index_name``."""
        self.es.indices.delete(index=index_name)

    def gendata(self, data):
        """Yield the documents one by one, as expected by ``helpers.bulk``."""
        for doc in data:
            yield doc

    def upsert(self, data: List[EmbedData]):
        """Convert ``data`` and bulk-index it into the adapter's index."""
        documents = self.convert(data)
        # Use the configured index instead of a hard-coded name so that a custom
        # index_name (constructor or create_index kwarg) is actually honored.
        bulk(client=self.es, index=self.index_name, actions=self.gendata(documents))


index_name = "anything"
elastic_api_key = os.environ.get("ELASTIC_API_KEY")
elastic_cloud_id = os.environ.get("ELASTIC_CLOUD_ID")

# Initialize the ElasticsearchAdapter class.
elasticsearch_adapter = ElasticsearchAdapter(
    api_key=elastic_api_key,
    cloud_id=elastic_cloud_id,
    index_name=index_name,
)

# Create the index with explicit mappings BEFORE inserting any documents.
# If documents were upserted first, Elasticsearch would auto-create the index
# with dynamic mappings ("embeddings" would not be a dense_vector), and the
# explicit create call would then fail because the index already exists.
mappings = {
    "properties": {
        "embeddings": {"type": "dense_vector", "dims": 384},
        "text": {"type": "text"},
    }
}
settings = {}

elasticsearch_adapter.create_index(
    dimension=384,
    metric="cosine",
    mappings=mappings,
    settings=settings,
)

# Parse the PDF and insert its chunks into Elasticsearch.
model = EmbeddingModel.from_pretrained_hf(
    WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
)

data = embed_anything.embed_file(
    "/home/sonamAI/projects/EmbedAnything/test_files/attention.pdf",
    embedder=model,
    adapter=elasticsearch_adapter,
)

# Delete the index (clean-up for this example).
elasticsearch_adapter.delete_index(index_name=index_name)

Using Weaviate

To use Weaviate, you need to install the weaviate-client package.

pip install weaviate-client
import weaviate, os
import weaviate.classes as wvc
from tqdm.auto import tqdm
import embed_anything
from embed_anything import EmbedData, EmbeddingModel, TextEmbedConfig, WhichModel
from embed_anything.vectordb import Adapter
import textwrap

## Weaviate Adapter

from typing import List


class WeaviateAdapter(Adapter):
    """Adapter that stores EmbedAnything embeddings in a Weaviate Cloud collection."""

    def __init__(self, api_key, url):
        """Connect to a Weaviate Cloud cluster.

        Args:
            api_key: Weaviate Cloud API key.
            url: Cluster URL of the Weaviate Cloud instance.
        """
        super().__init__(api_key)
        self.client = weaviate.connect_to_weaviate_cloud(
            cluster_url=url, auth_credentials=wvc.init.Auth.api_key(api_key)
        )
        # No collection is selected until create_index() is called.
        self.index_name = None
        if self.client.is_ready():
            print("Weaviate is ready")

    def create_index(self, index_name: str):
        """Create a collection without a server-side vectorizer and select it.

        Vectors are supplied by EmbedAnything, hence ``Vectorizer.none()``.

        Returns:
            The created collection (also stored on ``self.collection``).
        """
        self.index_name = index_name
        self.collection = self.client.collections.create(
            index_name, vectorizer_config=wvc.config.Configure.Vectorizer.none()
        )
        return self.collection

    def convert(self, embeddings: List[EmbedData]):
        """Turn EmbedData objects into Weaviate ``DataObject`` instances.

        Each object's properties are the embedding's metadata plus a ``text``
        property; its vector is the embedding itself.
        """
        data = []
        for embedding in embeddings:
            # Copy the metadata so the source dict is never mutated, and treat
            # missing metadata as empty instead of crashing.
            properties = dict(embedding.metadata or {})
            properties["text"] = embedding.text
            data.append(
                wvc.data.DataObject(properties=properties, vector=embedding.embedding)
            )
        return data

    def upsert(self, data_):
        """Convert ``data_`` and insert it into the selected collection.

        Raises:
            ValueError: If create_index() has not been called yet.
        """
        if not self.index_name:
            raise ValueError("Index must be created before upserting data")
        objects = self.convert(data_)
        self.client.collections.get(self.index_name).data.insert_many(objects)

    def delete_index(self, index_name: str):
        """Delete the collection named ``index_name``."""
        self.client.collections.delete(index_name)


URL = "URL"
API_KEY = "API_KEY"
weaviate_adapter = WeaviateAdapter(API_KEY, URL)

try:
    # Create the collection, dropping any previous one with the same name.
    index_name = "Test_index"
    if index_name in weaviate_adapter.client.collections.list_all():
        weaviate_adapter.delete_index(index_name)
    weaviate_adapter.create_index(index_name)

    model = EmbeddingModel.from_pretrained_hf(
        WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
    )

    data = embed_anything.embed_file(
        "/home/sonamAI/projects/EmbedAnything/test_files/attention.pdf",
        embedder=model,
        adapter=weaviate_adapter,
    )

    # Embed the query with the same model used for the documents.
    query_vector = embed_anything.embed_query(["What is attention"], embedder=model)[
        0
    ].embedding

    response = weaviate_adapter.collection.query.near_vector(
        near_vector=query_vector,
        limit=2,
        return_metadata=wvc.query.MetadataQuery(certainty=True),
    )

    # Print each hit once, wrapped for readability.
    for res in response.objects:
        print(textwrap.fill(res.properties["text"], width=120), end="\n\n")
finally:
    # The v4 client holds open connections; release them even if a step fails.
    weaviate_adapter.client.close()

Using Pinecone

To use Pinecone, you need to install the pinecone package.

pip install pinecone
import re
from typing import Dict, List
import uuid
import embed_anything
import os

from embed_anything.vectordb import Adapter
from pinecone import Pinecone, ServerlessSpec

from embed_anything import EmbedData, EmbeddingModel, WhichModel, TextEmbedConfig


class PineconeAdapter(Adapter):
    """
    Adapter class for interacting with Pinecone, a vector database service.
    """

    def __init__(self, api_key: str):
        """
        Initializes a new instance of the PineconeAdapter class.

        Args:
            api_key (str): The API key for accessing the Pinecone service.
        """
        super().__init__(api_key)
        self.pc = Pinecone(api_key=self.api_key)
        # Set by create_index(); upsert() refuses to run until then.
        self.index_name = None

    def create_index(
        self,
        dimension: int,
        metric: str = "cosine",
        index_name: str = "anything",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    ):
        """
        Creates a new index in Pinecone and makes it the upsert target.

        Args:
            dimension (int): The dimensionality of the embeddings. Must match the
                output size of the embedding model used for upserts.
            metric (str, optional): The distance metric to use for similarity search. Defaults to "cosine".
            index_name (str, optional): The name of the index. Defaults to "anything".
            spec (ServerlessSpec, optional): The serverless specification for the index. Defaults to AWS in us-east-1 region.
        """
        self.index_name = index_name
        self.pc.create_index(
            name=index_name, dimension=dimension, metric=metric, spec=spec
        )

    def delete_index(self, index_name: str):
        """
        Deletes an existing index from Pinecone.

        Args:
            index_name (str): The name of the index to delete.
        """
        self.pc.delete_index(name=index_name)

    def convert(self, embeddings: List[EmbedData]) -> List[Dict]:
        """
        Converts a list of embeddings into the required format for upserting into Pinecone.

        Each record gets a random UUID, the embedding vector, and metadata holding
        the chunk text and the base name of the source file ("" if unknown).

        Args:
            embeddings (List[EmbedData]): The list of embeddings to convert.

        Returns:
            List[Dict]: The converted data in the required format for upserting into Pinecone.
        """
        data_emb = []

        for embedding in embeddings:
            # Metadata (or its file_name entry) may be missing; treat it as empty.
            metadata = embedding.metadata or {}
            file_path = metadata.get("file_name") or ""
            data_emb.append(
                {
                    "id": str(uuid.uuid4()),
                    "values": embedding.embedding,
                    "metadata": {
                        "text": embedding.text,
                        # Keep only the base name, splitting on both "/" and "\".
                        "file": re.split(r"/|\\", file_path)[-1],
                    },
                }
            )
        return data_emb

    def upsert(self, data: List[EmbedData]):
        """
        Upserts data into the specified index in Pinecone.

        Args:
            data (List[EmbedData]): The embeddings to convert and upsert into Pinecone.

        Raises:
            ValueError: If the index has not been created before upserting data.
        """
        # Fail fast, before spending time on the conversion.
        if not self.index_name:
            raise ValueError("Index must be created before upserting data")
        vectors = self.convert(data)
        self.pc.Index(name=self.index_name).upsert(vectors)


# Initialize the PineconeAdapter class.
api_key = os.environ.get("PINECONE_API_KEY")
index_name = "anything"
pinecone_adapter = PineconeAdapter(api_key)

# Start from a clean index: delete it only if it already exists, rather than
# swallowing every possible error with a bare except.
if index_name in pinecone_adapter.pc.list_indexes().names():
    pinecone_adapter.delete_index(index_name)

# all-MiniLM-L12-v2 produces 384-dimensional embeddings, so the index must
# have exactly that dimension or the upsert is rejected.
pinecone_adapter.create_index(dimension=384, metric="cosine", index_name=index_name)

text_model = EmbeddingModel.from_pretrained_hf(
    WhichModel.Bert, model_id="sentence-transformers/all-MiniLM-L12-v2"
)

data = embed_anything.embed_file(
    "/home/sonamAI/projects/EmbedAnything/test_files/attention.pdf",
    embedder=text_model,
    adapter=pinecone_adapter,
)

# Images need a vision model (a BERT text model cannot embed them). CLIP
# ViT-B/16 produces 512-dimensional embeddings, so images go into their own
# 512-dimensional index. create_index() also makes it the upsert target.
image_index_name = "anything-images"
if image_index_name in pinecone_adapter.pc.list_indexes().names():
    pinecone_adapter.delete_index(image_index_name)
pinecone_adapter.create_index(
    dimension=512, metric="cosine", index_name=image_index_name
)

image_model = EmbeddingModel.from_pretrained_hf(
    WhichModel.Clip, model_id="openai/clip-vit-base-patch16"
)

data = embed_anything.embed_image_directory(
    "test_files",
    embedder=image_model,
    adapter=pinecone_adapter,
)
print(data)

Using Qdrant

To use Qdrant, you need to install the qdrant-client package.

pip install qdrant-client
import uuid
from typing import List, Dict
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    VectorParams,
    PointStruct,
)
import embed_anything
from embed_anything import EmbedData, EmbeddingModel, WhichModel
from embed_anything.vectordb import Adapter


class QdrantAdapter(Adapter):
    """
    Adapter class for interacting with [Qdrant](https://qdrant.tech/).
    """

    def __init__(self, client: QdrantClient):
        """
        Initializes a new instance of the QdrantAdapter class.

        Args:
            client : An instance of qdrant_client.QdrantClient
        """
        self.client = client
        # Set by create_index(); upsert() refuses to run until then.
        self.collection_name = None

    def create_index(
        self,
        dimension: int,
        metric: Distance = Distance.COSINE,
        index_name: str = "embed-anything",
        **kwargs,
    ):
        """
        Select ``index_name`` as the target collection, creating it if missing.

        An existing collection with that name is reused as-is.

        Args:
            dimension: Vector size of the collection.
            metric: Distance function. Defaults to cosine.
            index_name: Collection name. Defaults to "embed-anything".
            **kwargs: Accepted for interface compatibility; ignored.
        """
        self.collection_name = index_name

        if not self.client.collection_exists(index_name):
            self.client.create_collection(
                collection_name=index_name,
                vectors_config=VectorParams(size=dimension, distance=metric),
            )

    def delete_index(self, index_name: str):
        """Delete the collection named ``index_name``."""
        self.client.delete_collection(collection_name=index_name)

    def convert(self, embeddings: List[EmbedData]) -> List[PointStruct]:
        """
        Turn EmbedData objects into Qdrant points with random UUID ids.

        Raises:
            KeyError: If an embedding's metadata lacks ``file_name``,
                ``modified`` or ``created``.
        """
        return [
            PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding.embedding,
                payload={
                    "text": embedding.text,
                    "file_name": embedding.metadata["file_name"],
                    "modified": embedding.metadata["modified"],
                    "created": embedding.metadata["created"],
                },
            )
            for embedding in embeddings
        ]

    def upsert(self, data: List[EmbedData]):
        """
        Convert ``data`` and upsert it into the selected collection.

        Raises:
            ValueError: If create_index() has not been called yet.
        """
        if not self.collection_name:
            raise ValueError("Index must be created before upserting data")
        points = self.convert(data)
        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
        )


def main():
    """Embed a PDF into an in-memory Qdrant collection."""
    qdrant_adapter = QdrantAdapter(QdrantClient(location=":memory:"))
    # Collection size matches the 384-dimensional MiniLM vectors used below.
    qdrant_adapter.create_index(dimension=384)

    bert_model = EmbeddingModel.from_pretrained_hf(
        WhichModel.Bert,
        model_id="sentence-transformers/all-MiniLM-L12-v2",
    )

    pdf_path = "test_files/attention.pdf"
    embed_anything.embed_file(pdf_path, embedder=bert_model, adapter=qdrant_adapter)


if __name__ == "__main__":
    main()

Using Milvus

To use Milvus, you need to install the pymilvus package.

pip install pymilvus
from pymilvus import MilvusClient, DataType
import os
from typing import Dict, List

import embed_anything
from embed_anything.vectordb import Adapter
from embed_anything import EmbedData, EmbeddingModel, WhichModel

print("Milvus Vector DB - Adapter")

# Vector size used by create_index() when no dimension is passed.
EMBEDDINGS_DIM: int = 384
# max_length of the VARCHAR "text" field in the collection schema.
TEXT_CONTENT_VARCHARS: int = 4098

# Alias: a nested list of EmbedData objects.
VectorEmbeddings = List[List[EmbedData]]

class MilvusVectorAdapter(Adapter):
    def __init__(self, uri: str = './milvus.db', token: str = '', collection_name: str = "embed_anything_collection"):
        """
        Initialize the MilvusVectorAdapter.

        Args:
            uri (str): The URI to connect to, comes in the form of
                "https://address:port" for Milvus or Zilliz Cloud service,
                or "path/to/local/milvus.db" for the lite local Milvus. Defaults to
                "./milvus.db".
            token (str): The token for log in. Defaults to "".
            collection_name (str): Name of the collection to use. Defaults to
                "embed_anything_collection".
        """
        self.collection_name = collection_name
        self.client = MilvusClient(uri=uri, token=token)
        print("Ok - Milvus DB connection established.")

    def create_index(self, dimension: int = EMBEDDINGS_DIM, metric_type: str = "L2"):
        """
        Create (or re-create) the collection, build its vector index and load it.

        Any existing collection with the same name is dropped first.

        Args:
            dimension: Dimension of the embedding vectors.
            metric_type: Distance metric for the vector index (e.g. "L2", "IP",
                "COSINE"). Defaults to "L2".
        """
        # Drop the collection if it exists so the schema below always applies.
        if self.client.has_collection(self.collection_name):
            self.delete_index()

        # Collection schema: auto-generated INT64 id, the vector, and text/metadata.
        schema = self.client.create_schema(auto_id=True)
        schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="embeddings",
            datatype=DataType.FLOAT_VECTOR,
            dim=dimension
        )
        # NOTE(review): chunks longer than TEXT_CONTENT_VARCHARS will likely be
        # rejected by Milvus on insert; confirm against the chunk size in use.
        schema.add_field(
            field_name="text",
            datatype=DataType.VARCHAR,
            max_length=TEXT_CONTENT_VARCHARS
        )
        schema.add_field(
            field_name="file_name",
            datatype=DataType.VARCHAR,
            max_length=255
        )
        schema.add_field(
            field_name="modified",
            datatype=DataType.VARCHAR,
            max_length=50
        )
        schema.add_field(
            field_name="created",
            datatype=DataType.VARCHAR,
            max_length=50
        )

        # Create the collection
        self.client.create_collection(
            collection_name=self.collection_name,
            schema=schema
        )

        # Describe the vector index
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embeddings",
            index_type="IVF_FLAT",
            metric_type=metric_type,
            params={"nlist": 1024}
        )

        # Apply the index
        self.client.create_index(
            collection_name=self.collection_name,
            index_params=index_params
        )

        # Load the collection so it can be searched
        self.client.load_collection(
            collection_name=self.collection_name
        )

        print(f"Collection '{self.collection_name}' created with index.")

    def convert(self, embeddings: List[EmbedData]) -> List[Dict]:
        """
        Convert EmbedData objects to a format compatible with Milvus.

        Args:
            embeddings: List of EmbedData objects.

        Returns:
            List of dictionaries whose keys match the collection schema fields.

        Raises:
            KeyError: If an embedding's metadata lacks ``file_name``,
                ``modified`` or ``created``.
        """
        ret_data = [
            {
                "embeddings": embedding.embedding,
                "text": embedding.text,
                "file_name": embedding.metadata["file_name"],
                "modified": embedding.metadata["modified"],
                "created": embedding.metadata["created"],
            }
            for embedding in embeddings
        ]

        print(f"Converted {len(ret_data)} embeddings for insertion.")
        return ret_data

    def delete_index(self, index_name=None):
        """
        Drop a collection and its index (best effort; failures are printed).

        Args:
            index_name (str, optional): Collection to drop. Defaults to this
                adapter's collection. Accepted so the method matches the
                ``delete_index(index_name)`` signature of the other adapters.
        """
        target = index_name or self.collection_name
        try:
            self.client.drop_collection(target)
            print(f"Collection '{target}' dropped.")
        except Exception as e:
            # Deliberately best-effort: report the failure instead of raising.
            print(f"Failed to drop collection: {e}")

    def upsert(self, data: List[EmbedData]):
        """
        Insert embeddings into the collection.

        Args:
            data: List of EmbedData objects to insert.
        """
        # Convert data to Milvus format
        formatted_data = self.convert(data)

        # Insert data (ids are auto-generated by the schema)
        self.client.insert(
            collection_name=self.collection_name,
            data=formatted_data
        )

        print(f"Successfully inserted {len(formatted_data)} embeddings.")





if __name__ == "__main__":
    # Initialize the MilvusVectorAdapter class
    index_name = "embed_anything_milvus_collection"
    milvus_adapter = MilvusVectorAdapter(uri='./milvus.db', collection_name=index_name)

    # Drop the adapter's collection if it already exists. delete_index() acts
    # on the adapter's own collection and reports failures itself, so no bare
    # try/except is needed (the old call passed an argument the method did not
    # accept, and the resulting TypeError was silently swallowed).
    if milvus_adapter.client.has_collection(index_name):
        milvus_adapter.delete_index()

    # Create a new collection and index
    milvus_adapter.create_index()

    # Initialize the embedding model
    model = EmbeddingModel.from_pretrained_hf(
        WhichModel.Bert,
        model_id="sentence-transformers/all-MiniLM-L12-v2"
    )

    # Embed a PDF file and insert its chunks into Milvus
    data = embed_anything.embed_file(
        "/Users/jinhonglin/Desktop/sample.pdf",
        embedder=model,
        adapter=milvus_adapter,
    )