# Federated Learning + DQN

Privacy-preserving distributed reinforcement learning for healthcare scheduling.
## When to Use

- Multi-institution ML without sharing raw data
- Healthcare applications with privacy requirements
- Distributed optimization across organizations
## Architecture Overview

```
┌─────────────┐   ┌─────────────┐   ┌─────────────┐
│ Hospital A  │   │ Hospital B  │   │ Hospital C  │
│  Local DQN  │   │  Local DQN  │   │  Local DQN  │
└──────┬──────┘   └──────┬──────┘   └──────┬──────┘
       │                 │                 │
       └─────────────────┼─────────────────┘
                         │
                  ┌──────▼──────┐
                  │ Aggregator  │
                  │  (Server)   │
                  └─────────────┘
```
## Components

### Federated Learning

**FedAvg Algorithm:**

Server:

```python
def federated_averaging(models, weights):
    """Weighted average of client model parameters (FedAvg)."""
    total = sum(weights)
    averaged = {}
    for key in models[0].state_dict():
        averaged[key] = sum(
            w * model.state_dict()[key]
            for model, w in zip(models, weights)
        ) / total
    return averaged
```
Training round:

```python
for round_num in range(num_rounds):
    clients = select_clients()
    models, weights = [], []
    for client in clients:
        # Each client trains locally and reports its dataset size as its weight
        model, weight = client.train(local_epochs)
        models.append(model)
        weights.append(weight)
    global_model.load_state_dict(federated_averaging(models, weights))
```
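The weighted average can be sanity-checked without a full training stack. In this minimal sketch, plain dicts of scalars stand in for tensor state dicts; the toy client values and weights are made up for illustration:

```python
def fedavg_scalar(state_dicts, weights):
    # Same arithmetic as federated_averaging, on dicts of floats
    total = sum(weights)
    return {
        key: sum(w * sd[key] for sd, w in zip(state_dicts, weights)) / total
        for key in state_dicts[0]
    }

# Two toy "clients" with one scalar parameter each;
# client B holds three times as much data as client A.
client_a = {"w": 0.0}
client_b = {"w": 1.0}

avg = fedavg_scalar([client_a, client_b], weights=[1, 3])
print(avg["w"])  # (1*0.0 + 3*1.0) / 4 = 0.75
```

The data-size weighting is what makes FedAvg equivalent to a single pass over the pooled data when local updates are small.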
### Deep Q-Network (DQN)

**Network Architecture:**

```python
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )

    def forward(self, x):
        return self.net(x)
```
**Training Loop:**

```python
def train_dqn(agent, replay_buffer, target_net):
    for step in range(num_steps):
        state = env.reset()
        done = False

        while not done:
            # Epsilon-greedy action
            action = agent.select_action(state, epsilon)
            next_state, reward, done, _ = env.step(action)

            # Store transition
            replay_buffer.push(state, action, reward, next_state, done)

            # Sample batch
            batch = replay_buffer.sample(batch_size)

            # Compute TD target against the frozen target network
            q_values = agent(batch.state)
            with torch.no_grad():
                next_q_values = target_net(batch.next_state)
            target = batch.reward + gamma * next_q_values.max(1)[0] * (1 - batch.done)
            loss = nn.MSELoss()(q_values.gather(1, batch.action), target.unsqueeze(1))

            # Update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            state = next_state

        # Periodically sync the target network
        if step % target_update == 0:
            target_net.load_state_dict(agent.state_dict())
```
### Multi-Level Feedback Queue (MLFQ)

**Integration with DQN:**

```python
class MLFQScheduler:
    def __init__(self, num_queues=3):
        self.queues = [[] for _ in range(num_queues)]
        self.priority_boost = 10

    def add_patient(self, patient, priority):
        queue_idx = min(priority, len(self.queues) - 1)
        self.queues[queue_idx].append(patient)

    def get_next_patient(self):
        # DQN selects which queue to serve
        queue_state = self.get_queue_state()
        action = dqn_agent.select_action(queue_state)

        # Boost priority of long-waiting patients
        self.boost_priorities()

        return self.queues[action].pop(0) if self.queues[action] else None

    def boost_priorities(self):
        for i in range(len(self.queues) - 1, 0, -1):
            # Iterate over a copy: removing from a list while
            # iterating it directly would skip elements
            for patient in list(self.queues[i]):
                if patient.wait_time > self.priority_boost:
                    self.queues[i - 1].append(patient)
                    self.queues[i].remove(patient)
```
## Privacy Guarantees

### Differential Privacy

```python
import numpy as np
import torch

def add_dp_noise(gradients, epsilon, delta, sensitivity):
    """Add Gaussian noise for (ε, δ)-differential privacy."""
    sigma = sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
    noise = torch.randn_like(gradients) * sigma
    return gradients + noise
```
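The `sensitivity` argument is only a valid bound if each client's update is norm-clipped before noise is added. A hedged sketch of L2 clipping on a plain list of floats (the `clip_update` helper name is illustrative, not from the original):

```python
import math

def clip_update(update, clip_norm):
    # Rescale the update so its L2 norm is at most clip_norm;
    # after clipping, clip_norm is a valid `sensitivity` bound.
    norm = math.sqrt(sum(v * v for v in update))
    if norm <= clip_norm or norm == 0.0:
        return list(update)
    scale = clip_norm / norm
    return [v * scale for v in update]

update = [3.0, 4.0]                  # L2 norm = 5
clipped = clip_update(update, 1.0)   # rescaled to norm 1, direction preserved
```

Without clipping, a single outlier update can have unbounded influence and the (ε, δ) guarantee does not hold.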
### Secure Aggregation

- Clients encrypt model updates
- Server aggregates without seeing individual updates
- Only the decrypted aggregate is visible
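The core trick behind secure aggregation protocols such as Bonawitz et al.'s is pairwise additive masking: masks cancel in the sum, so the server learns only the aggregate. This is a toy sketch with scalar updates and a shared RNG standing in for real pairwise key agreement; it ignores client dropouts:

```python
import random

def masked_updates(updates, seed=0):
    # Client i adds +m_ij for each j > i and -m_ji for each j < i.
    # Every mask appears once positive and once negative, so the
    # server-side sum equals sum(updates) while each individual
    # masked value reveals nothing on its own.
    n = len(updates)
    rng = random.Random(seed)
    masks = {(i, j): rng.uniform(-1, 1) for i in range(n) for j in range(i + 1, n)}
    masked = []
    for i, u in enumerate(updates):
        m = sum(masks[(i, j)] for j in range(i + 1, n)) \
            - sum(masks[(j, i)] for j in range(i))
        masked.append(u + m)
    return masked

updates = [0.2, 0.5, 0.3]          # private per-client updates
masked = masked_updates(updates)   # individually meaningless to the server
print(sum(masked))                 # ≈ 1.0, equal to sum(updates)
```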
## Healthcare Scheduling Use Case

### State Representation

```python
state = {
    'queue_lengths': [len(q) for q in queues],         # Shape: (num_queues,)
    'patient_acuity': average_acuity_per_queue,        # Shape: (num_queues,)
    'resource_availability': [beds, staff, equipment],
    'time_features': [hour_of_day, day_of_week],
    'predicted_arrivals': next_hour_forecast,
}
```
### Action Space

```python
actions = {
    0: 'Schedule from high-priority queue',
    1: 'Schedule from medium-priority queue',
    2: 'Schedule from low-priority queue',
    3: 'Allocate additional resource',
    4: 'Request transfer from other hospital',
}
```
### Reward Function

```python
def calculate_reward(state, action, next_state):
    reward = 0

    # Minimize wait time (weighted by acuity)
    reward -= sum(
        patient.wait_time * patient.acuity
        for patient in all_patients
    )

    # Penalize queue imbalance
    reward -= variance(queue_lengths) * 10

    # Reward completing high-acuity cases
    reward += completed_high_acuity * 50

    # Penalize resource overutilization
    if resource_utilization > threshold:
        reward -= overutilization_penalty

    return reward
```
## Implementation Considerations

### Communication Efficiency

- Compression: Quantize model updates
- Federated Dropout: Train smaller subnetworks
- Asynchronous Updates: No synchronization barrier
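The compression bullet can be made concrete with uniform quantization: instead of full-precision floats, each client sends integer codes plus a range. The helpers below are an illustrative sketch on plain lists, not a specific library API:

```python
def quantize(update, levels=256):
    # Map each float into one of `levels` integer codes (8-bit here):
    # the client transmits (lo, step, ints) instead of raw floats.
    lo, hi = min(update), max(update)
    step = (hi - lo) / (levels - 1) or 1.0  # guard against a constant update
    ints = [round((v - lo) / step) for v in update]
    return lo, step, ints

def dequantize(lo, step, ints):
    # Server-side reconstruction; error is at most step / 2 per value
    return [lo + i * step for i in ints]

update = [0.0, 0.37, -0.5, 1.0]
lo, step, ints = quantize(update)
restored = dequantize(lo, step, ints)
# max reconstruction error is bounded by step / 2
```

Sending 8-bit codes in place of 32-bit floats cuts upstream traffic roughly 4x per round, at the cost of a small, bounded rounding error in the aggregate.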
### Handling Non-IID Data

- Personalization: Fine-tune the global model locally
- Clustered FL: Group similar hospitals
- Multi-task Learning: Shared representation + task-specific heads
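One way to sketch clustered FL is to group clients greedily by the cosine similarity of their model updates, so hospitals with similar case mixes share a cluster model. The function, threshold, and toy updates below are illustrative assumptions, not a prescribed algorithm:

```python
def cluster_clients(updates, threshold=0.9):
    # Greedy clustering: assign each client to the first cluster whose
    # representative update points in a similar direction, else start
    # a new cluster. `updates` is a list of flat update vectors.
    def cos(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        na = sum(x * x for x in a) ** 0.5
        nb = sum(y * y for y in b) ** 0.5
        return dot / (na * nb)

    clusters = []  # each cluster is a list of client indices
    for i, u in enumerate(updates):
        for cluster in clusters:
            if cos(u, updates[cluster[0]]) >= threshold:
                cluster.append(i)
                break
        else:
            clusters.append([i])
    return clusters

# Two hospitals with a similar case mix, plus one outlier
updates = [[1.0, 0.1], [0.9, 0.2], [-1.0, 0.0]]
print(cluster_clients(updates))  # [[0, 1], [2]]
```

Each cluster then runs its own FedAvg round, trading some data pooling for a better fit to non-IID local distributions.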
### System Heterogeneity

- Straggler Handling: Async aggregation or timeouts
- Variable Resources: Adaptive local epochs
- Device Selection: Probabilistic client sampling
Evaluation Metrics
Metric Description
Privacy Budget (ε) Differential privacy guarantee
Model Accuracy Comparison to centralized training
Communication Rounds Convergence speed
Patient Wait Time Scheduling effectiveness
Resource Utilization System efficiency
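The privacy budget in the table accumulates across communication rounds. Under basic sequential composition the total ε is simply the sum of per-round budgets (advanced composition or a moments accountant gives tighter bounds); the helper below is an illustrative one-liner, not from the original:

```python
def total_privacy_budget(eps_per_round, num_rounds):
    # Basic sequential composition: budgets add up across rounds
    return eps_per_round * num_rounds

# 50 federated rounds at ε = 0.1 each consume ε = 5.0 in total
print(total_privacy_budget(0.1, 50))  # 5.0
```

This is why Communication Rounds and Privacy Budget trade off directly: fewer rounds to convergence means a smaller total ε for the same per-round noise.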
## Resources

- Federated Learning Paper (McMahan et al.)
- DQN Paper (Mnih et al.)
- Healthcare Scheduling Survey