diff --git a/src/grib_context.c b/src/grib_context.c index ed2f62850..c02c25000 100644 --- a/src/grib_context.c +++ b/src/grib_context.c @@ -242,20 +242,22 @@ static void default_print(const grib_context* c, void* descriptor, const char* m void grib_context_set_print_proc(grib_context* c, grib_print_proc p) { - c = c ? c : grib_context_get_default(); - c->print = p; + c = c ? c : grib_context_get_default(); + /* Set logging back to the default if p is NULL */ + c->print = (p ? p : &default_print); } void grib_context_set_debug(grib_context* c, int mode) { - c = c ? c : grib_context_get_default(); + c = c ? c : grib_context_get_default(); c->debug = mode; } void grib_context_set_logging_proc(grib_context* c, grib_log_proc p) { - c = c ? c : grib_context_get_default(); - c->output_log = p; + c = c ? c : grib_context_get_default(); + /* Set logging back to the default if p is NULL */ + c->output_log = (p ? p : &default_log); } long grib_get_api_version() diff --git a/tests/unit_tests.c b/tests/unit_tests.c index 43870f7ef..6075baa75 100644 --- a/tests/unit_tests.c +++ b/tests/unit_tests.c @@ -14,6 +14,7 @@ #define NUMBER(x) (sizeof(x) / sizeof(x[0])) int assertion_caught = 0; +int logging_caught = 0; typedef enum { @@ -1468,6 +1469,28 @@ static void test_assertion_catching() free(list); } + +static void my_logging_proc(const grib_context* c, int level, const char* mesg) +{ + logging_caught = 1; +} +static void test_logging_proc() +{ + grib_context* context = grib_context_get_default(); + Assert(logging_caught == 0); + + /* Override default behaviour */ + grib_context_set_logging_proc(context, my_logging_proc); + grib_context_log(context, GRIB_LOG_ERROR, "This error will be handled by me"); + Assert(logging_caught == 1); + + /* Restore the logging proc */ + logging_caught = 0; + grib_context_set_logging_proc(context, NULL); + grib_context_log(context, GRIB_LOG_ERROR, "This will come out as normal"); + Assert(logging_caught == 0); +} + static void test_concept_condition_strings() { int err = 0; @@ -1631,10 +1654,12 @@ static void test_parse_keyval_string() free( (void*)values3[0].name ); } + int main(int argc, char** argv) { printf("Doing unit tests. ecCodes version = %ld\n", grib_get_api_version()); + test_logging_proc(); test_grib_binary_search(); test_parse_keyval_string();