Coverage for app/endpoint/knowledge/search_knowledge.py: 39%
36 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-16 01:07 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-16 01:07 +0000
1from app.endpoint.knowledge.router import router
2from app.infrastructure.database.db import get_session
4from fastapi import Depends
5from sqlmodel import Session
8# neo4j登録
9from neo4j import GraphDatabase
import os

# Neo4j connection settings. Each value can be overridden through an
# environment variable of the same name; the fallbacks are the values that
# were previously hard-coded, so behaviour is unchanged when nothing is set.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://search:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
# NOTE(review): the fallback password is a development default and should
# not be relied on in a real deployment.
NEO4J_PASS = os.environ.get("NEO4J_PASS", "password")
def register_graph(graph_data, text, embedding):
    """Open a Neo4j driver for registering graph data (currently a stub).

    None of the arguments are used yet: the function only creates a driver
    with the module-level connection settings and returns ``None``.

    Args:
        graph_data: Unused at present.
        text: Unused at present.
        embedding: Unused at present.
    """
    # The driver is opened through a context manager so it is always closed;
    # previously it was created and abandoned, leaking one driver per call.
    with GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS)):
        # TODO: persist graph_data / text / embedding; nothing is written yet.
        pass
18from pydantic import BaseModel
class SearchRequest(BaseModel):
    """Request body accepted by the knowledge search endpoint."""

    # Text searched for; also the input from which the query embedding is built.
    query: str
    # Passed as the LIMIT of each individual search query; defaults to 5.
    top_k: int = 5
def _dedupe_nodes(nodes):
    """Return one summary dict per distinct ``id`` property, in first-seen order."""
    seen = set()
    unique = []
    for node in nodes:
        node_id = node.get("id")
        # Nodes whose ``id`` property is missing or falsy are skipped, as are
        # nodes whose id has already been emitted.
        if not node_id or node_id in seen:
            continue
        seen.add(node_id)
        unique.append({
            "id": node_id,
            # First label of the node, or "Entity" when it has none.
            "label": next(iter(node.labels), "Entity"),
            "text": node.get("text"),
        })
    return unique


@router.post(
    "/knowledges:search"
)
async def search_knowledge(
    req: SearchRequest,
    session: Session = Depends(get_session),
):
    """Search the Neo4j knowledge graph for nodes related to ``req.query``.

    Three Cypher queries are run, each limited to ``req.top_k`` rows:

    1. Case-insensitive substring match on the ``text`` property.
    2. Cosine similarity between each node's ``embedding`` and the
       embedding computed for the query, highest score first.
    3. Relationships where either endpoint's ``text`` contains the query;
       both endpoints are collected.

    The collected nodes are de-duplicated on their ``id`` property.

    Args:
        req: Search parameters (query text and per-query row limit).
        session: Injected database session. It is not used by this function
            body but is kept as part of the endpoint's dependencies.

    Returns:
        ``{"results": [...]}`` where each item has ``id``, ``label`` and
        ``text`` keys.
    """
    # Imported here, as in the original, rather than at module level.
    from .create_knowledge import get_embedding

    query = req.query
    embedding = get_embedding(query)['embeddings'][0]
    text_params = {"query": query, "k": req.top_k}

    nodes = []
    # NOTE(review): the driver calls below are synchronous and run inside an
    # ``async def`` handler, so they block the event loop while they execute.
    # Both the driver and the Neo4j session are closed on exit, including on
    # error. The Neo4j session gets its own name so that it no longer shadows
    # the injected ``session`` parameter.
    with GraphDatabase.driver(
        NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS)
    ) as driver, driver.session() as graph_session:
        # 1. Full-text search (text CONTAINS)
        text_results = graph_session.run(
            """
            MATCH (n)
            WHERE toLower(n.text) CONTAINS toLower($query)
            RETURN n LIMIT $k
            """,
            text_params,
        )
        nodes.extend(record["n"] for record in text_results)

        # 2. Vector search (Neo4j 5.11+ according to the original note)
        vector_results = graph_session.run(
            """
            MATCH (n)
            WHERE n.embedding IS NOT NULL
            WITH n, vector.similarity.cosine(n.embedding, $embedding) AS score
            RETURN n ORDER BY score DESC LIMIT $k
            """,
            {"embedding": embedding, "k": req.top_k},
        )
        nodes.extend(record["n"] for record in vector_results)

        # 3. Graph structure search (text match plus relationships)
        graph_results = graph_session.run(
            """
            MATCH (a)-[r]-(b)
            WHERE toLower(a.text) CONTAINS toLower($query)
            OR toLower(b.text) CONTAINS toLower($query)
            RETURN a, r, b LIMIT $k
            """,
            text_params,
        )
        for record in graph_results:
            nodes.extend([record["a"], record["b"]])

    # De-duplicate (by id) and return.
    return {"results": _dedupe_nodes(nodes)}