Files
llm-registry/tools/download.py
2025-09-27 18:26:34 +01:00

168 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
download.py - Download/repair model files and update model.yaml metadata.
Usage:
./tools/download.py models/llama-2-7b-chat/model.yaml
- Always (re)runs snapshot_download with resume support, so partially
fetched directories get completed instead of being skipped.
- Updates YAML after each variant with fresh file list + total size.
- Tracks LFS via sensible patterns (plus a size threshold fallback).
- Emits clear logs so you can see progress per variant.
"""
import sys
import os
import yaml
import subprocess
from pathlib import Path
from typing import Iterable
from huggingface_hub import snapshot_download
# Glob patterns handed to `git lfs track` up front, so matching files are
# stored in Git LFS instead of the regular Git object store.
LFS_PATTERNS: list[str] = [
# Extensions commonly used for model artifacts
"*.safetensors",
"*.bin",
"*.pt",
"*.gguf",
"*.onnx",
"*.ckpt",
"*.tensors",
"*.npz",
"*.tar",
"*.tar.gz",
"*.zip",
]
# Per-file fallback: any file larger than this is LFS-tracked individually
# even if none of the patterns above matched it (see the scan loop in main).
SIZE_THRESHOLD_BYTES = 1_000_000 # 1 MB fallback if a file doesn't match any pattern
def run(cmd: list[str], check: bool = True) -> None:
    """Execute *cmd* as a subprocess.

    With *check* true (the default), a non-zero exit status raises
    ``subprocess.CalledProcessError``; with *check* false, failures are
    ignored and the call returns normally.
    """
    subprocess.run(cmd, check=check)
def track_lfs_patterns(patterns: Iterable[str]) -> None:
    """Register each pattern with Git LFS.

    Idempotent: ``git lfs track`` simply appends to ``.gitattributes`` as
    needed and is a no-op for patterns that are already tracked.
    """
    for pattern in patterns:
        try:
            run(["git", "lfs", "track", pattern], check=False)
        except Exception:
            # Non-fatal (e.g. git not installed): we'll still fall back to
            # the per-file size rule applied during the post-download scan.
            pass
def list_files_under(root: Path) -> list[Path]:
    """Return every regular file found recursively beneath *root*."""
    found: list[Path] = []
    for candidate in root.rglob("*"):
        if candidate.is_file():
            found.append(candidate)
    return found
def ensure_repo_root() -> None:
    """Best effort: warn (but don't die) if we're not inside a Git repo.

    The broad except covers both "git is not installed" (FileNotFoundError)
    and "not inside a work tree" (CalledProcessError); either way we only
    warn, since later Git/LFS steps will surface the real failure.
    """
    try:
        subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except Exception:
        # Bug fix: the warning emoji was mojibake ("⚠️" — UTF-8 bytes
        # mis-decoded as Latin-1); restored to the intended character.
        print("⚠️ Not inside a Git repository? Git/LFS steps may fail.", file=sys.stderr)
def main() -> None:
    """Entry point: download/resume every variant listed in a model.yaml,
    then rewrite the YAML with a fresh file list and total size per variant.

    Usage: ./tools/download.py <path-to-model.yaml>
    Exits with status 1 on bad usage or a missing YAML file.
    """
    if len(sys.argv) != 2:
        print(f"Usage: {sys.argv[0]} <path-to-model.yaml>", file=sys.stderr)
        sys.exit(1)
    model_yaml_path = Path(sys.argv[1])
    if not model_yaml_path.exists():
        print(f"Model YAML not found: {model_yaml_path}", file=sys.stderr)
        sys.exit(1)
    ensure_repo_root()
    # Load YAML; an empty/blank file yields None, normalized to {} so the
    # .get() chains below are safe.
    with open(model_yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    model_dir = model_yaml_path.parent
    # Proactively set up LFS tracking by patterns (idempotent)
    track_lfs_patterns(LFS_PATTERNS)
    # Iterate formats & variants
    formats = (data.get("model") or {}).get("formats") or []
    for fmt in formats:
        for variant in fmt.get("variants") or []:
            variant_id = variant.get("id")
            hf_repo = variant.get("hf_repo")
            if not hf_repo or not variant_id:
                continue  # malformed variant entry — skip silently
            dest_path = model_dir / variant_id
            dest_path.mkdir(parents=True, exist_ok=True)
            # hf_repo may be a full URL; snapshot_download wants "org/name".
            repo_id = hf_repo.replace("https://huggingface.co/", "")
            print(f"\n[DL] Downloading/resuming variant '{variant_id}' from '{repo_id}' into '{dest_path}'")
            # Always call snapshot_download with resume enabled. This will:
            # - no-op for already-complete files
            # - resume partials
            # - fetch any missing files
            try:
                # NOTE(review): resume_download / local_dir_use_symlinks are
                # deprecated no-ops on recent huggingface_hub (resume is
                # always on); kept for compatibility with older versions.
                snapshot_download(
                    repo_id=repo_id,
                    local_dir=str(dest_path),
                    local_dir_use_symlinks=False,
                    resume_download=True,
                )
            except Exception as e:
                print(f"❌ snapshot_download failed for {variant_id}: {e}", file=sys.stderr)
                raise
            files_list, total_size_bytes = _scan_variant_files(dest_path, model_dir)
            variant["files"] = files_list
            variant["size_bytes"] = int(total_size_bytes)
            # Save updated YAML progressively after each variant, so a crash
            # mid-run keeps the metadata for variants already processed.
            with open(model_yaml_path, "w", encoding="utf-8") as f:
                yaml.dump(data, f, sort_keys=False, allow_unicode=True)
            print(f"✅ Updated {model_yaml_path} for variant '{variant_id}'")
            _run_cleanup(model_dir, variant_id)
    print(f"\n✅ Download and YAML update complete for {model_yaml_path}.")


def _scan_variant_files(dest_path: Path, model_dir: Path) -> tuple[list[str], int]:
    """Scan *dest_path*; return (sorted YAML-relative paths, total bytes).

    Bug fix: the original appended a file to the list *before* stat-ing it,
    so a file deleted mid-scan appeared in the YAML file list but was
    excluded from size_bytes. Vanished files are now skipped entirely.
    Files above SIZE_THRESHOLD_BYTES are LFS-tracked individually in case
    no LFS_PATTERNS glob matched them.
    """
    files_list: list[str] = []
    total_size_bytes = 0
    for p in list_files_under(dest_path):
        try:
            size = p.stat().st_size
        except FileNotFoundError:
            continue  # removed mid-scan — keep list and size consistent
        files_list.append(str(p.relative_to(model_dir)))
        total_size_bytes += size
        # Fallback: ensure big files get tracked even if patterns miss them.
        if size > SIZE_THRESHOLD_BYTES:
            # Idempotent; harmless if already tracked.
            run(["git", "lfs", "track", str(p)], check=False)
    files_list.sort()
    return files_list, total_size_bytes


def _run_cleanup(model_dir: Path, variant_id: str) -> None:
    """Commit/push/prune via ./tools/cleanup.sh; failures are logged only,
    so later variants still get processed."""
    commit_message = f"Add/update model files for {model_dir.name}/{variant_id}"
    print(f"🧹 Running cleanup for {variant_id}...")
    try:
        run(["./tools/cleanup.sh", commit_message], check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ cleanup.sh failed (continue to next variant): {e}", file=sys.stderr)
        # Decide whether to continue or abort; continuing is usually fine.
        # raise  # uncomment to abort on failure


if __name__ == "__main__":
    main()