
PyTorch Model Recovery

Safety Notice

This listing is imported from skills.sh public index metadata. Review upstream SKILL.md and repository scripts before running.



This skill provides guidance for tasks involving PyTorch model architecture recovery from state dictionaries, selective layer training, and TorchScript export.

When to Use This Skill

This skill applies when:

  • Reconstructing a model architecture from a state dictionary (.pt or .pth file containing weights)

  • Training or fine-tuning specific layers while keeping others frozen

  • Converting a recovered model to TorchScript format

  • Debugging model loading issues or architecture mismatches

Approach Overview

Model recovery tasks require a systematic, incremental approach with verification at each step. The key phases are:

  • Architecture Analysis - Infer model structure from state dictionary keys

  • Architecture Implementation - Build the model class to match the state dict

  • Verification - Confirm weights load correctly before any training

  • Training - Fine-tune specific layers with appropriate hyperparameters

  • Export - Save to required format (often TorchScript)

Phase 1: Architecture Analysis

Examining the State Dictionary

To understand the model architecture, first load and inspect the state dictionary:

```python
import torch

weights = torch.load('model_weights.pt', map_location='cpu')

# Print all keys with shapes
for key, value in weights.items():
    print(f"{key}: {value.shape}")
```

Key Patterns to Identify

Common patterns in state dictionary keys:

| Key Pattern | Indicates |
| --- | --- |
| `encoder.layers.N.*` | Transformer encoder with N+1 layers |
| `decoder.layers.N.*` | Transformer decoder with N+1 layers |
| `embedding.weight` | Embedding layer |
| `pos_encoder.pe` | Positional encoding (often a buffer) |
| `output_layer.weight/bias` | Final linear projection |
| `*.in_proj_weight` | Combined QKV projection in attention |
| `.self_attn.` | Self-attention component |
| `.linear1/linear2.` | Feed-forward network layers |
| `.norm1/norm2.` | Layer normalization |

Inferring Dimensions

Extract model dimensions from weight shapes:

Example: inferring transformer dimensions:

```python
# in_proj_weight has shape [3*d_model, d_model] for combined QKV
d_model = weights['encoder.layers.0.self_attn.in_proj_weight'].shape[1]

# nhead cannot be read off the weight shapes: it only controls how d_model is
# split across heads internally. Choose a divisor of d_model (commonly 4 or 8)
# and verify against any available documentation or model outputs.

vocab_size = weights['embedding.weight'].shape[0]
num_layers = max(int(k.split('.')[2]) for k in weights if 'encoder.layers' in k) + 1
```
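As a sanity check on this shape-based reasoning, the same inference can be exercised against a freshly built `nn.TransformerEncoder` with known dimensions (the sizes below are arbitrary; note the keys lack the `encoder.` prefix because this is a bare `TransformerEncoder`):

```python
import torch.nn as nn

# Build a throwaway encoder with known dimensions
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256)
weights = nn.TransformerEncoder(layer, num_layers=3).state_dict()

# d_model from the combined QKV projection: shape is [3*d_model, d_model]
d_model = weights['layers.0.self_attn.in_proj_weight'].shape[1]

# dim_feedforward from the first FFN linear: shape is [dim_feedforward, d_model]
dim_feedforward = weights['layers.0.linear1.weight'].shape[0]

num_layers = max(int(k.split('.')[1]) for k in weights if k.startswith('layers.')) + 1

print(d_model, dim_feedforward, num_layers)  # 64 256 3
```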

Phase 2: Architecture Implementation

Building the Model Class

When implementing the model class:

  • Match the exact layer names used in the state dictionary

  • Use the same PyTorch module types (e.g., nn.TransformerEncoder vs custom)

  • Register buffers for non-learnable tensors (e.g., positional encodings)

```python
import torch.nn as nn

class RecoveredModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        # Ensure attribute names match state dict keys exactly
        self.embedding = nn.Embedding(vocab_size, d_model)

        # For positional encoding stored as a buffer
        self.pos_encoder = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True  # Check whether the original used batch_first
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
```
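The PositionalEncoding class referenced above is not defined by the checkpoint; a minimal sinusoidal sketch that registers `pe` as a buffer (so it appears in the state dict under `pos_encoder.pe`) might look like the following. The buffer's exact shape must match the recovered checkpoint:

```python
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding stored as a non-learnable buffer."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # register_buffer keeps pe in the state dict without making it trainable
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):  # x: [batch, seq_len, d_model] (batch_first)
        return x + self.pe[:, :x.size(1)]

enc = PositionalEncoding(d_model=16)
print('pe' in enc.state_dict(), sum(p.numel() for p in enc.parameters()))  # True 0
```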

Common Architecture Mistakes

  • Incorrect layer naming: self.fc vs self.output_layer; attribute names must match the state dict keys exactly

  • Missing buffers: Positional encodings often registered as buffers, not parameters

  • Wrong module types: Custom attention vs nn.MultiheadAttention

  • Batch dimension mismatch: batch_first=True vs batch_first=False

Phase 3: Verification (Critical)

Verify Architecture Before Training

Always verify the model loads weights correctly before any training:

```python
model = RecoveredModel(...)

# This will raise an error if keys don't match
model.load_state_dict(weights, strict=True)
print("Weights loaded successfully!")

# Verify a forward pass works
with torch.no_grad():
    dummy_input = torch.randint(0, vocab_size, (1, 10))
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")
```

Handling Key Mismatches

If load_state_dict fails, compare keys:

```python
model_keys = set(model.state_dict().keys())
weight_keys = set(weights.keys())

missing_from_model = weight_keys - model_keys
missing_from_weights = model_keys - weight_keys

print(f"In weights but not in model: {missing_from_model}")
print(f"In model but not in weights: {missing_from_weights}")
```
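Alternatively, load_state_dict itself reports the mismatch when called with strict=False: it returns a named tuple of missing and unexpected keys. A toy illustration (the fc-prefixed dict simulates a naming mismatch):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# Simulated checkpoint saved under a different attribute name
weights = {'fc.weight': torch.zeros(2, 4), 'fc.bias': torch.zeros(2)}

result = model.load_state_dict(weights, strict=False)
print(result.missing_keys)     # keys the model expects but the dict lacks
print(result.unexpected_keys)  # keys in the dict the model does not have
```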

Verify TorchScript Compatibility Early

If TorchScript export is required, test it early:

```python
# Test that scripting works before investing time in training
try:
    scripted = torch.jit.script(model)
    print("TorchScript scripting successful")
except Exception as e:
    print(f"Scripting failed: {e}")
    # Try tracing instead
    traced = torch.jit.trace(model, dummy_input)
    print("TorchScript tracing successful")
```

Phase 4: Training Specific Layers

Freezing Layers

To train only specific layers, freeze all others:

```python
# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the target layers
for param in model.output_layer.parameters():
    param.requires_grad = True

# Verify freeze status
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} parameters")
```
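When many submodules are involved, freezing by name prefix can be less error-prone than walking attributes one by one. A sketch with a toy stand-in model (with the recovered model, substitute its real attribute names):

```python
import torch.nn as nn

# Toy stand-in model; use the recovered model's real attribute names in practice
model = nn.Module()
model.encoder = nn.Linear(8, 8)
model.output_layer = nn.Linear(8, 2)

# requires_grad is True only for parameters under the target prefix
for name, param in model.named_parameters():
    param.requires_grad = name.startswith('output_layer')

print(sorted(n for n, p in model.named_parameters() if p.requires_grad))
# ['output_layer.bias', 'output_layer.weight']
```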

Computing Baseline Loss

Before training, establish a baseline:

```python
model.eval()
with torch.no_grad():
    outputs = model(inputs)
    original_loss = criterion(outputs, targets)
print(f"Baseline loss: {original_loss.item()}")
```

Training Loop Considerations

```python
# Create an optimizer over only the trainable parameters
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.001
)

# Training with progress tracking
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    outputs = model(inputs)
    loss = criterion(outputs, targets)

    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item():.6f}")
```

Alternative: Closed-Form Solution for Linear Layers

When retraining only a linear output layer, consider a closed-form solution for efficiency:

```python
# Pre-compute the frozen layers' outputs
model.eval()
with torch.no_grad():
    # Get the features feeding the output layer
    features = model.get_features(inputs)  # Shape: [N, d_model]

# Solve the least-squares problem features @ W.T ≈ targets,
# i.e. W.T = (features.T @ features)^-1 @ features.T @ targets
solution = torch.linalg.lstsq(features, targets).solution  # Shape: [d_model, out_dim]
model.output_layer.weight.data = solution.T
```
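If the output layer also has a bias, appending a ones column lets lstsq solve for the weight and bias jointly. A self-contained sketch on synthetic data (real usage would substitute the pre-computed features and targets):

```python
import torch

torch.manual_seed(0)
features = torch.randn(100, 8)                     # stand-in for pre-computed features
true_w, true_b = torch.randn(3, 8), torch.randn(3)
targets = features @ true_w.T + true_b

# The ones column makes the bias part of the least-squares unknowns
ones = torch.ones(features.shape[0], 1)
solution = torch.linalg.lstsq(torch.cat([features, ones], dim=1), targets).solution

weight, bias = solution[:-1].T, solution[-1]
print(torch.allclose(weight, true_w, atol=1e-4),
      torch.allclose(bias, true_b, atol=1e-4))
```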

Phase 5: TorchScript Export

Saving the Model

```python
# Ensure the model is in eval mode
model.eval()

# Script the model (preferred when there is control flow)
scripted_model = torch.jit.script(model)
scripted_model.save('/app/model.pt')

# Or trace the model (for simpler models)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('/app/model.pt')
```

Verify Saved Model

```python
# Reload and verify
loaded = torch.jit.load('/app/model.pt')
loaded.eval()

with torch.no_grad():
    original_out = model(test_input)
    loaded_out = loaded(test_input)

diff = (original_out - loaded_out).abs().max()
print(f"Max difference: {diff.item()}")
assert diff < 1e-5, "Model outputs don't match!"
```

Environment Considerations

Handling Slow Environments

When operating in resource-constrained environments:

  • Benchmark first: Test basic operations before committing to a full solution

```python
import time

start = time.time()
_ = model(torch.randint(0, vocab_size, (1, 10)))
print(f"Single forward pass: {time.time() - start:.2f}s")
```

  • Reduce batch size: Process samples individually if needed

  • Set realistic timeouts: Base them on benchmarks, not arbitrary values

  • Use incremental checkpoints: Save progress periodically
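The incremental-checkpoint point can be sketched as follows; the epoch number, file path, and saved fields are illustrative:

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')

# Save enough state to resume after an interruption
torch.save({'epoch': 40,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, path)

# Resuming later
ckpt = torch.load(path, map_location='cpu')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1
print(start_epoch)  # 41
```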

Memory Management

```python
# Clear GPU cache between operations
torch.cuda.empty_cache()

# Use gradient checkpointing for large models
from torch.utils.checkpoint import checkpoint

# Process in smaller batches
for batch in torch.split(data, batch_size):
    process(batch)
```

Common Pitfalls

  • Not verifying architecture match before training - Always test load_state_dict first

  • Arbitrary hyperparameters - Justify choices based on task characteristics

  • Ignoring TorchScript compatibility - Test export early, not after training

  • Syntax errors in edits - Review code changes carefully, especially string formatting

  • Incomplete state dict mapping - Verify all keys are accounted for

  • Not establishing baseline metrics - Compute original loss before training

  • Missing torch.no_grad() for inference - Use context manager for evaluation

  • Forgetting to set model.eval() - Required for consistent behavior during evaluation and export

Verification Checklist

Before considering the task complete:

  • State dictionary keys fully analyzed and documented

  • Model architecture matches state dict exactly (verified with load_state_dict)

  • Forward pass produces valid output

  • Baseline loss/metric computed

  • Target layers correctly unfrozen, others frozen

  • Training improves loss over baseline

  • TorchScript export succeeds

  • Exported model produces same outputs as original

  • Model saved to required path
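The load, forward, and export items above can be rolled into a single smoke test. A sketch using a stand-in nn.Linear in place of the recovered model:

```python
import torch
import torch.nn as nn

# Stand-in model; substitute RecoveredModel and the real weights
model = nn.Linear(8, 2)
weights = model.state_dict()

model.load_state_dict(weights, strict=True)   # architecture matches
model.eval()

x = torch.randn(1, 8)
with torch.no_grad():
    out = model(x)                            # forward pass works
assert out.shape == (1, 2)

scripted = torch.jit.script(model)            # TorchScript export succeeds
with torch.no_grad():
    assert torch.allclose(out, scripted(x), atol=1e-5)

print("All checks passed")
```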

