// Message model: data-access helpers for the `messages` table.
const database = require('./database');
class Message {
  /**
   * Build a Message from a database row (or any object with the same fields).
   *
   * @param {object} [data={}] - Row-shaped object; missing fields become undefined.
   */
  constructor(data = {}) {
    this.id = data.id;
    this.conversation_id = data.conversation_id;
    this.role = data.role; // 'user' or 'assistant'
    this.content = data.content;
    this.timestamp = data.timestamp;
  }

  /**
   * Create a new message.
   *
   * @param {*} conversationId - ID of the conversation the message belongs to.
   * @param {string} role - Must be 'user' or 'assistant'.
   * @param {*} content - Message text.
   * @returns {Promise<Message|null>} The inserted row, re-read via findById.
   * @throws {Error} If the role is invalid or a database call fails.
   */
  static async create(conversationId, role, content) {
    try {
      if (!['user', 'assistant'].includes(role)) {
        throw new Error('无效的消息角色');
      }

      // The timestamp is assigned by SQLite's datetime('now') (UTC, one-second resolution).
      const sql = `
        INSERT INTO messages (conversation_id, role, content, timestamp)
        VALUES (?, ?, ?, datetime('now'))
      `;
      const result = await database.run(sql, [conversationId, role, content]);

      // NOTE(review): assumes database.run resolves to an object whose `id` is the
      // inserted rowid — confirm against ./database.
      return await Message.findById(result.id);
    } catch (error) {
      console.error('创建消息失败:', error);
      throw error;
    }
  }

  /**
   * Look up a single message by its ID.
   *
   * @param {*} id - Message ID.
   * @returns {Promise<Message|null>} The message, or null when no row matches.
   */
  static async findById(id) {
    try {
      const sql = 'SELECT * FROM messages WHERE id = ?';
      const row = await database.get(sql, [id]);
      return row ? new Message(row) : null;
    } catch (error) {
      console.error('查找消息失败:', error);
      throw error;
    }
  }

  /**
   * Page through a conversation's messages, oldest first.
   *
   * @param {*} conversationId - Conversation ID.
   * @param {number} [limit=100] - Maximum number of rows to return.
   * @param {number} [offset=0] - Number of rows to skip.
   * @returns {Promise<Message[]>}
   */
  static async findByConversationId(conversationId, limit = 100, offset = 0) {
    try {
      // Timestamps have one-second resolution, so several messages can share one;
      // `id` (presumably an increasing rowid) makes the order deterministic.
      const sql = `
        SELECT * FROM messages
        WHERE conversation_id = ?
        ORDER BY timestamp ASC, id ASC
        LIMIT ? OFFSET ?
      `;
      const rows = await database.all(sql, [conversationId, limit, offset]);
      return rows.map((row) => new Message(row));
    } catch (error) {
      console.error('获取对话消息失败:', error);
      throw error;
    }
  }

  /**
   * Fetch the N most recent messages of a conversation (used to build LLM context).
   *
   * @param {*} conversationId - Conversation ID.
   * @param {number} [limit=10] - How many of the newest messages to return.
   * @returns {Promise<Message[]>} The selected messages in chronological (oldest-first) order.
   */
  static async getRecentMessages(conversationId, limit = 10) {
    try {
      // Select newest-first (with an `id` tiebreaker for same-second timestamps)...
      const sql = `
        SELECT * FROM messages
        WHERE conversation_id = ?
        ORDER BY timestamp DESC, id DESC
        LIMIT ?
      `;
      const rows = await database.all(sql, [conversationId, limit]);
      // ...then flip back to chronological order for the caller.
      return rows.reverse().map((row) => new Message(row));
    } catch (error) {
      console.error('获取最近消息失败:', error);
      throw error;
    }
  }

  /**
   * Count the messages in a conversation.
   *
   * @param {*} conversationId - Conversation ID.
   * @returns {Promise<number>}
   */
  static async getMessageCount(conversationId) {
    try {
      const sql = 'SELECT COUNT(*) as count FROM messages WHERE conversation_id = ?';
      const result = await database.get(sql, [conversationId]);
      return result.count;
    } catch (error) {
      console.error('获取消息数量失败:', error);
      throw error;
    }
  }

