diff --git a/src/confluent_kafka/src/confluent_kafka.c b/src/confluent_kafka/src/confluent_kafka.c index 443d02cfb..30fecefdb 100644 --- a/src/confluent_kafka/src/confluent_kafka.c +++ b/src/confluent_kafka/src/confluent_kafka.c @@ -540,6 +540,43 @@ static PyObject *Message_set_key(Message *self, PyObject *new_key) { Py_RETURN_NONE; } +static PyObject *Message_set_topic(Message *self, PyObject *new_topic) { + PyObject *old_topic = self->topic; + Py_INCREF(new_topic); /* take new ref before dropping the old one */ + self->topic = new_topic; + Py_XDECREF(old_topic); /* safe even if a destructor re-enters here */ + + Py_RETURN_NONE; +} + +static PyObject *Message_set_error(Message *self, PyObject *new_error) { + PyObject *old_error = self->error; + Py_INCREF(new_error); /* take new ref before dropping the old one */ + self->error = new_error; + Py_XDECREF(old_error); /* safe even if a destructor re-enters here */ + + Py_RETURN_NONE; +} + +static PyObject *Message_reduce(Message *self, PyObject *Py_UNUSED(ignored)) { +#ifdef RD_KAFKA_V_HEADERS + if (!self->headers && self->c_headers) { + self->headers = c_headers_to_py(self->c_headers); + rd_kafka_headers_destroy(self->c_headers); + self->c_headers = NULL; + } +#endif + + return Py_BuildValue( + "O(NNNNNiiiLN)", Py_TYPE(self), Message_topic(self, NULL), + Message_value(self, NULL), Message_key(self, NULL), + Message_headers(self, NULL), Message_error(self, NULL), + self->partition, self->offset, self->leader_epoch, self->timestamp, + (self->latency >= 0 + ? 
PyFloat_FromDouble((double)self->latency / 1000000.0) : cfl_PyInt_FromInt(-1))); +} + static PyMethodDef Message_methods[] = { {"error", (PyCFunction)Message_error, METH_NOARGS, " The message object is also used to propagate errors and events, " @@ -634,6 +671,22 @@ static PyMethodDef Message_methods[] = { " :returns: None.\n" " :rtype: None\n" "\n"}, + {"set_topic", (PyCFunction)Message_set_topic, METH_O, + " Set the field 'Message.topic' with new value.\n" + "\n" + " :param object value: Message.topic.\n" + " :returns: None.\n" + " :rtype: None\n" + "\n"}, + {"set_error", (PyCFunction)Message_set_error, METH_O, + " Set the field 'Message.error' with new value.\n" + "\n" + " :param object value: Message.error.\n" + " :returns: None.\n" + " :rtype: None\n" + "\n"}, + {"__reduce__", (PyCFunction)Message_reduce, METH_NOARGS, + " Function for serializing Message using the pickle protocol."}, {NULL}}; static int Message_clear(Message *self) { @@ -783,14 +836,66 @@ static int Message_traverse(Message *self, visitproc visit, void *arg) { return 0; } +static PyObject *Message_richcompare(PyObject *self, PyObject *other, int op) { + if (op != Py_EQ && op != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + if (self == other) { + return PyBool_FromLong(op == Py_EQ); /* new ref, not borrowed Py_True */ + } + + if (!PyObject_TypeCheck(other, &MessageType)) { + return PyBool_FromLong(op != Py_EQ); + } + + Message *msg_self = (Message *)self; + Message *msg_other = (Message *)other; + + int result; + +#define _LOCAL_COMPARE(left, right) \ + do { \ + result = (!left || !right) ? (left == right) : PyObject_RichCompareBool(left, right, Py_EQ); \ + if (result < 0) \ + return NULL; \ + if (result == 0) \ + return PyBool_FromLong(op != Py_EQ); \ + } while (0) + _LOCAL_COMPARE(msg_self->topic, msg_other->topic); + _LOCAL_COMPARE(msg_self->value, msg_other->value); + _LOCAL_COMPARE(msg_self->key, msg_other->key); + _LOCAL_COMPARE(msg_self->headers, msg_other->headers); + _LOCAL_COMPARE(msg_self->error, msg_other->error); +#undef _LOCAL_COMPARE + +#define _LOCAL_COMPARE(left, right) \ + do { \ + if (left != right) \ + return PyBool_FromLong(op != Py_EQ); \ + } while (0) + _LOCAL_COMPARE(msg_self->partition, msg_other->partition); + _LOCAL_COMPARE(msg_self->offset, msg_other->offset); + _LOCAL_COMPARE(msg_self->leader_epoch, msg_other->leader_epoch); + _LOCAL_COMPARE(msg_self->timestamp, msg_other->timestamp); + // latency is skipped, it is a float and not that significant. +#undef _LOCAL_COMPARE + + Py_RETURN_TRUE; +} + static Py_ssize_t Message__len__(Message *self) { - return self->value ? PyObject_Length(self->value) : 0; + return self->value && self->value != Py_None + ? PyObject_Length(self->value) + : 0; } static PySequenceMethods Message_seq_methods = { (lenfunc)Message__len__ /* sq_length */ }; + PyTypeObject MessageType = { PyVarObject_HEAD_INIT(NULL, 0) "cimpl.Message", /*tp_name*/ sizeof(Message), /*tp_basicsize*/ @@ -853,7 +958,7 @@ PyTypeObject MessageType = { "\n", /*tp_doc*/ (traverseproc)Message_traverse, /* tp_traverse */ (inquiry)Message_clear, /* tp_clear */ - 0, /* tp_richcompare */ + (richcmpfunc)Message_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ @@ -911,8 +1016,6 @@ PyObject *Message_new0(const Handle *handle, const rd_kafka_message_t *rkm) { return (PyObject *)self; } - - /**************************************************************************** * * @@ -1066,7 +1169,6 @@ PyTypeObject UuidType = { }; - /**************************************************************************** * * diff --git a/tests/test_Message.py b/tests/test_Message.py new file mode 100644 index 000000000..d6a31d0ca --- /dev/null 
+++ b/tests/test_Message.py @@ -0,0 +1,263 @@ +import pickle +import pytest +import sys + +from confluent_kafka.cimpl import Message + + +def empty_message_1(): + return Message() + + +def empty_message_2(): + return Message(None, None, None, None, None, -2, -2, -2, -2, -2) + + +def empty_message_3(): + msg = Message() + msg.set_topic(None) + msg.set_value(None) + msg.set_key(None) + msg.set_headers(None) + msg.set_error(None) + return msg + + +def empty_message_4(): + return Message.__new__(Message) + + +class Message2(Message): + def __init__(self, *args): + super().__init__(*args) + self.dummy = 1 + + +def empty_message_5(): + msg = Message2() + assert type(msg) is Message2 + assert msg.dummy == 1 + return msg + + +@pytest.mark.parametrize( + "make_message", + [ + empty_message_1, + empty_message_2, + empty_message_3, + empty_message_4, + empty_message_5, + ], +) +def test_message_create_empty(make_message): + # Checks the creation of an empty Message with no data. + + msg = make_message() + + assert len(msg) == 0 + assert msg.topic() is None + assert msg.value() is None + assert msg.key() is None + assert msg.headers() is None + assert msg.error() is None + assert msg.partition() is None + assert msg.offset() is None + assert msg.leader_epoch() is None + assert msg.timestamp() == (0, 0) + assert msg.latency() is None + assert str(msg) + assert repr(msg) + + subtest_pickling(msg, (None,) * 5 + (-1, -1, -1, 0, -1)) + + +def test_message_create_with_dummy(): + # Checks the creation of an Message with any kind of dummy arguments. Useful + # to create Message objects in unit tests with Mock objects as arguments, + # for instance. + + dummy = object() + msg = Message(dummy, dummy, dummy, dummy, dummy) + assert msg.topic() is dummy + assert msg.value() is dummy + assert msg.key() is dummy + assert msg.headers() is dummy + assert msg.error() is dummy + assert str(msg) + assert repr(msg) + + +def test_message_create_with_args(): + # Tests all positional arguments. 
+ + headers, error = [], object() + msg = Message("t", "v", "k", headers, error, 1, 2, 3, 4, 5.67) + assert len(msg) == 1 + assert msg.topic() == "t" + assert msg.value() == "v" + assert msg.key() == "k" + assert msg.headers() is headers + assert msg.error() is error + assert msg.partition() == 1 + assert msg.offset() == 2 + assert msg.leader_epoch() == 3 + assert msg.timestamp() == (0, 4) + assert msg.latency() == 5.67 + assert str(msg) + assert repr(msg) + + +def test_message_create_with_kwds(): + # Tests all keyword arguments. + + headers, error = [], object() + msg = Message( + topic="t", + value="v", + key="k", + headers=headers, + error=error, + partition=1, + offset=2, + leader_epoch=3, + timestamp=4, + latency=5.67, + ) + assert len(msg) == 1 + assert msg.topic() == "t" + assert msg.value() == "v" + assert msg.key() == "k" + assert msg.headers() is headers + assert msg.error() is error + assert msg.partition() == 1 + assert msg.offset() == 2 + assert msg.leader_epoch() == 3 + assert msg.timestamp() == (0, 4) + assert msg.latency() == 5.67 + + +def test_message_set_properties(): + # Tests all set_() methods. + + headers, error = [], object() + msg = Message() + assert len(msg) == 0 + msg.set_topic("t") + assert msg.topic() == "t" + msg.set_value("v") + assert msg.value() == "v" + assert len(msg) == 1 + msg.set_key("k") + assert msg.key() == "k" + msg.set_headers(headers) + assert msg.headers() is headers + msg.set_error(error) + assert msg.error() is error + + +@pytest.mark.parametrize("value", [None, object()]) +def test_message_exceptions(value): + # Tests many situations which should raise TypeError. This is important to + # ensure the "self" object is type checked to be a Message before trying to + # do anything with it internally in the C code. 
+ + with pytest.raises(TypeError): + Message.__new__(value) + with pytest.raises(TypeError): + Message.__new__(str) + + with pytest.raises(TypeError): + Message.__init__(value) + + with pytest.raises(TypeError): + Message.topic(value) + with pytest.raises(TypeError): + Message.value(value) + with pytest.raises(TypeError): + Message.key(value) + with pytest.raises(TypeError): + Message.headers(value) + with pytest.raises(TypeError): + Message.error(value) + + with pytest.raises(TypeError): + Message.partition(value) + with pytest.raises(TypeError): + Message.offset(value) + with pytest.raises(TypeError): + Message.leader_epoch(value) + with pytest.raises(TypeError): + Message.timestamp(value) + with pytest.raises(TypeError): + Message.latency(value) + + with pytest.raises(TypeError): + Message.set_topic(value, "t") + with pytest.raises(TypeError): + Message.set_value(value, "v") + with pytest.raises(TypeError): + Message.set_key(value, "k") + with pytest.raises(TypeError): + Message.set_headers(value, []) + with pytest.raises(TypeError): + Message.set_error(value, object()) + + with pytest.raises(TypeError): + len(Message(value=1)) + with pytest.raises(TypeError): + len(Message(value=object())) + + +def subtest_pickling(msg, exp_args): + assert msg.__reduce__() == (type(msg), exp_args) + + pickled = pickle.dumps(msg) + restored = pickle.loads(pickled) + + assert restored.__reduce__() == (type(msg), exp_args) + assert msg is not restored + assert type(msg) is type(restored) + + assert len(msg) == len(restored) + assert msg.topic() == restored.topic() + assert msg.value() == restored.value() + assert msg.key() == restored.key() + assert msg.headers() == restored.headers() + assert msg.error() == restored.error() + assert msg.partition() == restored.partition() + assert msg.offset() == restored.offset() + assert msg.leader_epoch() == restored.leader_epoch() + assert msg.timestamp() == restored.timestamp() + assert msg.latency() == restored.latency() + + +def 
test_message_pickle(): + args = "t", "v", "k", [], None, 1, 2, 3, 4, 5.67 + msg = Message(*args) + assert msg.latency() == 5.67 + + subtest_pickling(msg, args) + + +def test_message_compare(): + args0 = "t", "v", "k", [], None, 1, 2, 3, 4, 5.67 + args1 = "t", "v", "z", [], None, 1, 2, 3, 4, 5.67 + + msg0 = Message(*args0) + msg01 = Message(*args0) + msg1 = Message(*args1) + + assert msg0 == msg0 + assert msg0 == msg01 + assert msg0 != msg1 + assert msg0 != 1 + assert msg0 != None + assert msg0 != object() + + with pytest.raises(TypeError): + assert msg0 < msg0 + with pytest.raises(TypeError): + assert msg0 < None + with pytest.raises(TypeError): + assert msg0 < object() diff --git a/tests/test_message.py b/tests/test_message.py index f894ff121..b2a19dee8 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,6 +1,7 @@ -# #!/usr/bin/env python +#!/usr/bin/env python from confluent_kafka import KafkaError, Message +import pickle def test_init_no_params(): @@ -68,3 +69,86 @@ def test_set_value(): m = Message() m.set_value(b"value") assert m.value() == b"value" + + +def test_set_topic(): + m = Message() + m.set_topic("test_topic") + assert m.topic() == "test_topic" + m.set_topic("another_topic") + assert m.topic() == "another_topic" + + +def test_set_error(): + m = Message() + m.set_error(KafkaError(0)) + assert m.error() == KafkaError(0) + m.set_error(KafkaError(1)) + assert m.error() == KafkaError(1) + + +def test_equality(): + m1 = Message( + topic="test", + partition=1, + offset=2, + key=b"key", + value=b"value", + headers=[("h1", "v1")], + error=KafkaError(0), + timestamp=(1, 1762499956), + leader_epoch=1762499956, + ) + m2 = Message( + topic="test", + partition=1, + offset=2, + key=b"key", + value=b"value", + headers=[("h1", "v1")], + error=KafkaError(0), + timestamp=(1, 1762499956), + leader_epoch=1762499956, + ) + m3 = Message( + topic="different", + partition=1, + offset=2, + key=b"key", + value=b"value", + ) + + assert m1 == m2 + assert m1 != m3 
+ assert m2 != m3 + assert m1 != "not a message" + + +def test_pickling(): + m = Message( + topic="test", + partition=1, + offset=2, + key=b"key", + value=b"value", + headers=[("h1", "v1")], + error=KafkaError(0), + timestamp=(1, 1762499956), + latency=0.05, + leader_epoch=1762499956, + ) + + # Pickle and unpickle + pickled = pickle.dumps(m) + unpickled = pickle.loads(pickled) + + assert unpickled.topic() == m.topic() + assert unpickled.partition() == m.partition() + assert unpickled.offset() == m.offset() + assert unpickled.key() == m.key() + assert unpickled.value() == m.value() + assert unpickled.headers() == m.headers() + assert unpickled.error() == m.error() + assert unpickled.timestamp() == m.timestamp() + assert unpickled.latency() == m.latency() + assert unpickled.leader_epoch() == m.leader_epoch()