diff --git a/notebooks/kalman_CUDA_demo/kalman.h++ b/notebooks/kalman_CUDA_demo/kalman.hpp similarity index 100% rename from notebooks/kalman_CUDA_demo/kalman.h++ rename to notebooks/kalman_CUDA_demo/kalman.hpp diff --git a/notebooks/kalman_CUDA_demo/run_kf.ipynb b/notebooks/kalman_CUDA_demo/run_kf.ipynb index ab5a9757..020fd954 100644 --- a/notebooks/kalman_CUDA_demo/run_kf.ipynb +++ b/notebooks/kalman_CUDA_demo/run_kf.ipynb @@ -396,12 +396,15 @@ "metadata": {}, "outputs": [], "source": [ - "int run_kf(int expansion_factor, bool verbose) {\n", + "std::vector> run_kf(int expansion_factor, bool verbose) {\n", " \n", " int n = 3; // Number of states\n", " int m = 1; // Number of measurements\n", "\n", - " double dt = 1.0 / 30; // Time step\n", + " double dt = 1.0 / 40; // Time step\n", + " \n", + " std::vector g_preds;\n", + " std::vector g_est;\n", " \n", " std::vector> A(n, std::vector(n));\n", " std::vector> C(m, std::vector(n));\n", @@ -426,19 +429,32 @@ " 2.81003612051, 2.88321849354, 2.69789264832, 2.4342229249, 2.23464791825,\n", " 2.30278776224, 2.02069770395, 1.94393985809, 1.82498398739, 1.52526230354,\n", " 1.86967808173, 1.18073207847, 1.10729605087, 0.916168349913, 0.678547664519,\n", - " 0.562381751596, 0.355468474885, -0.155607486619, -0.287198661013, -0.602973173813\n", + " 0.562381751596, 0.355468474885, -0.155607486619, -0.287198661013, -0.602973173813,\n", + " -0.873817307503, -1.144661441193, -1.415505574883, -1.686349708573, -1.957193842263, \n", + " -2.228037975953, -2.498882109643, -2.769726243333, -3.040570377023, -3.311414510713, \n", + " -3.582258644403, -3.853102778093, -4.123946911783, -4.394791045473, -4.665635179163, \n", + " -4.936479312853, -5.207323446543, -5.478167580233, -5.749011713923, -6.019855847613,\n", + " -6.290699981303, -6.561544114993, -6.832388248683, -7.103232382373, -7.374076516063,\n", + " -7.644920649753, -7.915764783443, -8.186608917133, -8.457453050823, -8.728297184513, \n", + " -8.999141318203, -9.269985451893, -9.540829585583, -9.811673719273, -10.082517852963,\n", + " -10.353361986653, -10.624206120343, -10.895050254033, -11.165894387723, -11.436738521413,\n", + " -11.707582655103, -11.978426788793, -12.249270922483, -12.520115056173\n", " };\n", " \n", - " std::vector expandedMeasurements = generateExpandedData(measurements, expansion_factor);\n", + " // std::vector expandedMeasurements = generateExpandedData(measurements, expansion_factor);\n", " \n", - " std::vector x0 = {expandedMeasurements[0], 0, -9.81};\n", + " std::vector x0 = {measurements[0], 0, 0};\n", " kf.init(0, x0);\n", "\n", " // Feed measurements into filter, output estimated states\n", " std::vector y(m);\n", " if(verbose) {\n", " std::cout << \"t = \" << 0 << \", \" << \"x_hat[0]: \";\n", - " for (auto& val : kf.state()) std::cout << val << \" \";\n", + " for (auto& val : kf.state()){\n", + " std::cout << val << \" \";\n", + " }\n", + " g_est.push_back(kf.state()[2]);\n", + " \n", " std::cout << std::endl;\n", " }\n", " \n", @@ -448,7 +464,10 @@ " kf.update(y);\n", " if(verbose) {\n", " std::cout << \"t = \" << (i + 1) * dt << \", y[\" << i << \"] = \" << y[0] << \", x_hat[\" << i << \"] = \";\n", - " for (auto& val : kf.state()) std::cout << val << \" \";\n", + " for (auto& val : kf.state()) {\n", + " std::cout << val << \" \";\n", + " }\n", + " g_preds.push_back(kf.state()[2]);\n", " std::cout << std::endl;\n", " }\n", " }\n", @@ -456,7 +475,14 @@ " std::cout<<\"Exec Success, Final kf states:\";\n", " for (auto& val : kf.state()) std::cout << val << \" \";\n", " std::cout << std::endl;\n", - " return 0;\n", + " \n", + " std::vector> g_res;\n", + " for (size_t i = 0; i < g_preds.size(); ++i) {\n", + " std::vector pair = {g_preds[i], g_est[i]};\n", + " g_res.push_back(pair);\n", + " }\n", + " \n", + " return g_res;\n", "}" ] }, @@ -480,12 +506,12 @@ "metadata": {}, "outputs": [], "source": [ - "int pyrun_sim(int exp_factor, bool verbose) {\n", + "std::vector> pyrun_sim(int exp_factor, bool verbose) {\n", " start = std::chrono::high_resolution_clock::now();\n", - " run_kf(exp_factor, verbose);\n", + " std::vector> g_res = run_kf(exp_factor, verbose);\n", " stop = std::chrono::high_resolution_clock::now();\n", " duration = std::chrono::duration_cast(stop - start);\n", - " return duration.count();\n", + " return g_res;\n", "}" ] }, @@ -499,119 +525,263 @@ "name": "stdout", "output_type": "stream", "text": [ - "t = 0, x_hat[0]: 1.49273 0 -9.81 \n", - "t = 0.0333333, y[0] = 1.04203, x_hat[0] = 1.18055 -9.56663 -9.82201 \n", - "t = 0.0666667, y[1] = 1.10727, x_hat[1] = 1.04216 -7.18534 -9.81834 \n", - "t = 0.1, y[2] = 1.29135, x_hat[2] = 1.10856 -4.4342 -9.80991 \n", - "t = 0.133333, y[3] = 1.48485, x_hat[3] = 1.23897 -2.65176 -9.79492 \n", - "t = 0.166667, y[4] = 1.72826, x_hat[4] = 1.41541 -1.36441 -9.7679 \n", - "t = 0.2, y[5] = 1.74216, x_hat[5] = 1.5201 -0.923214 -9.74154 \n", - "t = 0.233333, y[6] = 2.11672, x_hat[6] = 1.71571 -0.249681 -9.67866 \n", - "t = 0.266667, y[7] = 2.14529, x_hat[7] = 1.85071 -0.0135839 -9.61971 \n", - "t = 0.3, y[8] = 2.1603, x_hat[8] = 1.94342 -0.0067291 -9.56589 \n", - "t = 0.333333, y[9] = 2.21269, x_hat[9] = 2.01834 -0.0830709 -9.50747 \n", - "t = 0.366667, y[10] = 2.57709, x_hat[10] = 2.16226 0.0419556 -9.35945 \n", - "t = 0.4, y[11] = 2.66822, x_hat[11] = 2.28825 0.0859778 -9.20125 \n", - "t = 0.433333, y[12] = 2.51642, x_hat[12] = 2.34414 -0.0748605 -9.11883 \n", - "t = 0.466667, y[13] = 2.76034, x_hat[13] = 2.43631 -0.124624 -8.94316 \n", - "t = 0.5, y[14] = 2.88132, x_hat[14] = 2.53044 -0.161939 -8.73035 \n", - "t = 0.533333, y[15] = 2.88374, x_hat[15] = 2.60153 -0.250322 -8.54117 \n", - "t = 0.566667, y[16] = 2.94485, x_hat[16] = 2.66672 -0.338958 -8.33748 \n", - "t = 0.6, y[17] = 2.82867, x_hat[17] = 2.69116 -0.520459 -8.22869 \n", - "t = 0.633333, y[18] = 3.00066, x_hat[18] = 2.74064 -0.611671 -8.00887 \n", - "t = 0.666667, y[19] = 3.12921, x_hat[19] = 2.80348 -0.646838 -7.71773 \n", - "t = 0.7, y[20] = 2.85836, x_hat[20] = 2.79746 -0.860126 -7.66077 \n", - "t = 0.733333, y[21] = 2.83808, x_hat[21] = 2.78289 -1.075 -7.6073 \n", - "t = 0.766667, y[22] = 2.68975, x_hat[22] = 2.73536 -1.36255 -7.65263 \n", - "t = 0.8, y[23] = 2.66533, x_hat[23] = 2.6849 -1.63242 -7.67241 \n", - "t = 0.833333, y[24] = 2.81613, x_hat[24] = 2.66865 -1.7756 -7.52222 \n", - "t = 0.866667, y[25] = 2.81004, x_hat[25] = 2.65085 -1.90394 -7.36013 \n", - "t = 0.9, y[26] = 2.88322, x_hat[26] = 2.6486 -1.96825 -7.12303 \n", - "t = 0.933333, y[27] = 2.69789, x_hat[27] = 2.60682 -2.13543 -7.0323 \n", - "t = 0.966667, y[28] = 2.43422, x_hat[28] = 2.51459 -2.43157 -7.11077 \n", - "t = 1, y[29] = 2.23465, x_hat[29] = 2.39226 -2.78874 -7.2608 \n", - "t = 1.03333, y[30] = 2.30279, x_hat[30] = 2.30003 -3.02869 -7.25825 \n", - "t = 1.06667, y[31] = 2.0207, x_hat[31] = 2.16219 -3.37581 -7.38461 \n", - "t = 1.1, y[32] = 1.94394, x_hat[32] = 2.02788 -3.68333 -7.45684 \n", - "t = 1.13333, y[33] = 1.82498, x_hat[33] = 1.88867 -3.97757 -7.50949 \n", - "t = 1.16667, y[34] = 1.52526, x_hat[34] = 1.709 -4.35693 -7.65511 \n", - "t = 1.2, y[35] = 1.86968, x_hat[35] = 1.6258 -4.44465 -7.47015 \n", - "t = 1.23333, y[36] = 1.18073, x_hat[36] = 1.41784 -4.8526 -7.64198 \n", - "t = 1.26667, y[37] = 1.1073, x_hat[37] = 1.22633 -5.18515 -7.72431 \n", - "t = 1.3, y[38] = 0.916168, x_hat[38] = 1.02624 -5.51275 -7.79693 \n", - "t = 1.33333, y[39] = 0.678548, x_hat[39] = 0.810209 -5.85433 -7.87974 \n", - "t = 1.36667, y[40] = 0.562382, x_hat[40] = 0.604779 -6.14259 -7.90515 \n", - "t = 1.4, y[41] = 0.355468, x_hat[41] = 0.391401 -6.42721 -7.92567 \n", - "t = 1.43333, y[42] = -0.155607, x_hat[42] = 0.113297 -6.84513 -8.072 \n", - "t = 1.46667, y[43] = -0.287199, x_hat[43] = -0.147659 -7.19181 -8.14437 \n", - "t = 1.5, y[44] = -0.602973, x_hat[44] = -0.428043 -7.55795 -8.23086 \n", + "t = 0, x_hat[0]: 1.04203 0 0 \n", + "t = 0.025, y[0] = 1.04203, x_hat[0] = 1.04203 0 0 \n", + "t = 0.05, y[1] = 1.10727, x_hat[1] = 1.08709 0.898405 0.00110609 \n", + "t = 0.075, y[2] = 1.29135, x_hat[2] = 1.22062 2.38572 0.00356059 \n", + "t = 0.1, y[3] = 1.48485, x_hat[3] = 1.3876 3.46905 0.00712492 \n", + "t = 0.125, y[4] = 1.72826, x_hat[4] = 1.58997 4.40683 0.013793 \n", + "t = 0.15, y[5] = 1.74216, x_hat[5] = 1.71702 4.52169 0.0154301 \n", + "t = 0.175, y[6] = 2.11672, x_hat[6] = 1.93313 5.12414 0.031192 \n", + "t = 0.2, y[7] = 2.14529, x_hat[7] = 2.08865 5.26575 0.0374157 \n", + "t = 0.225, y[8] = 2.1603, x_hat[8] = 2.20235 5.18421 0.0316618 \n", + "t = 0.25, y[9] = 2.21269, x_hat[9] = 2.29893 5.04726 0.0173069 \n", + "t = 0.275, y[10] = 2.57709, x_hat[10] = 2.46442 5.19829 0.039685 \n", + "t = 0.3, y[11] = 2.66822, x_hat[11] = 2.61234 5.26324 0.0527055 \n", + "t = 0.325, y[12] = 2.51642, x_hat[12] = 2.69148 5.08935 0.00552738 \n", + "t = 0.35, y[13] = 2.76034, x_hat[13] = 2.80588 5.04887 -0.00849014 \n", + "t = 0.375, y[14] = 2.88132, x_hat[14] = 2.92141 5.01625 -0.0224249 \n", + "t = 0.4, y[15] = 2.88374, x_hat[15] = 3.0137 4.91887 -0.0729322 \n", + "t = 0.425, y[16] = 2.94485, x_hat[16] = 3.09895 4.80966 -0.139272 \n", + "t = 0.45, y[17] = 2.82867, x_hat[17] = 3.14447 4.59753 -0.288577 \n", + "t = 0.475, y[18] = 3.00066, x_hat[18] = 3.21103 4.45686 -0.396929 \n", + "t = 0.5, y[19] = 3.12921, x_hat[19] = 3.287 4.34976 -0.484796 \n", + "t = 0.525, y[20] = 2.85836, x_hat[20] = 3.29866 4.07181 -0.747944 \n", + "t = 0.55, y[21] = 2.83808, x_hat[21] = 3.30007 3.77743 -1.04221 \n", + "t = 0.575, y[22] = 2.68975, x_hat[22] = 3.26982 3.40694 -1.43333 \n", + "t = 0.6, y[23] = 2.66533, x_hat[23] = 3.23376 3.03344 -1.83638 \n", + "t = 0.625, y[24] = 2.81613, x_hat[24] = 3.2232 2.74466 -2.13797 \n", + "t = 0.65, y[25] = 2.81004, x_hat[25] = 3.20765 2.45227 -2.44388 \n", + "t = 0.675, y[26] = 2.88322, x_hat[26] = 3.20161 2.19814 -2.69673 \n", + "t = 0.7, y[27] = 2.69789, x_hat[27] = 3.15898 1.84844 -3.07251 \n", + "t = 0.725, y[28] = 2.43422, x_hat[28] = 3.07035 1.3784 -3.6016 \n", + "t = 0.75, y[29] = 2.23465, x_hat[29] = 2.95232 0.840664 -4.20755 \n", + "t = 0.775, y[30] = 2.30279, x_hat[30] = 2.85557 0.387859 -4.67894 \n", + "t = 0.8, y[31] = 2.0207, x_hat[31] = 2.7166 -0.169629 -5.27548 \n", + "t = 0.825, y[32] = 1.94394, x_hat[32] = 2.57681 -0.704099 -5.81837 \n", + "t = 0.85, y[33] = 1.82498, x_hat[33] = 2.42945 -1.23525 -6.3351 \n", + "t = 0.875, y[34] = 1.52526, x_hat[34] = 2.24401 -1.85275 -6.94502 \n", + "t = 0.9, y[35] = 1.86968, x_hat[35] = 2.13957 -2.19864 -7.17158 \n", + "t = 0.925, y[36] = 1.18073, x_hat[36] = 1.92439 -2.85128 -7.78709 \n", + "t = 0.95, y[37] = 1.1073, x_hat[37] = 1.72092 -3.43475 -8.28638 \n", + "t = 0.975, y[38] = 0.916168, x_hat[38] = 1.50771 -4.01434 -8.75833 \n", + "t = 1, y[39] = 0.678548, x_hat[39] = 1.2784 -4.60797 -9.22649 \n", + "t = 1.025, y[40] = 0.562382, x_hat[40] = 1.05708 -5.14471 -9.60338 \n", + "t = 1.05, y[41] = 0.355468, x_hat[41] = 0.827473 -5.67368 -9.95378 \n", + "t = 1.075, y[42] = -0.155607, x_hat[42] = 0.537759 -6.3418 -10.4545 \n", + "t = 1.1, y[43] = -0.287199, x_hat[43] = 0.262434 -6.93117 -10.8401 \n", + "t = 1.125, y[44] = -0.602973, x_hat[44] = -0.0317108 -7.53829 -11.2291 \n", + "t = 1.15, y[45] = -0.873817, x_hat[45] = -0.333878 -8.13198 -11.5854 \n", + "t = 1.175, y[46] = -1.14466, x_hat[46] = -0.64242 -8.70817 -11.9065 \n", + "t = 1.2, y[47] = -1.41551, x_hat[47] = -0.955917 -9.2638 -12.1908 \n", + "t = 1.225, y[48] = -1.68635, x_hat[48] = -1.27315 -9.79661 -12.4381 \n", + "t = 1.25, y[49] = -1.95719, x_hat[49] = -1.5931 -10.3051 -12.6488 \n", + "t = 1.275, y[50] = -2.22804, x_hat[50] = -1.91487 -10.7882 -12.8241 \n", + "t = 1.3, y[51] = -2.49888, x_hat[51] = -2.23773 -11.2455 -12.9653 \n", + "t = 1.325, y[52] = -2.76973, x_hat[52] = -2.56107 -11.6768 -13.0743 \n", + "t = 1.35, y[53] = -3.04057, x_hat[53] = -2.88438 -12.0825 -13.1531 \n", + "t = 1.375, y[54] = -3.31141, x_hat[54] = -3.20725 -12.4629 -13.2039 \n", + "t = 1.4, y[55] = -3.58226, x_hat[55] = -3.52932 -12.8188 -13.2289 \n", + "t = 1.425, y[56] = -3.8531, x_hat[56] = -3.85034 -13.1508 -13.2301 \n", + "t = 1.45, y[57] = -4.12395, x_hat[57] = -4.17008 -13.46 -13.2098 \n", + "t = 1.475, y[58] = -4.39479, x_hat[58] = -4.48837 -13.7472 -13.17 \n", + "t = 1.5, y[59] = -4.66564, x_hat[59] = -4.8051 -14.0134 -13.1126 \n", + "t = 1.525, y[60] = -4.93648, x_hat[60] = -5.12016 -14.2597 -13.0395 \n", + "t = 1.55, y[61] = -5.20732, x_hat[61] = -5.4335 -14.4872 -12.9525 \n", + "t = 1.575, y[62] = -5.47817, x_hat[62] = -5.74508 -14.6968 -12.853 \n", + "t = 1.6, y[63] = -5.74901, x_hat[63] = -6.05487 -14.8895 -12.7428 \n", + "t = 1.625, y[64] = -6.01986, x_hat[64] = -6.36288 -15.0664 -12.623 \n", + "t = 1.65, y[65] = -6.2907, x_hat[65] = -6.66912 -15.2283 -12.4951 \n", + "t = 1.675, y[66] = -6.56154, x_hat[66] = -6.97361 -15.3763 -12.3601 \n", + "t = 1.7, y[67] = -6.83239, x_hat[67] = -7.27637 -15.5111 -12.2191 \n", + "t = 1.725, y[68] = -7.10323, x_hat[68] = -7.57746 -15.6336 -12.0731 \n", + "t = 1.75, y[69] = -7.37408, x_hat[69] = -7.87691 -15.7446 -11.9229 \n", + "t = 1.775, y[70] = -7.64492, x_hat[70] = -8.17476 -15.8449 -11.7694 \n", + "t = 1.8, y[71] = -7.91576, x_hat[71] = -8.47108 -15.9351 -11.6132 \n", + "t = 1.825, y[72] = -8.18661, x_hat[72] = -8.7659 -16.0159 -11.4551 \n", + "t = 1.85, y[73] = -8.45745, x_hat[73] = -9.05929 -16.0881 -11.2954 \n", + "t = 1.875, y[74] = -8.7283, x_hat[74] = -9.3513 -16.152 -11.1349 \n", + "t = 1.9, y[75] = -8.99914, x_hat[75] = -9.64199 -16.2084 -10.9739 \n", + "t = 1.925, y[76] = -9.26999, x_hat[76] = -9.9314 -16.2578 -10.8128 \n", + "t = 1.95, y[77] = -9.54083, x_hat[77] = -10.2196 -16.3005 -10.652 \n", + "t = 1.975, y[78] = -9.81167, x_hat[78] = -10.5066 -16.3372 -10.4918 \n", + "t = 2, y[79] = -10.0825, x_hat[79] = -10.7925 -16.3682 -10.3325 \n", + "t = 2.025, y[80] = -10.3534, x_hat[80] = -11.0774 -16.3939 -10.1743 \n", + "t = 2.05, y[81] = -10.6242, x_hat[81] = -11.3612 -16.4148 -10.0175 \n", + "t = 2.075, y[82] = -10.8951, x_hat[82] = -11.6441 -16.4311 -9.86215 \n", + "t = 2.1, y[83] = -11.1659, x_hat[83] = -11.9261 -16.4432 -9.70853 \n", + "t = 2.125, y[84] = -11.4367, x_hat[84] = -12.2072 -16.4514 -9.55674 \n", + "t = 2.15, y[85] = -11.7076, x_hat[85] = -12.4874 -16.4559 -9.4069 \n", + "t = 2.175, y[86] = -11.9784, x_hat[86] = -12.7669 -16.4572 -9.25909 \n", + "t = 2.2, y[87] = -12.2493, x_hat[87] = -13.0456 -16.4553 -9.1134 \n", + "t = 2.225, y[88] = -12.5201, x_hat[88] = -13.3236 -16.4505 -8.96991 \n", "\n", - "Exec Success, Final kf states:-0.428043 -7.55795 -8.23086 \n" + "Exec Success, Final kf states:-13.3236 -16.4505 -8.96991 \n" ] } ], "source": [ - "pyrun_sim(10, true)" + "std::vector> g_res = pyrun_sim(1, true);" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "5e8f3f24-b003-4416-bb16-eabc480e437e", + "metadata": {}, + "outputs": [], + "source": [ + "std::vector py_g_pred;\n", + "std::vector py_g_est;" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "107f53c0-0f2d-47b9-a218-67c29e16b87b", + "metadata": {}, + "outputs": [], + "source": [ + "int k = g_res.size();" ] }, { "cell_type": "code", "execution_count": 23, - "id": "d0152901-3f6d-49e5-9138-126d1b072912", + "id": "63f5ba73-393b-4a0a-9c83-6829b56fc526", + "metadata": {}, + "outputs": [], + "source": [ + "std::vector ret_1d_vector(std::vector> res, int axis) {\n", + " std::vector ret;\n", + " for (int i = 0; i < res.size(); i++) {\n", + " ret.push_back(res[i][axis]);\n", + " }\n", + " return ret;\n", + "}\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "a6bc89d0-026f-4f52-b8c9-9a8fca6948f2", + "metadata": {}, + "outputs": [], + "source": [ + "py_g_pred = ret_1d_vector(g_res, 0);\n", + "py_g_est = ret_1d_vector(g_res, 1);" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "1b22451b-0bd7-48d6-8590-ff59b56a8031", + "metadata": {}, + "outputs": [], + "source": [ + "void printMatrix(const std::vector& vec) {\n", + " for (size_t i = 0; i < vec.size(); i++) {\n", + " std::cout << vec[i] << \" \";\n", + " }\n", + " std::cout << std::endl;\n", + " }\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "47fd1710-59a8-4d80-96de-04ef86e233fd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n", - "Exec Success, Final kf states:-0.454554 -7.75636 -8.59785 \n", - "\n", - "Exec Success, Final kf states:-0.505705 -8.13916 -9.30595 \n", - "\n", - "Exec Success, Final kf states:-0.488794 -8.0126 -9.07184 \n", - "\n", - "Exec Success, Final kf states:-0.536246 -8.36772 -9.72873 \n", - "\n", - "Exec Success, Final kf states:-0.560802 -8.55149 -10.0687 \n" + "0 0.00110609 0.00356059 0.00712492 0.013793 0.0154301 0.031192 0.0374157 0.0316618 0.0173069 0.039685 0.0527055 0.00552738 -0.00849014 -0.0224249 -0.0729322 -0.139272 -0.288577 -0.396929 -0.484796 -0.747944 -1.04221 -1.43333 -1.83638 -2.13797 -2.44388 -2.69673 -3.07251 -3.6016 -4.20755 -4.67894 -5.27548 -5.81837 -6.3351 -6.94502 -7.17158 -7.78709 -8.28638 -8.75833 -9.22649 -9.60338 -9.95378 -10.4545 -10.8401 -11.2291 -11.5854 -11.9065 -12.1908 -12.4381 -12.6488 -12.8241 -12.9653 -13.0743 -13.1531 -13.2039 -13.2289 -13.2301 -13.2098 -13.17 -13.1126 -13.0395 -12.9525 -12.853 -12.7428 -12.623 -12.4951 -12.3601 -12.2191 -12.0731 -11.9229 -11.7694 -11.6132 -11.4551 -11.2954 -11.1349 -10.9739 -10.8128 -10.652 -10.4918 -10.3325 -10.1743 -10.0175 -9.86215 -9.70853 -9.55674 -9.4069 -9.25909 -9.1134 -8.96991 \n" ] } ], "source": [ - "%%python\n", + "printMatrix(py_g_pred);" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "d0152901-3f6d-49e5-9138-126d1b072912", + "metadata": {}, + "outputs": [], + "source": [ + "// %%python\n", "\n", - "expand_factor = [500000, 1000000, 10000000, 20000000, 40000000]\n", - "benchmarks = []\n", + "// expand_factor = [1, 2, 3, 4, 5]\n", + "// benchmarks = []\n", "\n", - "for i in expand_factor:\n", - " time_kf = cppyy.gbl.pyrun_sim(int(i), 0)\n", - " benchmarks.append(time_kf)\n" + "// for i in expand_factor:\n", + "// time_kf = cppyy.gbl.pyrun_sim(int(i), 0)\n", + "// benchmarks.append(time_kf)\n", + "// " ] }, { "cell_type": "code", - "execution_count": 24, - "id": "aa8b6916-eb9e-449b-b11a-255609a94420", + "execution_count": 28, + "id": "2c0b5c86-7875-45d9-8113-00febd1be3d5", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[706, 1404, 13545, 27020, 53801]\n" - ] - } - ], + "outputs": [], + "source": [ + "%%python\n", + "\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "e0d86e62-207f-4e14-80c1-2e04c8984fd6", + "metadata": {}, + "outputs": [], "source": [ "%%python\n", "\n", - "print(benchmarks)" + "true_val = 9.81\n", + "g_pred = list(cppyy.gbl.py_g_pred)\n", + "g_pred = list(-x for x in g_pred)\n", + "x = range(len(g_pred))\n", + "\n", + "\n", + "# Plot the constant green line\n", + "plt.axhline(y=true_val, color='green', linestyle='-')\n", + "\n", + "# Plot g_pred in orange\n", + "plt.plot(x, g_pred, color='orange', marker='o', label='KF Estimates')\n", + "plt.annotate(f'{true_val}', xy=(-0.5, true_val), color='green',\n", + " verticalalignment='center', horizontalalignment = 'left')\n", + "# Add labels and title\n", + "plt.xlabel('Index')\n", + "plt.ylabel('Acceleration (m/s₂)')\n", + "plt.title('True Value vs. g_pred')\n", + "plt.legend()\n", + "plt.savefig(\"1D_kf_plot.jpg\")\n", + " \n", + "plt.yscale('symlog')\n", + "# Show the plot\n", + "plt.show()\n" + ] + }, + { + "cell_type": "markdown", + "id": "7f478ffc-eafa-4c98-9fcc-cd3d18833987", + "metadata": {}, + "source": [ + "" ] }, { "cell_type": "code", "execution_count": null, - "id": "2c0b5c86-7875-45d9-8113-00febd1be3d5", + "id": "145a5f37-15ae-4336-b126-fba8b6760de2", "metadata": {}, "outputs": [], "source": []