diff --git a/.env b/.env
new file mode 100644
index 0000000..335f306
--- /dev/null
+++ b/.env
@@ -0,0 +1,2 @@
+# Postgres database address for cocoindex
+COCOINDEX_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/cocoindex
diff --git a/README.md b/README.md
index 36cf6ae..02fa1b0 100644
--- a/README.md
+++ b/README.md
@@ -28,13 +28,13 @@ export COCOINDEX_DATABASE_URL="postgresql://cocoindex:cocoindex@localhost:5432/c
 Setup index:
 
 ```bash
-python quickstart.py cocoindex setup
+cocoindex setup quickstart.py
 ```
 
 Update index:
 
 ```bash
-python quickstart.py cocoindex update
+cocoindex update quickstart.py
 ```
 
 Run query:
diff --git a/quickstart.py b/quickstart.py
index 911f101..579082c 100644
--- a/quickstart.py
+++ b/quickstart.py
@@ -1,4 +1,22 @@
 import cocoindex
+from dotenv import load_dotenv
+from psycopg_pool import ConnectionPool
+import os
+
+@cocoindex.transform_flow()
+def text_to_embedding(
+    text: cocoindex.DataSlice[str],
+) -> cocoindex.DataSlice[list[float]]:
+    """
+    Embed the text using a SentenceTransformer model.
+    This is a shared logic between indexing and querying, so extract it as a function.
+    """
+    return text.transform(
+        cocoindex.functions.SentenceTransformerEmbed(
+            model="sentence-transformers/all-MiniLM-L6-v2"
+        )
+    )
+
 @cocoindex.flow_def(name="TextEmbedding")
 def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):
     # Add a data source to read files from a directory
@@ -18,9 +36,7 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
         # Transform data of each chunk
         with doc["chunks"].row() as chunk:
             # Embed the chunk, put into `embedding` field
-            chunk["embedding"] = chunk["text"].transform(
-                cocoindex.functions.SentenceTransformerEmbed(
-                    model="sentence-transformers/all-MiniLM-L6-v2"))
+            chunk["embedding"] = text_to_embedding(chunk["text"])
             # Collect the chunk into the collector.
             doc_embeddings.collect(filename=doc["filename"],
                                    location=chunk["location"],
@@ -31,35 +47,54 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind
         "doc_embeddings",
         cocoindex.storages.Postgres(),
         primary_key_fields=["filename", "location"],
-        vector_index=[("embedding", cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)])
+        vector_indexes=[
+            cocoindex.VectorIndexDef(
+                field_name="embedding",
+                metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
+            )
+        ],
+    )
 
-query_handler = cocoindex.query.SimpleSemanticsQueryHandler(
-    name="SemanticsSearch",
-    flow=text_embedding_flow,
-    target_name="doc_embeddings",
-    query_transform_flow=lambda text: text.transform(
-        cocoindex.functions.SentenceTransformerEmbed(
-            model="sentence-transformers/all-MiniLM-L6-v2")),
-    default_similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)
+def search(pool: ConnectionPool, query: str, top_k: int = 5):
+    # Get the table name, for the export target in the text_embedding_flow above.
+    table_name = cocoindex.utils.get_target_storage_default_name(
+        text_embedding_flow, "doc_embeddings"
+    )
+    # Evaluate the transform flow defined above with the input query, to get the embedding.
+    query_vector = text_to_embedding.eval(query)
+    # Run the query and get the results.
+    with pool.connection() as conn:
+        with conn.cursor() as cur:
+            cur.execute(
+                f"""
+                SELECT filename, text, embedding <=> %s::vector AS distance
+                FROM {table_name} ORDER BY distance LIMIT %s
+            """,
+                (query_vector, top_k),
+            )
+            return [
+                {"filename": row[0], "text": row[1], "score": 1.0 - row[2]}
+                for row in cur.fetchall()
+            ]
 
-@cocoindex.main_fn()
 def _main():
-    # Run queries to demonstrate the query capabilities.
+    # Initialize the database connection pool.
+    pool = ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL"))
+    # Run queries in a loop to demonstrate the query capabilities.
     while True:
-        try:
-            query = input("Enter search query (or Enter to quit): ")
-            if query == '':
-                break
-            results, _ = query_handler.search(query, 10)
-            print("\nSearch results:")
-            for result in results:
-                print(f"[{result.score:.3f}] {result.data['filename']}")
-                print(f"    {result.data['text']}")
-                print("---")
-            print()
-        except KeyboardInterrupt:
+        query = input("Enter search query (or Enter to quit): ")
+        if query == "":
             break
-
+        # Run the query function with the database connection pool and the query.
+        results = search(pool, query)
+        print("\nSearch results:")
+        for result in results:
+            print(f"[{result['score']:.3f}] {result['filename']}")
+            print(f"    {result['text']}")
+            print("---")
+        print()
 
 if __name__ == "__main__":
+    load_dotenv()
+    cocoindex.init()
     _main()
\ No newline at end of file
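For reference, beyond the diff itself: assuming the index has already been built with `cocoindex setup quickstart.py` and `cocoindex update quickstart.py` (as in the README hunk above) and `COCOINDEX_DATABASE_URL` is set as in `.env`, the new `search()` helper can also be reused outside the interactive loop. The sketch below is illustrative only; the file name `query_once.py` and the sample query are made up, and it simply imports what `quickstart.py` defines after this change.

```python
# query_once.py -- hypothetical one-off script, not part of the change above.
# It reuses the search() helper that quickstart.py defines after this diff.
import os

from dotenv import load_dotenv
from psycopg_pool import ConnectionPool

import cocoindex
from quickstart import search

load_dotenv()     # pick up COCOINDEX_DATABASE_URL from .env
cocoindex.init()  # initialize cocoindex before evaluating the transform flow inside search()

pool = ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL"))
for result in search(pool, "How do I configure the database?", top_k=3):
    print(f"[{result['score']:.3f}] {result['filename']}")
```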