replicate-integration

Replicate API Integration

Safety Notice

This listing is imported from skills.sh public index metadata. Review upstream SKILL.md and repository scripts before running.


Install skill "replicate-integration" with this command: npx skills add guaderrama/cabo-health-ai/guaderrama-cabo-health-ai-replicate-integration


Purpose

Deploy and run AI models in production using Replicate's cloud platform. Specialized for image generation with Flux Dev, SDXL, custom LoRA fine-tuning, and prediction polling patterns.

When to Use

  • Generating images with Flux Dev, SDXL, or custom models

  • Fine-tuning models with custom LoRA weights

  • Running predictions with polling/webhook patterns

  • Deploying custom models to production

  • Managing long-running AI workloads

Architecture Pattern

Project Structure

backend/
├── services/
│   ├── replicate_service.py   # Main Replicate client
│   └── model_service.py       # Model-specific logic
├── models/
│   └── replicate_models.py    # Pydantic models
├── config/
│   └── replicate_config.py    # Configuration
└── utils/
    └── polling.py             # Polling utilities

Installation

pip install replicate httpx python-dotenv pydantic

Environment Setup

.env

REPLICATE_API_TOKEN=r8_...
FRONTEND_URL=http://localhost:3000
DEFAULT_MODEL=black-forest-labs/flux-dev
LORA_MODEL_ID=your-username/your-model-id
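Before constructing a client, it can help to fail fast when the token is missing or malformed. A minimal sketch (the `require_replicate_token` helper name is my own, not part of the Replicate SDK; the `r8_` prefix check matches the token format shown above):

```python
import os

def require_replicate_token() -> str:
    """Return the Replicate API token, or raise a clear error if unset."""
    token = os.getenv("REPLICATE_API_TOKEN", "")
    if not token.startswith("r8_"):
        raise RuntimeError(
            "REPLICATE_API_TOKEN is missing or malformed; "
            "expected a token starting with 'r8_' (see .env above)"
        )
    return token
```

Calling this once at startup surfaces configuration mistakes immediately instead of as opaque 401 errors mid-request.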

Quick Start

Basic Image Generation

import replicate
import os

client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

# Simple prediction

output = client.run(
    "black-forest-labs/flux-dev",
    input={
        "prompt": "A serene mountain landscape at sunset",
        "num_outputs": 1,
        "aspect_ratio": "16:9"
    }
)

print(f"Generated image: {output[0]}")

Async Pattern with Polling

import replicate
import asyncio
import os
from typing import List

async def generate_images_async(
    prompt: str,
    num_images: int = 4
) -> List[str]:
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    # Start prediction (official models can be referenced by name via `model=`)
    prediction = client.predictions.create(
        model="black-forest-labs/flux-dev",
        input={
            "prompt": prompt,
            "num_outputs": num_images,
            "guidance_scale": 3.5,
            "num_inference_steps": 28
        }
    )

    # Poll for completion
    while prediction.status not in ["succeeded", "failed", "canceled"]:
        await asyncio.sleep(1)
        prediction = client.predictions.get(prediction.id)

    if prediction.status == "succeeded":
        return prediction.output
    raise Exception(f"Prediction failed: {prediction.error}")

Flux Dev Integration

Complete Flux Dev Configuration

from pydantic import BaseModel, Field
from typing import List, Literal

class FluxDevInput(BaseModel):
    prompt: str = Field(description="Text description of image to generate")
    aspect_ratio: Literal["1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "9:16", "9:21"] = "1:1"
    num_outputs: int = Field(ge=1, le=4, default=1)
    num_inference_steps: int = Field(ge=1, le=50, default=28)
    guidance_scale: float = Field(ge=0, le=10, default=3.5)
    output_format: Literal["webp", "jpg", "png"] = "webp"
    output_quality: int = Field(ge=0, le=100, default=80)
    seed: int | None = None
    disable_safety_checker: bool = False

async def generate_flux_images(input: FluxDevInput) -> List[str]:
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    output = await client.async_run(
        "black-forest-labs/flux-dev",
        input=input.model_dump()
    )

    return output

Flux Dev Best Practices

# Optimal settings for portraits

PORTRAIT_CONFIG = {
    "aspect_ratio": "2:3",
    "num_inference_steps": 30,
    "guidance_scale": 4.0,
    "output_format": "webp",
    "output_quality": 90
}

# Optimal settings for landscapes

LANDSCAPE_CONFIG = {
    "aspect_ratio": "16:9",
    "num_inference_steps": 28,
    "guidance_scale": 3.5,
    "output_format": "webp",
    "output_quality": 85
}

# Batch generation pattern

async def batch_generate(prompts: List[str]) -> List[List[str]]:
    tasks = [generate_flux_images(FluxDevInput(prompt=p)) for p in prompts]
    return await asyncio.gather(*tasks)

LoRA Fine-Tuning Integration

Using Custom LoRA Models

class LoRAInput(BaseModel):
    prompt: str
    lora_scale: float = Field(ge=0, le=1, default=1.0, description="LoRA influence strength")
    trigger_word: str | None = Field(default=None, description="Special token for identity")
    num_outputs: int = Field(ge=1, le=10, default=1)

async def generate_with_lora(
    model_id: str,  # e.g., "daniel-carreon/danielcarrong"
    input: LoRAInput
) -> List[str]:
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    # Ensure trigger word is in prompt
    prompt = input.prompt
    if input.trigger_word and input.trigger_word not in prompt:
        prompt = f"{input.trigger_word} {prompt}"

    output = await client.async_run(
        model_id,
        input={
            "prompt": prompt,
            "lora_scale": input.lora_scale,
            "num_outputs": input.num_outputs,
            "num_inference_steps": 28,
            "guidance_scale": 3.5
        }
    )

    return output

Training Custom LoRA

async def train_lora(
    images_zip_url: str,
    trigger_word: str,
    steps: int = 1000
) -> str:
    """Train a custom LoRA model on Replicate."""
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    training = client.trainings.create(
        # Trainer reference; pinning a specific version hash is recommended
        version="ostris/flux-dev-lora-trainer",
        input={
            "input_images": images_zip_url,
            "trigger_word": trigger_word,
            "steps": steps,
            "lora_rank": 16,
            "optimizer": "adamw8bit",
            "batch_size": 1,
            "learning_rate": 4e-4,
            "caption_prefix": f"a photo of {trigger_word}"
        },
        destination=f"{os.getenv('REPLICATE_USERNAME')}/my-lora-model"
    )

    # Wait for training completion
    while training.status not in ["succeeded", "failed", "canceled"]:
        await asyncio.sleep(30)
        training = client.trainings.get(training.id)

    if training.status == "succeeded":
        return training.output  # Model URL
    raise Exception(f"Training failed: {training.error}")

Polling Patterns

Robust Polling with Retry

import asyncio
from typing import Callable, Any

class PollConfig(BaseModel):
    max_wait: int = 300               # 5 minutes
    poll_interval: float = 1.0        # 1 second
    timeout_multiplier: float = 1.5   # Backoff factor

async def poll_prediction(
    prediction_id: str,
    on_progress: Callable[[float], None] | None = None
) -> Any:
    """Poll a Replicate prediction with exponential backoff."""
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    start_time = asyncio.get_event_loop().time()
    poll_interval = 1.0

    while True:
        prediction = client.predictions.get(prediction_id)

        # Report progress
        if on_progress and hasattr(prediction, 'logs'):
            progress = extract_progress(prediction.logs)
            on_progress(progress)

        # Check status
        if prediction.status == "succeeded":
            return prediction.output
        elif prediction.status in ["failed", "canceled"]:
            raise Exception(f"Prediction {prediction.status}: {prediction.error}")

        # Timeout check
        elapsed = asyncio.get_event_loop().time() - start_time
        if elapsed > 300:  # 5 minutes
            raise TimeoutError(f"Prediction timeout after {elapsed}s")

        # Exponential backoff
        await asyncio.sleep(poll_interval)
        poll_interval = min(poll_interval * 1.5, 5.0)

def extract_progress(logs: str) -> float:
    """Extract progress from logs (0.0 to 1.0)."""
    # Example log line: "Progress: 50%"
    import re
    match = re.search(r"Progress: (\d+)%", logs or "")
    return float(match.group(1)) / 100 if match else 0.0

Webhook Pattern (Production)

from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/webhooks/replicate")
async def handle_webhook(request: Request):
    """Handle Replicate webhook callback."""
    payload = await request.json()

    prediction_id = payload["id"]
    status = payload["status"]

    if status == "succeeded":
        output = payload["output"]
        # Process completed prediction (save_results is application-defined)
        await save_results(prediction_id, output)
    elif status == "failed":
        error = payload["error"]
        # Handle error (log_error is application-defined)
        await log_error(prediction_id, error)

    return {"status": "received"}

# Start prediction with webhook

def create_prediction_with_webhook(prompt: str) -> str:
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    prediction = client.predictions.create(
        model="black-forest-labs/flux-dev",
        input={"prompt": prompt},
        webhook=f"{os.getenv('BACKEND_URL')}/webhooks/replicate",
        webhook_events_filter=["completed"]
    )

    return prediction.id

Error Handling

Comprehensive Error Handling

from enum import Enum

class ReplicateError(Exception):
    """Base Replicate error"""
    pass

class RateLimitError(ReplicateError):
    """Rate limit exceeded"""
    pass

class ModelNotFoundError(ReplicateError):
    """Model not found"""
    pass

async def safe_replicate_call(
    model: str,
    input: dict,
    max_retries: int = 3
) -> Any:
    """Call Replicate with retry logic."""
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

    for attempt in range(max_retries):
        try:
            output = await client.async_run(model, input=input)
            return output
        except replicate.exceptions.ModelError as e:
            if "not found" in str(e).lower():
                raise ModelNotFoundError(f"Model {model} not found")
            raise
        except replicate.exceptions.ReplicateError as e:
            if "rate limit" in str(e).lower():
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)  # Exponential backoff
                    continue
                raise RateLimitError("Rate limit exceeded")
            raise ReplicateError(f"Replicate error: {e}")
        except Exception:
            if attempt < max_retries - 1:
                await asyncio.sleep(1)
                continue
            raise

FastAPI Integration

Complete Replicate Endpoint

from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import List

app = FastAPI()

class GenerateRequest(BaseModel):
    prompt: str
    num_images: int = 4
    use_lora: bool = False
    lora_model_id: str | None = None

class GenerateResponse(BaseModel):
    prediction_id: str
    status: str
    images: List[str] | None = None

@app.post("/generate", response_model=GenerateResponse)
async def generate_images(request: GenerateRequest):
    """Generate images using Replicate."""
    try:
        client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))

        # Select model
        model = request.lora_model_id if request.use_lora else "black-forest-labs/flux-dev"

        # Create prediction
        prediction = client.predictions.create(
            model=model,
            input={
                "prompt": request.prompt,
                "num_outputs": request.num_images
            }
        )

        return GenerateResponse(
            prediction_id=prediction.id,
            status=prediction.status
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/predictions/{prediction_id}")
async def get_prediction(prediction_id: str):
    """Check prediction status."""
    try:
        client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
        prediction = client.predictions.get(prediction_id)

        return GenerateResponse(
            prediction_id=prediction.id,
            status=prediction.status,
            images=prediction.output if prediction.status == "succeeded" else None
        )
    except Exception:
        raise HTTPException(status_code=404, detail="Prediction not found")

Rate Limiting

Rate Limit Management

from collections import deque
from datetime import datetime, timedelta

class RateLimiter:
    def __init__(self, max_requests: int = 50, window: int = 60):
        self.max_requests = max_requests
        self.window = timedelta(seconds=window)
        self.requests: deque[datetime] = deque()

    async def acquire(self):
        """Wait if rate limit reached."""
        now = datetime.now()

        # Remove old requests
        while self.requests and self.requests[0] < now - self.window:
            self.requests.popleft()

        # Check limit
        if len(self.requests) >= self.max_requests:
            wait_time = (self.requests[0] + self.window - now).total_seconds()
            if wait_time > 0:
                await asyncio.sleep(wait_time)

        self.requests.append(now)

# Usage

limiter = RateLimiter(max_requests=50, window=60)

async def generate_with_limit(prompt: str):
    await limiter.acquire()
    return await generate_flux_images(FluxDevInput(prompt=prompt))

Best Practices

  • Always use async/await for non-blocking I/O

  • Implement polling with backoff to avoid rate limits

  • Use webhooks in production for long-running tasks

  • Cache prediction results to avoid redundant API calls

  • Monitor costs - log prediction IDs and metrics

  • Handle errors gracefully with retry logic

  • Use typed inputs with Pydantic models

  • Set timeouts for all predictions

  • Validate outputs before returning to users

  • Store metadata for debugging and analytics
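The "set timeouts for all predictions" advice above can be applied uniformly by wrapping any polling coroutine in `asyncio.wait_for`. A minimal sketch (the `run_with_timeout` helper and the `slow_poll` stand-in coroutine are illustrative, not part of any SDK):

```python
import asyncio

async def run_with_timeout(coro, timeout_s: float = 300.0):
    """Cancel a prediction wait that exceeds the timeout."""
    try:
        return await asyncio.wait_for(coro, timeout=timeout_s)
    except asyncio.TimeoutError:
        raise TimeoutError(f"Prediction did not finish within {timeout_s}s")

async def slow_poll():
    # Stand-in for a real poll_prediction(...) call
    await asyncio.sleep(0.5)
    return ["image-url"]
```

Usage: `asyncio.run(run_with_timeout(slow_poll(), timeout_s=0.1))` raises `TimeoutError`, while a generous timeout returns the poll result unchanged.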

Common Pitfalls

❌ Don't: Poll too frequently (wastes API calls)
✅ Do: Use exponential backoff (1s → 1.5s → 2.25s → ...)

❌ Don't: Forget to handle prediction failures
✅ Do: Check status and handle errors

❌ Don't: Hardcode model versions
✅ Do: Use environment variables for flexibility

❌ Don't: Block on predictions in API endpoints
✅ Do: Return prediction ID immediately, poll separately
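The backoff progression from the first pitfall (1s → 1.5s → 2.25s → ...) can be produced by a small helper, capped so sleep intervals never grow unbounded. A sketch (the `backoff_intervals` name and defaults are my own, chosen to match the polling code earlier in this document):

```python
def backoff_intervals(initial: float = 1.0, factor: float = 1.5,
                      cap: float = 5.0, max_polls: int = 10) -> list[float]:
    """Generate capped exponential-backoff sleep intervals."""
    intervals = []
    delay = initial
    for _ in range(max_polls):
        intervals.append(round(delay, 4))
        delay = min(delay * factor, cap)
    return intervals

# backoff_intervals(max_polls=4) → [1.0, 1.5, 2.25, 3.375]
```

Iterating over these intervals in a polling loop gives the same behavior as the inline `min(poll_interval * 1.5, 5.0)` pattern, but keeps the schedule testable in isolation.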

Complete Example: Production Service

from fastapi import FastAPI, BackgroundTasks
import replicate
from pydantic import BaseModel
from typing import List
import asyncio
import os

app = FastAPI()

class ImageGenerationService:
    def __init__(self):
        self.client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
        self.rate_limiter = RateLimiter(max_requests=50, window=60)

    async def generate(
        self,
        prompt: str,
        num_images: int = 4,
        model_id: str = "black-forest-labs/flux-dev"
    ) -> List[str]:
        """Generate images with rate limiting and error handling."""
        await self.rate_limiter.acquire()

        try:
            prediction = self.client.predictions.create(
                model=model_id,
                input={
                    "prompt": prompt,
                    "num_outputs": num_images,
                    "num_inference_steps": 28,
                    "guidance_scale": 3.5
                }
            )

            # Poll for completion
            output = await poll_prediction(
                prediction.id,
                on_progress=lambda p: print(f"Progress: {p*100:.0f}%")
            )

            return output
        except Exception as e:
            print(f"Generation failed: {e}")
            raise

# Global service instance

service = ImageGenerationService()

@app.post("/api/generate")
async def generate_endpoint(request: GenerateRequest):
    try:
        images = await service.generate(
            prompt=request.prompt,
            num_images=request.num_images
        )
        return {"status": "success", "images": images}
    except Exception as e:
        return {"status": "error", "message": str(e)}

Resources

  • Replicate Docs

  • Flux Dev Model

  • LoRA Training Guide

  • Python Client

