diff --git a/clstmhl.h b/clstmhl.h
index 1df36b0..923c0b8 100644
--- a/clstmhl.h
+++ b/clstmhl.h
@@ -146,8 +146,9 @@ struct CLSTMText {
 struct CLSTMOCR {
   shared_ptr<INormalizer> normalizer;
   Network net;
-  int target_height = 48;
+  int target_height;  // = 48; // to avoid unwanted values.
   int nclasses = -1;
+  string dewarp;  // Option for text-line normalization
   Sequence aligned, targets;
   Tensor2 image;
   void setLearningRate(float lr, float mom) { net->setLearningRate(lr, mom); }
@@ -161,7 +162,7 @@
       return false;
     }
     nclasses = net->codec.size();
-    normalizer.reset(make_CenterNormalizer());
+    normalizer.reset(make_Normalizer(dewarp));
     normalizer->target_height = target_height;
     return true;
   }
@@ -194,7 +195,7 @@
                              {"nhidden", nhidden}});
     net->initialize();
     net->codec.set(codec);
-    normalizer.reset(make_CenterNormalizer());
+    normalizer.reset(make_Normalizer(dewarp));
     normalizer->target_height = target_height;
   }
   std::wstring fwdbwd(TensorMap2 raw, const std::wstring &target) {
diff --git a/clstmocr.cc b/clstmocr.cc
index 131759d..414191c 100644
--- a/clstmocr.cc
+++ b/clstmocr.cc
@@ -46,8 +46,10 @@ int main1(int argc, char **argv) {
   string load_name = getsenv("load", "");
   if (load_name == "") THROW("must give load= parameter");
   CLSTMOCR clstm;
+  clstm.target_height = int(getrenv("target_height", 45));
+  clstm.dewarp = getsenv("dewarp", "none");
   clstm.load(load_name);
-
+
   bool conf = getienv("conf", 0);
   string output = getsenv("output", "text");
   bool save_text = getienv("save_text", 1);
diff --git a/clstmocrtrain.cc b/clstmocrtrain.cc
index eb31249..8cda1b6 100644
--- a/clstmocrtrain.cc
+++ b/clstmocrtrain.cc
@@ -65,6 +65,19 @@ struct Dataset {
     for (auto s : fnames) gtnames.push_back(basename(s) + ".gt.txt");
     codec.build(gtnames, charsep);
   }
+  void getCodec(Codec &codec, vector<string> file_lists) {
+    // get codec from several files, including training files, validation files,
+    // and perhaps testing files, in order to avoid unrecognized codecs
+    vector<string> gts;
+    for (int i = 0; i < file_lists.size(); i++) {
+      vector<string> temp_names;
+      read_lines(temp_names, file_lists[i]);
+      for (auto s : temp_names) gts.push_back(basename(s) + ".gt.txt");
+    }
+    // build the codecs
+    codec.build(gts, charsep);
+  }
+
   void readSample(Tensor2 &raw, wstring &gt, int index) {
     string fname = fnames[index];
     string base = basename(fname);
@@ -92,12 +105,19 @@ int main1(int argc, char **argv) {
   int ntrain = getienv("ntrain", 10000000);
   string save_name = getsenv("save_name", "_ocr");
   int report_time = getienv("report_time", 0);

+  // vector storing the training and testing files
+  vector<string> file_lists;
   if (argc < 2 || argc > 3) THROW("... training [testing]");
   Dataset trainingset(argv[1]);
+  file_lists.push_back(argv[1]);
   assert(trainingset.size() > 0);
   Dataset testset;
-  if (argc > 2) testset.readFileList(argv[2]);
+  if (argc > 2) {
+    testset.readFileList(argv[2]);
+    file_lists.push_back(argv[2]);
+  }
+
   print("got", trainingset.size(), "files,", testset.size(), "tests");

   string load_name = getsenv("load", "");
@@ -108,13 +128,16 @@
     clstm.load(load_name);
   } else {
     Codec codec;
-    trainingset.getCodec(codec);
+    //trainingset.getCodec(codec);
+    trainingset.getCodec(codec, file_lists);  // use all ground truth files
     print("got", codec.size(), "classes");
-    clstm.target_height = int(getrenv("target_height", 48));
+    clstm.target_height = int(getrenv("target_height", 45));
+    clstm.dewarp = getsenv("dewarp", "none");
     clstm.createBidi(codec.codec, getienv("nhidden", 100));
     clstm.setLearningRate(getdenv("lrate", 1e-4), getdenv("momentum", 0.9));
   }
+  file_lists.clear();  // clear the file_lists vector
   network_info(clstm.net);

   double test_error = 9999.0;
@@ -135,12 +158,16 @@
   Trigger report_trigger(getienv("report_every", 100), ntrain, start);
   Trigger display_trigger(getienv("display_every", 0), ntrain, start);

+  double train_errors = 0.0;
+  double train_count = 0.0;
   for (int trial = start; trial < ntrain; trial++) {
     int sample = lrand48() % trainingset.size();
     Tensor2 raw;
     wstring gt;
     trainingset.readSample(raw, gt, sample);
     wstring pred = clstm.train(raw(), gt);
+    train_count += gt.size();
+    train_errors += levenshtein(pred, gt);

     if (report_trigger(trial)) {
       print(trial);
@@ -168,6 +195,15 @@
       double count = tse.second;
       test_error = errors / count;
       print("ERROR", trial, test_error, " ", errors, count);
+      double train_error;
+      if (train_count > 0)
+        train_error = train_errors / train_count;
+      else
+        train_error = 9999.0;
+      print("Train ERROR: ", train_error);
+      train_count = 0.0;
+      train_errors = 0.0;
+
       if (test_error < best_error) {
         best_error = test_error;
         string fname = save_name + ".clstm";
diff --git a/extras.h b/extras.h
index eaa7533..3ba093b 100644
--- a/extras.h
+++ b/extras.h
@@ -31,7 +31,7 @@ using std::min;

 // text line normalization

 struct INormalizer {
-  int target_height = 48;
+  int target_height;  // = 48;
   float smooth2d = 1.0;
   float smooth1d = 0.3;
   float range = 4.0;