# multimodal_embedding.py (5.0 KB)
  1. from typing import List
  2. from PIL import Image
  3. import base64
  4. import io
  5. from langchain_openai import OpenAIEmbeddings
  6. from dashscope import MultiModalEmbedding
  7. from conf.config import ModelConfig
  8. from utils.file.image_util import image_util as ImageUtil
  9. class Embedding:
  10. """Embedding模型工具"""
  11. def __init__(self, model_name: str = None, api_key: str = None):
  12. """
  13. 初始化Embedding模型
  14. Args:
  15. model_name: 模型名称,若为None则使用配置文件中的值
  16. api_key: API密钥,若为None则使用配置文件中的值
  17. """
  18. # 获取模型配置
  19. self.model_provider = ModelConfig.get_model_provider()
  20. self.model_name = model_name or ModelConfig.get_model_name()
  21. self.base_url = ModelConfig.get_base_url()
  22. self.api_key = api_key or ModelConfig.get_api_key()
  23. # 使用langchain_openai初始化OpenAI Embeddings模型
  24. self.embeddings = OpenAIEmbeddings(
  25. model=self.model_name,
  26. base_url=self.base_url,
  27. api_key=self.api_key
  28. )
  29. def get_text_embedding(self, text: str) -> List[float]:
  30. """
  31. 获取文本的embedding
  32. Args:
  33. text: 要获取embedding的文本
  34. Returns:
  35. List[float]: 文本的embedding向量
  36. """
  37. try:
  38. embedding = self.embeddings.embed_query(text)
  39. return embedding
  40. except Exception as e:
  41. raise Exception(f"文本embedding生成失败: {str(e)}")
  42. def get_texts_embedding(self, texts: List[str]) -> List[List[float]]:
  43. """
  44. 获取多个文本的embedding
  45. Args:
  46. texts: 要获取embedding的文本列表
  47. Returns:
  48. List[List[float]]: 文本列表的embedding向量列表
  49. """
  50. try:
  51. embeddings = self.embeddings.embed_documents(texts)
  52. return embeddings
  53. except Exception as e:
  54. raise Exception(f"多个文本embedding生成失败: {str(e)}")
  55. def get_image_embedding(self, image: Image.Image) -> List[float]:
  56. """
  57. 获取图像的embedding
  58. Args:
  59. image: PIL图像对象
  60. Returns:
  61. List[float]: 图像的embedding向量
  62. """
  63. try:
  64. # 将图像转换为base64
  65. buffer = io.BytesIO()
  66. image.save(buffer, format="PNG")
  67. buffer.seek(0)
  68. # 压缩图片字节流
  69. compressed_bytes = ImageUtil._compress_image_to_bytes(buffer)
  70. image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
  71. # 构建输入项
  72. item = [
  73. {
  74. "image": f"data:image/png;base64,{image_base64}"
  75. }
  76. ]
  77. response = MultiModalEmbedding.call(
  78. model=self.model_name,
  79. api_key=self.api_key,
  80. input=item
  81. )
  82. if response.status_code == 200:
  83. return response.output["embeddings"][0]["embedding"]
  84. except Exception as e:
  85. raise Exception(f"图像embedding生成失败: {str(e)}")
  86. def get_multimodal_embedding(self, text: str, image: Image.Image) -> List[float]:
  87. """
  88. 获取多模态(文本+图像)的embedding
  89. Args:
  90. text: 文本内容
  91. image: PIL图像对象
  92. Returns:
  93. List[float]: 多模态的embedding向量
  94. """
  95. try:
  96. item = []
  97. if image is not None:
  98. buffer = io.BytesIO()
  99. image.save(buffer, format="PNG")
  100. buffer.seek(0)
  101. # 压缩图片字节流
  102. compressed_bytes = ImageUtil()._compress_image_to_bytes(buffer)
  103. image_base64 = base64.b64encode(compressed_bytes).decode("utf-8")
  104. item.append({'image': f"data:image/png;base64,{image_base64}"})
  105. # 判断text部位None或者不为空字符串
  106. if text is not None and text.strip() != "":
  107. item.append({'text': text})
  108. response = MultiModalEmbedding.call(
  109. model=self.model_name,
  110. api_key=self.api_key,
  111. input=item
  112. )
  113. if response.status_code == 200:
  114. return response.output["embeddings"][0]["embedding"]
  115. else:
  116. raise Exception(f"Error: {response.message}")
  117. except Exception as e:
  118. raise Exception(f"多模态embedding生成失败: {str(e)}")
  119. # 全局单例
  120. def get_embedding_model() -> Embedding:
  121. """
  122. 获取全局单例的Embedding模型
  123. Returns:
  124. Embedding: 全局单例的Embedding模型
  125. """
  126. return Embedding(
  127. model_name=ModelConfig.get_multimodal_embedding_model_name(),
  128. api_key=ModelConfig.get_dashscope_api_key()
  129. )