We had an app that used Confluent Kafka as its queue, and I wanted to scale machines out based on the current consumer lag in that queue.

At the time of writing, I couldn't find a Confluent Kafka snippet that does this in Python, so if you were looking for the same, here you go :) The script below spins up one consumer per partition with multiprocessing, and it uses the ccloud_lib helper (from Confluent's example clients) to strip the Schema Registry settings out of the config before building the Consumer.

import multiprocessing
import os
import time

from confluent_kafka import Consumer, TopicPartition

import ccloud_lib


def get_partition_lag(partition: int):
    topic_name = "production"
    CONFLUENT_CONFIG = {
        "bootstrap.servers": os.getenv("KAFKA_HOST"),
        "security.protocol": "SASL_SSL",
        "sasl.mechanisms": "PLAIN",
        "sasl.username": os.getenv("KAFKA_CONSUMER_KEY"),
        "sasl.password": os.getenv("KAFKA_CONSUMER_SECRET"),
        "schema.registry.url": "https://{{ SR_ENDPOINT }}",
        "basic.auth.credentials.source": "USER_INFO",
        "basic.auth.user.info": "{{ SR_API_KEY }}:{{ SR_API_SECRET }}",
    }
    # ccloud_lib (Confluent's example helper) removes the Schema Registry keys,
    # which aren't valid settings for the plain Consumer.
    conf = ccloud_lib.pop_schema_registry_params_from_config(CONFLUENT_CONFIG)
    conf["group.id"] = "fluidstack-consumers"
    conf["enable.auto.commit"] = False
    consumer = Consumer(conf)
    partition_lag = {}
    print(f"Getting lag for topic: {topic_name}, partition: {partition}")
    # Pin the consumer to this single partition of the topic.
    topic = TopicPartition(topic_name, partition)
    consumer.assign([topic])
    # Lag = latest offset in the partition minus the group's committed offset.
    committed = consumer.committed([topic])[0].offset
    last_offset = consumer.get_watermark_offsets(topic)[1]
    if committed < 0:
        # Nothing committed for this partition yet, so lag is undefined.
        consumer.close()
        return {}
    partition_lag[partition] = last_offset - committed
    print(f"Partition: {partition}, lag: {last_offset - committed}")
    consumer.close()
    return partition_lag


if __name__ == "__main__":
    topic_wise_lag = {}
    # One worker process per partition so all lags are fetched in parallel.
    partition_count = 10

    t0 = time.perf_counter()
    pool = multiprocessing.Pool(processes=partition_count)

    inputs = list(range(partition_count))

    outputs = pool.map(get_partition_lag, inputs)
    print(f"Time taken: {time.perf_counter()-t0}")

    for output in outputs:
        topic_wise_lag.update(output)

    # Partition with the highest lag.
    max_lag_partition = max(topic_wise_lag, key=topic_wise_lag.get)

    print(f"Max Lag: {topic_wise_lag[max_lag_partition]} (partition {max_lag_partition})")
    print(f"Total Unconsumed: {sum(topic_wise_lag.values())}")

And here’s a gist if you prefer that: https://gist.github.com/arjun921/4a2cc287d10487f37b08bf9f3eacfc09

Hope this helps :)

Peace ✌🏾