Files
llm-registry/tools/download.py
zcourts 849e7c4699
Some checks failed
Download Missing Models / download-models (push) Has been cancelled
Bring the cleanup work into the python script and drop the shell scripts
2025-09-27 18:29:05 +01:00

331 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
download.py - Download/repair model files, update model.yaml metadata,
and commit/push changes with proper Git LFS handling (no bash script needed).
Usage:
./tools/download.py models/llama-2-7b-chat/model.yaml
What this does:
- (Re)runs snapshot_download with resume support, so partially fetched directories
get completed instead of being skipped.
- Avoids adding Hugging Face housekeeping like ".cache/**" to your YAML.
- Updates YAML after each variant with fresh file list + total size.
- Tracks LFS via sensible patterns (plus a size threshold fallback) using
repo-relative paths so it actually applies.
- Runs a built-in cleanup step (commit, push, optional LFS push, and prune),
replacing the old cleanup.sh.
"""
from __future__ import annotations
import os
import sys
import yaml
import subprocess
from pathlib import Path
from typing import Iterable, List, Optional
from huggingface_hub import snapshot_download
# ----------------------------
# Configuration
# ----------------------------
# Glob patterns that should always live in Git LFS: common model-artifact
# extensions plus a few archive formats.
_LFS_EXTENSIONS = (
    "safetensors",
    "bin",
    "pt",
    "gguf",
    "onnx",
    "ckpt",
    "tensors",
    "npz",
    "tar",
    "tar.gz",
    "zip",
)
LFS_PATTERNS: list[str] = ["*." + ext for ext in _LFS_EXTENSIONS]

# Any file larger than this gets LFS-tracked individually even when no
# pattern above matches it (1 MB).
SIZE_THRESHOLD_BYTES = 1_000_000

# By default we skip pushing all LFS objects (same as prior bash script).
# Set env GIT_LFS_PUSH_ALL=1 to force a full "git lfs push origin --all".
LFS_PUSH_ALL = os.environ.get("GIT_LFS_PUSH_ALL", "0") == "1"
# ----------------------------
# Small subprocess helpers
# ----------------------------
def run(cmd: list[str], check: bool = True, cwd: Optional[Path] = None) -> None:
    """Execute *cmd* as a child process (no shell).

    Args:
        cmd: Command and arguments as a list.
        check: When True, raise CalledProcessError on a non-zero exit.
        cwd: Optional working directory for the child process.
    """
    workdir = str(cwd) if cwd else None
    subprocess.run(cmd, check=check, cwd=workdir)
def run_capture(cmd: list[str], cwd: Optional[Path] = None) -> str:
    """Run *cmd* and return its stdout, stripped of surrounding whitespace.

    stderr is discarded. Raises CalledProcessError on a non-zero exit.
    """
    raw = subprocess.check_output(
        cmd,
        cwd=str(cwd) if cwd else None,
        stderr=subprocess.DEVNULL,
    )
    return raw.decode().strip()
# ----------------------------
# Git / LFS utilities
# ----------------------------
def ensure_repo_root() -> Path:
    """Verify we are inside a git work tree and return its top-level path.

    Also (re)installs the Git LFS filters locally, which is idempotent.
    Falls back to the current working directory with a warning when git
    is unavailable or we are outside a repository.
    """
    quiet = {"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}
    try:
        subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            check=True,
            **quiet,
        )
        # Idempotent; a failure here is tolerated (check=False).
        subprocess.run(
            ["git", "lfs", "install", "--local"],
            check=False,
            **quiet,
        )
        return Path(run_capture(["git", "rev-parse", "--show-toplevel"]))
    except Exception:
        print("⚠️ Not inside a Git repository? Git/LFS steps may fail.", file=sys.stderr)
        return Path.cwd()
def repo_relative_path(repo_root: Path, p: Path) -> Path:
    """Express *p* relative to *repo_root*.

    Uses Path.relative_to when p is a strict subpath of the root; otherwise
    falls back to os.path.relpath, which also copes with symlinks and paths
    that require '..' traversal.
    """
    root = repo_root.resolve()
    target = p.resolve()
    try:
        return target.relative_to(root)
    except Exception:
        # Fallback (handles symlinks / different mounts)
        return Path(os.path.relpath(target, root))
def lfs_track_patterns(patterns: Iterable[str]) -> None:
    """Register each glob pattern with Git LFS (safe to run repeatedly).

    Errors are swallowed on purpose: the per-file size-threshold fallback in
    the caller still catches large artifacts these patterns might miss.
    """
    for pattern in patterns:
        try:
            run(["git", "lfs", "track", pattern], check=False)
        except Exception:
            # Non-fatal: the size-based per-file tracking covers misses.
            pass
def lfs_track_file(repo_root: Path, path_in_repo: Path) -> None:
    """Track a single file in Git LFS, addressed relative to the repo root.

    The path is rendered in POSIX form so .gitattributes entries stay
    consistent across platforms. Best-effort: failures are ignored.
    """
    posix_rel = path_in_repo.as_posix()
    try:
        run(["git", "lfs", "track", posix_rel], check=False, cwd=repo_root)
    except Exception:
        pass
def git_status_has_changes(repo_root: Path) -> bool:
    """Return True when `git status --porcelain=v1` reports any entries.

    Returns False if the status command itself fails (e.g. not a repo).
    """
    try:
        porcelain = run_capture(["git", "status", "--porcelain=v1"], cwd=repo_root)
    except Exception:
        return False
    return bool(porcelain.strip())
def git_stage_and_commit_push(
    repo_root: Path,
    scope_paths: list[Path],
    commit_message: str,
    lfs_push_all: bool = False,
) -> None:
    """
    Stages .gitattributes + renormalizes only the provided scope paths (e.g., 'models/…'),
    then commits, pushes, optionally pushes all LFS objects, and finally prunes LFS.

    Args:
        repo_root: Top-level directory of the git repository.
        scope_paths: Directories to renormalize; empty list means the whole repo.
        commit_message: Message used if a commit is created.
        lfs_push_all: When True, also run 'git lfs push origin --all'.
    """
    # Stage .gitattributes explicitly (ignore failures)
    try:
        run(["git", "add", ".gitattributes"], check=False, cwd=repo_root)
    except Exception:
        pass
    # Renormalize only the relevant directories to avoid sweeping the whole repo.
    # If scope_paths is empty, fall back to full repo (conservative).
    if scope_paths:
        for sp in scope_paths:
            rel = repo_relative_path(repo_root, sp)
            run(["git", "add", "--renormalize", str(rel)], check=False, cwd=repo_root)
    else:
        run(["git", "add", "--renormalize", "."], check=False, cwd=repo_root)
    # If nothing is staged, skip.
    # BUG FIX: 'git diff --cached --quiet' signals "staged changes exist" via
    # exit code 1, but subprocess.run() does NOT raise on a non-zero exit
    # unless check=True. The previous try/except therefore always concluded
    # the index was empty and returned early, so nothing was ever committed.
    # Inspect the returncode directly instead (0 = index clean, non-zero =
    # staged changes present).
    probe = subprocess.run(["git", "diff", "--cached", "--quiet"], cwd=str(repo_root))
    if probe.returncode == 0:
        print("No staged changes after normalization. Skipping commit and push.")
        return
    print("Committing and pushing changes...")
    run(["git", "commit", "-m", commit_message], cwd=repo_root)
    # Push main refs
    run(["git", "push"], cwd=repo_root)
    # Optionally ensure all LFS objects are uploaded
    if lfs_push_all:
        try:
            run(["git", "lfs", "push", "origin", "--all"], check=True, cwd=repo_root)
        except subprocess.CalledProcessError as e:
            print(f"⚠️ 'git lfs push --all' failed: {e}. Continuing.", file=sys.stderr)
    # Prune local LFS to save disk
    try:
        run(["git", "lfs", "prune", "--force"], check=False, cwd=repo_root)
    except Exception as e:
        print(f"⚠️ 'git lfs prune' failed: {e}. Continuing.", file=sys.stderr)
    print("✅ Cleanup complete.")
# ----------------------------
# Filesystem helpers
# ----------------------------
def list_files_under(root: Path) -> list[Path]:
    """Collect every regular file below *root*, ignoring housekeeping trees.

    A file is excluded when any of its parent directories (relative to root)
    is one of the known skip dirs or is hidden (name starts with '.').
    Hidden files sitting directly under root are still included.
    """
    skip_dirs = {".git", ".cache", ".hf_mirror_cache"}

    def _in_skipped_tree(path: Path) -> bool:
        # Only parent directory names are checked, never the file name itself.
        parents = path.relative_to(root).parts[:-1]
        return any(part in skip_dirs or part.startswith(".") for part in parents)

    return [
        entry
        for entry in root.rglob("*")
        if entry.is_file() and not _in_skipped_tree(entry)
    ]
# ----------------------------
# Main routine
# ----------------------------
def main() -> None:
    """CLI entry point.

    For each variant in the given model.yaml: download/resume the HF snapshot,
    refresh the variant's file list and total size in the YAML, ensure large
    files are LFS-tracked, then run the built-in git cleanup (commit/push/
    prune) — all progressively, after every variant.
    """
    if len(sys.argv) != 2:
        print(f"Usage: {sys.argv[0]} <path-to-model.yaml>", file=sys.stderr)
        sys.exit(1)
    model_yaml_path = Path(sys.argv[1])
    if not model_yaml_path.exists():
        print(f"Model YAML not found: {model_yaml_path}", file=sys.stderr)
        sys.exit(1)
    repo_root = ensure_repo_root()
    # Load YAML
    with open(model_yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    # Variant directories are created as siblings of the YAML file.
    model_dir = model_yaml_path.parent
    # Proactively set up LFS tracking by patterns (idempotent)
    lfs_track_patterns(LFS_PATTERNS)
    # Iterate formats & variants
    formats = (data.get("model") or {}).get("formats") or []
    for fmt in formats:
        variants = fmt.get("variants") or []
        for variant in variants:
            variant_id = variant.get("id")
            hf_repo = variant.get("hf_repo")
            if not hf_repo or not variant_id:
                # Skip malformed variant entries rather than failing the run.
                continue
            dest_path = model_dir / variant_id
            dest_path.mkdir(parents=True, exist_ok=True)
            # hf_repo may be a full URL; snapshot_download wants "org/name".
            repo_id = hf_repo.replace("https://huggingface.co/", "")
            print(f"\n[DL] Downloading/resuming variant '{variant_id}' from '{repo_id}' into '{dest_path}'")
            # Always call snapshot_download with resume enabled. Filter out .cache.
            try:
                snapshot_download(
                    repo_id=repo_id,
                    local_dir=str(dest_path),
                    local_dir_use_symlinks=False,
                    resume_download=True,  # explicit
                    ignore_patterns=[".cache/**"],  # prevent housekeeping into tree
                )
            except Exception as e:
                # Download failures abort the whole run (re-raise after logging).
                print(f"❌ snapshot_download failed for {variant_id}: {e}", file=sys.stderr)
                raise
            # Scan files, compute size, and ensure big files are tracked by LFS
            files_list: list[str] = []
            total_size_bytes = 0
            for p in list_files_under(dest_path):
                # Paths in the YAML are model-dir-relative, POSIX-style.
                rel_to_model = p.relative_to(model_dir)
                files_list.append(str(rel_to_model).replace("\\", "/"))
                try:
                    size = p.stat().st_size
                except FileNotFoundError:
                    # if a file was removed mid-scan, skip it
                    # NOTE(review): the name was already appended to files_list
                    # above, so a vanished file still shows up in the YAML —
                    # confirm whether that is intended.
                    continue
                total_size_bytes += size
                # Fallback: ensure big files get tracked even if patterns miss them
                if size > SIZE_THRESHOLD_BYTES:
                    rel_to_repo = repo_relative_path(repo_root, p)
                    lfs_track_file(repo_root, rel_to_repo)
            files_list.sort()
            variant["files"] = files_list
            variant["size_bytes"] = int(total_size_bytes)
            # Save updated YAML progressively after each variant
            with open(model_yaml_path, "w", encoding="utf-8") as f:
                yaml.dump(data, f, sort_keys=False, allow_unicode=True)
            print(f"✅ Updated {model_yaml_path} for variant '{variant_id}'")
            # ---- Built-in cleanup (replaces cleanup.sh) ----
            print(f"🧹 Running cleanup for {variant_id}...")
            try:
                # Only scope renormalization to model_dir to keep things fast.
                git_changed = git_status_has_changes(repo_root)
                if not git_changed:
                    print("No new files or changes to commit. Skipping commit and push.")
                else:
                    commit_message = f"Add/update model files for {model_dir.name}/{variant_id}"
                    git_stage_and_commit_push(
                        repo_root=repo_root,
                        scope_paths=[model_dir],
                        commit_message=commit_message,
                        lfs_push_all=LFS_PUSH_ALL,
                    )
            except subprocess.CalledProcessError as e:
                # Cleanup failure is non-fatal; move on to the next variant.
                print(f"❌ Cleanup failed (continue to next variant): {e}", file=sys.stderr)
                # Continue to next variant
    print(f"\n✅ Download and YAML update complete for {model_yaml_path}.")


if __name__ == "__main__":
    main()