Coverage for app/endpoint/knowledge/search_knowledge.py: 39%

36 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-16 01:07 +0000

1from app.endpoint.knowledge.router import router 

2from app.infrastructure.database.db import get_session 

3 

4from fastapi import Depends 

5from sqlmodel import Session 

6 

7 

# Neo4j registration

9from neo4j import GraphDatabase 

import os

# Connection settings for the Neo4j instance used by this module.
# Each value can be overridden through an environment variable of the same
# name, so credentials do not have to live in source control; the literals
# below remain the defaults when the variable is unset.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://search:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASS = os.environ.get("NEO4J_PASS", "password")

13 

def register_graph(graph_data, text, embedding):
    """Open (and release) a Neo4j driver for graph registration.

    The current implementation only builds a driver from the module-level
    connection settings; it writes nothing to the database and returns
    ``None``.

    Args:
        graph_data: Not used by the current implementation.
        text: Not used by the current implementation.
        embedding: Not used by the current implementation.
    """
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS))
    # Nothing consumes the driver yet, so release it immediately instead of
    # leaking its connection pool on every call.
    driver.close()

16 

17 

18from pydantic import BaseModel 

class SearchRequest(BaseModel):
    """Request body for the knowledge search endpoint."""

    # Search string; used both for substring matching and for embedding.
    query: str
    # Passed as the LIMIT of each individual search query (defaults to 5).
    top_k: int = 5

22 

23 

def _dedupe_nodes(nodes):
    """Turn nodes into response dicts, keeping the first node seen per ``id``."""
    seen_ids = set()
    unique = []
    for node in nodes:
        node_id = node.get("id")
        # Nodes whose "id" property is missing or falsy are dropped, and
        # later duplicates of an already-seen id are ignored.
        if not node_id or node_id in seen_ids:
            continue
        seen_ids.add(node_id)
        unique.append({
            "id": node_id,
            # First label of the node, or "Entity" when it has none.
            "label": next(iter(node.labels), "Entity"),
            "text": node.get("text"),
        })
    return unique


@router.post(
    "/knowledges:search"
)
async def search_knowledge(
    req: SearchRequest,
    session: Session = Depends(get_session),
):
    """Search the Neo4j graph for nodes related to ``req.query``.

    Three queries are run and their nodes concatenated in this order:

    1. Substring search: nodes whose ``text`` contains the query
       (case-insensitive).
    2. Vector search: nodes with an ``embedding`` property, ordered by
       cosine similarity to the query embedding.
    3. Graph search: both endpoints of relationships where either endpoint's
       ``text`` contains the query (case-insensitive).

    Each query is limited to ``req.top_k`` rows. The combined nodes are then
    de-duplicated on their ``id`` property, keeping first-seen order.

    Args:
        req: Search parameters (``query`` and ``top_k``).
        session: Injected SQLModel session. It is not used by this handler;
            it is kept so the endpoint's dependency signature is unchanged.

    Returns:
        ``{"results": [...]}`` where each item has ``id``, ``label`` and
        ``text`` keys.
    """
    # Function-scope import kept from the original code; presumably it avoids
    # a circular import with create_knowledge -- TODO confirm.
    from .create_knowledge import get_embedding

    query = req.query
    embedding = get_embedding(query)['embeddings'][0]
    text_params = {"query": query, "k": req.top_k}
    results = []

    # NOTE(review): these are synchronous driver calls inside an async
    # handler; consider the async driver or a threadpool if this endpoint
    # becomes latency-sensitive.
    # The driver is used as a context manager so it is closed even when a
    # query raises. The Neo4j session gets its own name so it no longer
    # shadows the injected ``session`` parameter.
    with GraphDatabase.driver(
        NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS)
    ) as driver, driver.session() as graph_session:
        # 1. Full-text search (text CONTAINS)
        text_results = graph_session.run(
            """
            MATCH (n)
            WHERE toLower(n.text) CONTAINS toLower($query)
            RETURN n LIMIT $k
            """,
            text_params,
        )
        results.extend(record["n"] for record in text_results)

        # 2. Vector search (Neo4j 5.11+)
        vector_results = graph_session.run(
            """
            MATCH (n)
            WHERE n.embedding IS NOT NULL
            WITH n, vector.similarity.cosine(n.embedding, $embedding) AS score
            RETURN n ORDER BY score DESC LIMIT $k
            """,
            {"embedding": embedding, "k": req.top_k},
        )
        results.extend(record["n"] for record in vector_results)

        # 3. Graph-structure search (text match plus relationships)
        graph_results = graph_session.run(
            """
            MATCH (a)-[r]-(b)
            WHERE toLower(a.text) CONTAINS toLower($query)
            OR toLower(b.text) CONTAINS toLower($query)
            RETURN a, r, b LIMIT $k
            """,
            text_params,
        )
        for record in graph_results:
            results.extend((record["a"], record["b"]))

    # De-duplicate on the "id" property and return.
    return {"results": _dedupe_nodes(results)}