How to Build Your Own Load Balancer from Scratch
Load balancers are a critical component of modern web infrastructure. They distribute incoming traffic across multiple servers to ensure high availability, reliability, and scalability. While there are many off-the-shelf solutions like Nginx and HAProxy, building your own load balancer is a great way to understand how they work under the hood.
In this blog post, we’ll walk through the process of building a simple yet powerful load balancer from scratch using Python. By the end, you’ll have a working load balancer with features like round-robin scheduling, health checks, weighted distribution, and least connections.
What is a Load Balancer?
A load balancer acts as a traffic cop, routing client requests to multiple backend servers. It ensures that no single server is overwhelmed, improving performance and reliability. Common use cases include:
- Distributing traffic across web servers.
- Handling failover in case a server goes down.
- Scaling applications horizontally.
What We’ll Build
We’ll create a Python-based load balancer with the following features:
- Round-Robin Scheduling: Distribute requests evenly across servers.
- Health Checks: Monitor backend servers and remove unhealthy ones.
- Weighted Round-Robin: Assign weights to servers based on their capacity.
- Least Connections: Forward requests to the server with the fewest active connections.
- Logging: Track requests and monitor performance.
Step 1: Setting Up Backend Servers
Before building the load balancer, we need backend servers to distribute traffic to. For simplicity, we’ll use Python’s built-in HTTP server.
Start Backend Servers:
# Server 1 (Port 8001)
python -m http.server 8001
# Server 2 (Port 8002)
python -m http.server 8002
These servers will act as our backends. In a real-world scenario, you'd replace them with actual web applications.
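If you want each backend to identify itself, which makes the round-robin behavior easier to see later, you can swap in a tiny custom server instead. This is an optional sketch; the backend.py filename and the response text are illustrative and nothing later in the post depends on them.

# backend.py - a minimal backend that reports which port served the request.
# Run one copy per port: python backend.py 8001 and python backend.py 8002
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer

PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 8001

class EchoHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        body = f"Hello from backend on port {PORT}\n".encode()
        self.send_response(200)
        self.send_header("Content-Type", "text/plain")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

if __name__ == "__main__":
    HTTPServer(("", PORT), EchoHandler).serve_forever()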
Step 2: Building the Basic Load Balancer
We’ll start with a simple round-robin load balancer that forwards requests to backend servers in rotation.
Code for Basic Load Balancer:
import http.server
import socketserver

import requests

# List of backend servers
backend_servers = [
    "http://127.0.0.1:8001",
    "http://127.0.0.1:8002"
]

# Round-robin counter
current_server = 0

class LoadBalancerHandler(http.server.BaseHTTPRequestHandler):
    def forward_request(self, backend):
        # Read the request body (if any) so it can be passed along
        content_length = int(self.headers.get('Content-Length', 0))
        body = self.rfile.read(content_length) if content_length else None
        try:
            # Forward the request to the backend server
            response = requests.request(
                method=self.command,
                url=f"{backend}{self.path}",
                headers=dict(self.headers),
                data=body,
                allow_redirects=False
            )
            # Send the backend's response back to the client
            self.send_response(response.status_code)
            for header, value in response.headers.items():
                # Skip headers that no longer apply once requests has decoded the body
                if header.lower() not in ("transfer-encoding", "content-encoding", "content-length", "connection"):
                    self.send_header(header, value)
            self.send_header("Content-Length", str(len(response.content)))
            self.end_headers()
            self.wfile.write(response.content)
        except requests.RequestException as e:
            self.send_error(500, f"Error forwarding request: {e}")

    def do_GET(self):
        global current_server
        backend = backend_servers[current_server]
        print(f"Forwarding request to {backend}")
        self.forward_request(backend)
        current_server = (current_server + 1) % len(backend_servers)

    def do_POST(self):
        self.do_GET()  # For simplicity, handle POST the same way as GET

if __name__ == "__main__":
    PORT = 8080
    with socketserver.TCPServer(("", PORT), LoadBalancerHandler) as httpd:
        print(f"Load balancer running on port {PORT}...")
        httpd.serve_forever()
How It Works:
- The load balancer listens on port 8080.
- It forwards incoming requests to backend servers in a round-robin fashion.
- The current_server variable keeps track of which server to forward the next request to.
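With both backends and the load balancer running, you can watch the rotation by repeating a request; the load balancer's console output alternates between the two backends:

curl http://localhost:8080/
curl http://localhost:8080/
curl http://localhost:8080/

If these are the first requests the load balancer receives, the first and third land on port 8001 and the second on port 8002.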
Step 3: Adding Health Checks
Health checks ensure that the load balancer only forwards requests to healthy servers.
Code for Health Checks:
import threading
import time

def health_check():
    while True:
        # Iterate over a copy so removing a server doesn't skip the next entry
        for server in list(backend_servers):
            try:
                response = requests.get(server, timeout=2)
                if response.status_code != 200:
                    print(f"Server {server} is unhealthy. Removing from pool.")
                    backend_servers.remove(server)
            except requests.RequestException:
                print(f"Server {server} is unhealthy. Removing from pool.")
                backend_servers.remove(server)
        time.sleep(10)  # Check every 10 seconds

# Start health check in a separate thread
threading.Thread(target=health_check, daemon=True).start()
How It Works:
- The health check runs in a separate thread.
- It periodically sends a request to each backend server.
- If a server fails to respond or returns an error, it’s removed from the pool.
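One limitation of this approach is that a removed server never comes back, even after it recovers. Here is a sketch of a variant that keeps the full server list and maintains a separate pool of healthy servers; the all_servers and healthy_servers names are illustrative and are not used elsewhere in this post.

import threading
import time

import requests

# All known backends, plus the subset currently considered healthy
all_servers = ["http://127.0.0.1:8001", "http://127.0.0.1:8002"]
healthy_servers = list(all_servers)

def health_check():
    while True:
        for server in all_servers:
            try:
                healthy = requests.get(server, timeout=2).status_code == 200
            except requests.RequestException:
                healthy = False
            if healthy and server not in healthy_servers:
                print(f"Server {server} recovered. Adding back to pool.")
                healthy_servers.append(server)
            elif not healthy and server in healthy_servers:
                print(f"Server {server} is unhealthy. Removing from pool.")
                healthy_servers.remove(server)
        time.sleep(10)

threading.Thread(target=health_check, daemon=True).start()

The load balancer would then pick from healthy_servers instead of mutating the main server list.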
Step 4: Adding Weighted Round-Robin
Weighted round-robin allows you to assign weights to servers based on their capacity.
Code for Weighted Round-Robin:
# List of backend servers with weights
backend_servers = [
    {"url": "http://127.0.0.1:8001", "weight": 2},
    {"url": "http://127.0.0.1:8002", "weight": 1}
]

# Round-robin counter and weight tracking
current_server = 0
current_weight = 0

def get_next_server():
    global current_server, current_weight
    while True:
        server = backend_servers[current_server]
        if current_weight < server["weight"]:
            current_weight += 1
            return server["url"]
        else:
            current_weight = 0
            current_server = (current_server + 1) % len(backend_servers)
How It Works:
- Servers with higher weights receive more requests.
- The current_weight variable ensures that requests are distributed according to the weights.
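To convince yourself the weights are respected, you can call get_next_server() on its own (alongside the definitions above) and inspect the sequence it produces. With weights 2 and 1, the first backend should appear twice for every appearance of the second:

# Quick sanity check of the weighted rotation (run alongside the definitions above)
for _ in range(6):
    print(get_next_server())

# Expected output with weights 2 and 1:
# http://127.0.0.1:8001
# http://127.0.0.1:8001
# http://127.0.0.1:8002
# (and then the pattern repeats)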
Step 5: Adding Least Connections
Least connections forwards requests to the server with the fewest active connections.
Code for Least Connections:
# Track active connections for each server
active_connections = {server["url"]: 0 for server in backend_servers}

def get_least_connections_server():
    return min(active_connections, key=active_connections.get)

class LoadBalancerHandler(http.server.BaseHTTPRequestHandler):
    def do_GET(self):
        backend = get_least_connections_server()
        active_connections[backend] += 1
        print(f"Forwarding request to {backend} (connections: {active_connections[backend]})")
        try:
            self.forward_request(backend)
        finally:
            # Always release the slot, even if forwarding raises
            active_connections[backend] -= 1
How It Works:
- The load balancer tracks the number of active connections for each server.
- Requests are forwarded to the server with the fewest active connections.
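One caveat: with the single-threaded socketserver.TCPServer we've used so far, requests are handled one at a time, so the connection counts rarely overlap. Below is a rough sketch of how you might handle requests concurrently and protect the shared counters with a lock. The helper names acquire_backend and release_backend are my own, and the snippet assumes the active_connections dictionary and LoadBalancerHandler defined above; ThreadingHTTPServer is part of the standard library.

import threading
from http.server import ThreadingHTTPServer

# Guard the shared counters, since requests may now run on several threads
connections_lock = threading.Lock()

def acquire_backend():
    with connections_lock:
        backend = min(active_connections, key=active_connections.get)
        active_connections[backend] += 1
        return backend

def release_backend(backend):
    with connections_lock:
        active_connections[backend] -= 1

# In do_GET you would call acquire_backend() / release_backend(backend),
# and serve with one thread per request so least-connections has an effect:
# ThreadingHTTPServer(("", 8080), LoadBalancerHandler).serve_forever()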
Step 6: Adding Logging
Logging helps you monitor the load balancer’s activity.
Code for Logging:
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

class LoadBalancerHandler(http.server.BaseHTTPRequestHandler):
    def do_GET(self):
        backend = get_next_server()  # or get_least_connections_server()
        logging.info(f"Forwarding request to {backend}")
        self.forward_request(backend)
How It Works:
- Logs are written to the console with timestamps and severity levels.
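If you also want logs on disk, the standard library's RotatingFileHandler can be attached alongside the console handler. The lb.log filename and the size limits below are arbitrary choices for illustration:

import logging
from logging.handlers import RotatingFileHandler

# Log to the console and to a size-capped file (keeps 3 rotated copies)
file_handler = RotatingFileHandler("lb.log", maxBytes=1_000_000, backupCount=3)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(), file_handler]
)

logging.info("Load balancer logging to console and lb.log")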
Final Code
Here’s the complete code with all the features integrated:
import http.server
import socketserver
import threading
import time
import logging

import requests

# List of backend servers with weights
backend_servers = [
    {"url": "http://127.0.0.1:8001", "weight": 2},
    {"url": "http://127.0.0.1:8002", "weight": 1}
]

# Track active connections for each server
active_connections = {server["url"]: 0 for server in backend_servers}

# Round-robin counter and weight tracking
current_server = 0
current_weight = 0

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

def health_check():
    while True:
        # Iterate over a copy so removing a server doesn't skip the next entry
        for server in list(backend_servers):
            try:
                response = requests.get(server["url"], timeout=2)
                if response.status_code != 200:
                    logging.warning(f"Server {server['url']} is unhealthy. Removing from pool.")
                    backend_servers.remove(server)
            except requests.RequestException:
                logging.warning(f"Server {server['url']} is unhealthy. Removing from pool.")
                backend_servers.remove(server)
        time.sleep(10)  # Check every 10 seconds

def get_next_server():
    global current_server, current_weight
    while True:
        # The pool can shrink while we run, so keep the index in range
        current_server %= len(backend_servers)
        server = backend_servers[current_server]
        if current_weight < server["weight"]:
            current_weight += 1
            return server["url"]
        else:
            current_weight = 0
            current_server = (current_server + 1) % len(backend_servers)

def get_least_connections_server():
    return min(active_connections, key=active_connections.get)

class LoadBalancerHandler(http.server.BaseHTTPRequestHandler):
    def forward_request(self, backend):
        # Read the request body (if any) so it can be passed along
        content_length = int(self.headers.get('Content-Length', 0))
        body = self.rfile.read(content_length) if content_length else None
        try:
            response = requests.request(
                method=self.command,
                url=f"{backend}{self.path}",
                headers=dict(self.headers),
                data=body,
                allow_redirects=False
            )
            # Send the backend's response back to the client
            self.send_response(response.status_code)
            for header, value in response.headers.items():
                # Skip headers that no longer apply once requests has decoded the body
                if header.lower() not in ("transfer-encoding", "content-encoding", "content-length", "connection"):
                    self.send_header(header, value)
            self.send_header("Content-Length", str(len(response.content)))
            self.end_headers()
            self.wfile.write(response.content)
        except requests.RequestException as e:
            self.send_error(500, f"Error forwarding request: {e}")

    def do_GET(self):
        backend = get_next_server()  # or get_least_connections_server()
        active_connections[backend] += 1
        logging.info(f"Forwarding request to {backend} (connections: {active_connections[backend]})")
        try:
            self.forward_request(backend)
        finally:
            # Always release the slot, even if forwarding raises
            active_connections[backend] -= 1

    def do_POST(self):
        self.do_GET()

if __name__ == "__main__":
    # Start health check in a separate thread
    threading.Thread(target=health_check, daemon=True).start()

    PORT = 8080
    with socketserver.TCPServer(("", PORT), LoadBalancerHandler) as httpd:
        logging.info(f"Load balancer running on port {PORT}...")
        httpd.serve_forever()
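To try it out, start the backends from Step 1, run the load balancer (assuming you saved the final code as load_balancer.py; any filename works), and send a few requests:

python -m http.server 8001
python -m http.server 8002
python load_balancer.py
curl http://localhost:8080/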
Feel free to experiment with the code and add your own features. Happy coding!