Source code for langchain_matrixone.vectorstores

from __future__ import annotations

import json
import logging
import uuid
from typing import Any, Dict, Iterable, List, Optional

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from matrixone import Client

logger = logging.getLogger(__name__)


class MatrixOneVectorStore(VectorStore):
    """MatrixOne vector store integration compatible with LangChain.

    Texts are kept in a single table made of a ``VARCHAR(36)`` primary key
    ``id``, a ``TEXT`` content column, a ``JSON`` metadata column and a
    ``VECF32`` vector column whose dimension is taken from the embedding
    model at construction time.

    NOTE(review): the table and column names are interpolated directly into
    SQL statements, so they must come from trusted configuration and never
    from end-user input.
    """

    def __init__(
        self,
        embedding: Embeddings,
        connection_args: Optional[Dict[str, Any]] = None,
        *,
        client: Optional[Client] = None,
        table_name: str = "langchain_vectors",
        content_column: str = "content",
        metadata_column: str = "metadata",
        vector_column: str = "embedding",
        drop_old: bool = False,
        distance: str = "l2",
    ) -> None:
        """Connect to MatrixOne (or reuse ``client``) and ensure the table exists.

        Args:
            embedding: Embedding model used for documents and queries.
            connection_args: Keyword arguments forwarded to ``Client.connect``
                when no ``client`` is supplied. Must contain ``host``,
                ``port``, ``user``, ``password`` and ``database``.
            client: An already connected MatrixOne client to reuse. When
                given, this store never disconnects it.
            table_name: Name of the backing table.
            content_column: Column holding the raw text.
            metadata_column: Column holding the JSON-encoded metadata.
            vector_column: Column holding the embedding vector.
            drop_old: Drop an existing table of the same name first.
            distance: Passed as ``distance_type`` to the client's similarity
                search.

        Raises:
            ValueError: If neither ``client`` nor ``connection_args`` is
                given, if ``connection_args`` lacks a required key, or if the
                supplied ``client`` is not connected.
        """
        if client is None and not connection_args:
            raise ValueError("Either an existing MatrixOne Client or connection_args must be provided.")

        self.embedding = embedding
        self.connection_args = connection_args or {}
        self.table_name = table_name
        self.content_column = content_column
        self.metadata_column = metadata_column
        self.vector_column = vector_column
        self.distance = distance

        self.client = client or Client()
        # Only a client created here is disconnected again in __del__.
        self._owns_client = client is None
        if self._owns_client:
            self._connect_client()
        elif not self.client.connected():
            raise ValueError("Provided MatrixOne Client is not connected.")

        self._create_table_if_not_exists(drop_old)

    @property
    def embeddings(self) -> Embeddings:
        """Return the embedding model this store was constructed with."""
        return self.embedding

    def _connect_client(self) -> None:
        """Validate ``connection_args`` and connect the owned client.

        Raises:
            ValueError: If any required connection key is missing.
        """
        required_keys = {"host", "port", "user", "password", "database"}
        missing = required_keys - self.connection_args.keys()
        if missing:
            raise ValueError(f"connection_args missing required keys: {', '.join(sorted(missing))}")
        self.client.connect(**self.connection_args)

    def _create_table_if_not_exists(self, drop_old: bool = False) -> None:
        """Create the backing table, optionally dropping an existing one first."""
        # The vector column needs a fixed dimension; embed a throwaway string
        # to discover what the embedding model produces.
        dummy_embedding = self.embedding.embed_query("matrixone-init")
        dim = len(dummy_embedding)

        if drop_old:
            self.client.execute(f"DROP TABLE IF EXISTS {self.table_name}")

        create_sql = f"""
        CREATE TABLE IF NOT EXISTS {self.table_name} (
            id VARCHAR(36) PRIMARY KEY,
            {self.content_column} TEXT,
            {self.metadata_column} JSON,
            {self.vector_column} VECF32({dim})
        )
        """
        self.client.execute(create_sql)

    def _format_metadata(self, metadata: Optional[dict]) -> str:
        """Serialize ``metadata`` to JSON text; empty or ``None`` becomes ``"{}"``."""
        if not metadata:
            return "{}"
        return json.dumps(metadata)

    @staticmethod
    def _parse_metadata(metadata_value: Any) -> Dict[str, Any]:
        """Turn a stored metadata value back into a dictionary.

        Dictionaries pass through unchanged. Non-empty strings are decoded as
        JSON; text that is not valid JSON, or that decodes to something other
        than an object, is preserved under the ``"raw_metadata"`` key so the
        result is always a dictionary. Anything else yields ``{}``.
        """
        if isinstance(metadata_value, dict):
            return metadata_value
        if isinstance(metadata_value, str) and metadata_value:
            try:
                parsed = json.loads(metadata_value)
            except json.JSONDecodeError:
                return {"raw_metadata": metadata_value}
            if isinstance(parsed, dict):
                return parsed
            if parsed is None:
                # JSON ``null`` carries no metadata.
                return {}
            # Valid JSON that is not an object (list, number, ...) cannot be
            # used as Document metadata directly.
            return {"raw_metadata": metadata_value}
        return {}

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts`` and insert them into the table.

        Args:
            texts: Texts to store.
            metadatas: Optional metadata, one entry per text. ``None`` or an
                empty list stores ``{}`` for every text.
            ids: Optional ids, one per text. ``None`` or an empty list
                generates UUID4 ids; individual ``None`` entries are replaced
                by generated ids as well.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The ids of the inserted rows, in the order of ``texts``.

        Raises:
            ValueError: If ``ids`` or ``metadatas`` is given with a length
                different from the number of texts.
        """
        texts = list(texts)

        # Fail early with a clear message instead of an IndexError (too few
        # entries) or silently ignoring the surplus (too many entries).
        if ids and len(ids) != len(texts):
            raise ValueError(
                f"Number of ids ({len(ids)}) must match number of texts ({len(texts)})."
            )
        if metadatas and len(metadatas) != len(texts):
            raise ValueError(
                f"Number of metadatas ({len(metadatas)}) must match number of texts ({len(texts)})."
            )

        if not texts:
            # Nothing to embed or insert; skip the embedding and database calls.
            return []

        if not ids:
            ids = [str(uuid.uuid4()) for _ in texts]
        else:
            # Build a new list so the caller's list is left untouched; the id
            # column is the primary key, so every row needs a real value.
            ids = [
                given if given is not None else str(uuid.uuid4())
                for given in ids
            ]
        if not metadatas:
            metadatas = [{} for _ in texts]

        embeddings = self.embedding.embed_documents(texts)

        records: List[Dict[str, Any]] = [
            {
                "id": record_id,
                self.content_column: text,
                self.metadata_column: self._format_metadata(metadata),
                self.vector_column: vector,
            }
            for record_id, text, metadata, vector in zip(ids, texts, metadatas, embeddings)
        ]

        self.client.vector_ops.batch_insert(self.table_name, records)
        return ids

    def similarity_search(self, query: str, k: int = 4, **kwargs: Any) -> List[Document]:
        """Return the ``k`` stored documents most similar to ``query``.

        The query is embedded with the store's embedding model and the search
        is delegated to :meth:`similarity_search_by_vector`. ``VectorStore``
        declares this method abstract, so it has to be implemented for the
        class to be instantiable.

        Args:
            query: Text to search for.
            k: Maximum number of documents to return.
            **kwargs: Forwarded to :meth:`similarity_search_by_vector`.
        """
        query_vector = self.embedding.embed_query(query)
        return self.similarity_search_by_vector(query_vector, k=k, **kwargs)

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return the ``k`` stored documents closest to ``embedding``.

        Args:
            embedding: Query vector.
            k: Maximum number of rows requested from the database.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            One ``Document`` per returned row that has content; rows whose
            content column is ``None`` are skipped.
        """
        # Rows are read with ``.get``, so they are expected to be
        # mapping-like - TODO confirm against the MatrixOne SDK.
        rows = self.client.vector_ops.similarity_search(
            self.table_name,
            vector_column=self.vector_column,
            query_vector=embedding,
            limit=k,
            select_columns=[self.content_column, self.metadata_column],
            distance_type=self.distance,
        )

        docs: List[Document] = []
        for row in rows:
            content = row.get(self.content_column)
            if content is None:
                continue
            metadata = self._parse_metadata(row.get(self.metadata_column))
            docs.append(Document(page_content=content, metadata=metadata))
        return docs

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete the rows whose id is in ``ids``.

        Args:
            ids: Ids to delete. When ``None`` or empty nothing is deleted.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Always ``True``.
        """
        if not ids:
            return True
        # One placeholder per id; the values are passed separately as parameters.
        placeholders = ", ".join(["?"] * len(ids))
        sql = f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})"
        self.client.execute(sql, tuple(ids))
        return True

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        connection_args: Optional[Dict[str, Any]] = None,
        client: Optional[Client] = None,
        *,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "MatrixOneVectorStore":
        """Create a store and populate it with ``texts``.

        Args:
            texts: Texts to store.
            embedding: Embedding model for the new store.
            metadatas: Optional metadata, one entry per text.
            connection_args: Connection settings, see ``__init__``.
            client: Existing connected client to reuse, see ``__init__``.
            ids: Optional ids for the texts. Taken here instead of being
                forwarded to the constructor, which does not accept it.
            **kwargs: Remaining keyword arguments for the constructor
                (``table_name``, ``drop_old``, ...).

        Returns:
            The populated store.
        """
        store = cls(
            embedding=embedding,
            connection_args=connection_args,
            client=client,
            **kwargs,
        )
        store.add_texts(texts, metadatas=metadatas, ids=ids)
        return store

    def __del__(self) -> None:
        """Disconnect the client on garbage collection if this store created it."""
        # getattr guards against __init__ having failed before the attribute
        # was assigned.
        if getattr(self, "_owns_client", False):
            try:
                self.client.disconnect()
            except Exception:  # pragma: no cover - best effort cleanup
                pass