chunk_service.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from typing import Dict, Any, List, Optional
  class ChunkService:
      """Client wrapper for the chunk endpoints of the dataset HTTP API.

      Every method sends a JSON body through ``http_client.post`` and reads the
      returned dict as an envelope with the keys ``code``, ``data`` and
      ``message``; ``code == 0`` means success. On failure a plain
      ``Exception`` is raised whose text is a Chinese prefix naming the
      operation, followed by the server's ``message`` (or ``未知错误``,
      "unknown error", when the response has none).
      """

      def __init__(self, http_client):
          """Store the HTTP client used for every request.

          Args:
              http_client: Object exposing ``post(endpoint, json_data=...)``
                  that returns the decoded JSON response as a dict
                  (assumed from how responses are used here — confirm).
          """
          self.http_client = http_client

      @staticmethod
      def _unwrap(response: Dict[str, Any], failure_prefix: str) -> Any:
          """Return ``response["data"]`` (possibly None), or raise on failure.

          Success is decided by ``code == 0`` alone. A truthy ``data`` is
          deliberately not required: an empty-but-successful reply (a
          retrieval with no hits, an update that returns no payload) is not an
          error and must not be reported with the server's success message.

          Raises:
              Exception: ``"<failure_prefix>: <message>"`` when ``code != 0``.
          """
          if response.get("code") != 0:
              raise Exception(f"{failure_prefix}: {response.get('message', '未知错误')}")
          return response.get("data")

      def _post(self, endpoint: str, body: Dict[str, Any], failure_prefix: str) -> Any:
          """POST ``body`` to ``endpoint`` and return the unwrapped ``data``.

          Every request goes through here so all calls use the same
          ``json_data=`` keyword. NOTE(review): two methods previously passed
          ``json=`` instead — confirm ``json_data`` matches the http_client API.
          """
          response = self.http_client.post(endpoint, json_data=body)
          return self._unwrap(response, failure_prefix)

      def create_chunk(self, dataset_id: str, document_id: str, content: str,
                       meta_fields: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
          """Create a chunk holding ``content`` in the given document.

          Args:
              dataset_id: Dataset that contains the document.
              document_id: Document that receives the new chunk.
              content: Chunk text.
              meta_fields: Optional metadata; sent only when not None.

          Returns:
              The response's ``data`` payload, or ``{}`` if the server reported
              success without one.

          Raises:
              Exception: ``"创建切片失败: ..."`` when the response code is non-zero.
          """
          endpoint = f"/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
          body: Dict[str, Any] = {"content": content}
          if meta_fields is not None:
              body["meta_fields"] = meta_fields
          data = self._post(endpoint, body, "创建切片失败")
          return data if data is not None else {}

      def update_chunk(self, dataset_id: str, chunk_id: str, content: Optional[str] = None,
                       meta_fields: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
          """Update a chunk's content and/or metadata.

          Only the arguments that are not None are included in the request.

          Returns:
              The response's ``data`` payload, or ``{}`` if the server reported
              success without one.

          Raises:
              Exception: ``"更新切片失败: ..."`` when the response code is non-zero.
          """
          endpoint = f"/api/v1/datasets/{dataset_id}/chunks/{chunk_id}"
          body: Dict[str, Any] = {}
          if content is not None:
              body["content"] = content
          if meta_fields is not None:
              body["meta_fields"] = meta_fields
          # NOTE(review): with both arguments None this POSTs an empty body to
          # the same URL delete_chunk uses — confirm that is intended.
          data = self._post(endpoint, body, "更新切片失败")
          return data if data is not None else {}

      def delete_chunk(self, dataset_id: str, chunk_id: str) -> bool:
          """Request deletion of a single chunk.

          Returns:
              True when the server reports success.

          Raises:
              Exception: ``"删除切片失败: ..."`` when the response code is non-zero.
          """
          # NOTE(review): this POSTs an empty body to the same URL as
          # update_chunk, so the request is indistinguishable from an update
          # with no fields. Confirm the server treats it as a delete, or
          # whether http_client offers a DELETE verb that should be used here.
          endpoint = f"/api/v1/datasets/{dataset_id}/chunks/{chunk_id}"
          self._post(endpoint, {}, "删除切片失败")
          return True

      def delete_chunks(self, dataset_id: str, document_id: str, chunk_ids: List[str]) -> bool:
          """Request deletion of several chunks belonging to one document.

          Args:
              dataset_id: Dataset that contains the document.
              document_id: Document whose chunks are deleted.
              chunk_ids: IDs of the chunks to delete; forwarded as-is.

          Returns:
              True when the server reports success.

          Raises:
              Exception: ``"批量删除切片失败: ..."`` when the response code is non-zero.
          """
          # NOTE(review): this POSTs to the same URL as create_chunk and is
          # distinguished only by its body. An empty chunk_ids list is sent
          # unchanged — confirm the server does not read it as "all chunks".
          endpoint = f"/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
          self._post(endpoint, {"chunk_ids": chunk_ids}, "批量删除切片失败")
          return True

      def retrieval(self, dataset_ids: List[str], query: str, top_k: int = 5,
                    similarity_threshold: float = 0.1, vector_similarity_weight: float = 0.3,
                    refine: bool = False) -> List[Dict[str, Any]]:
          """Run a retrieval query against the given datasets.

          ``top_k``, ``similarity_threshold``, ``vector_similarity_weight`` and
          ``refine`` are forwarded verbatim in the request body; their exact
          interpretation is up to the server.

          Returns:
              The response's ``data`` payload, or ``[]`` when the server
              reports success without data (presumably: no matches).

          Raises:
              Exception: ``"检索失败: ..."`` when the response code is non-zero.
          """
          body = {
              "dataset_ids": dataset_ids,
              "query": query,
              "top_k": top_k,
              "similarity_threshold": similarity_threshold,
              "vector_similarity_weight": vector_similarity_weight,
              "refine": refine,
          }
          data = self._post("/api/v1/retrieval", body, "检索失败")
          return data if data is not None else []