Set up repo with Phi 3
This commit is contained in:

tools/download.py · 167 lines · Normal file

@@ -0,0 +1,167 @@
#!/usr/bin/env python3
"""
download.py - Download/repair model files and update model.yaml metadata.

Usage:
    ./tools/download.py models/llama-2-7b-chat/model.yaml

- Always (re)runs snapshot_download with resume support, so partially
  fetched directories get completed instead of being skipped.
- Updates the YAML after each variant with a fresh file list + total size.
- Tracks LFS via sensible patterns (plus a size-threshold fallback).
- Emits clear logs so you can see progress per variant.
"""
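
# The script expects model.yaml to look roughly like the sketch below.
# Only `id` and `hf_repo` are read per variant, and `files` / `size_bytes`
# are written back; the repo URL and variant id here are illustrative
# assumptions, not part of any real config:
#
#   model:
#     formats:
#       - variants:
#           - id: some-variant
#             hf_repo: https://huggingface.co/org/repo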

import subprocess
import sys
from pathlib import Path
from typing import Iterable

import yaml
from huggingface_hub import snapshot_download

LFS_PATTERNS: list[str] = [
    # Extensions commonly used for model artifacts
    "*.safetensors",
    "*.bin",
    "*.pt",
    "*.gguf",
    "*.onnx",
    "*.ckpt",
    "*.tensors",
    "*.npz",
    "*.tar",
    "*.tar.gz",
    "*.zip",
]

SIZE_THRESHOLD_BYTES = 1_000_000  # 1 MB fallback if a file doesn't match any pattern


def run(cmd: list[str], check: bool = True) -> None:
    """Thin wrapper around subprocess.run, for readability."""
    subprocess.run(cmd, check=check)


def track_lfs_patterns(patterns: Iterable[str]) -> None:
    """
    Track a set of patterns in Git LFS. This is idempotent; it just
    appends to .gitattributes as needed.
    """
    for patt in patterns:
        try:
            run(["git", "lfs", "track", patt], check=False)
        except Exception:
            # Non-fatal: we'll still fall back to the per-file size rule below.
            pass
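
# For reference: `git lfs track "*.safetensors"` appends a line like the
# following to .gitattributes (the exact line is generated by git-lfs itself):
#
#   *.safetensors filter=lfs diff=lfs merge=lfs -text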


def list_files_under(root: Path) -> list[Path]:
    """Recursively list all regular files under *root*."""
    return [p for p in root.rglob("*") if p.is_file()]


def ensure_repo_root() -> None:
    # Best effort: warn (but don't die) if not inside a git repo.
    try:
        subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except Exception:
        print("⚠️ Not inside a Git repository? Git/LFS steps may fail.", file=sys.stderr)


def main() -> None:
    if len(sys.argv) != 2:
        print(f"Usage: {sys.argv[0]} <path-to-model.yaml>", file=sys.stderr)
        sys.exit(1)

    model_yaml_path = Path(sys.argv[1])
    if not model_yaml_path.exists():
        print(f"Model YAML not found: {model_yaml_path}", file=sys.stderr)
        sys.exit(1)

    ensure_repo_root()

    # Load YAML
    with open(model_yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}

    model_dir = model_yaml_path.parent

    # Proactively set up LFS tracking by patterns (idempotent)
    track_lfs_patterns(LFS_PATTERNS)

    # Iterate formats & variants
    formats = (data.get("model") or {}).get("formats") or []
    for fmt in formats:
        variants = fmt.get("variants") or []
        for variant in variants:
            variant_id = variant.get("id")
            hf_repo = variant.get("hf_repo")

            if not hf_repo or not variant_id:
                continue

            dest_path = model_dir / variant_id
            dest_path.mkdir(parents=True, exist_ok=True)

            repo_id = hf_repo.replace("https://huggingface.co/", "")
            print(f"\n[DL] Downloading/resuming variant '{variant_id}' from '{repo_id}' into '{dest_path}'")

            # Always call snapshot_download with resume enabled. This will:
            # - no-op for already-complete files
            # - resume partials
            # - fetch any missing files
            try:
                snapshot_download(
                    repo_id=repo_id,
                    local_dir=str(dest_path),
                    # NOTE: the two kwargs below are deprecated no-ops in
                    # recent huggingface_hub releases (>= 0.23), where
                    # downloads always resume; they are kept explicit here
                    # for compatibility with older versions.
                    local_dir_use_symlinks=False,
                    resume_download=True,
                    # You can add allow_patterns / ignore_patterns if you
                    # want to filter, e.g. allow_patterns=["*.safetensors"].
                )
            except Exception as e:
                print(f"❌ snapshot_download failed for {variant_id}: {e}", file=sys.stderr)
                raise
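
            # Tip (an assumption, not part of the original flow): installing
            # the optional hf_transfer package and setting the environment
            # variable HF_HUB_ENABLE_HF_TRANSFER=1 can speed up large downloads.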

            # Scan files, compute size, and ensure big files are tracked by LFS
            files_list: list[str] = []
            total_size_bytes = 0

            for p in list_files_under(dest_path):
                try:
                    size = p.stat().st_size
                except FileNotFoundError:
                    # If a file was removed mid-scan, skip it entirely
                    # (don't record it in the file list either).
                    continue
                files_list.append(str(p.relative_to(model_dir)))
                total_size_bytes += size

                # Fallback: ensure big files get tracked even if patterns miss them
                if size > SIZE_THRESHOLD_BYTES:
                    # Idempotent; harmless if already tracked. Use a
                    # forward-slash path so the .gitattributes entry matches
                    # on all platforms.
                    run(["git", "lfs", "track", p.as_posix()], check=False)

            files_list.sort()
            variant["files"] = files_list
            variant["size_bytes"] = int(total_size_bytes)

            # Save updated YAML progressively after each variant
            with open(model_yaml_path, "w", encoding="utf-8") as f:
                yaml.dump(data, f, sort_keys=False, allow_unicode=True)

            print(f"✅ Updated {model_yaml_path} for variant '{variant_id}'")
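
            # After this write, the variant entry in model.yaml looks roughly
            # like the sketch below (values are illustrative):
            #
            #   - id: some-variant
            #     hf_repo: https://huggingface.co/org/repo
            #     files:
            #       - some-variant/model.safetensors
            #     size_bytes: 4368439296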

            # Run the cleanup script to commit, push, and prune
            commit_message = f"Add/update model files for {model_dir.name}/{variant_id}"
            print(f"🧹 Running cleanup for {variant_id}...")
            try:
                run(["./tools/cleanup.sh", commit_message], check=True)
            except subprocess.CalledProcessError as e:
                print(f"❌ cleanup.sh failed (continuing to next variant): {e}", file=sys.stderr)
                # Decide whether to continue or abort; continuing is usually fine.
                # raise  # uncomment to abort on failure
print(f"\n✅ Download and YAML update complete for {model_yaml_path}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||