#!/usr/bin/env python3
"""
download.py - Download/repair model files and update model.yaml metadata.

Usage:
    ./tools/download.py models/llama-2-7b-chat/model.yaml

- Always (re)runs snapshot_download with resume support, so partially fetched
  directories get completed instead of being skipped.
- Updates YAML after each variant with fresh file list + total size.
- Tracks LFS via sensible patterns (plus a size threshold fallback).
- Emits clear logs so you can see progress per variant.
"""

import sys
import os
import yaml
import subprocess
from pathlib import Path
from typing import Iterable

from huggingface_hub import snapshot_download

# Glob patterns handed to `git lfs track` up front, so large artifacts are
# covered before any download lands in the work tree.
LFS_PATTERNS: list[str] = [
    # Extensions commonly used for model artifacts
    "*.safetensors",
    "*.bin",
    "*.pt",
    "*.gguf",
    "*.onnx",
    "*.ckpt",
    "*.tensors",
    "*.npz",
    "*.tar",
    "*.tar.gz",
    "*.zip",
]

SIZE_THRESHOLD_BYTES = 1_000_000  # 1 MB fallback if a file doesn't match any pattern


def run(cmd: list[str], check: bool = True) -> None:
    """Run *cmd* as an argument list (shell=False).

    Raises subprocess.CalledProcessError when check=True and the command
    exits non-zero.
    """
    subprocess.run(cmd, check=check)


def track_lfs_patterns(patterns: Iterable[str]) -> None:
    """
    Track a set of patterns in Git LFS.

    This is idempotent; it just appends to .gitattributes as needed.
    Failures are swallowed on purpose: LFS tracking is best-effort here.
    """
    for patt in patterns:
        try:
            run(["git", "lfs", "track", patt], check=False)
        except Exception:
            # Non-fatal: we'll still fall back to per-file size rule below.
            pass


def list_files_under(root: Path) -> list[Path]:
    """Return every regular file under *root*, recursively."""
    return [p for p in root.rglob("*") if p.is_file()]


def ensure_repo_root() -> None:
    """Warn (but don't die) if we are not inside a Git work tree."""
    try:
        subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except Exception:
        print("⚠️ Not inside a Git repository? Git/LFS steps may fail.", file=sys.stderr)


def main() -> None:
    """Download every variant listed in the given model.yaml, then update it.

    Exits with status 1 on bad usage or a missing YAML file.
    """
    if len(sys.argv) != 2:
        # Fixed: the argument placeholder had been lost from the usage string.
        print(f"Usage: {sys.argv[0]} <path/to/model.yaml>", file=sys.stderr)
        sys.exit(1)

    model_yaml_path = Path(sys.argv[1])
    if not model_yaml_path.exists():
        print(f"Model YAML not found: {model_yaml_path}", file=sys.stderr)
        sys.exit(1)

    ensure_repo_root()

    # Load YAML (safe_load: the file is config, never executable tags)
    with open(model_yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}

    model_dir = model_yaml_path.parent

    # Proactively set up LFS tracking by patterns (idempotent)
    track_lfs_patterns(LFS_PATTERNS)

    # Iterate formats & variants; expected schema:
    # model.formats[].variants[] with keys 'id' and 'hf_repo'
    # — TODO confirm against the model.yaml files in this repo.
    formats = (data.get("model") or {}).get("formats") or []
    for fmt in formats:
        variants = fmt.get("variants") or []
        for variant in variants:
            variant_id = variant.get("id")
            hf_repo = variant.get("hf_repo")
            if not hf_repo or not variant_id:
                continue

            dest_path = model_dir / variant_id
            dest_path.mkdir(parents=True, exist_ok=True)

            # Accept either a full URL or a bare "org/name" repo id.
            repo_id = hf_repo.replace("https://huggingface.co/", "")

            print(f"\n[DL] Downloading/resuming variant '{variant_id}' from '{repo_id}' into '{dest_path}'")

            # Always call snapshot_download with resume enabled. This will:
            #   - no-op for already-complete files
            #   - resume partials
            #   - fetch any missing files
            # NOTE(review): resume_download and local_dir_use_symlinks are
            # deprecated no-ops in recent huggingface_hub releases — confirm
            # the pinned version before removing them.
            try:
                snapshot_download(
                    repo_id=repo_id,
                    local_dir=str(dest_path),
                    local_dir_use_symlinks=False,
                    resume_download=True,  # explicit
                    # You can add allow_patterns / ignore_patterns if you want to filter
                    # allow_patterns=None,
                    # ignore_patterns=None,
                )
            except Exception as e:
                print(f"❌ snapshot_download failed for {variant_id}: {e}", file=sys.stderr)
                raise

            # Scan files, compute size, and ensure big files are tracked by LFS
            files_list: list[str] = []
            total_size_bytes = 0
            for p in list_files_under(dest_path):
                rel = p.relative_to(model_dir)
                files_list.append(str(rel))
                try:
                    size = p.stat().st_size
                except FileNotFoundError:
                    # if a file was removed mid-scan, skip it
                    continue
                total_size_bytes += size
                # Fallback: ensure big files get tracked even if patterns miss them
                if size > SIZE_THRESHOLD_BYTES:
                    # Idempotent; harmless if already tracked.
                    # NOTE(review): this passes a cwd-relative path — git lfs
                    # track patterns are resolved against the repo root; verify
                    # the script is always run from the repo root.
                    run(["git", "lfs", "track", str(p)], check=False)

            files_list.sort()
            variant["files"] = files_list
            variant["size_bytes"] = int(total_size_bytes)

            # Save updated YAML progressively after each variant, so progress
            # survives an interruption part-way through a multi-variant model.
            with open(model_yaml_path, "w", encoding="utf-8") as f:
                yaml.dump(data, f, sort_keys=False, allow_unicode=True)
            print(f"✅ Updated {model_yaml_path} for variant '{variant_id}'")

            # Run cleanup script to commit, push, and prune
            commit_message = f"Add/update model files for {model_dir.name}/{variant_id}"
            print(f"🧹 Running cleanup for {variant_id}...")
            try:
                run(["./tools/cleanup.sh", commit_message], check=True)
            except subprocess.CalledProcessError as e:
                print(f"❌ cleanup.sh failed (continue to next variant): {e}", file=sys.stderr)
                # Decide whether to continue or abort; continuing is usually fine.
                # raise  # uncomment to abort on failure

    print(f"\n✅ Download and YAML update complete for {model_yaml_path}.")


if __name__ == "__main__":
    main()