#include "elm.h"
#include <chrono>
#include <iomanip>
#include <algorithm>
#include <unordered_set>
#include <random>
VectorXf encode_addr(size_t rel_off, const ModelConfig& cfg);
/// Reports whether byte `c` can begin a UTF-8 code point.
/// Continuation bytes have the bit pattern 10xxxxxx; every other byte
/// (ASCII or a multi-byte lead byte) counts as a boundary.
inline bool is_utf8_boundary(unsigned char c) {
    const bool is_continuation = (c >> 6) == 0x2;
    return !is_continuation;
}
/// Removes whitespace, ASCII/CJK punctuation and common Chinese function
/// words from a UTF-8 string, keeping every other well-formed code point.
///
/// The input is walked one UTF-8 code point (1-4 bytes) at a time. A complete
/// code point that is not in the stop list is appended to the result.
/// Malformed bytes are dropped one byte at a time, so the output is always
/// valid UTF-8 even when the input was cut mid-character. Malformed bytes are
/// stray continuation bytes, invalid lead bytes, or sequences truncated by the
/// end of the string; cut input is common with fixed-size byte slices.
///
/// @param s  Text to filter; may contain arbitrary bytes.
/// @return   The filtered text.
std::string filter_low_signal(const std::string& s) {
    // Built once instead of on every call: this runs for every training slice
    // and every query.
    static const std::unordered_set<std::string> stop_strs = {
        " ", "\t", "\n", "\r", ",", ".", ";", ":", "!", "?", "\"", "'",
        "(", ")", "[", "]", "{", "}", "<", ">", "/", "\\", "-", "_", "=", "+", "*", "&",
        "^", "%", "$", "#", "@", "~", "`",
        ",", "。", ";", ":", "!", "?", "“", "”", "‘", "’", "、", "(", ")",
        // Full-width forms (U+FF0C U+FF1B U+FF1A U+FF01 U+FF1F U+FF08 U+FF09)
        // spelled as explicit UTF-8 bytes so they are always present.
        "\xEF\xBC\x8C", "\xEF\xBC\x9B", "\xEF\xBC\x9A", "\xEF\xBC\x81",
        "\xEF\xBC\x9F", "\xEF\xBC\x88", "\xEF\xBC\x89",
        "【", "】", "《", "》", "…", "—", "·",
        "的", "地", "得", "了", "着", "过", "是", "在", "有", "无", "和", "与",
        "及", "或", "而", "则", "就", "都", "也", "还", "又", "才", "只", "却",
        "不", "没", "别", "莫", "非", "否", "啊", "哦", "呢", "吗", "吧", "呀", "哇", "呐"
    };

    // Length of the UTF-8 sequence introduced by lead byte `c`, or 0 when `c`
    // cannot start a sequence (continuation byte, overlong 0xC0/0xC1, > 0xF4).
    auto sequence_length = [](unsigned char c) -> std::size_t {
        if (c < 0x80) return 1;
        if (c >= 0xC2 && c <= 0xDF) return 2;
        if (c >= 0xE0 && c <= 0xEF) return 3;
        if (c >= 0xF0 && c <= 0xF4) return 4;
        return 0;
    };

    std::string filtered;
    filtered.reserve(s.size());
    std::string current_char;
    for (std::size_t i = 0; i < s.size();) {
        const std::size_t len = sequence_length(static_cast<unsigned char>(s[i]));
        bool well_formed = len != 0 && i + len <= s.size();
        for (std::size_t k = 1; well_formed && k < len; ++k) {
            well_formed = (static_cast<unsigned char>(s[i + k]) & 0xC0) == 0x80;
        }
        if (!well_formed) {
            // Skip only the offending byte so following valid characters
            // (including ASCII) are still examined.
            ++i;
            continue;
        }
        current_char.assign(s, i, len);
        i += len;
        if (stop_strs.find(current_char) == stop_strs.end()) {
            filtered += current_char;
        }
    }
    return filtered;
}
/// Strips one pair of enclosing quotes when the string both starts and ends
/// with the same quote character (`"` or `'`). Anything shorter than two
/// characters, or not symmetrically quoted, is returned as-is.
std::string trim_quotes(const std::string& s) {
    if (s.size() >= 2) {
        const char first = s.front();
        const char last = s.back();
        const bool quoted = (first == '"' || first == '\'') && first == last;
        if (quoted) {
            return s.substr(1, s.size() - 2);
        }
    }
    return s;
}
class ShardedManager {
public:
ModelConfig cfg;
MatrixXf global_win;
std::vector<FuzzyELM> chunks32;
std::unique_ptr<SemanticFeatureExtractor> extractor;
ShardedManager() {
try {
extractor = std::make_unique<SemanticFeatureExtractor>("model.onnx", "tokenizer.json");
} catch (const std::exception& e) {
std::cerr << "❌ 语义模型加载失败: " << e.what() << std::endl;
exit(1);
}
init_win();
}
void init_win() {
std::mt19937 gen(cfg.seed);
std::normal_distribution<float> dist(0, 0.01);
global_win.resize(cfg.hidden_dim, cfg.input_dim);
for(int i=0; i<global_win.rows(); ++i)
for(int j=0; j<global_win.cols(); ++j)
global_win(i,j) = dist(gen);
}
void train(const std::string& data_file) {
std::ifstream ifs(data_file, std::ios::binary | std::ios::ate);
if (!ifs) {
std::cerr << "❌ 无法打开数据文件: " << data_file << std::endl;
return;
}
size_t total_size = ifs.tellg();
ifs.close();
size_t num_chunks = (total_size + CHUNK_SIZE - 1) / CHUNK_SIZE;
chunks32.clear();
chunks32.resize(num_chunks);
std::ofstream dbg("debug_samples.txt");
for (size_t i = 0; i < num_chunks; ++i) {
auto start_time = std::chrono::high_resolution_clock::now();
size_t start_pos = i * CHUNK_SIZE;
size_t curr_size = std::min(CHUNK_SIZE, total_size - start_pos);
std::ifstream f(data_file, std::ios::binary);
f.seekg(start_pos);
std::vector<char> buf(curr_size);
f.read(buf.data(), curr_size);
f.close();
std::vector<size_t> off_list;
for (size_t off = 0; off + cfg.slice_bytes < curr_size; off += 20) {
size_t safe_start = off;
while (safe_start < curr_size && !is_utf8_boundary((unsigned char)buf[safe_start])) {
safe_start++;
}
size_t safe_end = safe_start + cfg.slice_bytes;
while (safe_end < curr_size && !is_utf8_boundary((unsigned char)buf[safe_end])) {
safe_end++;
}
std::string text(buf.data() + safe_start, safe_end - safe_start);
std::string filtered = filter_low_signal(text);
if (filtered.size() >= 5 && safe_start < curr_size && safe_end <= curr_size) {
off_list.push_back(safe_start);
}
}
if (off_list.empty()) continue;
MatrixXf H32(off_list.size(), cfg.hidden_dim);
MatrixXf T32(off_list.size(), cfg.addr_bits);
size_t count = 0;
for (size_t safe_off : off_list) {
size_t safe_end = safe_off + cfg.slice_bytes;
while (safe_end < curr_size && !is_utf8_boundary((unsigned char)buf[safe_end])) {
safe_end++;
}
std::string text(buf.data() + safe_off, safe_end - safe_off);
std::string filtered_text = filter_low_signal(text);
dbg << "[Pos: " << (start_pos + safe_off) << "] | " << filtered_text << "\n";
VectorXf sem_v = extractor->extract(filtered_text);
H32.row(count) = (global_win * sem_v).unaryExpr([](float x){ return std::tanh(x); });
T32.row(count) = encode_addr(safe_off, cfg);
count++;
}
chunks32[i].train(H32, T32, cfg.lambda);
auto end_time = std::chrono::high_resolution_clock::now();
std::cout << "✅ Chunk " << i+1 << "/" << num_chunks
<< " | 样本数: " << count
<< " | 耗时: " << std::fixed << std::setprecision(2)
<< std::chrono::duration<double, std::milli>(end_time - start_time).count() << "ms" << std::endl;
}
dbg.close();
save(data_file);
regression_test_full(data_file, "debug_samples.txt");
}
void save(const std::string& f) {
std::ofstream os(f + ".model", std::ios::binary);
size_t n = chunks32.size();
os.write((char*)&n, sizeof(n));
for(size_t i=0; i<n; ++i) {
chunks32[i].save_to_stream(os);
}
os.close();
}
bool load(const std::string& f) {
std::ifstream is(f + ".model", std::ios::binary);
if(!is) return false;
size_t n;
is.read((char*)&n, sizeof(n));
chunks32.resize(n);
for(size_t i=0; i<n; i++) {
if (!chunks32[i].load_from_stream(is)) {
return false;
}
}
is.close();
return true;
}
std::set<size_t> search(const std::string& raw_query) {
std::set<size_t> total_results;
if (raw_query.empty()) return total_results;
std::string query = trim_quotes(raw_query);
std::string filtered_query = filter_low_signal(query);
if (filtered_query.empty()) return total_results;
VectorXf sem_v = extractor->extract(filtered_query);
for (auto& chunk : chunks32) {
if (chunk.Beta.size() == 0) continue;
VectorXf h = (global_win * sem_v).unaryExpr([](float x) {
return std::tanh(x);
});
VectorXf out = h.transpose() * chunk.Beta;
get_candidates(out, total_results, cfg.addr_bits);
}
return total_results;
}
void regression_test_full(const std::string& data_file, const std::string& log_file) {
std::ifstream log_ifs(log_file);
std::string line;
int hits = 0;
int total = 0;
std::ofstream dbg("debug_out.txt");
while (std::getline(log_ifs, line)) {
if (line.empty() || line[0] == '-') continue;
size_t pos;
if (sscanf(line.c_str(), "[Pos: %zu]", &pos) == 1) {
size_t delim = line.find("| ");
if (delim == std::string::npos) continue;
std::string content = line.substr(delim + 2);
if (content.empty()) continue;
total++;
auto results = this->search(content);
dbg << "Target Pos: " << pos << " | Candidates: [";
for (auto r : results) dbg << r << ", ";
dbg << "]\n";
if (results.count(pos)) {
hits++;
}
}
}
if (total > 0) {
std::cout << "📈 训练样本召回率: " << (float)hits/total * 100 << "% (" << hits << "/" << total << ")" << std::endl;
} else {
std::cout << "❌ 错误:未识别到任何测试样本" << std::endl;
}
dbg.close();
}
void run_fuzzy_benchmark(const std::string& data_file) {
std::ifstream ifs(data_file, std::ios::binary);
std::string content((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
ifs.close();
int test_count = 100;
int passed = 0;
std::mt19937 gen(42);
std::cout << "🧪 启动自动化模糊检索压力测试 (样本量: " << test_count << ")..." << std::endl;
for(int i=0; i<test_count; ++i) {
size_t target_pos = (gen() % (content.size() - 100)) / 3 * 3;
std::string original = content.substr(target_pos, 45);
std::string query = filter_low_signal(original);
if(query.size() > 30) query = query.substr(0, 30);
if(query.size() > 6) {
size_t flip_idx = (gen() % (query.size()/3)) * 3;
query[flip_idx] = (char)(gen() % 255);
}
auto results = this->search(query);
bool found = false;
for(auto res : results) {
if(std::abs((long)res - (long)target_pos) < 60) {
found = true;
break;
}
}
if(found) passed++;
}
std::cout << "📊 模糊检索鲁棒性评分: " << passed << "/" << test_count
<< " (" << (float)passed/test_count * 100 << "%)" << std::endl;
}
};
/// Builds the regression target for a chunk-relative byte offset. The offset
/// is mapped linearly so 0 -> -1 and CHUNK_SIZE -> +1, and the value is
/// repeated in every one of the cfg.addr_bits outputs.
VectorXf encode_addr(size_t rel_off, const ModelConfig& cfg) {
    const float scaled = 2.0f * static_cast<float>(rel_off) / static_cast<float>(CHUNK_SIZE);
    return VectorXf::Constant(cfg.addr_bits, scaled - 1.0f);
}
/// Command-line entry point.
///   train  <file>          train on <file> and write <file>.model
///   search <file> <text>   look up <text> (remaining args joined by spaces)
///   bench  <file>          run the randomized fuzzy-retrieval benchmark
/// Arguments are validated before the (expensive) semantic model is loaded.
int main(int argc, char** argv) {
    if (argc < 3) {
        std::cout << "Usage:\n"
                  << " Train: ./elm train <file>\n"
                  << " Search: ./elm search <file> \"text\"\n"
                  << " Bench: ./elm bench <file>" << std::endl;
        return 1;
    }
    const std::string cmd = argv[1];
    const std::string file = argv[2];
    if (cmd != "train" && cmd != "search" && cmd != "bench") {
        std::cerr << "❌ 无效命令: " << cmd << std::endl;
        return 1;
    }
    if (cmd == "search" && argc < 4) {
        std::cerr << "❌ 缺少搜索关键词" << std::endl;
        return 1;
    }

    ShardedManager mgr;
    if (cmd == "train") {
        mgr.train(file);
        return 0;
    }

    if (!mgr.load(file)) {
        if (cmd == "search") {
            std::cerr << "❌ 无法加载模型文件: " << file << ".model" << std::endl;
        } else {
            std::cerr << "❌ 无法加载模型文件" << std::endl;
        }
        return 1;
    }

    if (cmd == "bench") {
        mgr.run_fuzzy_benchmark(file);
        return 0;
    }

    // cmd == "search"
    std::string raw_query = argv[3];
    for (int i = 4; i < argc; ++i) raw_query += " " + std::string(argv[i]);
    std::cout << "🔍 检索内容: [" << raw_query << "]" << std::endl;
    const auto results = mgr.search(raw_query);
    if (results.empty()) {
        std::cout << "❌ 索引中未发现匹配点。" << std::endl;
        return 0;
    }

    constexpr long long kContextBefore = 32;
    constexpr std::streamsize kContextLen = 128;
    std::ifstream ifs(file, std::ios::binary);
    std::vector<char> context(static_cast<size_t>(kContextLen));
    for (auto addr : results) {
        long long start = static_cast<long long>(addr) - kContextBefore;
        if (start < 0) start = 0;
        // A short read near EOF sets failbit. seekg does not clear it, so
        // without this every later lookup would silently fail.
        ifs.clear();
        ifs.seekg(start);
        ifs.read(context.data(), kContextLen);
        // Print only the bytes actually read, not stale or zeroed buffer space.
        const std::string output(context.data(), static_cast<size_t>(ifs.gcount()));
        std::cout << "📍 地址: " << std::left << std::setw(8) << addr
                  << " | 上下文: [..." << output << "...]" << std::endl;
    }
    ifs.close();
    return 0;
}