search_request_models.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. """
  2. 搜索相关的Pydantic模型定义
  3. 该模块定义了用于图像和文本搜索的请求和响应模型。
  4. """
  5. from pydantic import BaseModel, Field, field_validator
  6. from typing import List, Optional
  7. class SearchRequest(BaseModel):
  8. """
  9. 搜索请求模型
  10. 支持通过图像URL或文本进行相似性搜索。
  11. 至少需要提供 image_url 或 matching_text 其中之一。
  12. Attributes:
  13. image_url: 用于图像搜索的图片URL (可选)
  14. matching_text: 用于文本匹配搜索的查询文本 (可选)
  15. topn: 返回的最大结果数量,默认为10,范围1-100
  16. """
  17. database_ids: List[str] = Field(
  18. default=[],
  19. description="知识库id",
  20. examples=["database_id1", "database_id2"]
  21. )
  22. image_url: Optional[str] = Field(
  23. default=None,
  24. description="用于图像搜索的图片URL",
  25. examples=["https://example.com/image.jpg"]
  26. )
  27. matching_text: Optional[str] = Field(
  28. default=None,
  29. description="用于文本匹配搜索的查询文本",
  30. examples=["搜索关键词"]
  31. )
  32. topk: int = Field(
  33. default=10,
  34. ge=1,
  35. le=100,
  36. description="返回的最大结果数量"
  37. )
  38. score_threshold: float = Field(
  39. default=0.0,
  40. ge=0.0,
  41. le=1.0,
  42. description="相似度分数阈值"
  43. )
  44. kb_id: Optional[str] = Field(
  45. default=None,
  46. description="知识库id",
  47. examples=["kb_id1", "kb_id2"]
  48. )
  49. @field_validator('image_url')
  50. @classmethod
  51. def validate_image_url(cls, v: Optional[str]) -> Optional[str]:
  52. """验证图片URL格式"""
  53. if v is not None and v.strip() == "":
  54. return None
  55. return v
  56. @field_validator('matching_text')
  57. @classmethod
  58. def validate_matching_text(cls, v: Optional[str]) -> Optional[str]:
  59. """验证匹配文本"""
  60. if v is not None and v.strip() == "":
  61. return None
  62. return v
  63. def model_post_init(self, __context) -> None:
  64. """模型初始化后验证:确保至少提供一个搜索条件"""
  65. if self.image_url is None and self.matching_text is None:
  66. raise ValueError("必须提供 image_url 或 matching_text 其中至少一个")
  67. @field_validator('kb_id')
  68. @classmethod
  69. def validate_kb_id(cls, v: Optional[str]) -> Optional[str]:
  70. """验证知识库id"""
  71. if v is not None and v.strip() == "":
  72. return None
  73. return v