openai_chat_model.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from typing import Optional, Dict, Any
  2. from langchain.chat_models import init_chat_model
  3. from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
  4. from conf.settings import model_settings
  5. class OpenAIChatModel:
  6. """
  7. OpenAI 风格的聊天模型封装
  8. 基于 langchain 的 init_chat_model 实现,
  9. 默认使用 deepseek-r1 模型,支持用户指定其他模型。
  10. """
  11. def __init__(
  12. self,
  13. model_provider: str = model_settings.model_provider,
  14. model_name: str = model_settings.chat_model_name,
  15. api_key: Optional[str] = model_settings.api_key,
  16. base_url: Optional[str] = model_settings.base_url,
  17. temperature: float = 0.7,
  18. max_tokens: Optional[int] = None,
  19. **kwargs
  20. ):
  21. """
  22. 初始化 OpenAI Chat 模型
  23. Args:
  24. model_name: 模型名称,默认为 "deepseek-r1"
  25. api_key: API 密钥
  26. base_url: API 基础 URL
  27. temperature: 生成文本的随机性,范围 0-2,默认为 0.7
  28. max_tokens: 最大生成 token 数
  29. **kwargs: 其他参数
  30. """
  31. # 使用 langchain 的 init_chat_model 初始化模型
  32. self.chat_model = init_chat_model(
  33. model_provider=model_provider,
  34. model=model_name,
  35. api_key=api_key,
  36. base_url=base_url,
  37. temperature=temperature,
  38. max_tokens=max_tokens,
  39. **kwargs
  40. )
  41. def get_chat_model(self):
  42. """
  43. 获取聊天模型实例
  44. Returns:
  45. 聊天模型实例(由 langchain.init_chat_model 返回的类型)
  46. """
  47. return self.chat_model
  48. def generate_response(
  49. self,
  50. prompt: str,
  51. system_prompt: Optional[str] = None,
  52. **kwargs
  53. ) -> str:
  54. """
  55. 生成响应
  56. Args:
  57. prompt: 用户提示
  58. system_prompt: 系统提示
  59. **kwargs: 其他参数
  60. Returns:
  61. str: 生成的响应
  62. """
  63. # 构建消息列表
  64. messages = []
  65. # 添加系统提示(如果有)
  66. if system_prompt:
  67. messages.append(SystemMessage(content=system_prompt))
  68. # 添加用户提示
  69. messages.append(HumanMessage(content=prompt))
  70. # 生成响应
  71. response = self.chat_model.invoke(messages, **kwargs)
  72. # 解析响应
  73. if isinstance(response, AIMessage):
  74. return response.content
  75. else:
  76. # 对于其他类型的响应,尝试获取内容
  77. return str(response)
  78. def chat(
  79. self,
  80. prompt: str,
  81. system_prompt: Optional[str] = None,
  82. **kwargs
  83. ) -> str:
  84. """
  85. 聊天接口(别名,向后兼容)
  86. Args:
  87. prompt: 用户提示
  88. system_prompt: 系统提示
  89. **kwargs: 其他参数
  90. Returns:
  91. str: 生成的响应
  92. """
  93. return self.generate_response(prompt, system_prompt, **kwargs)