diff --git a/src/osqp/codegen/utils.py b/src/osqp/codegen/utils.py index da4c8016..27c7900b 100644 --- a/src/osqp/codegen/utils.py +++ b/src/osqp/codegen/utils.py @@ -19,40 +19,47 @@ def write_vec(f, vec, name, vec_type): """ Write vector to file """ - f.write('%s %s[%d] = {\n' % (vec_type, name, len(vec))) + if len(vec) > 0: - # Write vector elements - for i in range(len(vec)): - if vec_type == 'c_float': - f.write('(c_float)%.20f,\n' % vec[i]) - else: - f.write('%i,\n' % vec[i]) + f.write('%s %s[%d] = {\n' % (vec_type, name, len(vec))) + + # Write vector elements + for i in range(len(vec)): + if vec_type == 'c_float': + f.write('(c_float)%.20f,\n' % vec[i]) + else: + f.write('%i,\n' % vec[i]) - f.write('};\n') + f.write('};\n') def write_vec_extern(f, vec, name, vec_type): """ Write vector prototype to file """ - f.write("extern %s %s[%d];\n" % (vec_type, name, len(vec))) + if len(vec) > 0: + f.write("extern %s %s[%d];\n" % (vec_type, name, len(vec))) def write_mat(f, mat, name): """ Write scipy sparse matrix in CSC form to file """ - write_vec(f, mat['i'], name + '_i', 'c_int') write_vec(f, mat['p'], name + '_p', 'c_int') - write_vec(f, mat['x'], name + '_x', 'c_float') + if len(mat['x']) > 0: + write_vec(f, mat['i'], name + '_i', 'c_int') + write_vec(f, mat['x'], name + '_x', 'c_float') f.write("csc %s = {" % name) f.write("%d, " % mat['nzmax']) f.write("%d, " % mat['m']) f.write("%d, " % mat['n']) f.write("%s_p, " % name) - f.write("%s_i, " % name) - f.write("%s_x, " % name) + if len(mat['x']) > 0: + f.write("%s_i, " % name) + f.write("%s_x, " % name) + else: + f.write("0, 0, ") f.write("%d};\n" % mat['nz']) @@ -219,9 +226,20 @@ def write_linsys_solver_src(f, linsys_solver, embedded_flag): f.write("%d, " % linsys_solver['m']) if embedded_flag != 1: - f.write("linsys_solver_Pdiag_idx, ") + if len(linsys_solver['Pdiag_idx']) > 0: + linsys_solver_Pdiag_idx_string = 'linsys_solver_Pdiag_idx' + linsys_solver_PtoKKT_string = 'linsys_solver_PtoKKT' + else: + linsys_solver_Pdiag_idx_string = '0' + linsys_solver_PtoKKT_string = '0' + if len(linsys_solver['AtoKKT']) > 0: + linsys_solver_AtoKKT_string = 'linsys_solver_AtoKKT' + else: + linsys_solver_AtoKKT_string = '0' + f.write("%s, " % linsys_solver_Pdiag_idx_string) f.write("%d, " % linsys_solver['Pdiag_n']) - f.write("&linsys_solver_KKT, linsys_solver_PtoKKT, linsys_solver_AtoKKT, linsys_solver_rhotoKKT, " + + f.write("&linsys_solver_KKT, %s, %s, linsys_solver_rhotoKKT, " + % (linsys_solver_PtoKKT_string, linsys_solver_AtoKKT_string) + "linsys_solver_D, linsys_solver_etree, linsys_solver_Lnz, " + "linsys_solver_iwork, linsys_solver_bwork, linsys_solver_fwork, ")