  /**
   * Delete one message.
   *
   * @param {*} id - Message ID.
   * @returns {Promise<true>}
   * @throws {Error} If no row was deleted (the message does not exist) or the query fails.
   */
  static async delete(id) {
    try {
      const sql = 'DELETE FROM messages WHERE id = ?';
      const result = await database.run(sql, [id]);

      if (result.changes === 0) {
        throw new Error('消息不存在');
      }

      return true;
    } catch (error) {
      console.error('删除消息失败:', error);
      throw error;
    }
  }

  /**
   * Delete every message of a conversation.
   *
   * @param {*} conversationId - Conversation ID.
   * @returns {Promise<number>} Number of rows deleted (as reported by database.run).
   */
  static async deleteByConversationId(conversationId) {
    try {
      const sql = 'DELETE FROM messages WHERE conversation_id = ?';
      const result = await database.run(sql, [conversationId]);
      return result.changes;
    } catch (error) {
      console.error('删除对话消息失败:', error);
      throw error;
    }
  }

  /**
   * Insert several messages inside one transaction.
   *
   * @param {Array<{conversationId: *, role: string, content: *}>} messages
   * @returns {Promise<Message[]>} The created messages, in input order.
   * @throws {Error} The first error encountered; the transaction is rolled back if it was started.
   */
  static async createBatch(messages) {
    // Only roll back a transaction we actually opened: rolling back after a failed
    // BEGIN would itself error and hide the real cause.
    let inTransaction = false;
    try {
      await database.beginTransaction();
      inTransaction = true;

      const results = [];
      for (const { conversationId, role, content } of messages) {
        results.push(await Message.create(conversationId, role, content));
      }

      await database.commit();
      inTransaction = false;
      return results;
    } catch (error) {
      if (inTransaction) {
        try {
          await database.rollback();
        } catch (rollbackError) {
          // Keep the original error as the one surfaced to the caller.
          console.error('批量创建消息回滚失败:', rollbackError);
        }
      }
      console.error('批量创建消息失败:', error);
      throw error;
    }
  }

  /**
   * Substring search over message content, newest first.
   *
   * The query is matched literally: LIKE wildcards (`%`, `_`) and the escape
   * character (`\`) in `query` are escaped.
   *
   * @param {string} query - Text to look for.
   * @param {*} [conversationId=null] - If truthy, restrict the search to this conversation.
   * @param {number} [limit=50] - Maximum number of results.
   * @returns {Promise<object[]>} Plain objects with the Message fields plus `conversation_title`.
   */
  static async search(query, conversationId = null, limit = 50) {
    try {
      let sql = `
        SELECT m.*, c.title as conversation_title
        FROM messages m
        JOIN conversations c ON m.conversation_id = c.id
        WHERE m.content LIKE ? ESCAPE '\\'
      `;
      const escapedQuery = String(query).replace(/[\\%_]/g, '\\$&');
      const params = [`%${escapedQuery}%`];

      if (conversationId) {
        sql += ' AND m.conversation_id = ?';
        params.push(conversationId);
      }

      sql += ' ORDER BY m.timestamp DESC, m.id DESC LIMIT ?';
      params.push(limit);

      const rows = await database.all(sql, params);
      return rows.map((row) => ({
        ...new Message(row),
        conversation_title: row.conversation_title,
      }));
    } catch (error) {
      console.error('搜索消息失败:', error);
      throw error;
    }
  }

  /**
   * Build the role/content history used as LLM context.
   *
   * Size is estimated as one token per character, which undercounts for CJK text.
   * The newest message is always included, even if it alone exceeds the budget.
   *
   * @param {*} conversationId - Conversation ID.
   * @param {number} [maxTokens=4000] - Approximate budget, measured in characters.
   * @param {number} [maxMessages=20] - How many of the most recent messages to consider.
   * @returns {Promise<Array<{role: string, content: *}>>} History in chronological order.
   */
  static async getConversationHistory(conversationId, maxTokens = 4000, maxMessages = 20) {
    try {
      const messages = await Message.getRecentMessages(conversationId, maxMessages);

      const history = [];
      let totalLength = 0;

      // Walk from newest to oldest so the most recent context survives truncation.
      for (let i = messages.length - 1; i >= 0; i--) {
        const { role, content } = messages[i];
        // Tolerate NULL content rather than crashing on `.length`.
        const messageLength = content ? content.length : 0;

        if (totalLength + messageLength > maxTokens && history.length > 0) {
          break;
        }

        history.push({ role, content });
        totalLength += messageLength;
      }

      // Collected newest-first; return in chronological order.
      return history.reverse();
    } catch (error) {
      console.error('获取对话历史失败:', error);
      throw error;
    }
  }
}
module.exports = Message;