| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- from typing import Optional, Dict, Any
- from langchain.chat_models import init_chat_model
- from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
- from conf.settings import model_settings
- class OpenAIChatModel:
- """
- OpenAI 风格的聊天模型封装
-
- 基于 langchain 的 init_chat_model 实现,
- 默认使用 deepseek-r1 模型,支持用户指定其他模型。
- """
-
- def __init__(
- self,
- model_provider: str = model_settings.model_provider,
- model_name: str = model_settings.chat_model_name,
- api_key: Optional[str] = model_settings.api_key,
- base_url: Optional[str] = model_settings.base_url,
- temperature: float = 0.7,
- max_tokens: Optional[int] = None,
- **kwargs
- ):
- """
- 初始化 OpenAI Chat 模型
-
- Args:
- model_name: 模型名称,默认为 "deepseek-r1"
- api_key: API 密钥
- base_url: API 基础 URL
- temperature: 生成文本的随机性,范围 0-2,默认为 0.7
- max_tokens: 最大生成 token 数
- **kwargs: 其他参数
- """
- # 使用 langchain 的 init_chat_model 初始化模型
- self.chat_model = init_chat_model(
- model_provider=model_provider,
- model=model_name,
- api_key=api_key,
- base_url=base_url,
- temperature=temperature,
- max_tokens=max_tokens,
- **kwargs
- )
-
- def get_chat_model(self):
- """
- 获取聊天模型实例
-
- Returns:
- 聊天模型实例(由 langchain.init_chat_model 返回的类型)
- """
- return self.chat_model
-
- def generate_response(
- self,
- prompt: str,
- system_prompt: Optional[str] = None,
- **kwargs
- ) -> str:
- """
- 生成响应
-
- Args:
- prompt: 用户提示
- system_prompt: 系统提示
- **kwargs: 其他参数
-
- Returns:
- str: 生成的响应
- """
- # 构建消息列表
- messages = []
-
- # 添加系统提示(如果有)
- if system_prompt:
- messages.append(SystemMessage(content=system_prompt))
-
- # 添加用户提示
- messages.append(HumanMessage(content=prompt))
-
- # 生成响应
- response = self.chat_model.invoke(messages, **kwargs)
-
- # 解析响应
- if isinstance(response, AIMessage):
- return response.content
- else:
- # 对于其他类型的响应,尝试获取内容
- return str(response)
-
- def chat(
- self,
- prompt: str,
- system_prompt: Optional[str] = None,
- **kwargs
- ) -> str:
- """
- 聊天接口(别名,向后兼容)
-
- Args:
- prompt: 用户提示
- system_prompt: 系统提示
- **kwargs: 其他参数
-
- Returns:
- str: 生成的响应
- """
- return self.generate_response(prompt, system_prompt, **kwargs)
|