Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/backends/tensorflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,10 @@ int RAI_ModelRunTF(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error)
outputTensorsValues, noutputs, NULL /* target_opers */, 0 /* ntargets */,
NULL /* run_Metadata */, status);

bool delete_output = true;
if (TF_GetCode(status) != TF_OK) {
RAI_SetError(error, RAI_EMODELRUN, TF_Message(status));
delete_output = false;
goto cleanup;
}

Expand Down Expand Up @@ -575,8 +577,10 @@ int RAI_ModelRunTF(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error)
}
TF_DeleteTensor(inputTensorsValues[i]);
}
for (size_t i = 0; i < noutputs; i++) {
TF_DeleteTensor(outputTensorsValues[i]);
if (delete_output) {
for (size_t i = 0; i < noutputs; i++) {
TF_DeleteTensor(outputTensorsValues[i]);
}
}
return res;
}
Expand Down
3 changes: 3 additions & 0 deletions tests/flow/test_data/frozen_bad_model.pb
Git LFS file not shown
11 changes: 11 additions & 0 deletions tests/flow/tests_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,3 +759,14 @@ def run():
env.assertEqual(out_values, [b'this is', b'the first batch'])
out_values = con.execute_command('AI.TENSORGET', 'second_batch{1}', 'VALUES')
env.assertEqual(out_values, [b'that is', b'the second batch'])

@skip_if_no_TF
def test_bad_execution_model(env):
con = get_connection(env, '{1}')

model_pb = load_file_content('frozen_bad_model.pb')
ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'TF', DEVICE, 'INPUTS', 1, 'x', 'OUTPUTS', 1, 'Identity', 'BLOB', model_pb)
env.assertEqual(ret, b'OK')
con.execute_command('AI.TENSORSET', 'my_str_tensor{1}', 'STRING', 4, 'BLOB', "how do I extract keys from a dict into a list?\x00debug public static void main(string[] args) {...}\x00should I use def main()\x00type hinting for list?\x00")
env.assertEqual(ret, b'OK')
check_error(env, con, 'AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'my_str_tensor{1}', 'OUTPUTS', 1, 'foo{1}')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you document what is the expected error?