Build your own Load Balancer!

Jan 26, 2025

How to Build Your Own Load Balancer from Scratch

Load balancers are a critical component of modern web infrastructure. They distribute incoming traffic across multiple servers to ensure high availability, reliability, and scalability. While there are many off-the-shelf solutions like Nginx and HAProxy, building your own load balancer is a great way to understand how they work under the hood.

In this blog post, we’ll walk through the process of building a simple yet powerful load balancer from scratch using Python. By the end, you’ll have a working load balancer with features like round-robin scheduling, health checks, weighted distribution, and least connections.


What is a Load Balancer?

A load balancer acts as a traffic cop, routing client requests to multiple backend servers. It ensures that no single server is overwhelmed, improving performance and reliability. Common use cases include:


What We’ll Build

We’ll create a Python-based load balancer with the following features:

  1. Round-Robin Scheduling: Distribute requests evenly across servers.
  2. Health Checks: Monitor backend servers and remove unhealthy ones.
  3. Weighted Round-Robin: Assign weights to servers based on their capacity.
  4. Least Connections: Forward requests to the server with the fewest active connections.
  5. Logging: Track requests and monitor performance.

Step 1: Setting Up Backend Servers

Before building the load balancer, we need backend servers to distribute traffic to. For simplicity, we’ll use Python’s built-in HTTP server.

Start Backend Servers:

# Server 1 (Port 8001)
python -m http.server 8001
 
# Server 2 (Port 8002)
python -m http.server 8002

These servers will act as our backend. In a real-world scenario, you’d replace these with actual web applications.
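
Both commands above serve the same directory listing, so it is hard to tell which backend answered a given request. Here is a minimal sketch of a backend that names its own port in every response. It uses only the standard library and assumes a file called echo_backend.py; the file name, make_handler(), and serve() are illustrative and not part of the original setup.

```python
import http.server
import sys


def make_handler(port):
    """Build a handler class whose responses name the port it serves."""

    class EchoPortHandler(http.server.BaseHTTPRequestHandler):
        def do_GET(self):
            body = f"Hello from backend on port {port}\n".encode()
            self.send_response(200)
            self.send_header("Content-Type", "text/plain; charset=utf-8")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

    return EchoPortHandler


def serve(port):
    # Blocks forever, answering every GET with the port number
    with http.server.HTTPServer(("127.0.0.1", port), make_handler(port)) as httpd:
        print(f"Backend listening on port {port}")
        httpd.serve_forever()


if __name__ == "__main__":
    if len(sys.argv) == 2:
        serve(int(sys.argv[1]))
    else:
        print("usage: python echo_backend.py PORT")
```

Start it with python echo_backend.py 8001 and python echo_backend.py 8002 instead of the http.server commands. The load balancer itself does not need to change.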


Step 2: Building the Basic Load Balancer

We’ll start with a simple round-robin load balancer that forwards requests to backend servers in rotation.

Code for Basic Load Balancer:

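import http.server
import socketserver
import requests

# List of backend servers
backend_servers = [
    "http://127.0.0.1:8001",
    "http://127.0.0.1:8002"
]

# Round-robin counter (a plain global is fine here because TCPServer
# handles one request at a time)
current_server = 0

# Hop-by-hop headers describe a single connection and must not be forwarded
HOP_BY_HOP = {"connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
              "te", "trailers", "transfer-encoding", "upgrade"}
# Response headers we set ourselves: send_response() adds Server and Date, and
# requests decompresses the body, so the original encoding and length no longer apply
SKIP_RESPONSE_HEADERS = HOP_BY_HOP | {"server", "date", "content-encoding", "content-length"}

class LoadBalancerHandler(http.server.BaseHTTPRequestHandler):
    def forward_request(self, backend):
        # Read the request body, if there is one
        content_length = int(self.headers.get('Content-Length', 0))
        body = self.rfile.read(content_length) if content_length else None
        # Copy the client's headers, minus Host (requests derives it from the URL)
        headers = {k: v for k, v in self.headers.items()
                   if k.lower() != "host" and k.lower() not in HOP_BY_HOP}

        # Forward the request to the backend server
        try:
            response = requests.request(
                method=self.command,
                url=f"{backend}{self.path}",
                headers=headers,
                data=body,
                allow_redirects=False,
                timeout=10
            )
        except requests.RequestException as e:
            # 502 Bad Gateway: the backend failed, not the load balancer itself
            self.send_error(502, "Bad Gateway", f"Error forwarding request: {e}")
            return

        # Send the backend's response back to the client
        self.send_response(response.status_code)
        for header, value in response.headers.items():
            if header.lower() not in SKIP_RESPONSE_HEADERS:
                self.send_header(header, value)
        self.send_header("Content-Length", str(len(response.content)))
        self.end_headers()
        self.wfile.write(response.content)

    def do_GET(self):
        global current_server
        backend = backend_servers[current_server]
        current_server = (current_server + 1) % len(backend_servers)
        print(f"Forwarding request to {backend}")
        self.forward_request(backend)

    def do_POST(self):
        self.do_GET()  # For simplicity, handle POST the same way as GET

if __name__ == "__main__":
    PORT = 8080
    with socketserver.TCPServer(("", PORT), LoadBalancerHandler) as httpd:
        print(f"Load balancer running on port {PORT}...")
        httpd.serve_forever()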

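How It Works:

The handler takes the backend at index current_server and forwards the client's method, path, headers, and body to it with requests.request(). It then relays the backend's status code, headers, and body back to the client. The counter advances modulo the number of servers, so successive requests alternate between ports 8001 and 8002. POST requests follow the same path as GET.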
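
The global counter above is only safe because socketserver.TCPServer handles one request at a time. Here is a minimal sketch of a picker that stays correct under a threaded server. It wraps itertools.cycle in a lock; the RoundRobin class and its pick() method are illustrative names, not part of the original code.

```python
import itertools
import threading


class RoundRobin:
    """Hands out backends in rotation; safe to call from several threads."""

    def __init__(self, servers):
        self._cycle = itertools.cycle(servers)
        self._lock = threading.Lock()

    def pick(self):
        # Guard the shared iterator so two threads never advance it at once
        with self._lock:
            return next(self._cycle)


if __name__ == "__main__":
    rr = RoundRobin(["http://127.0.0.1:8001", "http://127.0.0.1:8002"])
    print([rr.pick() for _ in range(4)])
```

In do_GET(), backend = rr.pick() would replace the global counter and the modulo arithmetic.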


Step 3: Adding Health Checks

Health checks ensure that the load balancer only forwards requests to healthy servers.

Code for Health Checks:

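import threading
import time
import requests

# backend_servers stays the full configured list; healthy_servers holds the ones
# that passed the latest check, so a server that recovers rejoins automatically
healthy_servers = list(backend_servers)

def health_check(interval=10):
    global healthy_servers
    while True:
        alive = []
        # Probe every configured server (never pop from a list while iterating over it)
        for server in backend_servers:
            try:
                response = requests.get(server, timeout=2)
                if response.status_code == 200:
                    alive.append(server)
                else:
                    print(f"Server {server} returned {response.status_code}. Removing from pool.")
            except requests.RequestException:
                print(f"Server {server} is unreachable. Removing from pool.")
        # Swap in the new list with one assignment so handlers never see a partial update
        healthy_servers = alive
        time.sleep(interval)  # Check every 10 seconds by default

# Start health check in a separate thread
threading.Thread(target=health_check, daemon=True).start()

# do_GET should now pick from healthy_servers instead of backend_servers
# (and answer 503 when the list is empty)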

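How It Works:

A background daemon thread sends a GET request to every backend every 10 seconds. A server is taken out of the pool if it does not respond within 2 seconds, refuses the connection, or returns a status other than 200. Requests then stop being forwarded to it.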
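
With the loop above, a single failed probe is enough to pull a server out of the pool. A server that is only briefly slow can therefore flap in and out. One common refinement, similar in spirit to HAProxy's rise and fall settings, is to require several consecutive results before changing a server's state. The sketch below assumes a hypothetical HealthTracker class that is called only from the health-check thread; rise=2 and fall=3 are illustrative thresholds.

```python
class HealthTracker:
    """Tracks server health with hysteresis: a server goes down only after
    `fall` consecutive failed checks and comes back after `rise` successes."""

    def __init__(self, servers, rise=2, fall=3):
        self.rise = rise
        self.fall = fall
        self.healthy = {server: True for server in servers}  # start optimistic
        # Number of consecutive results that disagree with the current state
        self._streak = {server: 0 for server in servers}

    def record(self, server, ok):
        ok = bool(ok)
        if ok == self.healthy[server]:
            self._streak[server] = 0  # result agrees with the current state
            return
        self._streak[server] += 1
        threshold = self.rise if ok else self.fall
        if self._streak[server] >= threshold:
            self.healthy[server] = ok
            self._streak[server] = 0

    def healthy_servers(self):
        return [server for server, up in self.healthy.items() if up]
```

Inside the health-check loop, you would call tracker.record(server, response.status_code == 200) for each probe (or record False on an exception) and route requests only to tracker.healthy_servers().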


Step 4: Adding Weighted Round-Robin

Weighted round-robin lets you assign weights to servers based on their capacity, so a server with weight 2 receives twice as many requests as a server with weight 1.

Code for Weighted Round-Robin:

# List of backend servers with weights
backend_servers = [
    {"url": "http://127.0.0.1:8001", "weight": 2},
    {"url": "http://127.0.0.1:8002", "weight": 1}
]
 
# Round-robin counter and weight tracking
current_server = 0
current_weight = 0
 
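def get_next_server():
    global current_server, current_weight
    # Keep returning the current server until it has been picked `weight` times,
    # then reset the count and move on: with weights 2 and 1 the order is
    # 8001, 8001, 8002, 8001, 8001, 8002, ...
    # Assumes at least one server has a weight of 1 or more; otherwise this never returns.
    while True:
        server = backend_servers[current_server]
        if current_weight < server["weight"]:
            current_weight += 1
            return server["url"]
        else:
            current_weight = 0
            current_server = (current_server + 1) % len(backend_servers)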

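How It Works:

get_next_server() keeps returning the current server until it has been chosen weight times in a row. It then resets current_weight and moves on to the next server. With the weights above, every three requests send two to port 8001 and one to port 8002.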
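
get_next_server() sends a server's whole share in one burst (8001, 8001, then 8002). The smooth weighted round-robin scheme that Nginx uses for weighted upstreams spreads the same share more evenly. Below is a minimal sketch of that scheme, offered as an alternative technique rather than the author's method; the class and method names are hypothetical.

```python
class SmoothWeightedRoundRobin:
    """Each pick, every server's running score grows by its weight; the
    highest score wins and is then reduced by the total weight."""

    def __init__(self, weights):
        self.weights = dict(weights)                 # server -> weight
        self.current = {server: 0 for server in self.weights}
        self.total = sum(self.weights.values())

    def pick(self):
        for server, weight in self.weights.items():
            self.current[server] += weight
        # max() returns the first server in insertion order when scores tie
        best = max(self.current, key=self.current.get)
        self.current[best] -= self.total
        return best


if __name__ == "__main__":
    swrr = SmoothWeightedRoundRobin({"a": 5, "b": 1, "c": 1})
    print([swrr.pick() for _ in range(7)])  # → ['a', 'a', 'b', 'a', 'c', 'a', 'a']
```

Each server still receives traffic in proportion to its weight (5 of every 7 picks go to a above), but the picks are interleaved instead of bunched. With the 2:1 weights from this step, the order becomes 8001, 8002, 8001.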


Step 5: Adding Least Connections

Least connections forwards requests to the server with the fewest active connections.

Code for Least Connections:

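import threading

# Track active connections for each server
active_connections = {server["url"]: 0 for server in backend_servers}
connections_lock = threading.Lock()

def get_least_connections_server():
    return min(active_connections, key=active_connections.get)

class LoadBalancerHandler(http.server.BaseHTTPRequestHandler):
    # forward_request() is the same as in Step 2. Counts only rise above 1 when
    # requests run concurrently, e.g. under socketserver.ThreadingTCPServer.

    def do_GET(self):
        # Select and count atomically so concurrent requests see up-to-date counts
        with connections_lock:
            backend = get_least_connections_server()
            active_connections[backend] += 1
            count = active_connections[backend]
        print(f"Forwarding request to {backend} (connections: {count})")
        try:
            self.forward_request(backend)
        finally:
            # Release the slot even if forwarding raised an exception
            with connections_lock:
                active_connections[backend] -= 1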

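How It Works:

A server's count goes up when a request is forwarded to it and back down once the response has been relayed. get_least_connections_server() returns the URL with the smallest count; ties go to the server listed first. The counts only rise above 1 when requests are handled concurrently, because the socketserver.TCPServer used in Step 2 handles one request at a time.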
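
A common refinement combines this step with Step 4 by comparing active connections divided by weight. Under this rule, a server with weight 2 can carry twice as many in-flight requests before it counts as busier. The sketch below assumes a hypothetical WeightedLeastConnections class. It wraps the increment and decrement in a context manager, so the count is always released, even when forwarding raises.

```python
import threading
from contextlib import contextmanager


class WeightedLeastConnections:
    """Picks the server with the lowest active-connections-to-weight ratio."""

    def __init__(self, weights):
        self.weights = dict(weights)                 # server URL -> weight (must be > 0)
        self.active = {server: 0 for server in self.weights}
        self._lock = threading.Lock()

    @contextmanager
    def connection(self):
        # Choose and increment under one lock so concurrent callers see fresh counts
        with self._lock:
            server = min(self.active, key=lambda s: self.active[s] / self.weights[s])
            self.active[server] += 1
        try:
            yield server
        finally:
            # Always release the slot, even if the caller raised
            with self._lock:
                self.active[server] -= 1
```

In do_GET(), this would read: with balancer.connection() as backend: self.forward_request(backend), where balancer is a module-level instance built from the weights in Step 4.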


Step 6: Adding Logging

Logging helps you monitor the load balancer’s activity.

Code for Logging:

import logging
 
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
 
class LoadBalancerHandler(http.server.BaseHTTPRequestHandler):
    # forward_request() is the same as in Step 2

    def do_GET(self):
        backend = get_next_server()  # or get_least_connections_server()
        logging.info(f"Forwarding request to {backend}")
        self.forward_request(backend)

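How It Works:

logging.basicConfig() sends messages at INFO level and above to standard error, prefixing each with a timestamp and the level name. Every forwarded request produces one line naming the backend that received it.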
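
BaseHTTPRequestHandler also writes its own access-log line for every request. It does this through log_message(), which goes straight to stderr and bypasses the format configured above. The minimal sketch below overrides log_message() so that line and a per-request duration share the same log. TimedLoggingHandler and its placeholder "ok" response are illustrative, not part of the original code.

```python
import http.server
import logging
import time

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)


class TimedLoggingHandler(http.server.BaseHTTPRequestHandler):
    def do_GET(self):
        start = time.perf_counter()
        body = b"ok\n"  # placeholder: the load balancer would call forward_request() here
        self.send_response(200)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
        elapsed_ms = (time.perf_counter() - start) * 1000
        logging.info("Handled %s %s in %.1f ms", self.command, self.path, elapsed_ms)

    def log_message(self, format, *args):
        # By default this writes to stderr; route it through logging so every
        # line gets the same timestamp and level prefix
        logging.info("%s - %s", self.address_string(), format % args)
```

The log_message() override can be copied into LoadBalancerHandler as-is, with the timing lines wrapped around the call to forward_request().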


Final Code

Here’s the complete code with all the features integrated:

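import http.server
import threading
import time
import logging

import requests

# List of backend servers with weights
backend_servers = [
    {"url": "http://127.0.0.1:8001", "weight": 2},
    {"url": "http://127.0.0.1:8002", "weight": 1}
]

# Hop-by-hop headers describe a single connection and must not be forwarded
HOP_BY_HOP = {"connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
              "te", "trailers", "transfer-encoding", "upgrade"}
# send_response() adds Server and Date itself, and requests decompresses the body,
# so the backend's encoding and length headers no longer apply
SKIP_RESPONSE_HEADERS = HOP_BY_HOP | {"server", "date", "content-encoding", "content-length"}

# Shared state, guarded by `lock` because requests are handled on multiple threads
lock = threading.Lock()
healthy = {server["url"] for server in backend_servers}  # URLs that passed the last check
active_connections = {server["url"]: 0 for server in backend_servers}
current_server = 0
current_weight = 0

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

def health_check(interval=10):
    # Probe every configured server each round, so a server that recovers
    # rejoins the pool instead of being dropped for good
    global healthy
    while True:
        alive = set()
        for server in backend_servers:
            try:
                response = requests.get(server["url"], timeout=2)
                if response.status_code == 200:
                    alive.add(server["url"])
                    continue
            except requests.RequestException:
                pass
            logging.warning(f"Server {server['url']} is unhealthy. Skipping it until it passes a check.")
        with lock:
            healthy = alive
        time.sleep(interval)  # Check every 10 seconds by default

def get_next_server():
    # Weighted round-robin over healthy servers; returns None if none is healthy
    global current_server, current_weight
    with lock:
        # One extra pass over every server is enough to reach a healthy one
        for _ in range(len(backend_servers) + 1):
            server = backend_servers[current_server]
            if server["url"] in healthy and current_weight < server["weight"]:
                current_weight += 1
                return server["url"]
            current_weight = 0
            current_server = (current_server + 1) % len(backend_servers)
        return None

def get_least_connections_server():
    # Healthy server with the fewest in-flight requests; None if none is healthy
    with lock:
        candidates = [url for url in active_connections if url in healthy]
        return min(candidates, key=active_connections.get) if candidates else None

class LoadBalancerHandler(http.server.BaseHTTPRequestHandler):
    def forward_request(self, backend):
        content_length = int(self.headers.get('Content-Length', 0))
        body = self.rfile.read(content_length) if content_length else None
        # Drop Host (requests sets it from the URL) and hop-by-hop headers
        headers = {k: v for k, v in self.headers.items()
                   if k.lower() != "host" and k.lower() not in HOP_BY_HOP}

        try:
            response = requests.request(
                method=self.command,
                url=f"{backend}{self.path}",
                headers=headers,
                data=body,
                allow_redirects=False,
                timeout=10
            )
        except requests.RequestException as e:
            self.send_error(502, "Bad Gateway", f"Error forwarding request: {e}")
            return

        self.send_response(response.status_code)
        for header, value in response.headers.items():
            if header.lower() not in SKIP_RESPONSE_HEADERS:
                self.send_header(header, value)
        self.send_header("Content-Length", str(len(response.content)))
        self.end_headers()
        self.wfile.write(response.content)

    def do_GET(self):
        backend = get_next_server()  # or get_least_connections_server()
        if backend is None:
            self.send_error(503, "Service Unavailable", "No healthy backend servers available")
            return
        with lock:
            active_connections[backend] += 1
            count = active_connections[backend]
        logging.info(f"Forwarding request to {backend} (connections: {count})")
        try:
            self.forward_request(backend)
        finally:
            # Release the slot even if forwarding raised an exception
            with lock:
                active_connections[backend] -= 1

    def do_POST(self):
        self.do_GET()

if __name__ == "__main__":
    # Start health check in a separate thread
    threading.Thread(target=health_check, daemon=True).start()

    PORT = 8080
    # ThreadingHTTPServer (Python 3.7+) handles each request on its own thread, so a
    # slow backend doesn't block other clients and connection counts are meaningful
    with http.server.ThreadingHTTPServer(("", PORT), LoadBalancerHandler) as httpd:
        logging.info(f"Load balancer running on port {PORT}...")
        httpd.serve_forever()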
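
To check the distribution, send a batch of requests and count which backend answered each one. The minimal client sketch below assumes each backend returns a body that identifies it, for example a small handler that writes its own port. (The default python -m http.server backends all return the same directory listing.) The script name tally.py and the fetch() and tally() helpers are hypothetical.

```python
import collections
import concurrent.futures
import sys
import urllib.request

# Talk to the load balancer directly, ignoring any HTTP_PROXY environment settings
_opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))


def fetch(url):
    with _opener.open(url, timeout=5) as resp:
        return resp.read().decode().strip()


def tally(url="http://127.0.0.1:8080/", n=30, workers=10):
    """Send n requests, `workers` at a time, and count the distinct response bodies."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
        return collections.Counter(pool.map(lambda _: fetch(url), range(n)))


if __name__ == "__main__":
    if len(sys.argv) == 2:
        for body, count in tally(sys.argv[1]).most_common():
            print(f"{count:3d}  {body}")
    else:
        print("usage: python tally.py URL")
```

With weights 2 and 1 and both backends healthy, about two thirds of the responses should come from port 8001.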

Feel free to experiment with the code and add your own features. Happy code EATING :)!