generated from nhcarrigan/template
Compare commits
17 Commits
main
...
v0.1.0-alpha
| Author | SHA1 | Date | |
|---|---|---|---|
|
e98b115104
|
|||
|
a5d6119ee0
|
|||
|
1918ffb868
|
|||
|
f9cb13f53f
|
|||
|
23519c0502
|
|||
|
76d851ad11
|
|||
|
4ed7462a17
|
|||
|
9efda8ded6
|
|||
|
9bf92d3365
|
|||
|
43a544a886
|
|||
|
a6843cb3f1
|
|||
|
e6c19b589e
|
|||
|
df8a89e05d
|
|||
|
c3acd8e7a6
|
|||
|
74c334c939
|
|||
| 3c8a46e5a6 | |||
|
96494a9997
|
@@ -0,0 +1,2 @@
|
|||||||
|
[target.x86_64-pc-windows-msvc]
|
||||||
|
linker = "lld-link"
|
||||||
+3
-1
@@ -5,4 +5,6 @@
|
|||||||
|
|
||||||
# Ignore binary files >:(
|
# Ignore binary files >:(
|
||||||
*.png binary
|
*.png binary
|
||||||
*.jpg binary
|
*.jpg binary
|
||||||
|
*.ico binary
|
||||||
|
*.icns binary
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
name: 🐛 Bug Report
|
name: 🐛 Bug Report
|
||||||
description: Something isn't working as expected? Let us know!
|
description: Something isn't working as expected? Let us know!
|
||||||
title: '[BUG] - '
|
title: "[BUG] - "
|
||||||
labels:
|
labels:
|
||||||
- "status/awaiting triage"
|
- "status/awaiting triage"
|
||||||
body:
|
body:
|
||||||
@@ -50,7 +50,7 @@ body:
|
|||||||
description: The operating system you are using, including the version/build number.
|
description: The operating system you are using, including the version/build number.
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
# Remove this section for non-web apps.
|
# Remove this section for non-web apps.
|
||||||
- type: input
|
- type: input
|
||||||
id: browser
|
id: browser
|
||||||
attributes:
|
attributes:
|
||||||
@@ -66,4 +66,3 @@ body:
|
|||||||
- No
|
- No
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
|
|
||||||
|
|||||||
@@ -2,4 +2,4 @@ blank_issues_enabled: false
|
|||||||
contact_links:
|
contact_links:
|
||||||
- name: "Discord"
|
- name: "Discord"
|
||||||
url: "https://chat.nhcarrigan.com"
|
url: "https://chat.nhcarrigan.com"
|
||||||
about: "Chat with us directly."
|
about: "Chat with us directly."
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
name: 💭 Feature Proposal
|
name: 💭 Feature Proposal
|
||||||
description: Have an idea for how we can improve? Share it here!
|
description: Have an idea for how we can improve? Share it here!
|
||||||
title: '[FEAT] - '
|
title: "[FEAT] - "
|
||||||
labels:
|
labels:
|
||||||
- "status/awaiting triage"
|
- "status/awaiting triage"
|
||||||
body:
|
body:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
name: ❓ Other Issue
|
name: ❓ Other Issue
|
||||||
description: I have something that is neither a bug nor a feature request.
|
description: I have something that is neither a bug nor a feature request.
|
||||||
title: '[OTHER] - '
|
title: "[OTHER] - "
|
||||||
labels:
|
labels:
|
||||||
- "status/awaiting triage"
|
- "status/awaiting triage"
|
||||||
body:
|
body:
|
||||||
|
|||||||
@@ -0,0 +1,201 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint-and-test:
|
||||||
|
name: Lint & Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Linux dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y \
|
||||||
|
libwebkit2gtk-4.1-dev \
|
||||||
|
librsvg2-dev \
|
||||||
|
patchelf \
|
||||||
|
libgtk-3-dev \
|
||||||
|
libayatana-appindicator3-dev \
|
||||||
|
libasound2-dev \
|
||||||
|
pkg-config \
|
||||||
|
libclang-dev \
|
||||||
|
cmake
|
||||||
|
|
||||||
|
- name: Setup pnpm
|
||||||
|
uses: pnpm/action-setup@v4
|
||||||
|
with:
|
||||||
|
version: 9
|
||||||
|
|
||||||
|
- name: Setup Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: 22
|
||||||
|
cache: pnpm
|
||||||
|
|
||||||
|
- name: Install frontend dependencies
|
||||||
|
run: pnpm install
|
||||||
|
|
||||||
|
- name: Run ESLint
|
||||||
|
run: pnpm lint
|
||||||
|
|
||||||
|
- name: Run Prettier check
|
||||||
|
run: pnpm format:check
|
||||||
|
|
||||||
|
- name: Build frontend
|
||||||
|
run: pnpm build
|
||||||
|
|
||||||
|
- name: Run frontend tests
|
||||||
|
run: pnpm test
|
||||||
|
|
||||||
|
- name: Setup Rust
|
||||||
|
uses: dtolnay/rust-toolchain@stable
|
||||||
|
with:
|
||||||
|
components: clippy
|
||||||
|
|
||||||
|
- name: Cache Rust dependencies
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cargo/bin/
|
||||||
|
~/.cargo/registry/index/
|
||||||
|
~/.cargo/registry/cache/
|
||||||
|
~/.cargo/git/db/
|
||||||
|
src-tauri/target/
|
||||||
|
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||||
|
|
||||||
|
- name: Run Clippy
|
||||||
|
working-directory: src-tauri
|
||||||
|
run: cargo clippy --all-targets --all-features -- -D warnings
|
||||||
|
|
||||||
|
- name: Run Rust tests
|
||||||
|
working-directory: src-tauri
|
||||||
|
run: cargo test
|
||||||
|
|
||||||
|
build-linux:
|
||||||
|
name: Build Linux
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: lint-and-test
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Linux dependencies
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y \
|
||||||
|
libwebkit2gtk-4.1-dev \
|
||||||
|
librsvg2-dev \
|
||||||
|
patchelf \
|
||||||
|
libgtk-3-dev \
|
||||||
|
libayatana-appindicator3-dev \
|
||||||
|
libasound2-dev \
|
||||||
|
pkg-config \
|
||||||
|
libclang-dev \
|
||||||
|
cmake \
|
||||||
|
xdg-utils
|
||||||
|
|
||||||
|
- name: Setup pnpm
|
||||||
|
uses: pnpm/action-setup@v4
|
||||||
|
with:
|
||||||
|
version: 9
|
||||||
|
|
||||||
|
- name: Setup Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: 22
|
||||||
|
cache: pnpm
|
||||||
|
|
||||||
|
- name: Install frontend dependencies
|
||||||
|
run: pnpm install
|
||||||
|
|
||||||
|
- name: Setup Rust
|
||||||
|
uses: dtolnay/rust-toolchain@stable
|
||||||
|
|
||||||
|
- name: Cache Rust dependencies
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cargo/bin/
|
||||||
|
~/.cargo/registry/index/
|
||||||
|
~/.cargo/registry/cache/
|
||||||
|
~/.cargo/git/db/
|
||||||
|
src-tauri/target/
|
||||||
|
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||||
|
|
||||||
|
- name: Build Linux
|
||||||
|
run: pnpm build:linux
|
||||||
|
|
||||||
|
build-windows:
|
||||||
|
name: Build Windows (cross-compile)
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: lint-and-test
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Linux dependencies for cross-compilation
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y \
|
||||||
|
libwebkit2gtk-4.1-dev \
|
||||||
|
librsvg2-dev \
|
||||||
|
patchelf \
|
||||||
|
libgtk-3-dev \
|
||||||
|
libayatana-appindicator3-dev \
|
||||||
|
libasound2-dev \
|
||||||
|
pkg-config \
|
||||||
|
libclang-dev \
|
||||||
|
cmake \
|
||||||
|
clang \
|
||||||
|
lld \
|
||||||
|
llvm \
|
||||||
|
nsis
|
||||||
|
|
||||||
|
- name: Setup pnpm
|
||||||
|
uses: pnpm/action-setup@v4
|
||||||
|
with:
|
||||||
|
version: 9
|
||||||
|
|
||||||
|
- name: Setup Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: 22
|
||||||
|
cache: pnpm
|
||||||
|
|
||||||
|
- name: Install frontend dependencies
|
||||||
|
run: pnpm install
|
||||||
|
|
||||||
|
- name: Setup Rust
|
||||||
|
uses: dtolnay/rust-toolchain@stable
|
||||||
|
with:
|
||||||
|
targets: x86_64-pc-windows-msvc
|
||||||
|
|
||||||
|
- name: Install cargo-xwin
|
||||||
|
run: |
|
||||||
|
curl -fsSL https://github.com/rust-cross/cargo-xwin/releases/download/v0.20.2/cargo-xwin-v0.20.2.x86_64-unknown-linux-musl.tar.gz | tar xz
|
||||||
|
sudo mv cargo-xwin /usr/local/bin/
|
||||||
|
|
||||||
|
- name: Cache Rust dependencies
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cargo/bin/
|
||||||
|
~/.cargo/registry/index/
|
||||||
|
~/.cargo/registry/cache/
|
||||||
|
~/.cargo/git/db/
|
||||||
|
src-tauri/target/
|
||||||
|
key: ${{ runner.os }}-cargo-windows-${{ hashFiles('**/Cargo.lock') }}
|
||||||
|
|
||||||
|
- name: Build Windows
|
||||||
|
run: pnpm build:windows
|
||||||
@@ -2,11 +2,11 @@ name: Security Scan and Upload
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ main ]
|
branches: [main]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [main]
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * 1'
|
- cron: "0 0 * * 1"
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
@@ -24,18 +24,18 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
DD_URL: ${{ secrets.DD_URL }}
|
DD_URL: ${{ secrets.DD_URL }}
|
||||||
DD_TOKEN: ${{ secrets.DD_TOKEN }}
|
DD_TOKEN: ${{ secrets.DD_TOKEN }}
|
||||||
PRODUCT_NAME: ${{ github.repository }}
|
PRODUCT_NAME: ${{ github.repository }}
|
||||||
PRODUCT_TYPE_ID: 1
|
PRODUCT_TYPE_ID: 1
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get install jq -y > /dev/null
|
sudo apt-get install jq -y > /dev/null
|
||||||
|
|
||||||
echo "Checking connection to $DD_URL..."
|
echo "Checking connection to $DD_URL..."
|
||||||
|
|
||||||
# Check if product exists - capture HTTP code to debug connection issues
|
# Check if product exists - capture HTTP code to debug connection issues
|
||||||
RESPONSE=$(curl --write-out "%{http_code}" --silent --output /tmp/response.json \
|
RESPONSE=$(curl --write-out "%{http_code}" --silent --output /tmp/response.json \
|
||||||
-H "Authorization: Token $DD_TOKEN" \
|
-H "Authorization: Token $DD_TOKEN" \
|
||||||
"$DD_URL/api/v2/products/?name=$PRODUCT_NAME")
|
"$DD_URL/api/v2/products/?name=$PRODUCT_NAME")
|
||||||
|
|
||||||
# If response is not 200, print error
|
# If response is not 200, print error
|
||||||
if [ "$RESPONSE" != "200" ]; then
|
if [ "$RESPONSE" != "200" ]; then
|
||||||
echo "::error::Failed to query DefectDojo. HTTP Code: $RESPONSE"
|
echo "::error::Failed to query DefectDojo. HTTP Code: $RESPONSE"
|
||||||
@@ -44,7 +44,7 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
COUNT=$(cat /tmp/response.json | jq -r '.count')
|
COUNT=$(cat /tmp/response.json | jq -r '.count')
|
||||||
|
|
||||||
if [ "$COUNT" = "0" ]; then
|
if [ "$COUNT" = "0" ]; then
|
||||||
echo "Creating product '$PRODUCT_NAME'..."
|
echo "Creating product '$PRODUCT_NAME'..."
|
||||||
curl -s -X POST "$DD_URL/api/v2/products/" \
|
curl -s -X POST "$DD_URL/api/v2/products/" \
|
||||||
@@ -75,7 +75,7 @@ jobs:
|
|||||||
echo "Uploading Trivy results..."
|
echo "Uploading Trivy results..."
|
||||||
# Generate today's date in YYYY-MM-DD format
|
# Generate today's date in YYYY-MM-DD format
|
||||||
TODAY=$(date +%Y-%m-%d)
|
TODAY=$(date +%Y-%m-%d)
|
||||||
|
|
||||||
HTTP_CODE=$(curl --write-out "%{http_code}" --output response.txt --silent -X POST "$DD_URL/api/v2/import-scan/" \
|
HTTP_CODE=$(curl --write-out "%{http_code}" --output response.txt --silent -X POST "$DD_URL/api/v2/import-scan/" \
|
||||||
-H "Authorization: Token $DD_TOKEN" \
|
-H "Authorization: Token $DD_TOKEN" \
|
||||||
-F "active=true" \
|
-F "active=true" \
|
||||||
@@ -86,7 +86,7 @@ jobs:
|
|||||||
-F "scan_date=$TODAY" \
|
-F "scan_date=$TODAY" \
|
||||||
-F "auto_create_context=true" \
|
-F "auto_create_context=true" \
|
||||||
-F "file=@trivy-results.json")
|
-F "file=@trivy-results.json")
|
||||||
|
|
||||||
if [[ "$HTTP_CODE" != "200" && "$HTTP_CODE" != "201" ]]; then
|
if [[ "$HTTP_CODE" != "200" && "$HTTP_CODE" != "201" ]]; then
|
||||||
echo "::error::Upload Failed with HTTP $HTTP_CODE"
|
echo "::error::Upload Failed with HTTP $HTTP_CODE"
|
||||||
echo "--- SERVER RESPONSE ---"
|
echo "--- SERVER RESPONSE ---"
|
||||||
@@ -154,7 +154,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
echo "Uploading Semgrep results..."
|
echo "Uploading Semgrep results..."
|
||||||
TODAY=$(date +%Y-%m-%d)
|
TODAY=$(date +%Y-%m-%d)
|
||||||
|
|
||||||
HTTP_CODE=$(curl --write-out "%{http_code}" --output response.txt --silent -X POST "$DD_URL/api/v2/import-scan/" \
|
HTTP_CODE=$(curl --write-out "%{http_code}" --output response.txt --silent -X POST "$DD_URL/api/v2/import-scan/" \
|
||||||
-H "Authorization: Token $DD_TOKEN" \
|
-H "Authorization: Token $DD_TOKEN" \
|
||||||
-F "active=true" \
|
-F "active=true" \
|
||||||
@@ -174,4 +174,4 @@ jobs:
|
|||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
echo "Upload Success!"
|
echo "Upload Success!"
|
||||||
fi
|
fi
|
||||||
|
|||||||
+61
@@ -0,0 +1,61 @@
|
|||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
node_modules
|
||||||
|
dist
|
||||||
|
dist-ssr
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
.venv/
|
||||||
|
*.egg-info/
|
||||||
|
|
||||||
|
# Models - large ML model files
|
||||||
|
models/
|
||||||
|
src/pretrained_models/
|
||||||
|
src-tauri/resources/models/
|
||||||
|
*.gguf
|
||||||
|
*.bin
|
||||||
|
|
||||||
|
# Tauri
|
||||||
|
src-tauri/target/
|
||||||
|
src-tauri/WixTools/
|
||||||
|
src-tauri/resources/
|
||||||
|
|
||||||
|
# Build outputs
|
||||||
|
build/
|
||||||
|
|
||||||
|
# App data
|
||||||
|
recordings/
|
||||||
|
transcripts/
|
||||||
|
summaries/
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
*.env.local
|
||||||
|
prod.env
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
build/
|
||||||
|
.svelte-kit/
|
||||||
|
dist/
|
||||||
|
src-tauri/target/
|
||||||
|
src-tauri/gen/
|
||||||
|
node_modules/
|
||||||
|
.pnpm-store/
|
||||||
|
pnpm-lock.yaml
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"semi": true,
|
||||||
|
"singleQuote": false,
|
||||||
|
"tabWidth": 2,
|
||||||
|
"trailingComma": "es5",
|
||||||
|
"printWidth": 100
|
||||||
|
}
|
||||||
Vendored
+3
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"recommendations": ["tauri-apps.tauri-vscode", "rust-lang.rust-analyzer"]
|
||||||
|
}
|
||||||
+120
@@ -0,0 +1,120 @@
|
|||||||
|
# Chronara - Local Meeting Transcription & Summarization
|
||||||
|
|
||||||
|
A Windows desktop application that transcribes, diarizes, and summarizes meetings using only locally-running models.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- 🎙️ Real-time audio transcription with speaker diarization (WhisperX)
|
||||||
|
- 📝 Intelligent meeting summarization (Llama 3.2)
|
||||||
|
- 🖥️ Everything runs locally - no cloud services required
|
||||||
|
- 📦 All models bundled - no separate downloads needed
|
||||||
|
|
||||||
|
## Tech Stack
|
||||||
|
|
||||||
|
- **Transcription**: WhisperX (Whisper + speaker diarization)
|
||||||
|
- **Summarization**: Llama 3.2 1B/3B
|
||||||
|
- **Backend**: Python with FastAPI
|
||||||
|
- **Frontend**: Tauri + React
|
||||||
|
- **Model Runtime**: llama-cpp-python
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
chronara/
|
||||||
|
├── src/
|
||||||
|
│ ├── backend/ # Python FastAPI backend
|
||||||
|
│ ├── components/ # React components
|
||||||
|
│ └── App.tsx # Main React app
|
||||||
|
├── src-tauri/ # Tauri configuration
|
||||||
|
├── models/ # Bundled model files
|
||||||
|
├── scripts/ # Build and setup scripts
|
||||||
|
└── assets/ # Icons, resources
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development Setup
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Node.js 18+ with pnpm
|
||||||
|
- Python 3.10+
|
||||||
|
- Rust (for Tauri)
|
||||||
|
- Windows build tools (for native modules)
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
1. Clone the repository:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/naomi-lgbt/chronara.git
|
||||||
|
cd chronara
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install frontend dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pnpm install
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Install Python dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Download the AI models:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/download_models.py
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Run in development mode:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pnpm tauri:dev
|
||||||
|
```
|
||||||
|
|
||||||
|
## Building for Production
|
||||||
|
|
||||||
|
### Windows
|
||||||
|
|
||||||
|
1. Download models if not already done:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/download_models.py
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Build the Windows executable:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/build_windows.py
|
||||||
|
```
|
||||||
|
|
||||||
|
The installer will be created in `src-tauri/target/release/bundle/nsis/`.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
1. **Start Recording**: Click the "Start Recording" button to begin capturing audio
|
||||||
|
2. **Real-time Transcription**: Watch as the conversation is transcribed with speaker labels
|
||||||
|
3. **Generate Summary**: After recording, click "Generate Summary" for an AI-powered meeting summary
|
||||||
|
4. **Export**: Download both the full transcript and summary as text files
|
||||||
|
|
||||||
|
## Model Information
|
||||||
|
|
||||||
|
### Transcription (WhisperX)
|
||||||
|
|
||||||
|
- **Model**: OpenAI Whisper base model with WhisperX enhancements
|
||||||
|
- **Features**: Speaker diarization, timestamp alignment
|
||||||
|
- **Size**: ~150MB
|
||||||
|
|
||||||
|
### Summarization (Llama 3.2)
|
||||||
|
|
||||||
|
- **1B Model**: Fast, good for basic summaries (~1.2GB)
|
||||||
|
- **3B Model**: Better quality summaries (~2.5GB)
|
||||||
|
- **Format**: GGUF quantized models
|
||||||
|
|
||||||
|
## Privacy & Security
|
||||||
|
|
||||||
|
- All processing happens locally on your machine
|
||||||
|
- No audio or text data is sent to external servers
|
||||||
|
- Models are bundled with the application
|
||||||
|
- Meeting data stays on your device
|
||||||
@@ -1,16 +1,6 @@
|
|||||||
# New Repository Template
|
# Chronara
|
||||||
|
|
||||||
This template contains all of our basic files for a new GitHub repository. There is also a handy workflow that will create an issue on a new repository made from this template, with a checklist for the steps we usually take in setting up a new repository.
|
A meeting transcription and summarisation tool that uses 100% local models.
|
||||||
|
|
||||||
If you're starting a Node.JS project with TypeScript, we have a [specific template](https://github.com/naomi-lgbt/nodejs-typescript-template) for that purpose.
|
|
||||||
|
|
||||||
## Readme
|
|
||||||
|
|
||||||
Delete all of the above text (including this line), and uncomment the below text to use our standard readme template.
|
|
||||||
|
|
||||||
<!-- # Project Name
|
|
||||||
|
|
||||||
Project Description
|
|
||||||
|
|
||||||
## Live Version
|
## Live Version
|
||||||
|
|
||||||
@@ -36,4 +26,4 @@ Copyright held by Naomi Carrigan.
|
|||||||
|
|
||||||
## Contact
|
## Contact
|
||||||
|
|
||||||
We may be contacted through our [Chat Server](http://chat.nhcarrigan.com) or via email at `contact@nhcarrigan.com`. -->
|
We may be contacted through our [Chat Server](http://chat.nhcarrigan.com) or via email at `contact@nhcarrigan.com`.
|
||||||
|
|||||||
Executable
+40
@@ -0,0 +1,40 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
echo "🔍 Running all checks..."
|
||||||
|
echo "========================================"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "📦 Installing dependencies..."
|
||||||
|
pnpm install
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "🔎 Running ESLint..."
|
||||||
|
pnpm lint
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "💅 Running Prettier check..."
|
||||||
|
pnpm format:check
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "🏗️ Building frontend..."
|
||||||
|
pnpm build
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "🧪 Running frontend tests..."
|
||||||
|
pnpm test
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "🦀 Running Clippy..."
|
||||||
|
cd src-tauri
|
||||||
|
cargo clippy --all-targets --all-features -- -D warnings
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "🧪 Running Rust tests..."
|
||||||
|
cargo test
|
||||||
|
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "========================================"
|
||||||
|
echo "✅ All checks passed!"
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
import js from "@eslint/js";
|
||||||
|
import tseslint from "typescript-eslint";
|
||||||
|
import reactHooks from "eslint-plugin-react-hooks";
|
||||||
|
import reactRefresh from "eslint-plugin-react-refresh";
|
||||||
|
import prettier from "eslint-config-prettier";
|
||||||
|
import globals from "globals";
|
||||||
|
|
||||||
|
export default tseslint.config(
|
||||||
|
js.configs.recommended,
|
||||||
|
...tseslint.configs.recommended,
|
||||||
|
prettier,
|
||||||
|
{
|
||||||
|
languageOptions: {
|
||||||
|
globals: {
|
||||||
|
...globals.browser,
|
||||||
|
...globals.node,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
files: ["**/*.{ts,tsx}"],
|
||||||
|
plugins: {
|
||||||
|
"react-hooks": reactHooks,
|
||||||
|
"react-refresh": reactRefresh,
|
||||||
|
},
|
||||||
|
rules: {
|
||||||
|
...reactHooks.configs.recommended.rules,
|
||||||
|
"react-refresh/only-export-components": ["warn", { allowConstantExport: true }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ignores: ["build/", "dist/", "src-tauri/target/", "node_modules/"],
|
||||||
|
}
|
||||||
|
);
|
||||||
+14
@@ -0,0 +1,14 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>Tauri + React + Typescript</title>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.tsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
{
|
||||||
|
"name": "chronara",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.1.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "tsc && vite build",
|
||||||
|
"lint": "eslint src",
|
||||||
|
"lint:fix": "eslint src --fix",
|
||||||
|
"format": "prettier --write .",
|
||||||
|
"format:check": "prettier --check .",
|
||||||
|
"preview": "vite preview",
|
||||||
|
"tauri": "tauri",
|
||||||
|
"tauri:dev": "tauri dev",
|
||||||
|
"build:linux": "tauri build",
|
||||||
|
"build:windows": "./scripts/build-windows-nsis.sh",
|
||||||
|
"build:all": "pnpm build:linux && pnpm build:windows",
|
||||||
|
"test": "vitest run",
|
||||||
|
"test:watch": "vitest",
|
||||||
|
"test:coverage": "vitest run --coverage"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"@tauri-apps/api": "^2",
|
||||||
|
"@tauri-apps/plugin-opener": "^2",
|
||||||
|
"react": "^19.1.0",
|
||||||
|
"react-dom": "^19.1.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@eslint/js": "^9.19.0",
|
||||||
|
"@tauri-apps/cli": "^2",
|
||||||
|
"@testing-library/jest-dom": "^6.9.1",
|
||||||
|
"@testing-library/react": "^16.3.0",
|
||||||
|
"@types/react": "^19.1.8",
|
||||||
|
"@types/react-dom": "^19.1.6",
|
||||||
|
"@vitejs/plugin-react": "^4.6.0",
|
||||||
|
"eslint": "^9.19.0",
|
||||||
|
"eslint-config-prettier": "^10.1.8",
|
||||||
|
"eslint-plugin-react-hooks": "^5.2.0",
|
||||||
|
"eslint-plugin-react-refresh": "^0.4.20",
|
||||||
|
"globals": "^17.0.0",
|
||||||
|
"jsdom": "^27.4.0",
|
||||||
|
"prettier": "^3.8.0",
|
||||||
|
"typescript": "~5.8.3",
|
||||||
|
"typescript-eslint": "^8.53.0",
|
||||||
|
"vite": "^7.0.4",
|
||||||
|
"vitest": "^4.0.17"
|
||||||
|
}
|
||||||
|
}
|
||||||
Generated
+373
@@ -0,0 +1,373 @@
|
|||||||
|
# This file is automatically @generated by Cargo.
|
||||||
|
# It is not intended for manual editing.
|
||||||
|
version = 4
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aho-corasick"
|
||||||
|
version = "1.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bindgen"
|
||||||
|
version = "0.72.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"cexpr",
|
||||||
|
"clang-sys",
|
||||||
|
"itertools",
|
||||||
|
"log",
|
||||||
|
"prettyplease",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"regex",
|
||||||
|
"rustc-hash",
|
||||||
|
"shlex",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bitflags"
|
||||||
|
version = "2.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cc"
|
||||||
|
version = "1.2.49"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215"
|
||||||
|
dependencies = [
|
||||||
|
"find-msvc-tools",
|
||||||
|
"jobserver",
|
||||||
|
"libc",
|
||||||
|
"shlex",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cexpr"
|
||||||
|
version = "0.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
|
||||||
|
dependencies = [
|
||||||
|
"nom",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cfg-if"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "clang-sys"
|
||||||
|
version = "1.8.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
|
||||||
|
dependencies = [
|
||||||
|
"glob",
|
||||||
|
"libc",
|
||||||
|
"libloading",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cmake"
|
||||||
|
version = "0.1.56"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b042e5d8a74ae91bb0961acd039822472ec99f8ab0948cbf6d1369588f8be586"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "either"
|
||||||
|
version = "1.12.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3dca9240753cf90908d7e4aac30f630662b02aebaa1b58a3cadabdb23385b58b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "find-msvc-tools"
|
||||||
|
version = "0.1.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "find_cuda_helper"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f9f9e65c593dd01ac77daad909ea4ad17f0d6d1776193fc8ea766356177abdad"
|
||||||
|
dependencies = [
|
||||||
|
"glob",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "glob"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itertools"
|
||||||
|
version = "0.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jobserver"
|
||||||
|
version = "0.1.31"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libc"
|
||||||
|
version = "0.2.155"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libloading"
|
||||||
|
version = "0.8.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"windows-targets",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "llama-cpp-sys-2"
|
||||||
|
version = "0.1.132"
|
||||||
|
dependencies = [
|
||||||
|
"bindgen",
|
||||||
|
"cc",
|
||||||
|
"cmake",
|
||||||
|
"find_cuda_helper",
|
||||||
|
"glob",
|
||||||
|
"walkdir",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "log"
|
||||||
|
version = "0.4.21"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memchr"
|
||||||
|
version = "2.7.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "minimal-lexical"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nom"
|
||||||
|
version = "7.1.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
"minimal-lexical",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "prettyplease"
|
||||||
|
version = "0.2.20"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro2"
|
||||||
|
version = "1.0.85"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quote"
|
||||||
|
version = "1.0.36"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex"
|
||||||
|
version = "1.10.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-automata",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-automata"
|
||||||
|
version = "0.4.7"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"memchr",
|
||||||
|
"regex-syntax",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "regex-syntax"
|
||||||
|
version = "0.8.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustc-hash"
|
||||||
|
version = "2.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "same-file"
|
||||||
|
version = "1.0.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-util",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "shlex"
|
||||||
|
version = "1.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "2.0.87"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-ident"
|
||||||
|
version = "1.0.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "walkdir"
|
||||||
|
version = "2.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
|
||||||
|
dependencies = [
|
||||||
|
"same-file",
|
||||||
|
"winapi-util",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "winapi-util"
|
||||||
|
version = "0.1.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
|
||||||
|
dependencies = [
|
||||||
|
"windows-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-sys"
|
||||||
|
version = "0.52.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
||||||
|
dependencies = [
|
||||||
|
"windows-targets",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows-targets"
|
||||||
|
version = "0.52.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
|
||||||
|
dependencies = [
|
||||||
|
"windows_aarch64_gnullvm",
|
||||||
|
"windows_aarch64_msvc",
|
||||||
|
"windows_i686_gnu",
|
||||||
|
"windows_i686_gnullvm",
|
||||||
|
"windows_i686_msvc",
|
||||||
|
"windows_x86_64_gnu",
|
||||||
|
"windows_x86_64_gnullvm",
|
||||||
|
"windows_x86_64_msvc",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_aarch64_gnullvm"
|
||||||
|
version = "0.52.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_aarch64_msvc"
|
||||||
|
version = "0.52.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_i686_gnu"
|
||||||
|
version = "0.52.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_i686_gnullvm"
|
||||||
|
version = "0.52.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_i686_msvc"
|
||||||
|
version = "0.52.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_x86_64_gnu"
|
||||||
|
version = "0.52.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_x86_64_gnullvm"
|
||||||
|
version = "0.52.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "windows_x86_64_msvc"
|
||||||
|
version = "0.52.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||||
|
#
|
||||||
|
# When uploading crates to the registry Cargo will automatically
|
||||||
|
# "normalize" Cargo.toml files for maximal compatibility
|
||||||
|
# with all versions of Cargo and also rewrite `path` dependencies
|
||||||
|
# to registry (e.g., crates.io) dependencies.
|
||||||
|
#
|
||||||
|
# If you are reading this file be aware that the original Cargo.toml
|
||||||
|
# will likely look very different (and much more reasonable).
|
||||||
|
# See Cargo.toml.orig for the original contents.
|
||||||
|
|
||||||
|
[package]
|
||||||
|
edition = "2021"
|
||||||
|
name = "llama-cpp-sys-2"
|
||||||
|
version = "0.1.132"
|
||||||
|
build = "build.rs"
|
||||||
|
links = "llama"
|
||||||
|
include = [
|
||||||
|
"wrapper.h",
|
||||||
|
"wrapper_mtmd.h",
|
||||||
|
"build.rs",
|
||||||
|
"/src",
|
||||||
|
"/llama.cpp/common/**/*.h",
|
||||||
|
"/llama.cpp/common/**/*.hpp",
|
||||||
|
"/llama.cpp/common/**/*.cpp",
|
||||||
|
"/llama.cpp/ggml/include/*.h",
|
||||||
|
"/llama.cpp/ggml/src/*.h",
|
||||||
|
"/llama.cpp/ggml/src/*.c",
|
||||||
|
"/llama.cpp/ggml/src/*.cpp",
|
||||||
|
"/llama.cpp/src/*.h",
|
||||||
|
"/llama.cpp/src/*.cpp",
|
||||||
|
"/llama.cpp/src/models/*.h",
|
||||||
|
"/llama.cpp/src/models/*.cpp",
|
||||||
|
"/llama.cpp/tools/mtmd/*.h",
|
||||||
|
"/llama.cpp/tools/mtmd/*.cpp",
|
||||||
|
"/llama.cpp/convert_hf_to_gguf.py",
|
||||||
|
"/llama.cpp/common/build-info.cpp.in",
|
||||||
|
"/llama.cpp/ggml/src/ggml-cuda.cu",
|
||||||
|
"/llama.cpp/ggml/src/ggml-metal.m",
|
||||||
|
"/llama.cpp/ggml/src/ggml-metal.metal",
|
||||||
|
"/llama.cpp/include/llama.h",
|
||||||
|
"/llama.cpp/include/llama-cpp.h",
|
||||||
|
"/llama.cpp/ggml/src/ggml-cpu/**/*",
|
||||||
|
"/llama.cpp/ggml/src/ggml-cuda/**/*",
|
||||||
|
"/llama.cpp/ggml/src/ggml-metal/**/*",
|
||||||
|
"/llama.cpp/ggml/src/ggml-vulkan/**/*",
|
||||||
|
"/llama.cpp/ggml/src/llamafile/sgemm.h",
|
||||||
|
"/llama.cpp/ggml/src/llamafile/sgemm.cpp",
|
||||||
|
"/llama.cpp/pocs",
|
||||||
|
"/llama.cpp/vendor",
|
||||||
|
"/llama.cpp/CMakeLists.txt",
|
||||||
|
"/llama.cpp/common/CMakeLists.txt",
|
||||||
|
"/llama.cpp/ggml/CMakeLists.txt",
|
||||||
|
"/llama.cpp/ggml/src/CMakeLists.txt",
|
||||||
|
"/llama.cpp/src/CMakeLists.txt",
|
||||||
|
"/llama.cpp/cmake",
|
||||||
|
"/llama.cpp/ggml/cmake",
|
||||||
|
"/llama.cpp/common/cmake",
|
||||||
|
]
|
||||||
|
autolib = false
|
||||||
|
autobins = false
|
||||||
|
autoexamples = false
|
||||||
|
autotests = false
|
||||||
|
autobenches = false
|
||||||
|
description = "Low Level Bindings to llama.cpp"
|
||||||
|
readme = "README.md"
|
||||||
|
license = "MIT OR Apache-2.0"
|
||||||
|
repository = "https://github.com/utilityai/llama-cpp-rs"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
cuda = []
|
||||||
|
cuda-no-vmm = ["cuda"]
|
||||||
|
dynamic-link = []
|
||||||
|
metal = []
|
||||||
|
mtmd = []
|
||||||
|
openmp = []
|
||||||
|
shared-stdcxx = []
|
||||||
|
system-ggml = []
|
||||||
|
vulkan = []
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "llama_cpp_sys_2"
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
|
||||||
|
[build-dependencies.bindgen]
|
||||||
|
version = "0.72.1"
|
||||||
|
|
||||||
|
[build-dependencies.cc]
|
||||||
|
version = "1.2.49"
|
||||||
|
features = ["parallel"]
|
||||||
|
|
||||||
|
[build-dependencies.cmake]
|
||||||
|
version = "0.1"
|
||||||
|
|
||||||
|
[build-dependencies.find_cuda_helper]
|
||||||
|
version = "0.2.0"
|
||||||
|
|
||||||
|
[build-dependencies.glob]
|
||||||
|
version = "0.3.3"
|
||||||
|
|
||||||
|
[build-dependencies.walkdir]
|
||||||
|
version = "2"
|
||||||
Generated
+85
@@ -0,0 +1,85 @@
|
|||||||
|
[package]
|
||||||
|
name = "llama-cpp-sys-2"
|
||||||
|
description = "Low Level Bindings to llama.cpp"
|
||||||
|
version = "0.1.132"
|
||||||
|
edition = "2021"
|
||||||
|
license = "MIT OR Apache-2.0"
|
||||||
|
repository = "https://github.com/utilityai/llama-cpp-rs"
|
||||||
|
links = "llama"
|
||||||
|
|
||||||
|
include = [
|
||||||
|
"wrapper.h",
|
||||||
|
"wrapper_mtmd.h",
|
||||||
|
"build.rs",
|
||||||
|
"/src",
|
||||||
|
|
||||||
|
"/llama.cpp/common/**/*.h",
|
||||||
|
"/llama.cpp/common/**/*.hpp",
|
||||||
|
"/llama.cpp/common/**/*.cpp",
|
||||||
|
"/llama.cpp/ggml/include/*.h",
|
||||||
|
"/llama.cpp/ggml/src/*.h",
|
||||||
|
"/llama.cpp/ggml/src/*.c",
|
||||||
|
"/llama.cpp/ggml/src/*.cpp",
|
||||||
|
"/llama.cpp/src/*.h",
|
||||||
|
"/llama.cpp/src/*.cpp",
|
||||||
|
"/llama.cpp/src/models/*.h",
|
||||||
|
"/llama.cpp/src/models/*.cpp",
|
||||||
|
"/llama.cpp/tools/mtmd/*.h",
|
||||||
|
"/llama.cpp/tools/mtmd/*.cpp",
|
||||||
|
|
||||||
|
"/llama.cpp/convert_hf_to_gguf.py", # Yes, it's required
|
||||||
|
"/llama.cpp/common/build-info.cpp.in",
|
||||||
|
|
||||||
|
"/llama.cpp/ggml/src/ggml-cuda.cu",
|
||||||
|
"/llama.cpp/ggml/src/ggml-metal.m",
|
||||||
|
"/llama.cpp/ggml/src/ggml-metal.metal",
|
||||||
|
|
||||||
|
"/llama.cpp/include/llama.h",
|
||||||
|
"/llama.cpp/include/llama-cpp.h",
|
||||||
|
|
||||||
|
"/llama.cpp/ggml/src/ggml-cpu/**/*",
|
||||||
|
"/llama.cpp/ggml/src/ggml-cuda/**/*",
|
||||||
|
"/llama.cpp/ggml/src/ggml-metal/**/*",
|
||||||
|
"/llama.cpp/ggml/src/ggml-vulkan/**/*",
|
||||||
|
|
||||||
|
"/llama.cpp/ggml/src/llamafile/sgemm.h",
|
||||||
|
"/llama.cpp/ggml/src/llamafile/sgemm.cpp",
|
||||||
|
|
||||||
|
"/llama.cpp/pocs",
|
||||||
|
"/llama.cpp/vendor",
|
||||||
|
|
||||||
|
"/llama.cpp/CMakeLists.txt",
|
||||||
|
"/llama.cpp/common/CMakeLists.txt",
|
||||||
|
"/llama.cpp/ggml/CMakeLists.txt",
|
||||||
|
"/llama.cpp/ggml/src/CMakeLists.txt",
|
||||||
|
"/llama.cpp/src/CMakeLists.txt",
|
||||||
|
|
||||||
|
"/llama.cpp/cmake",
|
||||||
|
"/llama.cpp/ggml/cmake",
|
||||||
|
"/llama.cpp/common/cmake",
|
||||||
|
]
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
bindgen = { workspace = true }
|
||||||
|
cc = { workspace = true, features = ["parallel"] }
|
||||||
|
cmake = "0.1"
|
||||||
|
find_cuda_helper = "0.2.0"
|
||||||
|
glob = "0.3.3"
|
||||||
|
walkdir = "2"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
cuda = []
|
||||||
|
# Disables the need to dynamically link against libcuda.so / cuda.dll
|
||||||
|
cuda-no-vmm = ["cuda"]
|
||||||
|
metal = []
|
||||||
|
dynamic-link = []
|
||||||
|
vulkan = []
|
||||||
|
openmp = []
|
||||||
|
# Only has an impact on Android.
|
||||||
|
shared-stdcxx = []
|
||||||
|
system-ggml = []
|
||||||
|
mtmd = []
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
# llama-cpp-sys
|
||||||
|
|
||||||
|
Raw bindings to llama.cpp with CUDA support.
|
||||||
|
|
||||||
|
See [llama-cpp-2](https://crates.io/crates/llama-cpp-2) for a safe API.
|
||||||
@@ -0,0 +1,952 @@
|
|||||||
|
use cmake::Config;
|
||||||
|
use glob::glob;
|
||||||
|
use std::env;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::process::Command;
|
||||||
|
use walkdir::DirEntry;
|
||||||
|
|
||||||
|
/// Windows toolchain flavor. MSVC targets get dedicated handling
/// (bindgen clang flags, `.lib` naming) later in the build script.
enum WindowsVariant {
    Msvc,
    Other,
}
|
||||||
|
|
||||||
|
/// Apple platform flavor: macOS (`*-apple-darwin`) versus any other
/// Apple target triple.
enum AppleVariant {
    MacOS,
    Other,
}
|
||||||
|
|
||||||
|
/// Coarse classification of the compile target's operating system,
/// derived from the `TARGET` triple by `parse_target_os`.
enum TargetOs {
    Windows(WindowsVariant),
    Apple(AppleVariant),
    Linux,
    Android,
}
|
||||||
|
|
||||||
|
/// Emit a `cargo:warning=[DEBUG] ...` line, but only when the
/// `BUILD_DEBUG` environment variable is set — opt-in build tracing
/// that is visible in cargo's output.
macro_rules! debug_log {
    ($($arg:tt)*) => {
        if std::env::var("BUILD_DEBUG").is_ok() {
            println!("cargo:warning=[DEBUG] {}", format!($($arg)*));
        }
    };
}
|
||||||
|
|
||||||
|
fn parse_target_os() -> Result<(TargetOs, String), String> {
|
||||||
|
let target = env::var("TARGET").unwrap();
|
||||||
|
|
||||||
|
if target.contains("windows") {
|
||||||
|
if target.ends_with("-windows-msvc") {
|
||||||
|
Ok((TargetOs::Windows(WindowsVariant::Msvc), target))
|
||||||
|
} else {
|
||||||
|
Ok((TargetOs::Windows(WindowsVariant::Other), target))
|
||||||
|
}
|
||||||
|
} else if target.contains("apple") {
|
||||||
|
if target.ends_with("-apple-darwin") {
|
||||||
|
Ok((TargetOs::Apple(AppleVariant::MacOS), target))
|
||||||
|
} else {
|
||||||
|
Ok((TargetOs::Apple(AppleVariant::Other), target))
|
||||||
|
}
|
||||||
|
} else if target.contains("android")
|
||||||
|
|| target == "aarch64-linux-android"
|
||||||
|
|| target == "armv7-linux-androideabi"
|
||||||
|
|| target == "i686-linux-android"
|
||||||
|
|| target == "x86_64-linux-android"
|
||||||
|
{
|
||||||
|
// Handle both full android targets and short names like arm64-v8a that cargo ndk might use
|
||||||
|
Ok((TargetOs::Android, target))
|
||||||
|
} else if target.contains("linux") {
|
||||||
|
Ok((TargetOs::Linux, target))
|
||||||
|
} else {
|
||||||
|
Err(target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Walk three ancestors up from `OUT_DIR` to find cargo's build output
/// directory (for an `OUT_DIR` of `.../build/<pkg-hash>/out` this is
/// presumably the profile directory, e.g. `target/debug` — derived from
/// the fixed `nth(3)` hop below).
fn get_cargo_target_dir() -> Result<PathBuf, Box<dyn std::error::Error>> {
    let out_path = PathBuf::from(env::var("OUT_DIR")?);
    match out_path.ancestors().nth(3) {
        Some(dir) => Ok(dir.to_path_buf()),
        None => Err("OUT_DIR is not deep enough".into()),
    }
}
|
||||||
|
|
||||||
|
fn extract_lib_names(out_dir: &Path, build_shared_libs: bool) -> Vec<String> {
|
||||||
|
// Use CARGO_CFG_TARGET_OS to detect TARGET platform, not HOST
|
||||||
|
// This fixes cross-compilation from Linux to Windows
|
||||||
|
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
|
||||||
|
let lib_pattern = if target_os == "windows" {
|
||||||
|
"*.lib"
|
||||||
|
} else if target_os == "macos" {
|
||||||
|
if build_shared_libs {
|
||||||
|
"*.dylib"
|
||||||
|
} else {
|
||||||
|
"*.a"
|
||||||
|
}
|
||||||
|
} else if build_shared_libs {
|
||||||
|
"*.so"
|
||||||
|
} else {
|
||||||
|
"*.a"
|
||||||
|
};
|
||||||
|
let libs_dir = out_dir.join("lib*");
|
||||||
|
let pattern = libs_dir.join(lib_pattern);
|
||||||
|
debug_log!("Extract libs {}", pattern.display());
|
||||||
|
|
||||||
|
let mut lib_names: Vec<String> = Vec::new();
|
||||||
|
|
||||||
|
// Process the libraries based on the pattern
|
||||||
|
for entry in glob(pattern.to_str().unwrap()).unwrap() {
|
||||||
|
match entry {
|
||||||
|
Ok(path) => {
|
||||||
|
let stem = path.file_stem().unwrap();
|
||||||
|
let stem_str = stem.to_str().unwrap();
|
||||||
|
|
||||||
|
// Remove the "lib" prefix if present
|
||||||
|
let lib_name = if stem_str.starts_with("lib") {
|
||||||
|
stem_str.strip_prefix("lib").unwrap_or(stem_str)
|
||||||
|
} else {
|
||||||
|
if path.extension() == Some(std::ffi::OsStr::new("a")) {
|
||||||
|
let target = path.parent().unwrap().join(format!("lib{}.a", stem_str));
|
||||||
|
std::fs::rename(&path, &target).unwrap_or_else(|e| {
|
||||||
|
panic!("Failed to rename {path:?} to {target:?}: {e:?}");
|
||||||
|
})
|
||||||
|
}
|
||||||
|
stem_str
|
||||||
|
};
|
||||||
|
lib_names.push(lib_name.to_string());
|
||||||
|
}
|
||||||
|
Err(e) => println!("cargo:warning=error={}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lib_names
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_lib_assets(out_dir: &Path) -> Vec<PathBuf> {
|
||||||
|
// Use CARGO_CFG_TARGET_OS to detect TARGET platform, not HOST
|
||||||
|
// This fixes cross-compilation from Linux to Windows
|
||||||
|
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
|
||||||
|
let shared_lib_pattern = if target_os == "windows" {
|
||||||
|
"*.dll"
|
||||||
|
} else if target_os == "macos" {
|
||||||
|
"*.dylib"
|
||||||
|
} else {
|
||||||
|
"*.so"
|
||||||
|
};
|
||||||
|
|
||||||
|
let shared_libs_dir = if target_os == "windows" { "bin" } else { "lib" };
|
||||||
|
let libs_dir = out_dir.join(shared_libs_dir);
|
||||||
|
let pattern = libs_dir.join(shared_lib_pattern);
|
||||||
|
debug_log!("Extract lib assets {}", pattern.display());
|
||||||
|
let mut files = Vec::new();
|
||||||
|
|
||||||
|
for entry in glob(pattern.to_str().unwrap()).unwrap() {
|
||||||
|
match entry {
|
||||||
|
Ok(path) => {
|
||||||
|
files.push(path);
|
||||||
|
}
|
||||||
|
Err(e) => eprintln!("cargo:warning=error={}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
files
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ask `clang` for its library search directories and derive the
/// `<libraries>/lib/darwin` path from its `--print-search-dirs` output.
///
/// Returns `None` (after logging) when `clang` is missing, fails, or its
/// output contains no `libraries: =` line.
fn macos_link_search_path() -> Option<String> {
    let output = Command::new("clang")
        .arg("--print-search-dirs")
        .output()
        .ok()?;
    if !output.status.success() {
        println!(
            "failed to run 'clang --print-search-dirs', continuing without a link search path"
        );
        return None;
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    let found = stdout
        .lines()
        .find(|line| line.contains("libraries: ="))
        // The matched line is "libraries: =<path>"; take everything after '='.
        .and_then(|line| line.split('=').nth(1))
        .map(|path| format!("{}/lib/darwin", path));

    if found.is_none() {
        println!("failed to determine link search path, continuing without it");
    }
    found
}
|
||||||
|
|
||||||
|
/// Sanity-check an Android NDK installation rooted at `ndk_path`.
///
/// Verifies that the directory exists and contains the CMake toolchain file
/// the build relies on; returns a human-readable error message otherwise.
fn validate_android_ndk(ndk_path: &str) -> Result<(), String> {
    let root = Path::new(ndk_path);

    if !root.exists() {
        return Err(format!(
            "Android NDK path does not exist: {}",
            root.display()
        ));
    }

    // Without the CMake toolchain file, the NDK cannot drive our cmake build.
    let toolchain_file = root.join("build/cmake/android.toolchain.cmake");
    if toolchain_file.exists() {
        Ok(())
    } else {
        Err(format!(
            "Android NDK toolchain file not found: {}\n\
             This indicates an incomplete NDK installation.",
            toolchain_file.display()
        ))
    }
}
|
||||||
|
|
||||||
|
fn is_hidden(e: &DirEntry) -> bool {
|
||||||
|
e.file_name()
|
||||||
|
.to_str()
|
||||||
|
.map(|s| s.starts_with('.'))
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
println!("cargo:rerun-if-changed=build.rs");
|
||||||
|
|
||||||
|
let (target_os, target_triple) =
|
||||||
|
parse_target_os().unwrap_or_else(|t| panic!("Failed to parse target os {t}"));
|
||||||
|
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
|
||||||
|
|
||||||
|
let target_dir = get_cargo_target_dir().unwrap();
|
||||||
|
let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("Failed to get CARGO_MANIFEST_DIR");
|
||||||
|
let llama_src = Path::new(&manifest_dir).join("llama.cpp");
|
||||||
|
let build_shared_libs = cfg!(feature = "dynamic-link");
|
||||||
|
|
||||||
|
let build_shared_libs = std::env::var("LLAMA_BUILD_SHARED_LIBS")
|
||||||
|
.map(|v| v == "1")
|
||||||
|
.unwrap_or(build_shared_libs);
|
||||||
|
let profile = env::var("LLAMA_LIB_PROFILE").unwrap_or("Release".to_string());
|
||||||
|
let static_crt = env::var("LLAMA_STATIC_CRT")
|
||||||
|
.map(|v| v == "1")
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
println!("cargo:rerun-if-env-changed=LLAMA_LIB_PROFILE");
|
||||||
|
println!("cargo:rerun-if-env-changed=LLAMA_BUILD_SHARED_LIBS");
|
||||||
|
println!("cargo:rerun-if-env-changed=LLAMA_STATIC_CRT");
|
||||||
|
|
||||||
|
debug_log!("TARGET: {}", target_triple);
|
||||||
|
debug_log!("CARGO_MANIFEST_DIR: {}", manifest_dir);
|
||||||
|
debug_log!("TARGET_DIR: {}", target_dir.display());
|
||||||
|
debug_log!("OUT_DIR: {}", out_dir.display());
|
||||||
|
debug_log!("BUILD_SHARED: {}", build_shared_libs);
|
||||||
|
|
||||||
|
// Make sure that changes to the llama.cpp project trigger a rebuild.
|
||||||
|
let rebuild_on_children_of = [
|
||||||
|
llama_src.join("src"),
|
||||||
|
llama_src.join("ggml/src"),
|
||||||
|
llama_src.join("common"),
|
||||||
|
];
|
||||||
|
for entry in walkdir::WalkDir::new(&llama_src)
|
||||||
|
.into_iter()
|
||||||
|
.filter_entry(|e| !is_hidden(e))
|
||||||
|
{
|
||||||
|
let entry = entry.expect("Failed to obtain entry");
|
||||||
|
let rebuild = entry
|
||||||
|
.file_name()
|
||||||
|
.to_str()
|
||||||
|
.map(|f| f.starts_with("CMake"))
|
||||||
|
.unwrap_or_default()
|
||||||
|
|| rebuild_on_children_of
|
||||||
|
.iter()
|
||||||
|
.any(|src_folder| entry.path().starts_with(src_folder));
|
||||||
|
if rebuild {
|
||||||
|
println!("cargo:rerun-if-changed={}", entry.path().display());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Speed up build
|
||||||
|
env::set_var(
|
||||||
|
"CMAKE_BUILD_PARALLEL_LEVEL",
|
||||||
|
std::thread::available_parallelism()
|
||||||
|
.unwrap()
|
||||||
|
.get()
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Bindings
|
||||||
|
let mut bindings_builder = bindgen::Builder::default()
|
||||||
|
.header("wrapper.h")
|
||||||
|
.clang_arg(format!("-I{}", llama_src.join("include").display()))
|
||||||
|
.clang_arg(format!("-I{}", llama_src.join("ggml/include").display()))
|
||||||
|
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
|
||||||
|
.derive_partialeq(true)
|
||||||
|
.allowlist_function("ggml_.*")
|
||||||
|
.allowlist_type("ggml_.*")
|
||||||
|
.allowlist_function("llama_.*")
|
||||||
|
.allowlist_type("llama_.*")
|
||||||
|
.prepend_enum_name(false);
|
||||||
|
|
||||||
|
// Configure mtmd feature if enabled
|
||||||
|
if cfg!(feature = "mtmd") {
|
||||||
|
bindings_builder = bindings_builder
|
||||||
|
.header("wrapper_mtmd.h")
|
||||||
|
.allowlist_function("mtmd_.*")
|
||||||
|
.allowlist_type("mtmd_.*");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure Android-specific bindgen settings
|
||||||
|
if matches!(target_os, TargetOs::Android) {
|
||||||
|
// Detect Android NDK from environment variables
|
||||||
|
let android_ndk = env::var("ANDROID_NDK")
|
||||||
|
.or_else(|_| env::var("ANDROID_NDK_ROOT"))
|
||||||
|
.or_else(|_| env::var("NDK_ROOT"))
|
||||||
|
.or_else(|_| env::var("CARGO_NDK_ANDROID_NDK"))
|
||||||
|
.or_else(|_| {
|
||||||
|
// Try to auto-detect NDK from Android SDK
|
||||||
|
if let Some(home) = env::home_dir() {
|
||||||
|
let android_home = env::var("ANDROID_HOME")
|
||||||
|
.or_else(|_| env::var("ANDROID_SDK_ROOT"))
|
||||||
|
.unwrap_or_else(|_| format!("{}/Android/Sdk", home.display()));
|
||||||
|
|
||||||
|
let ndk_dir = format!("{}/ndk", android_home);
|
||||||
|
if let Ok(entries) = std::fs::read_dir(&ndk_dir) {
|
||||||
|
let mut versions: Vec<_> = entries
|
||||||
|
.filter_map(|e| e.ok())
|
||||||
|
.filter(|e| e.file_type().map(|t| t.is_dir()).unwrap_or(false))
|
||||||
|
.filter_map(|e| e.file_name().to_str().map(|s| s.to_string()))
|
||||||
|
.collect();
|
||||||
|
versions.sort();
|
||||||
|
if let Some(latest) = versions.last() {
|
||||||
|
return Ok(format!("{}/{}", ndk_dir, latest));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(env::VarError::NotPresent)
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|_| {
|
||||||
|
panic!(
|
||||||
|
"Android NDK not found. Please set one of: ANDROID_NDK, NDK_ROOT, ANDROID_NDK_ROOT\n\
|
||||||
|
Current target: {}\n\
|
||||||
|
Download from: https://developer.android.com/ndk/downloads",
|
||||||
|
target_triple
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Get Android API level
|
||||||
|
let android_api = env::var("ANDROID_API_LEVEL")
|
||||||
|
.or_else(|_| env::var("ANDROID_PLATFORM").map(|p| p.replace("android-", "")))
|
||||||
|
.or_else(|_| env::var("CARGO_NDK_ANDROID_PLATFORM").map(|p| p.replace("android-", "")))
|
||||||
|
.unwrap_or_else(|_| "28".to_string());
|
||||||
|
|
||||||
|
// Determine host platform
|
||||||
|
let host_tag = if cfg!(target_os = "macos") {
|
||||||
|
"darwin-x86_64"
|
||||||
|
} else if cfg!(target_os = "linux") {
|
||||||
|
"linux-x86_64"
|
||||||
|
} else if cfg!(target_os = "windows") {
|
||||||
|
"windows-x86_64"
|
||||||
|
} else {
|
||||||
|
panic!("Unsupported host platform for Android NDK");
|
||||||
|
};
|
||||||
|
|
||||||
|
// Map Rust target to Android architecture
|
||||||
|
let android_target_prefix = if target_triple.contains("aarch64") {
|
||||||
|
"aarch64-linux-android"
|
||||||
|
} else if target_triple.contains("armv7") {
|
||||||
|
"arm-linux-androideabi"
|
||||||
|
} else if target_triple.contains("x86_64") {
|
||||||
|
"x86_64-linux-android"
|
||||||
|
} else if target_triple.contains("i686") {
|
||||||
|
"i686-linux-android"
|
||||||
|
} else {
|
||||||
|
panic!("Unsupported Android target: {}", target_triple);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Setup Android toolchain paths
|
||||||
|
let toolchain_path = format!("{}/toolchains/llvm/prebuilt/{}", android_ndk, host_tag);
|
||||||
|
let sysroot = format!("{}/sysroot", toolchain_path);
|
||||||
|
|
||||||
|
// Validate toolchain existence
|
||||||
|
if !std::path::Path::new(&toolchain_path).exists() {
|
||||||
|
panic!(
|
||||||
|
"Android NDK toolchain not found at: {}\n\
|
||||||
|
Please ensure you have the correct Android NDK for your platform.",
|
||||||
|
toolchain_path
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find clang builtin includes
|
||||||
|
let clang_builtin_includes = {
|
||||||
|
let clang_lib_path = format!("{}/lib/clang", toolchain_path);
|
||||||
|
std::fs::read_dir(&clang_lib_path).ok().and_then(|entries| {
|
||||||
|
entries
|
||||||
|
.filter_map(|e| e.ok())
|
||||||
|
.find(|entry| {
|
||||||
|
entry.file_type().map(|t| t.is_dir()).unwrap_or(false)
|
||||||
|
&& entry
|
||||||
|
.file_name()
|
||||||
|
.to_str()
|
||||||
|
.map(|name| name.chars().next().unwrap_or('0').is_ascii_digit())
|
||||||
|
.unwrap_or(false)
|
||||||
|
})
|
||||||
|
.and_then(|entry| {
|
||||||
|
let include_path =
|
||||||
|
format!("{}/{}/include", clang_lib_path, entry.file_name().to_str()?);
|
||||||
|
if std::path::Path::new(&include_path).exists() {
|
||||||
|
Some(include_path)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
};
|
||||||
|
|
||||||
|
// Configure bindgen for Android
|
||||||
|
bindings_builder = bindings_builder
|
||||||
|
.clang_arg(format!("--sysroot={}", sysroot))
|
||||||
|
.clang_arg(format!("-D__ANDROID_API__={}", android_api))
|
||||||
|
.clang_arg("-D__ANDROID__");
|
||||||
|
|
||||||
|
// Add include paths in correct order
|
||||||
|
if let Some(ref builtin_includes) = clang_builtin_includes {
|
||||||
|
bindings_builder = bindings_builder
|
||||||
|
.clang_arg("-isystem")
|
||||||
|
.clang_arg(builtin_includes);
|
||||||
|
}
|
||||||
|
|
||||||
|
bindings_builder = bindings_builder
|
||||||
|
.clang_arg("-isystem")
|
||||||
|
.clang_arg(format!("{}/usr/include/{}", sysroot, android_target_prefix))
|
||||||
|
.clang_arg("-isystem")
|
||||||
|
.clang_arg(format!("{}/usr/include", sysroot))
|
||||||
|
.clang_arg("-include")
|
||||||
|
.clang_arg("stdbool.h")
|
||||||
|
.clang_arg("-include")
|
||||||
|
.clang_arg("stdint.h");
|
||||||
|
|
||||||
|
// Set additional clang args for cargo ndk compatibility
|
||||||
|
if env::var("CARGO_SUBCOMMAND").as_deref() == Ok("ndk") {
|
||||||
|
std::env::set_var(
|
||||||
|
"BINDGEN_EXTRA_CLANG_ARGS",
|
||||||
|
format!("--target={}", target_triple),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fix bindgen header discovery on Windows MSVC
|
||||||
|
// Use cc crate to discover MSVC include paths by compiling a dummy file
|
||||||
|
if matches!(target_os, TargetOs::Windows(WindowsVariant::Msvc)) {
|
||||||
|
// Create a minimal dummy C file to extract compiler flags
|
||||||
|
let out_dir = env::var("OUT_DIR").unwrap();
|
||||||
|
let dummy_c = Path::new(&out_dir).join("dummy.c");
|
||||||
|
std::fs::write(&dummy_c, "int main() { return 0; }").unwrap();
|
||||||
|
|
||||||
|
// Use cc crate to get compiler with proper environment setup
|
||||||
|
let mut build = cc::Build::new();
|
||||||
|
build.file(&dummy_c);
|
||||||
|
|
||||||
|
// Get the actual compiler command cc would use
|
||||||
|
let compiler = build.try_get_compiler().unwrap();
|
||||||
|
|
||||||
|
// Extract include paths by checking compiler's environment
|
||||||
|
// cc crate sets up MSVC environment internally
|
||||||
|
let env_include = compiler
|
||||||
|
.env()
|
||||||
|
.iter()
|
||||||
|
.find(|(k, _)| k.eq_ignore_ascii_case("INCLUDE"))
|
||||||
|
.map(|(_, v)| v);
|
||||||
|
|
||||||
|
if let Some(include_paths) = env_include {
|
||||||
|
for include_path in include_paths
|
||||||
|
.to_string_lossy()
|
||||||
|
.split(';')
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
{
|
||||||
|
bindings_builder = bindings_builder
|
||||||
|
.clang_arg("-isystem")
|
||||||
|
.clang_arg(include_path);
|
||||||
|
debug_log!("Added MSVC include path: {}", include_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add MSVC compatibility flags
|
||||||
|
bindings_builder = bindings_builder
|
||||||
|
.clang_arg(format!("--target={}", target_triple))
|
||||||
|
.clang_arg("-fms-compatibility")
|
||||||
|
.clang_arg("-fms-extensions");
|
||||||
|
|
||||||
|
debug_log!(
|
||||||
|
"Configured bindgen with MSVC toolchain for target: {}",
|
||||||
|
target_triple
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let bindings = bindings_builder
|
||||||
|
.generate()
|
||||||
|
.expect("Failed to generate bindings");
|
||||||
|
|
||||||
|
// Write the generated bindings to an output file
|
||||||
|
let bindings_path = out_dir.join("bindings.rs");
|
||||||
|
bindings
|
||||||
|
.write_to_file(bindings_path)
|
||||||
|
.expect("Failed to write bindings");
|
||||||
|
|
||||||
|
println!("cargo:rerun-if-changed=wrapper.h");
|
||||||
|
println!("cargo:rerun-if-changed=wrapper_mtmd.h");
|
||||||
|
|
||||||
|
debug_log!("Bindings Created");
|
||||||
|
|
||||||
|
// Build with Cmake
|
||||||
|
|
||||||
|
let mut config = Config::new(&llama_src);
|
||||||
|
|
||||||
|
// Would require extra source files to pointlessly
|
||||||
|
// be included in what's uploaded to and downloaded from
|
||||||
|
// crates.io, so deactivating these instead
|
||||||
|
config.define("LLAMA_BUILD_TESTS", "OFF");
|
||||||
|
config.define("LLAMA_BUILD_EXAMPLES", "OFF");
|
||||||
|
config.define("LLAMA_BUILD_SERVER", "OFF");
|
||||||
|
config.define("LLAMA_BUILD_TOOLS", "OFF");
|
||||||
|
config.define("LLAMA_CURL", "OFF");
|
||||||
|
|
||||||
|
if cfg!(feature = "mtmd") {
|
||||||
|
config.define("LLAMA_BUILD_COMMON", "ON");
|
||||||
|
// mtmd support in llama-cpp is within the tools directory
|
||||||
|
config.define("LLAMA_BUILD_TOOLS", "ON");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass CMAKE_ environment variables down to CMake
|
||||||
|
for (key, value) in env::vars() {
|
||||||
|
if key.starts_with("CMAKE_") {
|
||||||
|
config.define(&key, &value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the target-cpu config value, if specified
|
||||||
|
let target_cpu = std::env::var("CARGO_ENCODED_RUSTFLAGS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|rustflags| {
|
||||||
|
rustflags
|
||||||
|
.split('\x1f')
|
||||||
|
.find(|f| f.contains("target-cpu="))
|
||||||
|
.and_then(|f| f.split("target-cpu=").nth(1))
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
});
|
||||||
|
|
||||||
|
if target_cpu == Some("native".into()) {
|
||||||
|
debug_log!("Detected target-cpu=native, compiling with GGML_NATIVE");
|
||||||
|
config.define("GGML_NATIVE", "ON");
|
||||||
|
}
|
||||||
|
// if native isn't specified, enable specific features for ggml instead
|
||||||
|
else {
|
||||||
|
// rust code isn't using `target-cpu=native`, so llama.cpp shouldn't use GGML_NATIVE either
|
||||||
|
config.define("GGML_NATIVE", "OFF");
|
||||||
|
|
||||||
|
// if `target-cpu` is set set, also set -march for llama.cpp to the same value
|
||||||
|
if let Some(ref cpu) = target_cpu {
|
||||||
|
debug_log!("Setting baseline architecture: -march={}", cpu);
|
||||||
|
config.cflag(&format!("-march={}", cpu));
|
||||||
|
config.cxxflag(&format!("-march={}", cpu));
|
||||||
|
}
|
||||||
|
|
||||||
|
// I expect this env var to always be present
|
||||||
|
let features = std::env::var("CARGO_CFG_TARGET_FEATURE")
|
||||||
|
.expect("Env var CARGO_CFG_TARGET_FEATURE not found.");
|
||||||
|
debug_log!("Compiling with target features: {}", features);
|
||||||
|
|
||||||
|
// list of rust target_features here:
|
||||||
|
// https://doc.rust-lang.org/reference/attributes/codegen.html#the-target_feature-attribute
|
||||||
|
// GGML config flags have been found by looking at:
|
||||||
|
// llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt
|
||||||
|
for feature in features.split(',') {
|
||||||
|
match feature {
|
||||||
|
"avx" => {
|
||||||
|
config.define("GGML_AVX", "ON");
|
||||||
|
}
|
||||||
|
"avx2" => {
|
||||||
|
config.define("GGML_AVX2", "ON");
|
||||||
|
}
|
||||||
|
"avx512bf16" => {
|
||||||
|
config.define("GGML_AVX512_BF16", "ON");
|
||||||
|
}
|
||||||
|
"avx512vbmi" => {
|
||||||
|
config.define("GGML_AVX512_VBMI", "ON");
|
||||||
|
}
|
||||||
|
"avx512vnni" => {
|
||||||
|
config.define("GGML_AVX512_VNNI", "ON");
|
||||||
|
}
|
||||||
|
"avxvnni" => {
|
||||||
|
config.define("GGML_AVX_VNNI", "ON");
|
||||||
|
}
|
||||||
|
"bmi2" => {
|
||||||
|
config.define("GGML_BMI2", "ON");
|
||||||
|
}
|
||||||
|
"f16c" => {
|
||||||
|
config.define("GGML_F16C", "ON");
|
||||||
|
}
|
||||||
|
"fma" => {
|
||||||
|
config.define("GGML_FMA", "ON");
|
||||||
|
}
|
||||||
|
"sse4.2" => {
|
||||||
|
config.define("GGML_SSE42", "ON");
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
debug_log!(
|
||||||
|
"Unrecognized cpu feature: '{}' - skipping GGML config for it.",
|
||||||
|
feature
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.define(
|
||||||
|
"BUILD_SHARED_LIBS",
|
||||||
|
if build_shared_libs { "ON" } else { "OFF" },
|
||||||
|
);
|
||||||
|
|
||||||
|
if matches!(target_os, TargetOs::Apple(_)) {
|
||||||
|
config.define("GGML_BLAS", "OFF");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (matches!(target_os, TargetOs::Windows(WindowsVariant::Msvc))
|
||||||
|
&& matches!(
|
||||||
|
profile.as_str(),
|
||||||
|
"Release" | "RelWithDebInfo" | "MinSizeRel"
|
||||||
|
))
|
||||||
|
{
|
||||||
|
// Debug Rust builds under MSVC turn off optimization even though we're ideally building the release profile of llama.cpp.
|
||||||
|
// Looks like an upstream bug:
|
||||||
|
// https://github.com/rust-lang/cmake-rs/issues/240
|
||||||
|
// For now explicitly reinject the optimization flags that a CMake Release build is expected to have on in this scenario.
|
||||||
|
// This fixes CPU inference performance when part of a Rust debug build.
|
||||||
|
for flag in &["/O2", "/DNDEBUG", "/Ob2"] {
|
||||||
|
config.cflag(flag);
|
||||||
|
config.cxxflag(flag);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.static_crt(static_crt);
|
||||||
|
|
||||||
|
if matches!(target_os, TargetOs::Android) {
|
||||||
|
// Android NDK Build Configuration
|
||||||
|
let android_ndk = env::var("ANDROID_NDK")
|
||||||
|
.or_else(|_| env::var("NDK_ROOT"))
|
||||||
|
.or_else(|_| env::var("ANDROID_NDK_ROOT"))
|
||||||
|
.unwrap_or_else(|_| {
|
||||||
|
panic!(
|
||||||
|
"Android NDK not found. Please set one of: ANDROID_NDK, NDK_ROOT, ANDROID_NDK_ROOT\n\
|
||||||
|
Download from: https://developer.android.com/ndk/downloads"
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Validate NDK installation
|
||||||
|
if let Err(error) = validate_android_ndk(&android_ndk) {
|
||||||
|
panic!("Android NDK validation failed: {}", error);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rerun build script if NDK environment variables change
|
||||||
|
println!("cargo:rerun-if-env-changed=ANDROID_NDK");
|
||||||
|
println!("cargo:rerun-if-env-changed=NDK_ROOT");
|
||||||
|
println!("cargo:rerun-if-env-changed=ANDROID_NDK_ROOT");
|
||||||
|
|
||||||
|
// Set CMake toolchain file for Android
|
||||||
|
let toolchain_file = format!("{}/build/cmake/android.toolchain.cmake", android_ndk);
|
||||||
|
config.define("CMAKE_TOOLCHAIN_FILE", &toolchain_file);
|
||||||
|
|
||||||
|
// Configure Android platform (API level)
|
||||||
|
let android_platform = env::var("ANDROID_PLATFORM").unwrap_or_else(|_| {
|
||||||
|
env::var("ANDROID_API_LEVEL")
|
||||||
|
.map(|level| format!("android-{}", level))
|
||||||
|
.unwrap_or_else(|_| "android-28".to_string())
|
||||||
|
});
|
||||||
|
|
||||||
|
println!("cargo:rerun-if-env-changed=ANDROID_PLATFORM");
|
||||||
|
println!("cargo:rerun-if-env-changed=ANDROID_API_LEVEL");
|
||||||
|
config.define("ANDROID_PLATFORM", &android_platform);
|
||||||
|
|
||||||
|
// Map Rust target to Android ABI
|
||||||
|
let android_abi = if target_triple.contains("aarch64") {
|
||||||
|
"arm64-v8a"
|
||||||
|
} else if target_triple.contains("armv7") {
|
||||||
|
"armeabi-v7a"
|
||||||
|
} else if target_triple.contains("x86_64") {
|
||||||
|
"x86_64"
|
||||||
|
} else if target_triple.contains("i686") {
|
||||||
|
"x86"
|
||||||
|
} else {
|
||||||
|
panic!(
|
||||||
|
"Unsupported Android target: {}\n\
|
||||||
|
Supported targets: aarch64-linux-android, armv7-linux-androideabi, i686-linux-android, x86_64-linux-android",
|
||||||
|
target_triple
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
config.define("ANDROID_ABI", android_abi);
|
||||||
|
|
||||||
|
// Configure architecture-specific compiler flags
|
||||||
|
match android_abi {
|
||||||
|
"arm64-v8a" => {
|
||||||
|
config.cflag("-march=armv8-a");
|
||||||
|
config.cxxflag("-march=armv8-a");
|
||||||
|
}
|
||||||
|
"armeabi-v7a" => {
|
||||||
|
config.cflag("-march=armv7-a");
|
||||||
|
config.cxxflag("-march=armv7-a");
|
||||||
|
config.cflag("-mfpu=neon");
|
||||||
|
config.cxxflag("-mfpu=neon");
|
||||||
|
config.cflag("-mthumb");
|
||||||
|
config.cxxflag("-mthumb");
|
||||||
|
}
|
||||||
|
"x86_64" => {
|
||||||
|
config.cflag("-march=x86-64");
|
||||||
|
config.cxxflag("-march=x86-64");
|
||||||
|
}
|
||||||
|
"x86" => {
|
||||||
|
config.cflag("-march=i686");
|
||||||
|
config.cxxflag("-march=i686");
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Android-specific CMake configurations
|
||||||
|
config.define("GGML_LLAMAFILE", "OFF");
|
||||||
|
|
||||||
|
// Link Android system libraries
|
||||||
|
println!("cargo:rustc-link-lib=log");
|
||||||
|
println!("cargo:rustc-link-lib=android");
|
||||||
|
}
|
||||||
|
|
||||||
|
if matches!(target_os, TargetOs::Linux)
|
||||||
|
&& target_triple.contains("aarch64")
|
||||||
|
&& target_cpu != Some("native".into())
|
||||||
|
{
|
||||||
|
// If the target-cpu is not specified as native, we take off the native ARM64 support.
|
||||||
|
// It is useful in docker environments where the native feature is not enabled.
|
||||||
|
config.define("GGML_NATIVE", "OFF");
|
||||||
|
config.define("GGML_CPU_ARM_ARCH", "armv8-a");
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg!(feature = "vulkan") {
|
||||||
|
config.define("GGML_VULKAN", "ON");
|
||||||
|
match target_os {
|
||||||
|
TargetOs::Windows(_) => {
|
||||||
|
let vulkan_path = env::var("VULKAN_SDK").expect(
|
||||||
|
"Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set",
|
||||||
|
);
|
||||||
|
let vulkan_lib_path = Path::new(&vulkan_path).join("Lib");
|
||||||
|
println!("cargo:rustc-link-search={}", vulkan_lib_path.display());
|
||||||
|
println!("cargo:rustc-link-lib=vulkan-1");
|
||||||
|
|
||||||
|
// workaround for this error: "FileTracker : error FTK1011: could not create the new file tracking log file"
|
||||||
|
// it has to do with MSBuild FileTracker not respecting the path
|
||||||
|
// limit configuration set in the windows registry.
|
||||||
|
// I'm not sure why that's a thing, but this makes my builds work.
|
||||||
|
// (crates that depend on llama-cpp-rs w/ vulkan easily exceed the default PATH_MAX on windows)
|
||||||
|
env::set_var("TrackFileAccess", "false");
|
||||||
|
// since we disabled TrackFileAccess, we can now run into problems with parallel
|
||||||
|
// access to pdb files. /FS solves this.
|
||||||
|
config.cflag("/FS");
|
||||||
|
config.cxxflag("/FS");
|
||||||
|
}
|
||||||
|
TargetOs::Linux => {
|
||||||
|
// If we are not using system provided vulkan SDK, add vulkan libs for linking
|
||||||
|
if let Ok(vulkan_path) = env::var("VULKAN_SDK") {
|
||||||
|
let vulkan_lib_path = Path::new(&vulkan_path).join("lib");
|
||||||
|
println!("cargo:rustc-link-search={}", vulkan_lib_path.display());
|
||||||
|
}
|
||||||
|
println!("cargo:rustc-link-lib=vulkan");
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg!(feature = "cuda") {
|
||||||
|
config.define("GGML_CUDA", "ON");
|
||||||
|
|
||||||
|
if cfg!(feature = "cuda-no-vmm") {
|
||||||
|
config.define("GGML_CUDA_NO_VMM", "ON");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Android doesn't have OpenMP support AFAICT and openmp is a default feature. Do this here
|
||||||
|
// rather than modifying the defaults in Cargo.toml just in case someone enables the OpenMP feature
|
||||||
|
// and tries to build for Android anyway.
|
||||||
|
if cfg!(feature = "openmp") && !matches!(target_os, TargetOs::Android) {
|
||||||
|
config.define("GGML_OPENMP", "ON");
|
||||||
|
} else {
|
||||||
|
config.define("GGML_OPENMP", "OFF");
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg!(feature = "system-ggml") {
|
||||||
|
config.define("LLAMA_USE_SYSTEM_GGML", "ON");
|
||||||
|
}
|
||||||
|
|
||||||
|
// General
|
||||||
|
config
|
||||||
|
.profile(&profile)
|
||||||
|
.very_verbose(std::env::var("CMAKE_VERBOSE").is_ok()) // Not verbose by default
|
||||||
|
.always_configure(false);
|
||||||
|
|
||||||
|
let build_dir = config.build();
|
||||||
|
|
||||||
|
// Search paths
|
||||||
|
println!("cargo:rustc-link-search={}", out_dir.join("lib").display());
|
||||||
|
println!(
|
||||||
|
"cargo:rustc-link-search={}",
|
||||||
|
out_dir.join("lib64").display()
|
||||||
|
);
|
||||||
|
println!("cargo:rustc-link-search={}", build_dir.display());
|
||||||
|
|
||||||
|
if cfg!(feature = "system-ggml") {
|
||||||
|
// Extract library directory from CMake's found GGML package
|
||||||
|
let cmake_cache = build_dir.join("build").join("CMakeCache.txt");
|
||||||
|
if let Ok(cache_contents) = std::fs::read_to_string(&cmake_cache) {
|
||||||
|
let mut ggml_lib_dirs = std::collections::HashSet::new();
|
||||||
|
|
||||||
|
// Parse CMakeCache.txt to find where GGML libraries were found
|
||||||
|
for line in cache_contents.lines() {
|
||||||
|
if line.starts_with("GGML_LIBRARY:")
|
||||||
|
|| line.starts_with("GGML_BASE_LIBRARY:")
|
||||||
|
|| line.starts_with("GGML_CPU_LIBRARY:")
|
||||||
|
{
|
||||||
|
if let Some(lib_path) = line.split('=').nth(1) {
|
||||||
|
if let Some(parent) = Path::new(lib_path).parent() {
|
||||||
|
ggml_lib_dirs.insert(parent.to_path_buf());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add each unique library directory to the search path
|
||||||
|
for lib_dir in ggml_lib_dirs {
|
||||||
|
println!("cargo:rustc-link-search=native={}", lib_dir.display());
|
||||||
|
debug_log!("Added system GGML library path: {}", lib_dir.display());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg!(feature = "cuda") && !build_shared_libs {
|
||||||
|
// Re-run build script if CUDA_PATH environment variable changes
|
||||||
|
println!("cargo:rerun-if-env-changed=CUDA_PATH");
|
||||||
|
|
||||||
|
// Add CUDA library directories to the linker search path
|
||||||
|
for lib_dir in find_cuda_helper::find_cuda_lib_dirs() {
|
||||||
|
println!("cargo:rustc-link-search=native={}", lib_dir.display());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Platform-specific linking
|
||||||
|
if cfg!(target_os = "windows") {
|
||||||
|
// ✅ On Windows, use dynamic linking.
|
||||||
|
// Static linking is problematic because NVIDIA does not provide culibos.lib,
|
||||||
|
// and static CUDA libraries (like cublas_static.lib) are usually not shipped.
|
||||||
|
|
||||||
|
println!("cargo:rustc-link-lib=cudart"); // Links to cudart64_*.dll
|
||||||
|
println!("cargo:rustc-link-lib=cublas"); // Links to cublas64_*.dll
|
||||||
|
println!("cargo:rustc-link-lib=cublasLt"); // Links to cublasLt64_*.dll
|
||||||
|
|
||||||
|
// Link to CUDA driver API (nvcuda.dll via cuda.lib)
|
||||||
|
if !cfg!(feature = "cuda-no-vmm") {
|
||||||
|
println!("cargo:rustc-link-lib=cuda");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// ✅ On non-Windows platforms (e.g., Linux), static linking is preferred and supported.
|
||||||
|
// Static libraries like cudart_static and cublas_static depend on culibos.
|
||||||
|
|
||||||
|
println!("cargo:rustc-link-lib=static=cudart_static");
|
||||||
|
println!("cargo:rustc-link-lib=static=cublas_static");
|
||||||
|
println!("cargo:rustc-link-lib=static=cublasLt_static");
|
||||||
|
|
||||||
|
// Link to CUDA driver API (libcuda.so)
|
||||||
|
if !cfg!(feature = "cuda-no-vmm") {
|
||||||
|
println!("cargo:rustc-link-lib=cuda");
|
||||||
|
}
|
||||||
|
|
||||||
|
// culibos is required when statically linking cudart_static
|
||||||
|
println!("cargo:rustc-link-lib=static=culibos");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link libraries
|
||||||
|
let llama_libs_kind = if build_shared_libs || cfg!(feature = "system-ggml") {
|
||||||
|
"dylib"
|
||||||
|
} else {
|
||||||
|
"static"
|
||||||
|
};
|
||||||
|
let llama_libs = extract_lib_names(&out_dir, build_shared_libs);
|
||||||
|
assert_ne!(llama_libs.len(), 0);
|
||||||
|
|
||||||
|
if cfg!(feature = "system-ggml") {
|
||||||
|
println!("cargo:rustc-link-lib={llama_libs_kind}=ggml");
|
||||||
|
println!("cargo:rustc-link-lib={llama_libs_kind}=ggml-base");
|
||||||
|
println!("cargo:rustc-link-lib={llama_libs_kind}=ggml-cpu");
|
||||||
|
}
|
||||||
|
for lib in llama_libs {
|
||||||
|
let link = format!("cargo:rustc-link-lib={}={}", llama_libs_kind, lib);
|
||||||
|
debug_log!("LINK {link}",);
|
||||||
|
println!("{link}",);
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenMP
|
||||||
|
if cfg!(feature = "openmp") && target_triple.contains("gnu") {
|
||||||
|
println!("cargo:rustc-link-lib=gomp");
|
||||||
|
}
|
||||||
|
|
||||||
|
match target_os {
|
||||||
|
TargetOs::Windows(WindowsVariant::Msvc) => {
|
||||||
|
println!("cargo:rustc-link-lib=advapi32");
|
||||||
|
if cfg!(debug_assertions) {
|
||||||
|
println!("cargo:rustc-link-lib=dylib=msvcrtd");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TargetOs::Linux => {
|
||||||
|
println!("cargo:rustc-link-lib=dylib=stdc++");
|
||||||
|
}
|
||||||
|
TargetOs::Apple(variant) => {
|
||||||
|
println!("cargo:rustc-link-lib=framework=Foundation");
|
||||||
|
println!("cargo:rustc-link-lib=framework=Metal");
|
||||||
|
println!("cargo:rustc-link-lib=framework=MetalKit");
|
||||||
|
println!("cargo:rustc-link-lib=framework=Accelerate");
|
||||||
|
println!("cargo:rustc-link-lib=c++");
|
||||||
|
|
||||||
|
match variant {
|
||||||
|
AppleVariant::MacOS => {
|
||||||
|
// On (older) OSX we need to link against the clang runtime,
|
||||||
|
// which is hidden in some non-default path.
|
||||||
|
//
|
||||||
|
// More details at https://github.com/alexcrichton/curl-rust/issues/279.
|
||||||
|
if let Some(path) = macos_link_search_path() {
|
||||||
|
println!("cargo:rustc-link-lib=clang_rt.osx");
|
||||||
|
println!("cargo:rustc-link-search={}", path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
AppleVariant::Other => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy DLLs to target
|
||||||
|
if build_shared_libs {
|
||||||
|
let libs_assets = extract_lib_assets(&out_dir);
|
||||||
|
for asset in libs_assets {
|
||||||
|
let asset_clone = asset.clone();
|
||||||
|
let filename = asset_clone.file_name().unwrap();
|
||||||
|
let filename = filename.to_str().unwrap();
|
||||||
|
let dst = target_dir.join(filename);
|
||||||
|
debug_log!("HARD LINK {} TO {}", asset.display(), dst.display());
|
||||||
|
if !dst.exists() {
|
||||||
|
std::fs::hard_link(asset.clone(), dst).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy DLLs to examples as well
|
||||||
|
if target_dir.join("examples").exists() {
|
||||||
|
let dst = target_dir.join("examples").join(filename);
|
||||||
|
debug_log!("HARD LINK {} TO {}", asset.display(), dst.display());
|
||||||
|
if !dst.exists() {
|
||||||
|
std::fs::hard_link(asset.clone(), dst).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy DLLs to target/profile/deps as well for tests
|
||||||
|
let dst = target_dir.join("deps").join(filename);
|
||||||
|
debug_log!("HARD LINK {} TO {}", asset.display(), dst.display());
|
||||||
|
if !dst.exists() {
|
||||||
|
std::fs::hard_link(asset.clone(), dst).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,309 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
project("llama.cpp" C CXX)
include(CheckIncludeFileCXX)

#set(CMAKE_WARN_DEPRECATED YES)
set(CMAKE_WARN_UNUSED_CLI YES)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Default to a Release build when the user did not choose a configuration.
# Query the GENERATOR_IS_MULTI_CONFIG property instead of guessing from
# XCODE/MSVC: multi-config generators (VS, Xcode, Ninja Multi-Config) ignore
# CMAKE_BUILD_TYPE entirely, while every single-config generator -- including
# NMake/Ninja driving MSVC -- needs a build type to produce optimized output.
get_property(LLAMA_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if (NOT LLAMA_IS_MULTI_CONFIG AND NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()

message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")

# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

# Standalone vs. embedded: when llama.cpp is the top-level project we pull in
# git metadata and enable extra artifacts (tests/examples/tools) by default.
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
    set(LLAMA_STANDALONE ON)

    include(git-vars)

    # configure project version
    # TODO
else()
    set(LLAMA_STANDALONE OFF)
endif()

option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF)

option(LLAMA_WASM_MEM64 "llama: use 64-bit memory in WASM builds" ON)

if (EMSCRIPTEN)
    set(BUILD_SHARED_LIBS_DEFAULT OFF)

    # Use 64-bit memory to support backend_get_memory queries
    # TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using
    if (LLAMA_WASM_MEM64)
        add_compile_options("-sMEMORY64=1")
        add_link_options("-sMEMORY64=1")
    endif()
    add_link_options("-sALLOW_MEMORY_GROWTH=1")

    option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF)
    option(LLAMA_BUILD_HTML "llama: build HTML file" ON)
    if (LLAMA_BUILD_HTML)
        set(CMAKE_EXECUTABLE_SUFFIX ".html")
    endif()
else()
    if (MINGW)
        set(BUILD_SHARED_LIBS_DEFAULT OFF)
    else()
        set(BUILD_SHARED_LIBS_DEFAULT ON)
    endif()
endif()

option(BUILD_SHARED_LIBS "build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})

if (WIN32)
    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
endif()

if (MSVC)
    # force UTF-8 source/execution charsets and large object file support
    add_compile_options("$<$<COMPILE_LANGUAGE:C>:/utf-8>")
    add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/utf-8>")
    add_compile_options("$<$<COMPILE_LANGUAGE:C>:/bigobj>")
    add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
endif()

if (LLAMA_STANDALONE)
    # enable parallel builds for msbuild
    list(APPEND CMAKE_VS_GLOBALS UseMultiToolTask=true)
    list(APPEND CMAKE_VS_GLOBALS EnforceProcessCountAcrossBuilds=true)
endif()

if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
    set(LLAMA_TOOLS_INSTALL_DEFAULT OFF)
else()
    set(LLAMA_TOOLS_INSTALL_DEFAULT ${LLAMA_STANDALONE})
endif()

#
# option list
#

# debug
option(LLAMA_ALL_WARNINGS           "llama: enable all compiler warnings"                   ON)
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)

# build
option(LLAMA_FATAL_WARNINGS "llama: enable -Werror flag" OFF)

# sanitizers
option(LLAMA_SANITIZE_THREAD    "llama: enable thread sanitizer"    OFF)
option(LLAMA_SANITIZE_ADDRESS   "llama: enable address sanitizer"   OFF)
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)

# utils
option(LLAMA_BUILD_COMMON "llama: build common utils library" ${LLAMA_STANDALONE})

# extra artifacts
option(LLAMA_BUILD_TESTS    "llama: build tests"          ${LLAMA_STANDALONE})
option(LLAMA_BUILD_TOOLS    "llama: build tools"          ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples"       ${LLAMA_STANDALONE})
option(LLAMA_BUILD_SERVER   "llama: build server example" ${LLAMA_STANDALONE})
option(LLAMA_TOOLS_INSTALL  "llama: install tools"        ${LLAMA_TOOLS_INSTALL_DEFAULT})

# 3rd party libs
option(LLAMA_CURL       "llama: use libcurl to download model from an URL" ON)
option(LLAMA_HTTPLIB    "llama: if libcurl is disabled, use httplib to download model from an URL" ON)
option(LLAMA_OPENSSL    "llama: use openssl to support HTTPS" OFF)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)

# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake)

# BUILD_NUMBER / BUILD_COMMIT are produced by build-info.cmake above; allow a
# parent project or the command line to pre-seed the LLAMA_ variants.
if (NOT DEFINED LLAMA_BUILD_NUMBER)
    set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER})
endif()
if (NOT DEFINED LLAMA_BUILD_COMMIT)
    set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT})
endif()
set(LLAMA_INSTALL_VERSION 0.0.${LLAMA_BUILD_NUMBER})

# override ggml options
set(GGML_ALL_WARNINGS   ${LLAMA_ALL_WARNINGS})
set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS})

# change the default for these ggml options
if (NOT DEFINED GGML_LLAMAFILE)
    set(GGML_LLAMAFILE_DEFAULT ON)
endif()

if (NOT DEFINED GGML_CUDA_GRAPHS)
    set(GGML_CUDA_GRAPHS_DEFAULT ON)
endif()
|
||||||
|
|
||||||
|
# transition helpers

# llama_option_depr(<msg-type> <old-option> <new-option>)
#
# Backwards-compatibility shim: when the deprecated option <old-option> is
# enabled, emit a message of the requested type (WARNING, or FATAL_ERROR to
# abort configuration) and switch on its replacement <new-option> in the
# caller's scope.
function(llama_option_depr type old new)
    if (${old})
        message(${type} "${old} is deprecated and will be removed in the future.\nUse ${new} instead\n")
        set(${new} ON PARENT_SCOPE)
    endif()
endfunction()
|
||||||
|
|
||||||
|
llama_option_depr(FATAL_ERROR LLAMA_CUBLAS              GGML_CUDA)
llama_option_depr(WARNING     LLAMA_CUDA                GGML_CUDA)
llama_option_depr(WARNING     LLAMA_METAL               GGML_METAL)
llama_option_depr(WARNING     LLAMA_METAL_EMBED_LIBRARY GGML_METAL_EMBED_LIBRARY)
llama_option_depr(WARNING     LLAMA_NATIVE              GGML_NATIVE)
llama_option_depr(WARNING     LLAMA_RPC                 GGML_RPC)
llama_option_depr(WARNING     LLAMA_SYCL                GGML_SYCL)
llama_option_depr(WARNING     LLAMA_SYCL_F16            GGML_SYCL_F16)
llama_option_depr(WARNING     LLAMA_CANN                GGML_CANN)

# Sanitizers must be applied to both the compile and link steps.
# MSVC uses a different mechanism (/fsanitize=...), so skip it here.
if (NOT MSVC)
    if (LLAMA_SANITIZE_THREAD)
        message(STATUS "Using -fsanitize=thread")

        add_compile_options(-fsanitize=thread)
        link_libraries     (-fsanitize=thread)
    endif()

    if (LLAMA_SANITIZE_ADDRESS)
        message(STATUS "Using -fsanitize=address")

        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
        link_libraries     (-fsanitize=address)
    endif()

    if (LLAMA_SANITIZE_UNDEFINED)
        message(STATUS "Using -fsanitize=undefined")

        add_compile_options(-fsanitize=undefined)
        link_libraries     (-fsanitize=undefined)
    endif()
endif()

include("cmake/license.cmake")
license_add_file("llama.cpp" "LICENSE")

#
# 3rd-party
#

if (LLAMA_USE_SYSTEM_GGML)
    message(STATUS "Using system-provided libggml, skipping ggml build")
    find_package(ggml REQUIRED)
    add_library(ggml ALIAS ggml::ggml)
endif()

if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML)
    set(GGML_BUILD_NUMBER ${LLAMA_BUILD_NUMBER})
    set(GGML_BUILD_COMMIT ${LLAMA_BUILD_COMMIT})
    add_subdirectory(ggml)
    # ... otherwise assume ggml is added by a parent CMakeLists.txt
endif()

#
# build the library
#

add_subdirectory(src)

#
# utils, programs, examples and tests
#

if (NOT LLAMA_BUILD_COMMON)
    # libcurl support lives in the common utils library
    message(STATUS "LLAMA_BUILD_COMMON is OFF, disabling LLAMA_CURL")
    set(LLAMA_CURL OFF)
endif()

if (LLAMA_BUILD_COMMON)
    add_subdirectory(common)
    if (LLAMA_HTTPLIB)
        add_subdirectory(vendor/cpp-httplib)
    endif()
endif()

if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
    include(CTest)
    add_subdirectory(tests)
endif()

if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES)
    add_subdirectory(examples)
    add_subdirectory(pocs)
endif()

if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS)
    add_subdirectory(tools)
endif()

# Automatically add all files from the 'licenses' directory.
# Use CMAKE_CURRENT_SOURCE_DIR (this project's root), not CMAKE_SOURCE_DIR,
# so the glob still resolves when llama.cpp is embedded as a subproject.
file(GLOB EXTRA_LICENSES "${CMAKE_CURRENT_SOURCE_DIR}/licenses/LICENSE-*")

foreach(FILE_PATH IN LISTS EXTRA_LICENSES)
    get_filename_component(FILE_NAME "${FILE_PATH}" NAME)
    # the dependency name is the filename with the "LICENSE-" prefix stripped
    string(REGEX REPLACE "^LICENSE-" "" NAME "${FILE_NAME}")
    license_add_file("${NAME}" "${FILE_PATH}")
endforeach()

if (LLAMA_BUILD_COMMON)
    license_generate(common)
endif()

#
# install
#

include(GNUInstallDirs)
include(CMakePackageConfigHelpers)

set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
set(LLAMA_LIB_INSTALL_DIR     ${CMAKE_INSTALL_LIBDIR}     CACHE PATH "Location of library files")
set(LLAMA_BIN_INSTALL_DIR     ${CMAKE_INSTALL_BINDIR}     CACHE PATH "Location of binary files")

set(LLAMA_PUBLIC_HEADERS
    ${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h
    ${CMAKE_CURRENT_SOURCE_DIR}/include/llama-cpp.h)

set_target_properties(llama
    PROPERTIES
        PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}")

install(TARGETS llama LIBRARY PUBLIC_HEADER)

configure_package_config_file(
        ${CMAKE_CURRENT_SOURCE_DIR}/cmake/llama-config.cmake.in
        ${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake
    INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama
    PATH_VARS LLAMA_INCLUDE_INSTALL_DIR
              LLAMA_LIB_INSTALL_DIR
              LLAMA_BIN_INSTALL_DIR)

write_basic_package_version_file(
    ${CMAKE_CURRENT_BINARY_DIR}/llama-version.cmake
    VERSION ${LLAMA_INSTALL_VERSION}
    COMPATIBILITY SameMajorVersion)

install(FILES ${CMAKE_CURRENT_BINARY_DIR}/llama-config.cmake
              ${CMAKE_CURRENT_BINARY_DIR}/llama-version.cmake
        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama)

# the HF conversion script is installed as an executable helper
install(
    FILES convert_hf_to_gguf.py
    PERMISSIONS
        OWNER_READ
        OWNER_WRITE
        OWNER_EXECUTE
        GROUP_READ
        GROUP_EXECUTE
        WORLD_READ
        WORLD_EXECUTE
    DESTINATION ${CMAKE_INSTALL_BINDIR})

# @ONLY: only expand @VAR@ so stray ${...} in the template survives intact
configure_file(cmake/llama.pc.in
        "${CMAKE_CURRENT_BINARY_DIR}/llama.pc"
        @ONLY)

install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc"
        DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
# CMake toolchain file: cross-compile for 64-bit ARM macOS (Mach-O) with clang.
set(CMAKE_SYSTEM_NAME      Darwin)
set(CMAKE_SYSTEM_PROCESSOR arm64)

# clang target triple shared by the C and C++ drivers
set(target arm64-apple-darwin-macho)

set(CMAKE_C_COMPILER   clang)
set(CMAKE_CXX_COMPILER clang++)

set(CMAKE_C_COMPILER_TARGET   ${target})
set(CMAKE_CXX_COMPILER_TARGET ${target})

# Baseline ISA plus fast-math (with NaN/Inf handling restored via
# -fno-finite-math-only), and warnings silenced for vendored sources.
set(arch_c_flags "-march=armv8.4-a -fvectorize -ffp-model=fast -fno-finite-math-only")
set(warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function")

# *_FLAGS_INIT seeds the default flags without clobbering user-supplied ones
set(CMAKE_C_FLAGS_INIT   "${arch_c_flags} ${warn_c_flags}")
set(CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}")
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
set( CMAKE_SYSTEM_NAME Windows )
|
||||||
|
set( CMAKE_SYSTEM_PROCESSOR arm64 )
|
||||||
|
|
||||||
|
set( target arm64-pc-windows-msvc )
|
||||||
|
|
||||||
|
set( CMAKE_C_COMPILER clang )
|
||||||
|
set( CMAKE_CXX_COMPILER clang++ )
|
||||||
|
|
||||||
|
set( CMAKE_C_COMPILER_TARGET ${target} )
|
||||||
|
set( CMAKE_CXX_COMPILER_TARGET ${target} )
|
||||||
|
|
||||||
|
set( arch_c_flags "-march=armv8.7-a -fvectorize -ffp-model=fast -fno-finite-math-only" )
|
||||||
|
set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function -Wno-gnu-zero-variadic-macro-arguments" )
|
||||||
|
|
||||||
|
set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
||||||
|
set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" )
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
set(BUILD_NUMBER 0)
|
||||||
|
set(BUILD_COMMIT "unknown")
|
||||||
|
set(BUILD_COMPILER "unknown")
|
||||||
|
set(BUILD_TARGET "unknown")
|
||||||
|
|
||||||
|
# Look for git
|
||||||
|
find_package(Git)
|
||||||
|
if(NOT Git_FOUND)
|
||||||
|
find_program(GIT_EXECUTABLE NAMES git git.exe)
|
||||||
|
if(GIT_EXECUTABLE)
|
||||||
|
set(Git_FOUND TRUE)
|
||||||
|
message(STATUS "Found Git: ${GIT_EXECUTABLE}")
|
||||||
|
else()
|
||||||
|
message(WARNING "Git not found. Build info will not be accurate.")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# Get the commit count and hash
|
||||||
|
if(Git_FOUND)
|
||||||
|
execute_process(
|
||||||
|
COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
OUTPUT_VARIABLE HEAD
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
RESULT_VARIABLE RES
|
||||||
|
)
|
||||||
|
if (RES EQUAL 0)
|
||||||
|
set(BUILD_COMMIT ${HEAD})
|
||||||
|
endif()
|
||||||
|
execute_process(
|
||||||
|
COMMAND ${GIT_EXECUTABLE} rev-list --count HEAD
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
OUTPUT_VARIABLE COUNT
|
||||||
|
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||||
|
RESULT_VARIABLE RES
|
||||||
|
)
|
||||||
|
if (RES EQUAL 0)
|
||||||
|
set(BUILD_NUMBER ${COUNT})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
|
||||||
|
|
||||||
|
if(CMAKE_VS_PLATFORM_NAME)
|
||||||
|
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
|
||||||
|
else()
|
||||||
|
set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}")
|
||||||
|
endif()
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
include("ggml/cmake/common.cmake")
|
||||||
|
|
||||||
|
function(llama_add_compile_flags)
|
||||||
|
if (LLAMA_FATAL_WARNINGS)
|
||||||
|
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||||
|
list(APPEND C_FLAGS -Werror)
|
||||||
|
list(APPEND CXX_FLAGS -Werror)
|
||||||
|
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
|
||||||
|
add_compile_options(/WX)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (LLAMA_ALL_WARNINGS)
|
||||||
|
if (NOT MSVC)
|
||||||
|
list(APPEND C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
|
||||||
|
-Werror=implicit-int -Werror=implicit-function-declaration)
|
||||||
|
|
||||||
|
list(APPEND CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn)
|
||||||
|
|
||||||
|
list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
|
||||||
|
|
||||||
|
list(APPEND C_FLAGS ${WARNING_FLAGS})
|
||||||
|
list(APPEND CXX_FLAGS ${WARNING_FLAGS})
|
||||||
|
|
||||||
|
ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})
|
||||||
|
|
||||||
|
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>"
|
||||||
|
"$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>")
|
||||||
|
else()
|
||||||
|
# todo : msvc
|
||||||
|
set(C_FLAGS "" PARENT_SCOPE)
|
||||||
|
set(CXX_FLAGS "" PARENT_SCOPE)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
find_package(Git)
|
||||||
|
|
||||||
|
# the commit's SHA1
|
||||||
|
execute_process(COMMAND
|
||||||
|
"${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8
|
||||||
|
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||||
|
OUTPUT_VARIABLE GIT_SHA1
|
||||||
|
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
|
|
||||||
|
# the date of the commit
|
||||||
|
execute_process(COMMAND
|
||||||
|
"${GIT_EXECUTABLE}" log -1 --format=%ad --date=local
|
||||||
|
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||||
|
OUTPUT_VARIABLE GIT_DATE
|
||||||
|
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
|
|
||||||
|
# the subject of the commit
|
||||||
|
execute_process(COMMAND
|
||||||
|
"${GIT_EXECUTABLE}" log -1 --format=%s
|
||||||
|
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||||
|
OUTPUT_VARIABLE GIT_COMMIT_SUBJECT
|
||||||
|
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
define_property(GLOBAL PROPERTY LICENSE_TEXT
|
||||||
|
BRIEF_DOCS "Embedded licenses"
|
||||||
|
FULL_DOCS "Global string containing all aggregated licenses"
|
||||||
|
)
|
||||||
|
|
||||||
|
function(license_add_file NAME FILE)
|
||||||
|
if(NOT IS_ABSOLUTE "${FILE}")
|
||||||
|
set(FILE "${CMAKE_CURRENT_SOURCE_DIR}/${FILE}")
|
||||||
|
endif()
|
||||||
|
if(EXISTS "${FILE}")
|
||||||
|
set(TITLE "License for ${NAME}")
|
||||||
|
string(REGEX REPLACE "." "=" UNDERLINE "${TITLE}")
|
||||||
|
file(READ "${FILE}" TEXT)
|
||||||
|
get_property(TMP GLOBAL PROPERTY LICENSE_TEXT)
|
||||||
|
string(APPEND TMP "R\"=L=(${TITLE}\n${UNDERLINE}\n\n${TEXT})=L=\",\n")
|
||||||
|
set_property(GLOBAL PROPERTY LICENSE_TEXT "${TMP}")
|
||||||
|
else()
|
||||||
|
message(WARNING "License file '${FILE}' not found")
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
function(license_generate TARGET_NAME)
|
||||||
|
message(STATUS "Generating embedded license file for target: ${TARGET_NAME}")
|
||||||
|
get_property(TEXT GLOBAL PROPERTY LICENSE_TEXT)
|
||||||
|
|
||||||
|
set(CPP_CONTENT "// Generated by CMake\n\n")
|
||||||
|
string(APPEND CPP_CONTENT "const char* LICENSES[] = {\n")
|
||||||
|
string(APPEND CPP_CONTENT "${TEXT}")
|
||||||
|
string(APPEND CPP_CONTENT "nullptr\n")
|
||||||
|
string(APPEND CPP_CONTENT "};\n")
|
||||||
|
|
||||||
|
set(CPP_FILE "${CMAKE_BINARY_DIR}/license.cpp")
|
||||||
|
file(WRITE "${CPP_FILE}" "${CPP_CONTENT}")
|
||||||
|
|
||||||
|
if(TARGET ${TARGET_NAME})
|
||||||
|
target_sources(${TARGET_NAME} PRIVATE "${CPP_FILE}")
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Target '${TARGET_NAME}' does not exist")
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
set(LLAMA_VERSION @LLAMA_INSTALL_VERSION@)
|
||||||
|
set(LLAMA_BUILD_COMMIT @LLAMA_BUILD_COMMIT@)
|
||||||
|
set(LLAMA_BUILD_NUMBER @LLAMA_BUILD_NUMBER@)
|
||||||
|
set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@)
|
||||||
|
|
||||||
|
@PACKAGE_INIT@
|
||||||
|
|
||||||
|
set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@")
|
||||||
|
set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@")
|
||||||
|
set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@")
|
||||||
|
|
||||||
|
find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake)
|
||||||
|
|
||||||
|
find_library(llama_LIBRARY llama
|
||||||
|
REQUIRED
|
||||||
|
HINTS ${LLAMA_LIB_DIR}
|
||||||
|
NO_CMAKE_FIND_ROOT_PATH
|
||||||
|
)
|
||||||
|
|
||||||
|
add_library(llama UNKNOWN IMPORTED)
|
||||||
|
set_target_properties(llama
|
||||||
|
PROPERTIES
|
||||||
|
INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}"
|
||||||
|
INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;"
|
||||||
|
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
||||||
|
IMPORTED_LOCATION "${llama_LIBRARY}"
|
||||||
|
INTERFACE_COMPILE_FEATURES c_std_90
|
||||||
|
POSITION_INDEPENDENT_CODE ON)
|
||||||
|
|
||||||
|
check_required_components(Llama)
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
prefix=@CMAKE_INSTALL_PREFIX@
|
||||||
|
exec_prefix=@CMAKE_INSTALL_PREFIX@
|
||||||
|
libdir=@CMAKE_INSTALL_FULL_LIBDIR@
|
||||||
|
includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
|
||||||
|
|
||||||
|
Name: llama
|
||||||
|
Description: Port of Facebook's LLaMA model in C/C++
|
||||||
|
Version: @LLAMA_INSTALL_VERSION@
|
||||||
|
Libs: -L${libdir} -lggml -lggml-base -lllama
|
||||||
|
Cflags: -I${includedir}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
set(CMAKE_SYSTEM_NAME Linux)
|
||||||
|
set(CMAKE_SYSTEM_PROCESSOR riscv64)
|
||||||
|
set(CMAKE_SYSTEM_VERSION 1)
|
||||||
|
|
||||||
|
if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)")
|
||||||
|
message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}")
|
||||||
|
else()
|
||||||
|
set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple")
|
||||||
|
if (DEFINED ENV{RISCV_ROOT_PATH})
|
||||||
|
file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH)
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain")
|
||||||
|
set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc)
|
||||||
|
set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++)
|
||||||
|
set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu")
|
||||||
|
set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||||
|
set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}")
|
||||||
|
set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CXX_FLAGS}")
|
||||||
|
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic")
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
set( CMAKE_SYSTEM_NAME Windows )
|
||||||
|
set( CMAKE_SYSTEM_PROCESSOR x86_64 )
|
||||||
|
|
||||||
|
set( CMAKE_C_COMPILER clang )
|
||||||
|
set( CMAKE_CXX_COMPILER clang++ )
|
||||||
@@ -0,0 +1,157 @@
|
|||||||
|
# common
|
||||||
|
|
||||||
|
find_package(Threads REQUIRED)
|
||||||
|
|
||||||
|
llama_add_compile_flags()
|
||||||
|
|
||||||
|
# Build info header
|
||||||
|
#
|
||||||
|
|
||||||
|
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
|
||||||
|
set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git")
|
||||||
|
|
||||||
|
# Is git submodule
|
||||||
|
if(NOT IS_DIRECTORY "${GIT_DIR}")
|
||||||
|
file(READ ${GIT_DIR} REAL_GIT_DIR_LINK)
|
||||||
|
string(REGEX REPLACE "gitdir: (.*)\n$" "\\1" REAL_GIT_DIR ${REAL_GIT_DIR_LINK})
|
||||||
|
string(FIND "${REAL_GIT_DIR}" "/" SLASH_POS)
|
||||||
|
if (SLASH_POS EQUAL 0)
|
||||||
|
set(GIT_DIR "${REAL_GIT_DIR}")
|
||||||
|
else()
|
||||||
|
set(GIT_DIR "${PROJECT_SOURCE_DIR}/${REAL_GIT_DIR}")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(EXISTS "${GIT_DIR}/index")
|
||||||
|
# For build-info.cpp below
|
||||||
|
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${GIT_DIR}/index")
|
||||||
|
else()
|
||||||
|
message(WARNING "Git index not found in git repository.")
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
message(WARNING "Git repository not found; to enable automatic generation of build info, make sure Git is installed and the project is a Git repository.")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(TEMPLATE_FILE "${CMAKE_CURRENT_SOURCE_DIR}/build-info.cpp.in")
|
||||||
|
set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/build-info.cpp")
|
||||||
|
configure_file(${TEMPLATE_FILE} ${OUTPUT_FILE})
|
||||||
|
|
||||||
|
set(TARGET build_info)
|
||||||
|
add_library(${TARGET} OBJECT ${OUTPUT_FILE})
|
||||||
|
if (BUILD_SHARED_LIBS)
|
||||||
|
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(TARGET common)
|
||||||
|
|
||||||
|
add_library(${TARGET} STATIC
|
||||||
|
arg.cpp
|
||||||
|
arg.h
|
||||||
|
base64.hpp
|
||||||
|
chat-parser.cpp
|
||||||
|
chat-parser.h
|
||||||
|
chat-parser-xml-toolcall.h
|
||||||
|
chat-parser-xml-toolcall.cpp
|
||||||
|
chat-peg-parser.cpp
|
||||||
|
chat-peg-parser.h
|
||||||
|
chat.cpp
|
||||||
|
chat.h
|
||||||
|
common.cpp
|
||||||
|
common.h
|
||||||
|
console.cpp
|
||||||
|
console.h
|
||||||
|
download.cpp
|
||||||
|
download.h
|
||||||
|
http.h
|
||||||
|
json-partial.cpp
|
||||||
|
json-partial.h
|
||||||
|
json-schema-to-grammar.cpp
|
||||||
|
llguidance.cpp
|
||||||
|
log.cpp
|
||||||
|
log.h
|
||||||
|
ngram-cache.cpp
|
||||||
|
ngram-cache.h
|
||||||
|
peg-parser.cpp
|
||||||
|
peg-parser.h
|
||||||
|
preset.cpp
|
||||||
|
preset.h
|
||||||
|
regex-partial.cpp
|
||||||
|
regex-partial.h
|
||||||
|
sampling.cpp
|
||||||
|
sampling.h
|
||||||
|
speculative.cpp
|
||||||
|
speculative.h
|
||||||
|
unicode.cpp
|
||||||
|
unicode.h
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(${TARGET} PUBLIC . ../vendor)
|
||||||
|
target_compile_features (${TARGET} PUBLIC cxx_std_17)
|
||||||
|
|
||||||
|
if (BUILD_SHARED_LIBS)
|
||||||
|
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
|
||||||
|
set(LLAMA_COMMON_EXTRA_LIBS build_info)
|
||||||
|
|
||||||
|
if (LLAMA_CURL)
|
||||||
|
# Use curl to download model url
|
||||||
|
find_package(CURL)
|
||||||
|
if (NOT CURL_FOUND)
|
||||||
|
message(FATAL_ERROR "Could NOT find CURL. Hint: to disable this feature, set -DLLAMA_CURL=OFF")
|
||||||
|
endif()
|
||||||
|
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
|
||||||
|
include_directories(${CURL_INCLUDE_DIRS})
|
||||||
|
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
|
||||||
|
elseif (LLAMA_HTTPLIB)
|
||||||
|
# otherwise, use cpp-httplib
|
||||||
|
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
|
||||||
|
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (LLAMA_LLGUIDANCE)
|
||||||
|
include(ExternalProject)
|
||||||
|
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
||||||
|
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
|
||||||
|
|
||||||
|
# Set the correct library file extension based on platform
|
||||||
|
if (WIN32)
|
||||||
|
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
|
||||||
|
# Add Windows-specific libraries
|
||||||
|
set(LLGUIDANCE_PLATFORM_LIBS
|
||||||
|
ws2_32 # Windows Sockets API
|
||||||
|
userenv # For GetUserProfileDirectoryW
|
||||||
|
ntdll # For NT functions
|
||||||
|
bcrypt # For BCryptGenRandom
|
||||||
|
)
|
||||||
|
else()
|
||||||
|
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
|
||||||
|
set(LLGUIDANCE_PLATFORM_LIBS "")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
ExternalProject_Add(llguidance_ext
|
||||||
|
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||||
|
# v1.0.1:
|
||||||
|
GIT_TAG d795912fedc7d393de740177ea9ea761e7905774
|
||||||
|
PREFIX ${CMAKE_BINARY_DIR}/llguidance
|
||||||
|
SOURCE_DIR ${LLGUIDANCE_SRC}
|
||||||
|
BUILD_IN_SOURCE TRUE
|
||||||
|
CONFIGURE_COMMAND ""
|
||||||
|
BUILD_COMMAND cargo build --release --package llguidance
|
||||||
|
INSTALL_COMMAND ""
|
||||||
|
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h
|
||||||
|
UPDATE_COMMAND ""
|
||||||
|
)
|
||||||
|
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
|
||||||
|
|
||||||
|
add_library(llguidance STATIC IMPORTED)
|
||||||
|
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME})
|
||||||
|
add_dependencies(llguidance llguidance_ext)
|
||||||
|
|
||||||
|
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
|
||||||
|
# Add platform libraries to the main target
|
||||||
|
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,131 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
// pseudo-env variable to identify preset-only arguments
|
||||||
|
#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
|
||||||
|
#define COMMON_ARG_PRESET_STOP_TIMEOUT "__PRESET_STOP_TIMEOUT"
|
||||||
|
|
||||||
|
//
|
||||||
|
// CLI argument parsing
|
||||||
|
//
|
||||||
|
|
||||||
|
struct common_arg {
|
||||||
|
std::set<enum llama_example> examples = {LLAMA_EXAMPLE_COMMON};
|
||||||
|
std::set<enum llama_example> excludes = {};
|
||||||
|
std::vector<const char *> args;
|
||||||
|
std::vector<const char *> args_neg; // for negated args like --no-xxx
|
||||||
|
const char * value_hint = nullptr; // help text or example for arg value
|
||||||
|
const char * value_hint_2 = nullptr; // for second arg value
|
||||||
|
const char * env = nullptr;
|
||||||
|
std::string help;
|
||||||
|
bool is_sparam = false; // is current arg a sampling param?
|
||||||
|
bool is_preset_only = false; // is current arg preset-only (not treated as CLI arg)
|
||||||
|
void (*handler_void) (common_params & params) = nullptr;
|
||||||
|
void (*handler_string) (common_params & params, const std::string &) = nullptr;
|
||||||
|
void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr;
|
||||||
|
void (*handler_int) (common_params & params, int) = nullptr;
|
||||||
|
void (*handler_bool) (common_params & params, bool) = nullptr;
|
||||||
|
|
||||||
|
common_arg() = default;
|
||||||
|
|
||||||
|
common_arg(
|
||||||
|
const std::initializer_list<const char *> & args,
|
||||||
|
const char * value_hint,
|
||||||
|
const std::string & help,
|
||||||
|
void (*handler)(common_params & params, const std::string &)
|
||||||
|
) : args(args), value_hint(value_hint), help(help), handler_string(handler) {}
|
||||||
|
|
||||||
|
common_arg(
|
||||||
|
const std::initializer_list<const char *> & args,
|
||||||
|
const char * value_hint,
|
||||||
|
const std::string & help,
|
||||||
|
void (*handler)(common_params & params, int)
|
||||||
|
) : args(args), value_hint(value_hint), help(help), handler_int(handler) {}
|
||||||
|
|
||||||
|
common_arg(
|
||||||
|
const std::initializer_list<const char *> & args,
|
||||||
|
const std::string & help,
|
||||||
|
void (*handler)(common_params & params)
|
||||||
|
) : args(args), help(help), handler_void(handler) {}
|
||||||
|
|
||||||
|
common_arg(
|
||||||
|
const std::initializer_list<const char *> & args,
|
||||||
|
const std::initializer_list<const char *> & args_neg,
|
||||||
|
const std::string & help,
|
||||||
|
void (*handler)(common_params & params, bool)
|
||||||
|
) : args(args), args_neg(args_neg), help(help), handler_bool(handler) {}
|
||||||
|
|
||||||
|
// support 2 values for arg
|
||||||
|
common_arg(
|
||||||
|
const std::initializer_list<const char *> & args,
|
||||||
|
const char * value_hint,
|
||||||
|
const char * value_hint_2,
|
||||||
|
const std::string & help,
|
||||||
|
void (*handler)(common_params & params, const std::string &, const std::string &)
|
||||||
|
) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}
|
||||||
|
|
||||||
|
common_arg & set_examples(std::initializer_list<enum llama_example> examples);
|
||||||
|
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
|
||||||
|
common_arg & set_env(const char * env);
|
||||||
|
common_arg & set_sparam();
|
||||||
|
common_arg & set_preset_only();
|
||||||
|
bool in_example(enum llama_example ex);
|
||||||
|
bool is_exclude(enum llama_example ex);
|
||||||
|
bool get_value_from_env(std::string & output) const;
|
||||||
|
bool has_value_from_env() const;
|
||||||
|
std::string to_string() const;
|
||||||
|
|
||||||
|
// for using as key in std::map
|
||||||
|
bool operator<(const common_arg& other) const {
|
||||||
|
if (args.empty() || other.args.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return strcmp(args[0], other.args[0]) < 0;
|
||||||
|
}
|
||||||
|
bool operator==(const common_arg& other) const {
|
||||||
|
if (args.empty() || other.args.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return strcmp(args[0], other.args[0]) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// get all args and env vars (including negated args/env)
|
||||||
|
std::vector<std::string> get_args() const;
|
||||||
|
std::vector<std::string> get_env() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace common_arg_utils {
|
||||||
|
bool is_truthy(const std::string & value);
|
||||||
|
bool is_falsey(const std::string & value);
|
||||||
|
bool is_autoy(const std::string & value);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct common_params_context {
|
||||||
|
enum llama_example ex = LLAMA_EXAMPLE_COMMON;
|
||||||
|
common_params & params;
|
||||||
|
std::vector<common_arg> options;
|
||||||
|
void(*print_usage)(int, char **) = nullptr;
|
||||||
|
common_params_context(common_params & params) : params(params) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// parse input arguments from CLI
|
||||||
|
// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message)
|
||||||
|
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||||
|
|
||||||
|
// parse input arguments from CLI into a map
|
||||||
|
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
|
||||||
|
|
||||||
|
// populate preset-only arguments
|
||||||
|
// these arguments are not treated as command line arguments
|
||||||
|
// see: https://github.com/ggml-org/llama.cpp/issues/18163
|
||||||
|
void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||||
|
|
||||||
|
// initialize argument parser context - used by test-arg-parser and preset
|
||||||
|
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||||
@@ -0,0 +1,392 @@
|
|||||||
|
/*
|
||||||
|
This is free and unencumbered software released into the public domain.
|
||||||
|
|
||||||
|
Anyone is free to copy, modify, publish, use, compile, sell, or
|
||||||
|
distribute this software, either in source code form or as a compiled
|
||||||
|
binary, for any purpose, commercial or non-commercial, and by any
|
||||||
|
means.
|
||||||
|
|
||||||
|
In jurisdictions that recognize copyright laws, the author or authors
|
||||||
|
of this software dedicate any and all copyright interest in the
|
||||||
|
software to the public domain. We make this dedication for the benefit
|
||||||
|
of the public at large and to the detriment of our heirs and
|
||||||
|
successors. We intend this dedication to be an overt act of
|
||||||
|
relinquishment in perpetuity of all present and future rights to this
|
||||||
|
software under copyright law.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||||
|
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||||
|
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||||
|
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||||
|
OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
|
For more information, please refer to <http://unlicense.org>
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef PUBLIC_DOMAIN_BASE64_HPP_
|
||||||
|
#define PUBLIC_DOMAIN_BASE64_HPP_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <iterator>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
class base64_error : public std::runtime_error
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
using std::runtime_error::runtime_error;
|
||||||
|
};
|
||||||
|
|
||||||
|
class base64
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
enum class alphabet
|
||||||
|
{
|
||||||
|
/** the alphabet is detected automatically */
|
||||||
|
auto_,
|
||||||
|
/** the standard base64 alphabet is used */
|
||||||
|
standard,
|
||||||
|
/** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/
|
||||||
|
url_filename_safe
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class decoding_behavior
|
||||||
|
{
|
||||||
|
/** if the input is not padded, the remaining bits are ignored */
|
||||||
|
moderate,
|
||||||
|
/** if a padding character is encounter decoding is finished */
|
||||||
|
loose
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
Encodes all the elements from `in_begin` to `in_end` to `out`.
|
||||||
|
|
||||||
|
@warning The source and destination cannot overlap. The destination must be able to hold at least
|
||||||
|
`required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator.
|
||||||
|
|
||||||
|
@tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than
|
||||||
|
8 bits
|
||||||
|
@tparam Output_iterator the destination; the elements written to it are from the type `char`
|
||||||
|
@param in_begin the beginning of the source
|
||||||
|
@param in_end the ending of the source
|
||||||
|
@param out the destination iterator
|
||||||
|
@param alphabet which alphabet should be used
|
||||||
|
@returns the iterator to the next element past the last element copied
|
||||||
|
@throws see `Input_iterator` and `Output_iterator`
|
||||||
|
*/
|
||||||
|
template<typename Input_iterator, typename Output_iterator>
|
||||||
|
static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
|
||||||
|
alphabet alphabet = alphabet::standard)
|
||||||
|
{
|
||||||
|
constexpr auto pad = '=';
|
||||||
|
const char* alpha = alphabet == alphabet::url_filename_safe
|
||||||
|
? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
||||||
|
: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||||
|
|
||||||
|
while (in_begin != in_end) {
|
||||||
|
std::uint8_t i0 = 0, i1 = 0, i2 = 0;
|
||||||
|
|
||||||
|
// first character
|
||||||
|
i0 = static_cast<std::uint8_t>(*in_begin);
|
||||||
|
++in_begin;
|
||||||
|
|
||||||
|
*out = alpha[i0 >> 2 & 0x3f];
|
||||||
|
++out;
|
||||||
|
|
||||||
|
// part of first character and second
|
||||||
|
if (in_begin != in_end) {
|
||||||
|
i1 = static_cast<std::uint8_t>(*in_begin);
|
||||||
|
++in_begin;
|
||||||
|
|
||||||
|
*out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)];
|
||||||
|
++out;
|
||||||
|
} else {
|
||||||
|
*out = alpha[(i0 & 0x3) << 4];
|
||||||
|
++out;
|
||||||
|
|
||||||
|
// last padding
|
||||||
|
*out = pad;
|
||||||
|
++out;
|
||||||
|
|
||||||
|
// last padding
|
||||||
|
*out = pad;
|
||||||
|
++out;
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// part of second character and third
|
||||||
|
if (in_begin != in_end) {
|
||||||
|
i2 = static_cast<std::uint8_t>(*in_begin);
|
||||||
|
++in_begin;
|
||||||
|
|
||||||
|
*out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)];
|
||||||
|
++out;
|
||||||
|
} else {
|
||||||
|
*out = alpha[(i1 & 0xf) << 2];
|
||||||
|
++out;
|
||||||
|
|
||||||
|
// last padding
|
||||||
|
*out = pad;
|
||||||
|
++out;
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// rest of third
|
||||||
|
*out = alpha[i2 & 0x3f];
|
||||||
|
++out;
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
Encodes a string.
|
||||||
|
|
||||||
|
@param str the string that should be encoded
|
||||||
|
@param alphabet which alphabet should be used
|
||||||
|
@returns the encoded base64 string
|
||||||
|
@throws see base64::encode()
|
||||||
|
*/
|
||||||
|
static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard)
|
||||||
|
{
|
||||||
|
std::string result;
|
||||||
|
|
||||||
|
result.reserve(required_encode_size(str.length()) + 1);
|
||||||
|
|
||||||
|
encode(str.begin(), str.end(), std::back_inserter(result), alphabet);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
Encodes a char array.
|
||||||
|
|
||||||
|
@param buffer the char array
|
||||||
|
@param size the size of the array
|
||||||
|
@param alphabet which alphabet should be used
|
||||||
|
@returns the encoded string
|
||||||
|
*/
|
||||||
|
static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard)
|
||||||
|
{
|
||||||
|
std::string result;
|
||||||
|
|
||||||
|
result.reserve(required_encode_size(size) + 1);
|
||||||
|
|
||||||
|
encode(buffer, buffer + size, std::back_inserter(result), alphabet);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`,
|
||||||
|
in other words: inplace decoding is possible.
|
||||||
|
|
||||||
|
@warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`,
|
||||||
|
otherwise the behavior depends on the output iterator.
|
||||||
|
|
||||||
|
@tparam Input_iterator the source; the returned elements are cast to `char`
|
||||||
|
@tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t`
|
||||||
|
@param in_begin the beginning of the source
|
||||||
|
@param in_end the ending of the source
|
||||||
|
@param out the destination iterator
|
||||||
|
@param alphabet which alphabet should be used
|
||||||
|
@param behavior the behavior when an error was detected
|
||||||
|
@returns the iterator to the next element past the last element copied
|
||||||
|
@throws base64_error depending on the set behavior
|
||||||
|
@throws see `Input_iterator` and `Output_iterator`
|
||||||
|
*/
|
||||||
|
template<typename Input_iterator, typename Output_iterator>
|
||||||
|
static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
|
||||||
|
alphabet alphabet = alphabet::auto_,
|
||||||
|
decoding_behavior behavior = decoding_behavior::moderate)
|
||||||
|
{
|
||||||
|
//constexpr auto pad = '=';
|
||||||
|
std::uint8_t last = 0;
|
||||||
|
auto bits = 0;
|
||||||
|
|
||||||
|
while (in_begin != in_end) {
|
||||||
|
auto c = *in_begin;
|
||||||
|
++in_begin;
|
||||||
|
|
||||||
|
if (c == '=') {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto part = _base64_value(alphabet, c);
|
||||||
|
|
||||||
|
// enough bits for one byte
|
||||||
|
if (bits + 6 >= 8) {
|
||||||
|
*out = (last << (8 - bits)) | (part >> (bits - 2));
|
||||||
|
++out;
|
||||||
|
|
||||||
|
bits -= 2;
|
||||||
|
} else {
|
||||||
|
bits += 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
last = part;
|
||||||
|
}
|
||||||
|
|
||||||
|
// check padding
|
||||||
|
if (behavior != decoding_behavior::loose) {
|
||||||
|
while (in_begin != in_end) {
|
||||||
|
auto c = *in_begin;
|
||||||
|
++in_begin;
|
||||||
|
|
||||||
|
if (c != '=') {
|
||||||
|
throw base64_error("invalid base64 character.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
Decodes a string.
|
||||||
|
|
||||||
|
@param str the base64 encoded string
|
||||||
|
@param alphabet which alphabet should be used
|
||||||
|
@param behavior the behavior when an error was detected
|
||||||
|
@returns the decoded string
|
||||||
|
@throws see base64::decode()
|
||||||
|
*/
|
||||||
|
static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_,
|
||||||
|
decoding_behavior behavior = decoding_behavior::moderate)
|
||||||
|
{
|
||||||
|
std::string result;
|
||||||
|
|
||||||
|
result.reserve(max_decode_size(str.length()));
|
||||||
|
|
||||||
|
decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
Decodes a string.
|
||||||
|
|
||||||
|
@param buffer the base64 encoded buffer
|
||||||
|
@param size the size of the buffer
|
||||||
|
@param alphabet which alphabet should be used
|
||||||
|
@param behavior the behavior when an error was detected
|
||||||
|
@returns the decoded string
|
||||||
|
@throws see base64::decode()
|
||||||
|
*/
|
||||||
|
static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_,
|
||||||
|
decoding_behavior behavior = decoding_behavior::moderate)
|
||||||
|
{
|
||||||
|
std::string result;
|
||||||
|
|
||||||
|
result.reserve(max_decode_size(size));
|
||||||
|
|
||||||
|
decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
Decodes a string inplace.
|
||||||
|
|
||||||
|
@param[in,out] str the base64 encoded string
|
||||||
|
@param alphabet which alphabet should be used
|
||||||
|
@param behavior the behavior when an error was detected
|
||||||
|
@throws base64::decode_inplace()
|
||||||
|
*/
|
||||||
|
static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_,
|
||||||
|
decoding_behavior behavior = decoding_behavior::moderate)
|
||||||
|
{
|
||||||
|
str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin());
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
Decodes a char array inplace.
|
||||||
|
|
||||||
|
@param[in,out] str the string array
|
||||||
|
@param size the length of the array
|
||||||
|
@param alphabet which alphabet should be used
|
||||||
|
@param behavior the behavior when an error was detected
|
||||||
|
@returns the pointer to the next element past the last element decoded
|
||||||
|
@throws base64::decode_inplace()
|
||||||
|
*/
|
||||||
|
static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_,
|
||||||
|
decoding_behavior behavior = decoding_behavior::moderate)
|
||||||
|
{
|
||||||
|
return decode(str, str + size, str, alphabet, behavior);
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
Returns the required decoding size for a given size. The value is calculated with the following formula:
|
||||||
|
|
||||||
|
$$
|
||||||
|
\lceil \frac{size}{4} \rceil \cdot 3
|
||||||
|
$$
|
||||||
|
|
||||||
|
@param size the size of the encoded input
|
||||||
|
@returns the size of the resulting decoded buffer; this the absolute maximum
|
||||||
|
*/
|
||||||
|
static std::size_t max_decode_size(std::size_t size) noexcept
|
||||||
|
{
|
||||||
|
return (size / 4 + (size % 4 ? 1 : 0)) * 3;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
Returns the required encoding size for a given size. The value is calculated with the following formula:
|
||||||
|
|
||||||
|
$$
|
||||||
|
\lceil \frac{size}{3} \rceil \cdot 4
|
||||||
|
$$
|
||||||
|
|
||||||
|
@param size the size of the decoded input
|
||||||
|
@returns the size of the resulting encoded buffer
|
||||||
|
*/
|
||||||
|
static std::size_t required_encode_size(std::size_t size) noexcept
|
||||||
|
{
|
||||||
|
return (size / 3 + (size % 3 ? 1 : 0)) * 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static std::uint8_t _base64_value(alphabet& alphabet, char c)
|
||||||
|
{
|
||||||
|
if (c >= 'A' && c <= 'Z') {
|
||||||
|
return c - 'A';
|
||||||
|
} else if (c >= 'a' && c <= 'z') {
|
||||||
|
return c - 'a' + 26;
|
||||||
|
} else if (c >= '0' && c <= '9') {
|
||||||
|
return c - '0' + 52;
|
||||||
|
}
|
||||||
|
|
||||||
|
// comes down to alphabet
|
||||||
|
if (alphabet == alphabet::standard) {
|
||||||
|
if (c == '+') {
|
||||||
|
return 62;
|
||||||
|
} else if (c == '/') {
|
||||||
|
return 63;
|
||||||
|
}
|
||||||
|
} else if (alphabet == alphabet::url_filename_safe) {
|
||||||
|
if (c == '-') {
|
||||||
|
return 62;
|
||||||
|
} else if (c == '_') {
|
||||||
|
return 63;
|
||||||
|
}
|
||||||
|
} // auto detect
|
||||||
|
else {
|
||||||
|
if (c == '+') {
|
||||||
|
alphabet = alphabet::standard;
|
||||||
|
|
||||||
|
return 62;
|
||||||
|
} else if (c == '/') {
|
||||||
|
alphabet = alphabet::standard;
|
||||||
|
|
||||||
|
return 63;
|
||||||
|
} else if (c == '-') {
|
||||||
|
alphabet = alphabet::url_filename_safe;
|
||||||
|
|
||||||
|
return 62;
|
||||||
|
} else if (c == '_') {
|
||||||
|
alphabet = alphabet::url_filename_safe;
|
||||||
|
|
||||||
|
return 63;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw base64_error("invalid base64 character.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // !PUBLIC_DOMAIN_BASE64_HPP_
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
int LLAMA_BUILD_NUMBER = @LLAMA_BUILD_NUMBER@;
|
||||||
|
char const *LLAMA_COMMIT = "@LLAMA_BUILD_COMMIT@";
|
||||||
|
char const *LLAMA_COMPILER = "@BUILD_COMPILER@";
|
||||||
|
char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@";
|
||||||
@@ -0,0 +1,879 @@
|
|||||||
|
#include "chat.h"
|
||||||
|
#include "chat-parser.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "json-partial.h"
|
||||||
|
#include "json-schema-to-grammar.h"
|
||||||
|
#include "log.h"
|
||||||
|
#include "regex-partial.h"
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
class xml_toolcall_syntax_exception : public std::runtime_error {
|
||||||
|
public:
|
||||||
|
xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline void sort_uniq(std::vector<T> &vec) {
|
||||||
|
std::sort(vec.begin(), vec.end());
|
||||||
|
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
inline bool all_space(const T &str) {
|
||||||
|
return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); });
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t utf8_truncate_safe(const std::string_view s) {
|
||||||
|
size_t len = s.size();
|
||||||
|
if (len == 0) return 0;
|
||||||
|
size_t i = len;
|
||||||
|
for (size_t back = 0; back < 4 && i > 0; ++back) {
|
||||||
|
--i;
|
||||||
|
unsigned char c = s[i];
|
||||||
|
if ((c & 0x80) == 0) {
|
||||||
|
return len;
|
||||||
|
} else if ((c & 0xC0) == 0xC0) {
|
||||||
|
size_t expected_len = 0;
|
||||||
|
if ((c & 0xE0) == 0xC0) expected_len = 2;
|
||||||
|
else if ((c & 0xF0) == 0xE0) expected_len = 3;
|
||||||
|
else if ((c & 0xF8) == 0xF0) expected_len = 4;
|
||||||
|
else return i;
|
||||||
|
if (len - i >= expected_len) {
|
||||||
|
return len;
|
||||||
|
} else {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len - std::min(len, size_t(3));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void utf8_truncate_safe_resize(std::string &s) {
|
||||||
|
s.resize(utf8_truncate_safe(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
|
||||||
|
return s.substr(0, utf8_truncate_safe(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
|
||||||
|
if (literal1.size() == 0) return builder.try_find_literal(literal2);
|
||||||
|
const auto saved_pos = builder.pos();
|
||||||
|
while (auto res = builder.try_find_literal(literal1)) {
|
||||||
|
builder.consume_spaces();
|
||||||
|
const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
|
||||||
|
if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
|
||||||
|
if (res->prelude.size() != res->groups[0].begin - saved_pos) {
|
||||||
|
res->prelude = builder.str({saved_pos, res->groups[0].begin});
|
||||||
|
}
|
||||||
|
builder.move_to(builder.pos() + match_len);
|
||||||
|
res->groups[0].end = builder.pos();
|
||||||
|
GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
builder.move_to(res->groups[0].begin + 1);
|
||||||
|
}
|
||||||
|
builder.move_to(saved_pos);
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||||
|
*/
|
||||||
|
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
|
||||||
|
constexpr auto charclass_escape = [](unsigned char c) -> std::string {
|
||||||
|
if (c == '\\' || c == ']' || c == '^' || c == '-') {
|
||||||
|
std::string s = "\\";
|
||||||
|
s.push_back((char)c);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
if (isprint(c)) {
|
||||||
|
return std::string(1, (char)c);
|
||||||
|
}
|
||||||
|
char buf[16];
|
||||||
|
snprintf(buf, 15, "\\x%02X", c);
|
||||||
|
return std::string(buf);
|
||||||
|
};
|
||||||
|
constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
|
||||||
|
std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
|
||||||
|
int i = l;
|
||||||
|
while (i < r) {
|
||||||
|
const std::string &s = forbids[i];
|
||||||
|
if ((int)s.size() == depth) {
|
||||||
|
++i;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
unsigned char c = (unsigned char)s[depth];
|
||||||
|
int j = i;
|
||||||
|
while (j < r && (int)forbids[j].size() > depth &&
|
||||||
|
(unsigned char)forbids[j][depth] == c) {
|
||||||
|
++j;
|
||||||
|
}
|
||||||
|
children.push_back({c, {i, j}});
|
||||||
|
i = j;
|
||||||
|
}
|
||||||
|
std::vector<std::string> alts;
|
||||||
|
if (!children.empty()) {
|
||||||
|
std::string cls;
|
||||||
|
for (auto &ch : children) cls += charclass_escape(ch.first);
|
||||||
|
alts.push_back(std::string("[^") + cls + "]");
|
||||||
|
}
|
||||||
|
for (auto &ch : children) {
|
||||||
|
std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
|
||||||
|
if (!childExpr.empty()) {
|
||||||
|
std::string quoted_ch = "\"";
|
||||||
|
if (ch.first == '\\') quoted_ch += "\\\\";
|
||||||
|
else if (ch.first == '"') quoted_ch += "\\\"";
|
||||||
|
else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
|
||||||
|
else {
|
||||||
|
char buf[16];
|
||||||
|
snprintf(buf, 15, "\\x%02X", ch.first);
|
||||||
|
quoted_ch += buf;
|
||||||
|
}
|
||||||
|
quoted_ch += "\"";
|
||||||
|
std::string branch = quoted_ch + std::string(" ") + childExpr;
|
||||||
|
alts.push_back(branch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (alts.empty()) return "";
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "( ";
|
||||||
|
for (size_t k = 0; k < alts.size(); ++k) {
|
||||||
|
if (k) oss << " | ";
|
||||||
|
oss << alts[k];
|
||||||
|
}
|
||||||
|
oss << " )";
|
||||||
|
return oss.str();
|
||||||
|
};
|
||||||
|
if (forbids.empty()) return "( . )*";
|
||||||
|
sort(forbids.begin(), forbids.end());
|
||||||
|
std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
|
||||||
|
if (expr.empty()) {
|
||||||
|
std::string cls;
|
||||||
|
for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
|
||||||
|
expr = std::string("( [^") + cls + "] )";
|
||||||
|
}
|
||||||
|
if (forbids.size() == 1)
|
||||||
|
return expr + "*";
|
||||||
|
else
|
||||||
|
return std::string("( ") + expr + " )*";
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build grammar for xml-style tool call
|
||||||
|
* form.scope_start and form.scope_end can be empty.
|
||||||
|
* Requires data.format for model-specific hacks.
|
||||||
|
*/
|
||||||
|
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
|
||||||
|
GGML_ASSERT(!form.tool_start.empty());
|
||||||
|
GGML_ASSERT(!form.tool_sep.empty());
|
||||||
|
GGML_ASSERT(!form.key_start.empty());
|
||||||
|
GGML_ASSERT(!form.val_end.empty());
|
||||||
|
GGML_ASSERT(!form.tool_end.empty());
|
||||||
|
|
||||||
|
std::string key_val_sep = form.key_val_sep;
|
||||||
|
if (form.key_val_sep2) {
|
||||||
|
key_val_sep += "\n";
|
||||||
|
key_val_sep += *form.key_val_sep2;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(!key_val_sep.empty());
|
||||||
|
|
||||||
|
if (tools.is_array() && !tools.empty()) {
|
||||||
|
data.grammar = build_grammar([&](const common_grammar_builder &builder) {
|
||||||
|
auto string_arg_val = form.last_val_end ?
|
||||||
|
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
|
||||||
|
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));
|
||||||
|
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
for (const auto & tool : tools) {
|
||||||
|
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||||
|
LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const auto & function = tool.at("function");
|
||||||
|
if (!function.contains("name") || !function.at("name").is_string()) {
|
||||||
|
LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
|
||||||
|
LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::string name = function.at("name");
|
||||||
|
auto parameters = function.at("parameters");
|
||||||
|
builder.resolve_refs(parameters);
|
||||||
|
|
||||||
|
struct parameter_rule {
|
||||||
|
std::string symbol_name;
|
||||||
|
bool is_required;
|
||||||
|
};
|
||||||
|
std::vector<parameter_rule> arg_rules;
|
||||||
|
if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
|
||||||
|
LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
std::vector<std::string> requiredParameters;
|
||||||
|
if (parameters.contains("required")) {
|
||||||
|
try { parameters.at("required").get_to(requiredParameters); }
|
||||||
|
catch (const std::runtime_error&) {
|
||||||
|
LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort_uniq(requiredParameters);
|
||||||
|
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||||
|
std::string quoted_key = key;
|
||||||
|
bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
|
||||||
|
if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
|
||||||
|
quoted_key = gbnf_format_literal(key);
|
||||||
|
quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
|
||||||
|
}
|
||||||
|
arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
|
||||||
|
gbnf_format_literal(form.key_start) + " " +
|
||||||
|
gbnf_format_literal(quoted_key) + " " +
|
||||||
|
gbnf_format_literal(key_val_sep) + " " +
|
||||||
|
((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
|
||||||
|
(form.raw_argval ?
|
||||||
|
string_arg_val :
|
||||||
|
"( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
|
||||||
|
) :
|
||||||
|
builder.add_schema(name + "-arg-" + key, value)
|
||||||
|
)
|
||||||
|
), required});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
|
||||||
|
decltype(next_arg_with_sep) next_arg = "\"\"";
|
||||||
|
for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
|
||||||
|
std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
|
||||||
|
next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
|
||||||
|
include_this_arg : "( " + include_this_arg + " ) | " + next_arg
|
||||||
|
);
|
||||||
|
include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
|
||||||
|
next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
|
||||||
|
include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string quoted_name = name;
|
||||||
|
if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
|
||||||
|
quoted_name = gbnf_format_literal(name);
|
||||||
|
quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
|
||||||
|
}
|
||||||
|
quoted_name = gbnf_format_literal(quoted_name);
|
||||||
|
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||||
|
if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||||
|
quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
|
||||||
|
}
|
||||||
|
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||||
|
gbnf_format_literal(form.tool_start) + " " +
|
||||||
|
quoted_name + " " +
|
||||||
|
gbnf_format_literal(form.tool_sep) + " " +
|
||||||
|
next_arg
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
|
||||||
|
auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
|
||||||
|
auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
|
||||||
|
auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
|
||||||
|
builder.add_rule("root",
|
||||||
|
(form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
|
||||||
|
tool_call_multiple_with_end + "?" +
|
||||||
|
(form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// grammar trigger for tool call
|
||||||
|
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||||
|
* Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser.
|
||||||
|
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||||
|
*/
|
||||||
|
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
|
||||||
|
GGML_ASSERT(!form.tool_start.empty());
|
||||||
|
GGML_ASSERT(!form.key_start.empty());
|
||||||
|
GGML_ASSERT(!form.key_val_sep.empty());
|
||||||
|
GGML_ASSERT(!form.val_end.empty());
|
||||||
|
GGML_ASSERT(!form.tool_end.empty());
|
||||||
|
|
||||||
|
// Helper to choose return false or throw error
|
||||||
|
constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
|
||||||
|
if (recovery) {
|
||||||
|
builder.move_to(start_pos);
|
||||||
|
return false;
|
||||||
|
} else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output.");
|
||||||
|
};
|
||||||
|
// Drop substring from needle to end from a JSON
|
||||||
|
constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
|
||||||
|
auto pos = json_str.rfind(needle);
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
|
||||||
|
unsigned char ch = static_cast<unsigned char>(json_str[i]);
|
||||||
|
if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (pos != 0 && json_str[pos - 1] == '"') {
|
||||||
|
--pos;
|
||||||
|
}
|
||||||
|
json_str.resize(pos);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
// Helper to generate a partial argument JSON
|
||||||
|
constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
|
||||||
|
auto rest = builder.consume_rest();
|
||||||
|
utf8_truncate_safe_resize(rest);
|
||||||
|
set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
|
||||||
|
auto tool_str = arguments.dump();
|
||||||
|
if (partial_json(tool_str)) {
|
||||||
|
if (builder.add_tool_call(function_name, "", tool_str)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
|
||||||
|
};
|
||||||
|
// Helper to find a close (because there may be form.last_val_end or form.last_tool_end)
|
||||||
|
constexpr auto try_find_close = [](
|
||||||
|
common_chat_msg_parser & builder,
|
||||||
|
const std::string & end,
|
||||||
|
const std::optional<std::string> & alt_end,
|
||||||
|
const std::string & end_next,
|
||||||
|
const std::optional<std::string> & alt_end_next
|
||||||
|
) {
|
||||||
|
auto saved_pos = builder.pos();
|
||||||
|
auto tc = builder.try_find_literal(end);
|
||||||
|
auto val_end_size = end.size();
|
||||||
|
if (alt_end) {
|
||||||
|
auto pos_1 = builder.pos();
|
||||||
|
builder.move_to(saved_pos);
|
||||||
|
auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
|
||||||
|
if (alt_end_next) {
|
||||||
|
builder.move_to(saved_pos);
|
||||||
|
auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
|
||||||
|
if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
|
||||||
|
tc2 = tc3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
|
||||||
|
tc = tc2;
|
||||||
|
tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
|
||||||
|
builder.move_to(tc->groups[0].end);
|
||||||
|
val_end_size = alt_end->size();
|
||||||
|
} else {
|
||||||
|
builder.move_to(pos_1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(val_end_size, tc);
|
||||||
|
};
|
||||||
|
// Helper to find a val_end or last_val_end, returns matched pattern size
|
||||||
|
const auto try_find_val_end = [try_find_close, &builder, &form]() {
|
||||||
|
return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
|
||||||
|
};
|
||||||
|
// Helper to find a tool_end or last_tool_end, returns matched pattern size
|
||||||
|
const auto try_find_tool_end = [try_find_close, &builder, &form]() {
|
||||||
|
return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
|
||||||
|
};
|
||||||
|
|
||||||
|
bool recovery = true;
|
||||||
|
const auto start_pos = builder.pos();
|
||||||
|
if (!all_space(form.scope_start)) {
|
||||||
|
if (auto tc = builder.try_find_literal(form.scope_start)) {
|
||||||
|
if (all_space(tc->prelude)) {
|
||||||
|
if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
|
||||||
|
} else {
|
||||||
|
builder.move_to(start_pos);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else return false;
|
||||||
|
}
|
||||||
|
while (auto tc = builder.try_find_literal(form.tool_start)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||||
|
gbnf_format_literal(form.tool_start).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find tool name
|
||||||
|
auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
|
||||||
|
if (!func_name) {
|
||||||
|
auto [sz, tc] = try_find_tool_end();
|
||||||
|
func_name = tc;
|
||||||
|
}
|
||||||
|
if (!func_name) {
|
||||||
|
// Partial tool name not supported
|
||||||
|
throw common_chat_msg_partial_exception("incomplete tool_call");
|
||||||
|
}
|
||||||
|
// If the model generate multiple tool call and the first tool call has no argument
|
||||||
|
if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
|
||||||
|
builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
|
||||||
|
auto [sz, tc] = try_find_tool_end();
|
||||||
|
func_name = tc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse tool name
|
||||||
|
builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
|
||||||
|
std::string function_name = string_strip(func_name->prelude);
|
||||||
|
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||||
|
if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||||
|
if (string_starts_with(function_name, "functions.")) {
|
||||||
|
static const std::regex re(":\\d+$");
|
||||||
|
if (std::regex_search(function_name, re)) {
|
||||||
|
function_name = function_name.substr(10, function_name.rfind(":") - 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Argument JSON
|
||||||
|
json arguments = json::object();
|
||||||
|
|
||||||
|
// Helper to generate a partial argument JSON
|
||||||
|
const auto gen_partial_args = [&](auto set_partial_arg) {
|
||||||
|
gen_partial_json(set_partial_arg, arguments, builder, function_name);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse all arg_key/arg_value pairs
|
||||||
|
while (auto tc = builder.try_find_literal(form.key_start)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||||
|
gbnf_format_literal(form.key_start).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
|
||||||
|
auto tool_call_arg = arguments.dump();
|
||||||
|
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||||
|
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||||
|
}
|
||||||
|
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse arg_key
|
||||||
|
auto key_res = builder.try_find_literal(form.key_val_sep);
|
||||||
|
if (!key_res) {
|
||||||
|
gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
|
||||||
|
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
|
||||||
|
}
|
||||||
|
if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
|
||||||
|
}
|
||||||
|
auto &key = key_res->prelude;
|
||||||
|
recovery = false;
|
||||||
|
|
||||||
|
// Parse arg_value
|
||||||
|
if (form.key_val_sep2) {
|
||||||
|
if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
|
||||||
|
gbnf_format_literal(tc->prelude).c_str(),
|
||||||
|
gbnf_format_literal(form.key_val_sep).c_str(),
|
||||||
|
gbnf_format_literal(*form.key_val_sep2).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, false);
|
||||||
|
}
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto val_start = builder.pos();
|
||||||
|
|
||||||
|
// Test if arg_val is a partial JSON
|
||||||
|
std::optional<common_json> value_json = std::nullopt;
|
||||||
|
if (!form.raw_argval || !*form.raw_argval) {
|
||||||
|
try { value_json = builder.try_consume_json(); }
|
||||||
|
catch (const std::runtime_error&) { builder.move_to(val_start); }
|
||||||
|
// TODO: Delete this when json_partial adds top-level support for null/true/false
|
||||||
|
if (builder.pos() == val_start) {
|
||||||
|
const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
|
||||||
|
builder.consume_spaces();
|
||||||
|
std::string_view sv = utf8_truncate_safe_view(builder.input());
|
||||||
|
sv.remove_prefix(builder.pos());
|
||||||
|
std::string rest = "a";
|
||||||
|
if (sv.size() < 6) rest = sv;
|
||||||
|
if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
|
||||||
|
value_json = {123, {"123", "123"}};
|
||||||
|
builder.consume_rest();
|
||||||
|
} else {
|
||||||
|
builder.move_to(val_start);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it is a JSON and followed by </arg_value>, parse as json
|
||||||
|
// cannot support streaming because it may be a plain text starting with JSON
|
||||||
|
if (value_json) {
|
||||||
|
auto json_end = builder.pos();
|
||||||
|
builder.consume_spaces();
|
||||||
|
if (builder.pos() == builder.input().size()) {
|
||||||
|
if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
|
||||||
|
arguments[key] = value_json->json;
|
||||||
|
auto json_str = arguments.dump();
|
||||||
|
if (!value_json->healing_marker.json_dump_marker.empty()) {
|
||||||
|
GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||||
|
json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(json_str.back() == '}');
|
||||||
|
json_str.resize(json_str.size() - 1);
|
||||||
|
}
|
||||||
|
builder.add_tool_call(function_name, "", json_str);
|
||||||
|
} else {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
}
|
||||||
|
LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||||
|
throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
|
||||||
|
}
|
||||||
|
builder.move_to(json_end);
|
||||||
|
auto [val_end_size, tc] = try_find_val_end();
|
||||||
|
if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||||
|
LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||||
|
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
|
||||||
|
} else arguments[key] = value_json->json;
|
||||||
|
} else builder.move_to(val_start);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not, parse as plain text
|
||||||
|
if (val_start == builder.pos()) {
|
||||||
|
if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
|
||||||
|
auto &value_str = value_plain->prelude;
|
||||||
|
if (form.trim_raw_argval) value_str = string_strip(value_str);
|
||||||
|
if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
|
||||||
|
gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
|
||||||
|
throw common_chat_msg_partial_exception(
|
||||||
|
"Expected " + gbnf_format_literal(form.val_end) +
|
||||||
|
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||||
|
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
arguments[key] = value_str;
|
||||||
|
} else {
|
||||||
|
if (form.trim_raw_argval) {
|
||||||
|
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
|
||||||
|
} else {
|
||||||
|
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
|
||||||
|
}
|
||||||
|
throw common_chat_msg_partial_exception(
|
||||||
|
"Expected " + gbnf_format_literal(form.val_end) +
|
||||||
|
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||||
|
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume closing tag
|
||||||
|
if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||||
|
gbnf_format_literal(form.tool_end).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, recovery);
|
||||||
|
}
|
||||||
|
if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
|
||||||
|
// Add the parsed tool call
|
||||||
|
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
|
||||||
|
throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
|
||||||
|
}
|
||||||
|
recovery = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tool_call_arg = arguments.dump();
|
||||||
|
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||||
|
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||||
|
}
|
||||||
|
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||||
|
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
|
||||||
|
}
|
||||||
|
if (auto tc = builder.try_find_literal(form.scope_end)) {
|
||||||
|
if (!all_space(tc->prelude)) {
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||||
|
gbnf_format_literal(form.scope_end).c_str(),
|
||||||
|
gbnf_format_literal(tc->prelude).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, recovery);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (all_space(form.scope_end)) return true;
|
||||||
|
builder.consume_spaces();
|
||||||
|
if (builder.pos() == builder.input().size())
|
||||||
|
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||||
|
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||||
|
gbnf_format_literal(form.scope_end).c_str(),
|
||||||
|
gbnf_format_literal(builder.consume_rest()).c_str()
|
||||||
|
);
|
||||||
|
return return_error(builder, start_pos, recovery);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||||
|
* May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client.
|
||||||
|
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||||
|
*/
|
||||||
|
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
|
||||||
|
auto pos = pos_;
|
||||||
|
auto tsize = result_.tool_calls.size();
|
||||||
|
try { return parse_xml_tool_calls(*this, form); }
|
||||||
|
catch (const xml_toolcall_syntax_exception&) {}
|
||||||
|
move_to(pos);
|
||||||
|
result_.tool_calls.resize(tsize);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse content uses reasoning and XML-Style tool call
|
||||||
|
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||||
|
*/
|
||||||
|
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
|
||||||
|
constexpr auto rstrip = [](std::string &s) {
|
||||||
|
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
|
||||||
|
};
|
||||||
|
// Erase substring from l to r, along with additional spaces nearby
|
||||||
|
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
|
||||||
|
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
|
||||||
|
++l;
|
||||||
|
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
|
||||||
|
if (l < r) str[l] = '\n';
|
||||||
|
if (l + 1 < r) str[l + 1] = '\n';
|
||||||
|
if (l != 0) l += 2;
|
||||||
|
str.erase(l, r - l);
|
||||||
|
return l;
|
||||||
|
};
|
||||||
|
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
|
||||||
|
auto best_match = content.size();
|
||||||
|
for (auto pattern: list) {
|
||||||
|
if (pattern.size() == 0) continue;
|
||||||
|
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
|
||||||
|
auto match_len = content.size() - match_idx;
|
||||||
|
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
|
||||||
|
best_match = match_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (content.size() > best_match) {
|
||||||
|
content.erase(best_match);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
|
||||||
|
return trim_suffix(content, {
|
||||||
|
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
|
||||||
|
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
|
||||||
|
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
|
||||||
|
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
|
||||||
|
form.scope_end
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Trim leading spaces without affecting keyword matching
|
||||||
|
static const common_regex spaces_regex("\\s*");
|
||||||
|
{
|
||||||
|
auto tc = builder.consume_regex(spaces_regex);
|
||||||
|
auto spaces = builder.str(tc.groups[0]);
|
||||||
|
auto s1 = spaces.size();
|
||||||
|
trim_potential_partial_word(spaces);
|
||||||
|
auto s2 = spaces.size();
|
||||||
|
builder.move_to(builder.pos() - (s1 - s2));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse content
|
||||||
|
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
|
||||||
|
std::string unclosed_reasoning_content("");
|
||||||
|
for (;;) {
|
||||||
|
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
|
||||||
|
std::string content;
|
||||||
|
std::string tool_call_start;
|
||||||
|
|
||||||
|
if (tc) {
|
||||||
|
content = std::move(tc->prelude);
|
||||||
|
tool_call_start = builder.str(tc->groups[0]);
|
||||||
|
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
|
||||||
|
} else {
|
||||||
|
content = builder.consume_rest();
|
||||||
|
utf8_truncate_safe_resize(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle unclosed think block
|
||||||
|
if (reasoning_unclosed) {
|
||||||
|
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
|
||||||
|
unclosed_reasoning_content += content;
|
||||||
|
if (!(form.allow_toolcall_in_think && tc)) {
|
||||||
|
unclosed_reasoning_content += tool_call_start;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
reasoning_unclosed = false;
|
||||||
|
std::string reasoning_content;
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
reasoning_content = std::move(content);
|
||||||
|
} else {
|
||||||
|
reasoning_content = content.substr(0, pos);
|
||||||
|
content.erase(0, pos + end_think.size());
|
||||||
|
}
|
||||||
|
if (builder.pos() == builder.input().size() && all_space(content)) {
|
||||||
|
rstrip(reasoning_content);
|
||||||
|
trim_potential_partial_word(reasoning_content);
|
||||||
|
rstrip(reasoning_content);
|
||||||
|
if (reasoning_content.empty()) {
|
||||||
|
rstrip(unclosed_reasoning_content);
|
||||||
|
trim_potential_partial_word(unclosed_reasoning_content);
|
||||||
|
rstrip(unclosed_reasoning_content);
|
||||||
|
if (unclosed_reasoning_content.empty()) continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||||
|
builder.add_content(start_think);
|
||||||
|
builder.add_content(unclosed_reasoning_content);
|
||||||
|
builder.add_content(reasoning_content);
|
||||||
|
if (builder.pos() != builder.input().size() || !all_space(content))
|
||||||
|
builder.add_content(end_think);
|
||||||
|
} else {
|
||||||
|
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||||
|
builder.add_reasoning_content(reasoning_content);
|
||||||
|
}
|
||||||
|
unclosed_reasoning_content.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multiple think block
|
||||||
|
bool toolcall_in_think = false;
|
||||||
|
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
|
||||||
|
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||||
|
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
|
||||||
|
builder.add_reasoning_content(reasoning_content);
|
||||||
|
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
|
||||||
|
} else {
|
||||||
|
think_start = think_end + end_think.size() - 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// This <tool_call> start is in thinking block, skip this tool call
|
||||||
|
// This <tool_call> start is in thinking block
|
||||||
|
if (form.allow_toolcall_in_think) {
|
||||||
|
unclosed_reasoning_content = content.substr(think_start + start_think.size());
|
||||||
|
} else {
|
||||||
|
unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start;
|
||||||
|
}
|
||||||
|
reasoning_unclosed = true;
|
||||||
|
content.resize(think_start);
|
||||||
|
toolcall_in_think = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||||
|
rstrip(content);
|
||||||
|
// Handle unclosed </think> token from content: delete all </think> token
|
||||||
|
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
|
||||||
|
while (pos != std::string::npos) {
|
||||||
|
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
|
||||||
|
pos = content.rfind(end_think, pos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Strip if needed
|
||||||
|
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
|
||||||
|
content = string_strip(content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove potential partial suffix
|
||||||
|
if (builder.pos() == builder.input().size()) {
|
||||||
|
if (unclosed_reasoning_content.empty()) {
|
||||||
|
rstrip(content);
|
||||||
|
trim_potential_partial_word(content);
|
||||||
|
rstrip(content);
|
||||||
|
} else {
|
||||||
|
rstrip(unclosed_reasoning_content);
|
||||||
|
trim_potential_partial_word(unclosed_reasoning_content);
|
||||||
|
rstrip(unclosed_reasoning_content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// consume unclosed_reasoning_content if allow_toolcall_in_think is set
|
||||||
|
if (form.allow_toolcall_in_think && !unclosed_reasoning_content.empty()) {
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||||
|
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||||
|
} else {
|
||||||
|
if (content.empty()) {
|
||||||
|
content = start_think + unclosed_reasoning_content;
|
||||||
|
} else {
|
||||||
|
content += "\n\n" + start_think;
|
||||||
|
content += unclosed_reasoning_content;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unclosed_reasoning_content.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add content
|
||||||
|
if (!content.empty()) {
|
||||||
|
// If there are multiple content blocks
|
||||||
|
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
|
||||||
|
builder.add_content("\n\n");
|
||||||
|
}
|
||||||
|
builder.add_content(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This <tool_call> start is in thinking block and toolcall_in_think not set, skip this tool call
|
||||||
|
if (toolcall_in_think && !form.allow_toolcall_in_think) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// There is no tool call and all content is parsed
|
||||||
|
if (!tc) {
|
||||||
|
GGML_ASSERT(builder.pos() == builder.input().size());
|
||||||
|
GGML_ASSERT(unclosed_reasoning_content.empty());
|
||||||
|
if (!form.allow_toolcall_in_think) GGML_ASSERT(!reasoning_unclosed);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.move_to(tc->groups[0].begin);
|
||||||
|
if (builder.try_consume_xml_tool_calls(form)) {
|
||||||
|
auto end_of_tool = builder.pos();
|
||||||
|
builder.consume_spaces();
|
||||||
|
if (builder.pos() != builder.input().size()) {
|
||||||
|
builder.move_to(end_of_tool);
|
||||||
|
if (!builder.result().content.empty()) {
|
||||||
|
builder.add_content("\n\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
static const common_regex next_char_regex(".");
|
||||||
|
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
|
||||||
|
rstrip(c);
|
||||||
|
builder.add_content(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse content uses reasoning and XML-Style tool call
|
||||||
|
*/
|
||||||
|
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
|
||||||
|
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
|
||||||
|
}
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "chat.h"
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
|
||||||
|
// Sample config:
|
||||||
|
// MiniMax-M2 (left): <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
|
||||||
|
// GLM 4.5 (right): <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
|
||||||
|
struct xml_tool_call_format {
|
||||||
|
std::string scope_start; // <minimax:tool_call>\n // \n // can be empty
|
||||||
|
std::string tool_start; // <invoke name=\" // <tool_call>
|
||||||
|
std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls
|
||||||
|
std::string key_start; // <parameter name=\" // <arg_key>
|
||||||
|
std::string key_val_sep; // \"> // </arg_key>\n<arg_value>
|
||||||
|
std::string val_end; // </parameter>\n // </arg_value>\n
|
||||||
|
std::string tool_end; // </invoke>\n // </tool_call>\n
|
||||||
|
std::string scope_end; // </minimax:tool_call> // // can be empty
|
||||||
|
// Set this if there can be dynamic spaces inside key_val_sep.
|
||||||
|
// e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM4.5
|
||||||
|
std::optional<std::string> key_val_sep2 = std::nullopt;
|
||||||
|
// Set true if argval should only be raw string. e.g. Hello "world" hi
|
||||||
|
// Set false if argval should only be json string. e.g. "Hello \"world\" hi"
|
||||||
|
// Defaults to std::nullopt, both will be allowed.
|
||||||
|
std::optional<bool> raw_argval = std::nullopt;
|
||||||
|
std::optional<std::string> last_val_end = std::nullopt;
|
||||||
|
std::optional<std::string> last_tool_end = std::nullopt;
|
||||||
|
bool trim_raw_argval = false;
|
||||||
|
bool allow_toolcall_in_think = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
// make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||||
|
std::string make_gbnf_excluding(std::vector<std::string> forbids);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build grammar for xml-style tool call
|
||||||
|
* form.scope_start and form.scope_end can be empty.
|
||||||
|
* Requires data.format for model-specific hacks.
|
||||||
|
*/
|
||||||
|
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,133 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "chat.h"
|
||||||
|
#include "chat-parser-xml-toolcall.h"
|
||||||
|
#include "json-partial.h"
|
||||||
|
#include "regex-partial.h"
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
class common_chat_msg_partial_exception : public std::runtime_error {
|
||||||
|
public:
|
||||||
|
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_chat_msg_parser {
|
||||||
|
std::string input_;
|
||||||
|
bool is_partial_;
|
||||||
|
common_chat_syntax syntax_;
|
||||||
|
std::string healing_marker_;
|
||||||
|
|
||||||
|
size_t pos_ = 0;
|
||||||
|
common_chat_msg result_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||||
|
const std::string & input() const { return input_; }
|
||||||
|
size_t pos() const { return pos_; }
|
||||||
|
const std::string & healing_marker() const { return healing_marker_; }
|
||||||
|
const bool & is_partial() const { return is_partial_; }
|
||||||
|
const common_chat_msg & result() const { return result_; }
|
||||||
|
const common_chat_syntax & syntax() const { return syntax_; }
|
||||||
|
|
||||||
|
void move_to(size_t pos) {
|
||||||
|
if (pos > input_.size()) {
|
||||||
|
throw std::runtime_error("Invalid position!");
|
||||||
|
}
|
||||||
|
pos_ = pos;
|
||||||
|
}
|
||||||
|
void move_back(size_t n) {
|
||||||
|
if (pos_ < n) {
|
||||||
|
throw std::runtime_error("Can't move back that far!");
|
||||||
|
}
|
||||||
|
pos_ -= n;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the substring of the input at the given range
|
||||||
|
std::string str(const common_string_range & rng) const;
|
||||||
|
|
||||||
|
// Appends to the result.content field
|
||||||
|
void add_content(const std::string & content);
|
||||||
|
|
||||||
|
// Appends to the result.reasoning_content field
|
||||||
|
void add_reasoning_content(const std::string & reasoning_content);
|
||||||
|
|
||||||
|
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
|
||||||
|
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
|
||||||
|
|
||||||
|
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
|
||||||
|
bool add_tool_call(const nlohmann::ordered_json & tool_call);
|
||||||
|
|
||||||
|
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
|
||||||
|
bool add_tool_calls(const nlohmann::ordered_json & arr);
|
||||||
|
|
||||||
|
// Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } }
|
||||||
|
bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call);
|
||||||
|
|
||||||
|
void finish();
|
||||||
|
|
||||||
|
bool consume_spaces();
|
||||||
|
|
||||||
|
void consume_literal(const std::string & literal);
|
||||||
|
|
||||||
|
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
||||||
|
|
||||||
|
std::string consume_rest();
|
||||||
|
|
||||||
|
struct find_regex_result {
|
||||||
|
std::string prelude;
|
||||||
|
std::vector<common_string_range> groups;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
|
||||||
|
|
||||||
|
bool try_consume_literal(const std::string & literal);
|
||||||
|
|
||||||
|
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
||||||
|
|
||||||
|
find_regex_result consume_regex(const common_regex & regex);
|
||||||
|
|
||||||
|
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
||||||
|
|
||||||
|
std::optional<common_json> try_consume_json();
|
||||||
|
common_json consume_json();
|
||||||
|
|
||||||
|
struct consume_json_result {
|
||||||
|
nlohmann::ordered_json value;
|
||||||
|
bool is_partial;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
|
||||||
|
|
||||||
|
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
|
||||||
|
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
|
||||||
|
|
||||||
|
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
|
||||||
|
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
|
||||||
|
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
|
||||||
|
*/
|
||||||
|
consume_json_result consume_json_with_dumped_args(
|
||||||
|
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||||
|
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||||
|
);
|
||||||
|
std::optional<consume_json_result> try_consume_json_with_dumped_args(
|
||||||
|
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||||
|
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||||
|
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||||
|
*/
|
||||||
|
bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);
|
||||||
|
|
||||||
|
// Parse content uses reasoning and XML-Style tool call
|
||||||
|
void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");
|
||||||
|
|
||||||
|
void clear_tools();
|
||||||
|
};
|
||||||
@@ -0,0 +1,124 @@
|
|||||||
|
#include "chat-peg-parser.h"
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
|
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
|
||||||
|
int count = 0;
|
||||||
|
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
|
||||||
|
if (max != -1 && count <= max) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sv.remove_suffix(1);
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
return sv;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) {
|
||||||
|
arena.visit(result, [this](const common_peg_ast_node & node) {
|
||||||
|
map(node);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
||||||
|
bool is_reasoning = node.tag == common_chat_peg_builder::REASONING;
|
||||||
|
bool is_content = node.tag == common_chat_peg_builder::CONTENT;
|
||||||
|
|
||||||
|
if (is_reasoning) {
|
||||||
|
result.reasoning_content = std::string(trim_trailing_space(node.text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_content) {
|
||||||
|
result.content = std::string(trim_trailing_space(node.text));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) {
|
||||||
|
common_chat_peg_mapper::map(node);
|
||||||
|
|
||||||
|
bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN;
|
||||||
|
bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME;
|
||||||
|
bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID;
|
||||||
|
bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS;
|
||||||
|
|
||||||
|
if (is_tool_open) {
|
||||||
|
result.tool_calls.emplace_back();
|
||||||
|
current_tool = &result.tool_calls.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_tool_id && current_tool) {
|
||||||
|
current_tool->id = std::string(trim_trailing_space(node.text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_tool_name && current_tool) {
|
||||||
|
current_tool->name = std::string(trim_trailing_space(node.text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_tool_args && current_tool) {
|
||||||
|
current_tool->arguments = std::string(trim_trailing_space(node.text));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
|
||||||
|
common_chat_peg_mapper::map(node);
|
||||||
|
|
||||||
|
bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN;
|
||||||
|
bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME;
|
||||||
|
bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE;
|
||||||
|
bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN;
|
||||||
|
bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE;
|
||||||
|
bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME;
|
||||||
|
bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE;
|
||||||
|
bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE;
|
||||||
|
|
||||||
|
if (is_tool_open) {
|
||||||
|
result.tool_calls.emplace_back();
|
||||||
|
current_tool = &result.tool_calls.back();
|
||||||
|
arg_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_tool_name) {
|
||||||
|
current_tool->name = std::string(node.text);
|
||||||
|
current_tool->arguments = "{";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_arg_open) {
|
||||||
|
needs_closing_quote = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_arg_name && current_tool) {
|
||||||
|
if (arg_count > 0) {
|
||||||
|
current_tool->arguments += ",";
|
||||||
|
}
|
||||||
|
current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":";
|
||||||
|
++arg_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_arg_string && current_tool) {
|
||||||
|
// Serialize to JSON, but exclude the end quote
|
||||||
|
std::string dumped = json(trim_trailing_space(node.text)).dump();
|
||||||
|
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
|
||||||
|
needs_closing_quote = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_arg_close && current_tool) {
|
||||||
|
if (needs_closing_quote) {
|
||||||
|
current_tool->arguments += "\"";
|
||||||
|
needs_closing_quote = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_arg_json && current_tool) {
|
||||||
|
current_tool->arguments += std::string(trim_trailing_space(node.text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_tool_close && current_tool) {
|
||||||
|
if (needs_closing_quote) {
|
||||||
|
current_tool->arguments += "\"";
|
||||||
|
needs_closing_quote = false;
|
||||||
|
}
|
||||||
|
current_tool->arguments += "}";
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "chat.h"
|
||||||
|
#include "peg-parser.h"
|
||||||
|
|
||||||
|
class common_chat_peg_builder : public common_peg_parser_builder {
|
||||||
|
public:
|
||||||
|
static constexpr const char * REASONING_BLOCK = "reasoning-block";
|
||||||
|
static constexpr const char * REASONING = "reasoning";
|
||||||
|
static constexpr const char * CONTENT = "content";
|
||||||
|
|
||||||
|
common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); }
|
||||||
|
common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); }
|
||||||
|
common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); }
|
||||||
|
};
|
||||||
|
|
||||||
|
inline common_peg_arena build_chat_peg_parser(const std::function<common_peg_parser(common_chat_peg_builder & builder)> & fn) {
|
||||||
|
common_chat_peg_builder builder;
|
||||||
|
builder.set_root(fn(builder));
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
class common_chat_peg_mapper {
|
||||||
|
public:
|
||||||
|
common_chat_msg & result;
|
||||||
|
|
||||||
|
common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {}
|
||||||
|
|
||||||
|
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
||||||
|
virtual void map(const common_peg_ast_node & node);
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_chat_peg_native_builder : public common_chat_peg_builder {
|
||||||
|
public:
|
||||||
|
static constexpr const char * TOOL = "tool";
|
||||||
|
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||||
|
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||||
|
static constexpr const char * TOOL_ID = "tool-id";
|
||||||
|
static constexpr const char * TOOL_NAME = "tool-name";
|
||||||
|
static constexpr const char * TOOL_ARGS = "tool-args";
|
||||||
|
|
||||||
|
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||||
|
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||||
|
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||||
|
common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); }
|
||||||
|
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||||
|
common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); }
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_chat_peg_native_mapper : public common_chat_peg_mapper {
|
||||||
|
common_chat_tool_call * current_tool;
|
||||||
|
|
||||||
|
public:
|
||||||
|
common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||||
|
|
||||||
|
void map(const common_peg_ast_node & node) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline common_peg_arena build_chat_peg_native_parser(const std::function<common_peg_parser(common_chat_peg_native_builder & builder)> & fn) {
|
||||||
|
common_chat_peg_native_builder builder;
|
||||||
|
builder.set_root(fn(builder));
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
class common_chat_peg_constructed_builder : public common_chat_peg_builder {
|
||||||
|
public:
|
||||||
|
static constexpr const char * TOOL = "tool";
|
||||||
|
static constexpr const char * TOOL_OPEN = "tool-open";
|
||||||
|
static constexpr const char * TOOL_CLOSE = "tool-close";
|
||||||
|
static constexpr const char * TOOL_NAME = "tool-name";
|
||||||
|
static constexpr const char * TOOL_ARG = "tool-arg";
|
||||||
|
static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open";
|
||||||
|
static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close";
|
||||||
|
static constexpr const char * TOOL_ARG_NAME = "tool-arg-name";
|
||||||
|
static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value";
|
||||||
|
static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value";
|
||||||
|
|
||||||
|
common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); }
|
||||||
|
common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); }
|
||||||
|
common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); }
|
||||||
|
common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); }
|
||||||
|
common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); }
|
||||||
|
common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); }
|
||||||
|
common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); }
|
||||||
|
common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); }
|
||||||
|
common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); }
|
||||||
|
common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); }
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_chat_peg_constructed_mapper : public common_chat_peg_mapper {
|
||||||
|
common_chat_tool_call * current_tool;
|
||||||
|
int arg_count = 0;
|
||||||
|
bool needs_closing_quote = false;
|
||||||
|
|
||||||
|
public:
|
||||||
|
common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||||
|
|
||||||
|
void map(const common_peg_ast_node & node) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline common_peg_arena build_chat_peg_constructed_parser(const std::function<common_peg_parser(common_chat_peg_constructed_builder & builder)> & fn) {
|
||||||
|
common_chat_peg_constructed_builder builder;
|
||||||
|
builder.set_root(fn(builder));
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,234 @@
|
|||||||
|
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "peg-parser.h"
|
||||||
|
#include <functional>
|
||||||
|
#include <chrono>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
struct common_chat_templates;
|
||||||
|
|
||||||
|
struct common_chat_tool_call {
|
||||||
|
std::string name;
|
||||||
|
std::string arguments;
|
||||||
|
std::string id;
|
||||||
|
|
||||||
|
bool operator==(const common_chat_tool_call & other) const {
|
||||||
|
return name == other.name && arguments == other.arguments && id == other.id;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_msg_content_part {
|
||||||
|
std::string type;
|
||||||
|
std::string text;
|
||||||
|
|
||||||
|
bool operator==(const common_chat_msg_content_part & other) const {
|
||||||
|
return type == other.type && text == other.text;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_msg {
|
||||||
|
std::string role;
|
||||||
|
std::string content;
|
||||||
|
std::vector<common_chat_msg_content_part> content_parts;
|
||||||
|
std::vector<common_chat_tool_call> tool_calls;
|
||||||
|
std::string reasoning_content;
|
||||||
|
std::string tool_name;
|
||||||
|
std::string tool_call_id;
|
||||||
|
|
||||||
|
template <class T> T to_json_oaicompat() const;
|
||||||
|
|
||||||
|
bool empty() const {
|
||||||
|
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||||
|
}
|
||||||
|
void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||||
|
for (auto i = 0u; i < tool_calls.size(); i++) {
|
||||||
|
if (ids_cache.size() <= i) {
|
||||||
|
auto id = tool_calls[i].id;
|
||||||
|
if (id.empty()) {
|
||||||
|
id = gen_tool_call_id();
|
||||||
|
}
|
||||||
|
ids_cache.push_back(id);
|
||||||
|
}
|
||||||
|
tool_calls[i].id = ids_cache[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bool operator==(const common_chat_msg & other) const {
|
||||||
|
return role == other.role
|
||||||
|
&& content == other.content
|
||||||
|
&& content_parts == other.content_parts
|
||||||
|
&& tool_calls == other.tool_calls
|
||||||
|
&& reasoning_content == other.reasoning_content
|
||||||
|
&& tool_name == other.tool_name
|
||||||
|
&& tool_call_id == other.tool_call_id;
|
||||||
|
}
|
||||||
|
bool operator!=(const common_chat_msg & other) const {
|
||||||
|
return !(*this == other);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_msg_diff {
|
||||||
|
std::string reasoning_content_delta;
|
||||||
|
std::string content_delta;
|
||||||
|
size_t tool_call_index = std::string::npos;
|
||||||
|
common_chat_tool_call tool_call_delta;
|
||||||
|
|
||||||
|
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new);
|
||||||
|
|
||||||
|
bool operator==(const common_chat_msg_diff & other) const {
|
||||||
|
return content_delta == other.content_delta
|
||||||
|
&& tool_call_index == other.tool_call_index
|
||||||
|
&& tool_call_delta == other.tool_call_delta;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_tool {
|
||||||
|
std::string name;
|
||||||
|
std::string description;
|
||||||
|
std::string parameters;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum common_chat_tool_choice {
|
||||||
|
COMMON_CHAT_TOOL_CHOICE_AUTO,
|
||||||
|
COMMON_CHAT_TOOL_CHOICE_REQUIRED,
|
||||||
|
COMMON_CHAT_TOOL_CHOICE_NONE,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum common_chat_format {
|
||||||
|
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||||
|
COMMON_CHAT_FORMAT_GENERIC,
|
||||||
|
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
||||||
|
COMMON_CHAT_FORMAT_MAGISTRAL,
|
||||||
|
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||||
|
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||||
|
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||||
|
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||||
|
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||||
|
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||||
|
COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||||
|
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||||
|
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||||
|
COMMON_CHAT_FORMAT_GRANITE,
|
||||||
|
COMMON_CHAT_FORMAT_GPT_OSS,
|
||||||
|
COMMON_CHAT_FORMAT_SEED_OSS,
|
||||||
|
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||||
|
COMMON_CHAT_FORMAT_APERTUS,
|
||||||
|
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||||
|
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||||
|
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||||
|
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||||
|
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
||||||
|
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||||
|
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||||
|
COMMON_CHAT_FORMAT_SOLAR_OPEN,
|
||||||
|
|
||||||
|
// These are intended to be parsed by the PEG parser
|
||||||
|
COMMON_CHAT_FORMAT_PEG_SIMPLE,
|
||||||
|
COMMON_CHAT_FORMAT_PEG_NATIVE,
|
||||||
|
COMMON_CHAT_FORMAT_PEG_CONSTRUCTED,
|
||||||
|
|
||||||
|
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_templates_inputs {
|
||||||
|
std::vector<common_chat_msg> messages;
|
||||||
|
std::string grammar;
|
||||||
|
std::string json_schema;
|
||||||
|
bool add_generation_prompt = true;
|
||||||
|
bool use_jinja = true;
|
||||||
|
// Parameters below only supported when use_jinja is true
|
||||||
|
std::vector<common_chat_tool> tools;
|
||||||
|
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||||
|
bool parallel_tool_calls = false;
|
||||||
|
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||||
|
bool enable_thinking = true;
|
||||||
|
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||||
|
std::map<std::string, std::string> chat_template_kwargs;
|
||||||
|
bool add_bos = false;
|
||||||
|
bool add_eos = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_params {
|
||||||
|
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||||
|
std::string prompt;
|
||||||
|
std::string grammar;
|
||||||
|
bool grammar_lazy = false;
|
||||||
|
bool thinking_forced_open = false;
|
||||||
|
std::vector<common_grammar_trigger> grammar_triggers;
|
||||||
|
std::vector<std::string> preserved_tokens;
|
||||||
|
std::vector<std::string> additional_stops;
|
||||||
|
std::string parser;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_chat_syntax {
|
||||||
|
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||||
|
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||||
|
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||||
|
bool reasoning_in_content = false;
|
||||||
|
bool thinking_forced_open = false;
|
||||||
|
bool parse_tool_calls = true;
|
||||||
|
common_peg_arena parser = {};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||||
|
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
||||||
|
|
||||||
|
void common_chat_templates_free(struct common_chat_templates * tmpls);
|
||||||
|
|
||||||
|
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
|
||||||
|
|
||||||
|
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
|
||||||
|
|
||||||
|
common_chat_templates_ptr common_chat_templates_init(
|
||||||
|
const struct llama_model * model,
|
||||||
|
const std::string & chat_template_override,
|
||||||
|
const std::string & bos_token_override = "",
|
||||||
|
const std::string & eos_token_override = "");
|
||||||
|
|
||||||
|
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||||
|
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
|
||||||
|
|
||||||
|
|
||||||
|
struct common_chat_params common_chat_templates_apply(
|
||||||
|
const struct common_chat_templates * tmpls,
|
||||||
|
const struct common_chat_templates_inputs & inputs);
|
||||||
|
|
||||||
|
// Format single message, while taking into account the position of that message in chat history
|
||||||
|
std::string common_chat_format_single(
|
||||||
|
const struct common_chat_templates * tmpls,
|
||||||
|
const std::vector<common_chat_msg> & past_msg,
|
||||||
|
const common_chat_msg & new_msg,
|
||||||
|
bool add_ass,
|
||||||
|
bool use_jinja);
|
||||||
|
|
||||||
|
// Returns an example of formatted chat
|
||||||
|
std::string common_chat_format_example(
|
||||||
|
const struct common_chat_templates * tmpls,
|
||||||
|
bool use_jinja,
|
||||||
|
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||||
|
|
||||||
|
const char* common_chat_format_name(common_chat_format format);
|
||||||
|
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||||
|
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||||
|
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||||
|
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||||
|
|
||||||
|
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||||
|
|
||||||
|
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
|
||||||
|
|
||||||
|
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
||||||
|
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||||
|
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
||||||
|
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||||
|
|
||||||
|
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
||||||
|
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||||
|
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
||||||
|
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||||
|
|
||||||
|
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,858 @@
|
|||||||
|
// Various helper functions and utilities
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml-opt.h"
|
||||||
|
#include "llama-cpp.h"
|
||||||
|
|
||||||
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#if defined(_WIN32) && !defined(_WIN32_WINNT)
|
||||||
|
#define _WIN32_WINNT 0x0A00
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define DIRECTORY_SEPARATOR '\\'
|
||||||
|
#else
|
||||||
|
#define DIRECTORY_SEPARATOR '/'
|
||||||
|
#endif // _WIN32
|
||||||
|
|
||||||
|
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
|
||||||
|
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
|
||||||
|
|
||||||
|
#define print_build_info() do { \
|
||||||
|
fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \
|
||||||
|
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
|
||||||
|
} while(0)
|
||||||
|
|
||||||
|
struct common_time_meas {
|
||||||
|
common_time_meas(int64_t & t_acc, bool disable = false);
|
||||||
|
~common_time_meas();
|
||||||
|
|
||||||
|
const int64_t t_start_us;
|
||||||
|
|
||||||
|
int64_t & t_acc;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_adapter_lora_info {
|
||||||
|
std::string path;
|
||||||
|
float scale;
|
||||||
|
|
||||||
|
std::string task_name;
|
||||||
|
std::string prompt_prefix;
|
||||||
|
|
||||||
|
struct llama_adapter_lora * ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
using llama_tokens = std::vector<llama_token>;
|
||||||
|
|
||||||
|
// build info
|
||||||
|
extern int LLAMA_BUILD_NUMBER;
|
||||||
|
extern const char * LLAMA_COMMIT;
|
||||||
|
extern const char * LLAMA_COMPILER;
|
||||||
|
extern const char * LLAMA_BUILD_TARGET;
|
||||||
|
|
||||||
|
struct common_control_vector_load_info;
|
||||||
|
|
||||||
|
//
|
||||||
|
// CPU utils
|
||||||
|
//
|
||||||
|
|
||||||
|
struct cpu_params {
|
||||||
|
int n_threads = -1;
|
||||||
|
bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
|
||||||
|
bool mask_valid = false; // Default: any CPU
|
||||||
|
enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
|
||||||
|
bool strict_cpu = false; // Use strict CPU placement
|
||||||
|
uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
|
||||||
|
};
|
||||||
|
|
||||||
|
int32_t cpu_get_num_physical_cores();
|
||||||
|
int32_t cpu_get_num_math();
|
||||||
|
|
||||||
|
//
|
||||||
|
// Common params
|
||||||
|
//
|
||||||
|
|
||||||
|
enum llama_example {
|
||||||
|
LLAMA_EXAMPLE_DEBUG,
|
||||||
|
LLAMA_EXAMPLE_COMMON,
|
||||||
|
LLAMA_EXAMPLE_SPECULATIVE,
|
||||||
|
LLAMA_EXAMPLE_COMPLETION,
|
||||||
|
LLAMA_EXAMPLE_CLI,
|
||||||
|
LLAMA_EXAMPLE_EMBEDDING,
|
||||||
|
LLAMA_EXAMPLE_PERPLEXITY,
|
||||||
|
LLAMA_EXAMPLE_RETRIEVAL,
|
||||||
|
LLAMA_EXAMPLE_PASSKEY,
|
||||||
|
LLAMA_EXAMPLE_IMATRIX,
|
||||||
|
LLAMA_EXAMPLE_BENCH,
|
||||||
|
LLAMA_EXAMPLE_SERVER,
|
||||||
|
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
|
||||||
|
LLAMA_EXAMPLE_EXPORT_LORA,
|
||||||
|
LLAMA_EXAMPLE_MTMD,
|
||||||
|
LLAMA_EXAMPLE_LOOKUP,
|
||||||
|
LLAMA_EXAMPLE_PARALLEL,
|
||||||
|
LLAMA_EXAMPLE_TTS,
|
||||||
|
LLAMA_EXAMPLE_DIFFUSION,
|
||||||
|
LLAMA_EXAMPLE_FINETUNE,
|
||||||
|
LLAMA_EXAMPLE_FIT_PARAMS,
|
||||||
|
|
||||||
|
LLAMA_EXAMPLE_COUNT,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum common_sampler_type {
|
||||||
|
COMMON_SAMPLER_TYPE_NONE = 0,
|
||||||
|
COMMON_SAMPLER_TYPE_DRY = 1,
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_K = 2,
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_P = 3,
|
||||||
|
COMMON_SAMPLER_TYPE_MIN_P = 4,
|
||||||
|
//COMMON_SAMPLER_TYPE_TFS_Z = 5,
|
||||||
|
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
|
||||||
|
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
|
||||||
|
COMMON_SAMPLER_TYPE_XTC = 8,
|
||||||
|
COMMON_SAMPLER_TYPE_INFILL = 9,
|
||||||
|
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
|
||||||
|
};
|
||||||
|
|
||||||
|
// dimensionality reduction methods, used by cvector-generator
|
||||||
|
enum dimre_method {
|
||||||
|
DIMRE_METHOD_PCA,
|
||||||
|
DIMRE_METHOD_MEAN,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum common_conversation_mode {
|
||||||
|
COMMON_CONVERSATION_MODE_DISABLED = 0,
|
||||||
|
COMMON_CONVERSATION_MODE_ENABLED = 1,
|
||||||
|
COMMON_CONVERSATION_MODE_AUTO = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum common_grammar_trigger_type {
|
||||||
|
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
||||||
|
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||||
|
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||||
|
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_grammar_trigger {
|
||||||
|
common_grammar_trigger_type type;
|
||||||
|
std::string value;
|
||||||
|
llama_token token = LLAMA_TOKEN_NULL;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum common_params_sampling_config : uint64_t {
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
|
||||||
|
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// sampling parameters
|
||||||
|
struct common_params_sampling {
|
||||||
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||||
|
|
||||||
|
int32_t n_prev = 64; // number of previous tokens to remember
|
||||||
|
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||||
|
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
||||||
|
int32_t top_k = 40; // <= 0 to use vocab size
|
||||||
|
float top_p = 0.95f; // 1.0 = disabled
|
||||||
|
float min_p = 0.05f; // 0.0 = disabled
|
||||||
|
float xtc_probability = 0.00f; // 0.0 = disabled
|
||||||
|
float xtc_threshold = 0.10f; // > 0.5 disables XTC
|
||||||
|
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
||||||
|
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
||||||
|
float dynatemp_range = 0.00f; // 0.0 = disabled
|
||||||
|
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
||||||
|
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||||
|
float penalty_repeat = 1.00f; // 1.0 = disabled
|
||||||
|
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||||
|
float penalty_present = 0.00f; // 0.0 = disabled
|
||||||
|
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
|
||||||
|
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
|
||||||
|
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
||||||
|
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
||||||
|
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
|
float top_n_sigma = -1.00f;// -1.0 = disabled
|
||||||
|
float mirostat_tau = 5.00f; // target entropy
|
||||||
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
|
bool ignore_eos = false;
|
||||||
|
bool no_perf = false; // disable performance metrics
|
||||||
|
bool timing_per_token = false;
|
||||||
|
|
||||||
|
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
|
||||||
|
|
||||||
|
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||||
|
|
||||||
|
std::vector<enum common_sampler_type> samplers = {
|
||||||
|
COMMON_SAMPLER_TYPE_PENALTIES,
|
||||||
|
COMMON_SAMPLER_TYPE_DRY,
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_K,
|
||||||
|
COMMON_SAMPLER_TYPE_TYPICAL_P,
|
||||||
|
COMMON_SAMPLER_TYPE_TOP_P,
|
||||||
|
COMMON_SAMPLER_TYPE_MIN_P,
|
||||||
|
COMMON_SAMPLER_TYPE_XTC,
|
||||||
|
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||||
|
bool grammar_lazy = false;
|
||||||
|
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
||||||
|
std::set<llama_token> preserved_tokens;
|
||||||
|
|
||||||
|
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||||
|
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||||
|
|
||||||
|
bool backend_sampling = false;
|
||||||
|
|
||||||
|
bool has_logit_bias() const {
|
||||||
|
return !logit_bias.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
// print the parameters into a string
|
||||||
|
std::string print() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_params_model {
|
||||||
|
std::string path = ""; // model local path // NOLINT
|
||||||
|
std::string url = ""; // model url to download // NOLINT
|
||||||
|
std::string hf_repo = ""; // HF repo // NOLINT
|
||||||
|
std::string hf_file = ""; // HF file // NOLINT
|
||||||
|
std::string docker_repo = ""; // Docker repo // NOLINT
|
||||||
|
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_params_speculative {
|
||||||
|
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||||
|
|
||||||
|
int32_t n_ctx = 0; // draft context size
|
||||||
|
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||||
|
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||||
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||||
|
float p_split = 0.1f; // speculative decoding split probability
|
||||||
|
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||||
|
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||||
|
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||||
|
|
||||||
|
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||||
|
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||||
|
|
||||||
|
struct cpu_params cpuparams;
|
||||||
|
struct cpu_params cpuparams_batch;
|
||||||
|
|
||||||
|
struct common_params_model model;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_params_vocoder {
|
||||||
|
struct common_params_model model;
|
||||||
|
|
||||||
|
std::string speaker_file = ""; // speaker file path // NOLINT
|
||||||
|
|
||||||
|
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_params_diffusion {
|
||||||
|
int32_t steps = 128;
|
||||||
|
bool visual_mode = false;
|
||||||
|
|
||||||
|
float eps = 0; // epsilon for timesteps
|
||||||
|
int32_t block_length = 0; // block length for generation
|
||||||
|
|
||||||
|
int32_t algorithm = 4; // default algorithm: low-confidence
|
||||||
|
float alg_temp = 0.0f; // algorithm temperature
|
||||||
|
|
||||||
|
float cfg_scale = 0; // classifier-free guidance scale
|
||||||
|
bool add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
// reasoning API response format (not to be confused as chat template's reasoning format)
|
||||||
|
enum common_reasoning_format {
|
||||||
|
COMMON_REASONING_FORMAT_NONE,
|
||||||
|
COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content`
|
||||||
|
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
|
||||||
|
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
|
||||||
|
// do not extend this enum unless you absolutely have to
|
||||||
|
// in most cases, use COMMON_REASONING_FORMAT_AUTO
|
||||||
|
// see: https://github.com/ggml-org/llama.cpp/pull/15408
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct lr_opt {
|
||||||
|
float lr0 = 1e-5; // learning rate at first epoch
|
||||||
|
float lr_min = -1;
|
||||||
|
float decay_epochs = -1; // if >0, the learning rate starts at lr0 and decays to lr_min after this many epochs
|
||||||
|
float scale_epoch = 0;
|
||||||
|
float wd = 0;
|
||||||
|
unsigned epochs = 2;
|
||||||
|
|
||||||
|
unsigned epoch; // set by optimizer outer (epochs) loop
|
||||||
|
// learning rate decay - constant LR per epoch only for now
|
||||||
|
float get_lr(float e) const;
|
||||||
|
float get_lr() const { return get_lr(epoch); }
|
||||||
|
// must call after arg parse, before get_lr
|
||||||
|
void init();
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
|
||||||
|
|
||||||
|
struct common_params {
|
||||||
|
int32_t n_predict = -1; // max. number of new tokens to predict, -1 == no limit
|
||||||
|
int32_t n_ctx = 0; // context size, 0 == context the model was trained with
|
||||||
|
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
|
||||||
|
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
|
||||||
|
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||||
|
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
||||||
|
int32_t n_parallel = 1; // number of parallel sequences to decode
|
||||||
|
int32_t n_sequences = 1; // number of sequences to decode
|
||||||
|
int32_t grp_attn_n = 1; // group-attention factor
|
||||||
|
int32_t grp_attn_w = 512; // group-attention width
|
||||||
|
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
|
||||||
|
float rope_freq_base = 0.0f; // RoPE base frequency
|
||||||
|
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
|
||||||
|
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
|
||||||
|
float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
|
||||||
|
float yarn_beta_fast = -1.0f; // YaRN low correction dim
|
||||||
|
float yarn_beta_slow = -1.0f; // YaRN high correction dim
|
||||||
|
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||||
|
|
||||||
|
// offload params
|
||||||
|
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||||
|
|
||||||
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
|
||||||
|
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||||
|
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
||||||
|
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
|
||||||
|
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
|
||||||
|
|
||||||
|
// margin per device in bytes for fitting parameters to free memory:
|
||||||
|
std::vector<size_t> fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024*1024);
|
||||||
|
|
||||||
|
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
||||||
|
|
||||||
|
struct cpu_params cpuparams;
|
||||||
|
struct cpu_params cpuparams_batch;
|
||||||
|
|
||||||
|
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
||||||
|
void * cb_eval_user_data = nullptr;
|
||||||
|
|
||||||
|
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
|
||||||
|
|
||||||
|
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
||||||
|
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
|
||||||
|
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
|
||||||
|
enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
|
||||||
|
|
||||||
|
struct common_params_sampling sampling;
|
||||||
|
struct common_params_speculative speculative;
|
||||||
|
struct common_params_vocoder vocoder;
|
||||||
|
struct common_params_diffusion diffusion;
|
||||||
|
|
||||||
|
struct common_params_model model;
|
||||||
|
|
||||||
|
std::string model_alias = ""; // model alias // NOLINT
|
||||||
|
std::string hf_token = ""; // HF token // NOLINT
|
||||||
|
std::string prompt = ""; // NOLINT
|
||||||
|
std::string system_prompt = ""; // NOLINT
|
||||||
|
std::string prompt_file = ""; // store the external prompt file name // NOLINT
|
||||||
|
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
||||||
|
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
||||||
|
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
|
||||||
|
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
||||||
|
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||||
|
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
||||||
|
|
||||||
|
// llama-debug specific options
|
||||||
|
std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
|
||||||
|
bool save_logits = false; // whether to save logits to files // NOLINT
|
||||||
|
std::vector<std::string> tensor_filter; // filter tensor names for debug output (regex) // NOLINT
|
||||||
|
|
||||||
|
std::vector<std::string> in_files; // all input files
|
||||||
|
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
|
||||||
|
std::vector<llama_model_kv_override> kv_overrides;
|
||||||
|
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||||
|
|
||||||
|
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
|
||||||
|
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
|
||||||
|
|
||||||
|
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
|
||||||
|
|
||||||
|
int32_t verbosity = 3; // LOG_LEVEL_INFO
|
||||||
|
int32_t control_vector_layer_start = -1; // layer range for control vector
|
||||||
|
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||||
|
bool offline = false;
|
||||||
|
|
||||||
|
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
||||||
|
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
||||||
|
// (which is more convenient to use for plotting)
|
||||||
|
//
|
||||||
|
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
|
||||||
|
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
|
||||||
|
|
||||||
|
bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt
|
||||||
|
size_t winogrande_tasks = 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
|
||||||
|
|
||||||
|
bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
|
||||||
|
size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
|
||||||
|
|
||||||
|
bool kl_divergence = false; // compute KL divergence
|
||||||
|
|
||||||
|
bool usage = false; // print usage
|
||||||
|
bool completion = false; // print source-able completion script
|
||||||
|
bool use_color = false; // use color to distinguish generations and inputs
|
||||||
|
bool special = false; // enable special token output
|
||||||
|
bool interactive = false; // interactive mode
|
||||||
|
bool interactive_first = false; // wait for user input immediately
|
||||||
|
bool prompt_cache_all = false; // save user input and generations to prompt cache
|
||||||
|
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
|
||||||
|
|
||||||
|
bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
|
||||||
|
bool multiline_input = false; // reverse the usage of `\`
|
||||||
|
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||||
|
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
||||||
|
bool no_perf = false; // disable performance metrics
|
||||||
|
bool show_timings = true; // show timing information on CLI
|
||||||
|
bool ctx_shift = false; // context shift on infinite text generation
|
||||||
|
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||||
|
bool kv_unified = false; // enable unified KV cache
|
||||||
|
|
||||||
|
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||||
|
bool use_mmap = true; // enable mmap to use filesystem cache
|
||||||
|
bool use_direct_io = true; // read from disk without buffering for faster model loading
|
||||||
|
bool use_mlock = false; // use mlock to keep model in memory
|
||||||
|
bool verbose_prompt = false; // print prompt tokens before generation
|
||||||
|
bool display_prompt = true; // print prompt before generation
|
||||||
|
bool no_kv_offload = false; // disable KV offloading
|
||||||
|
bool warmup = true; // warmup run
|
||||||
|
bool check_tensors = false; // validate tensor data
|
||||||
|
bool no_op_offload = false; // globally disable offload host tensor operations to device
|
||||||
|
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
|
||||||
|
bool no_host = false; // bypass host buffer allowing extra buffers to be used
|
||||||
|
|
||||||
|
bool single_turn = false; // single turn chat conversation
|
||||||
|
|
||||||
|
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||||
|
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||||
|
|
||||||
|
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
|
||||||
|
|
||||||
|
// multimodal models (see tools/mtmd)
|
||||||
|
struct common_params_model mmproj;
|
||||||
|
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
||||||
|
bool no_mmproj = false; // explicitly disable multimodal model
|
||||||
|
std::vector<std::string> image; // path to image file(s)
|
||||||
|
int image_min_tokens = -1;
|
||||||
|
int image_max_tokens = -1;
|
||||||
|
|
||||||
|
// finetune
|
||||||
|
struct lr_opt lr;
|
||||||
|
enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
||||||
|
float val_split = 0.05f; // fraction of the data used for the validation set
|
||||||
|
|
||||||
|
// embedding
|
||||||
|
bool embedding = false; // get only sentence embedding
|
||||||
|
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
||||||
|
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
|
||||||
|
std::string embd_sep = "\n"; // separator of embeddings
|
||||||
|
std::string cls_sep = "\t"; // separator of classification sequences
|
||||||
|
|
||||||
|
// server params
|
||||||
|
int32_t port = 8080; // server listens on this network port
|
||||||
|
int32_t timeout_read = 600; // http read timeout in seconds
|
||||||
|
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
||||||
|
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||||
|
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||||
|
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
||||||
|
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||||
|
|
||||||
|
std::string hostname = "127.0.0.1";
|
||||||
|
std::string public_path = ""; // NOLINT
|
||||||
|
std::string api_prefix = ""; // NOLINT
|
||||||
|
std::string chat_template = ""; // NOLINT
|
||||||
|
bool use_jinja = true; // NOLINT
|
||||||
|
bool enable_chat_template = true;
|
||||||
|
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||||
|
int reasoning_budget = -1;
|
||||||
|
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||||
|
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
|
||||||
|
|
||||||
|
std::vector<std::string> api_keys;
|
||||||
|
|
||||||
|
std::string ssl_file_key = ""; // NOLINT
|
||||||
|
std::string ssl_file_cert = ""; // NOLINT
|
||||||
|
|
||||||
|
std::map<std::string, std::string> default_template_kwargs;
|
||||||
|
|
||||||
|
// webui configs
|
||||||
|
bool webui = true;
|
||||||
|
std::string webui_config_json;
|
||||||
|
|
||||||
|
// "advanced" endpoints are disabled by default for better security
|
||||||
|
bool endpoint_slots = true;
|
||||||
|
bool endpoint_props = false; // only control POST requests, not GET
|
||||||
|
bool endpoint_metrics = false;
|
||||||
|
|
||||||
|
// router server configs
|
||||||
|
std::string models_dir = ""; // directory containing models for the router server
|
||||||
|
std::string models_preset = ""; // directory containing model presets for the router server
|
||||||
|
int models_max = 4; // maximum number of models to load simultaneously
|
||||||
|
bool models_autoload = true; // automatically load models when requested via the router server
|
||||||
|
|
||||||
|
bool log_json = false;
|
||||||
|
|
||||||
|
std::string slot_save_path;
|
||||||
|
std::string media_path; // path to directory for loading media files
|
||||||
|
|
||||||
|
float slot_prompt_similarity = 0.1f;
|
||||||
|
|
||||||
|
// batched-bench params
|
||||||
|
bool is_pp_shared = false;
|
||||||
|
bool is_tg_separate = false;
|
||||||
|
|
||||||
|
std::vector<int32_t> n_pp;
|
||||||
|
std::vector<int32_t> n_tg;
|
||||||
|
std::vector<int32_t> n_pl;
|
||||||
|
|
||||||
|
// retrieval params
|
||||||
|
std::vector<std::string> context_files; // context files to embed
|
||||||
|
|
||||||
|
int32_t chunk_size = 64; // chunk size for context embedding
|
||||||
|
|
||||||
|
std::string chunk_separator = "\n"; // chunk separator for context embedding
|
||||||
|
|
||||||
|
// passkey params
|
||||||
|
int32_t n_junk = 250; // number of times to repeat the junk text
|
||||||
|
int32_t i_pos = -1; // position of the passkey in the junk text
|
||||||
|
|
||||||
|
// imatrix params
|
||||||
|
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
|
||||||
|
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
|
||||||
|
int32_t i_chunk = 0; // start processing from this chunk
|
||||||
|
int8_t imat_dat = 0; // whether the legacy imatrix.dat format should be output (gguf <= 0 < dat)
|
||||||
|
|
||||||
|
bool process_output = false; // collect data for the output tensor
|
||||||
|
bool compute_ppl = true; // whether to compute perplexity
|
||||||
|
bool show_statistics = false; // show imatrix statistics per tensor
|
||||||
|
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
|
||||||
|
|
||||||
|
// cvector-generator params
|
||||||
|
int n_pca_batch = 100;
|
||||||
|
int n_pca_iterations = 1000;
|
||||||
|
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
|
||||||
|
std::string cvector_positive_file = "tools/cvector-generator/positive.txt";
|
||||||
|
std::string cvector_negative_file = "tools/cvector-generator/negative.txt";
|
||||||
|
|
||||||
|
bool spm_infill = false; // suffix/prefix/middle pattern for infill
|
||||||
|
|
||||||
|
// batched-bench params
|
||||||
|
bool batched_bench_output_jsonl = false;
|
||||||
|
|
||||||
|
// common params
|
||||||
|
std::string out_file; // output filename for all example programs
|
||||||
|
// optional callback for model loading progress and cancellation:
|
||||||
|
// called with a progress value between 0.0 and 1.0.
|
||||||
|
// return false from callback to abort model loading or true to continue
|
||||||
|
llama_progress_callback load_progress_callback = NULL;
|
||||||
|
void * load_progress_callback_user_data = NULL;
|
||||||
|
|
||||||
|
bool has_speculative() const {
|
||||||
|
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// call once at the start of a program if it uses libcommon
|
||||||
|
// initializes the logging system and prints info about the build
|
||||||
|
void common_init();
|
||||||
|
|
||||||
|
std::string common_params_get_system_info(const common_params & params);
|
||||||
|
|
||||||
|
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
||||||
|
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
||||||
|
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
|
||||||
|
bool set_process_priority(enum ggml_sched_priority prio);
|
||||||
|
|
||||||
|
//
|
||||||
|
// String utils
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifdef __GNUC__
|
||||||
|
# if defined(__MINGW32__) && !defined(__clang__)
|
||||||
|
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||||
|
# else
|
||||||
|
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||||
|
# endif
|
||||||
|
#else
|
||||||
|
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
|
||||||
|
std::string string_format(const char * fmt, ...);
|
||||||
|
|
||||||
|
std::string string_strip(const std::string & str);
|
||||||
|
std::string string_get_sortable_timestamp();
|
||||||
|
|
||||||
|
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
|
||||||
|
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
|
||||||
|
std::string string_repeat(const std::string & str, size_t n);
|
||||||
|
|
||||||
|
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
|
||||||
|
|
||||||
|
std::string regex_escape(const std::string & s);
|
||||||
|
|
||||||
|
// Split `str` on `delim` and convert each piece to T via stream extraction.
// Pieces that fail extraction still contribute a (default/partial) element,
// matching the behavior of operator>> on an istringstream.
template<class T>
static std::vector<T> string_split(const std::string & str, char delim) {
    static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
    std::vector<T> result;
    std::istringstream stream(str);
    for (std::string piece; std::getline(stream, piece, delim); ) {
        T parsed;
        std::istringstream piece_stream(piece);
        piece_stream >> parsed;
        result.push_back(parsed);
    }
    return result;
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
|
||||||
|
{
|
||||||
|
std::vector<std::string> parts;
|
||||||
|
size_t begin_pos = 0;
|
||||||
|
size_t separator_pos = input.find(separator);
|
||||||
|
while (separator_pos != std::string::npos) {
|
||||||
|
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
|
||||||
|
parts.emplace_back(part);
|
||||||
|
begin_pos = separator_pos + 1;
|
||||||
|
separator_pos = input.find(separator, begin_pos);
|
||||||
|
}
|
||||||
|
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
|
||||||
|
return parts;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool string_starts_with(const std::string & str,
                               const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
    // compare() of the leading prefix.size() characters is 0 iff str begins
    // with prefix (an over-long prefix yields a non-zero length mismatch).
    return str.compare(0, prefix.size(), prefix) == 0;
}
|
||||||
|
|
||||||
|
// While we wait for C++20's std::string::ends_with...
|
||||||
|
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||||
|
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
|
||||||
|
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||||
|
|
||||||
|
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||||
|
void string_process_escapes(std::string & input);
|
||||||
|
|
||||||
|
std::string string_from(bool value);
|
||||||
|
std::string string_from(const std::vector<int> & values);
|
||||||
|
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
|
||||||
|
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Filesystem utils
|
||||||
|
//
|
||||||
|
|
||||||
|
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
|
||||||
|
bool fs_create_directory_with_parents(const std::string & path);
|
||||||
|
bool fs_is_directory(const std::string & path);
|
||||||
|
|
||||||
|
std::string fs_get_cache_directory();
|
||||||
|
std::string fs_get_cache_file(const std::string & filename);
|
||||||
|
|
||||||
|
// Metadata for a single filesystem entry, as returned by fs_list().
struct common_file_info {
    std::string path;    // full path of the entry
    std::string name;    // entry name (last path component)
    size_t size = 0;     // in bytes
    bool is_dir = false; // true if the entry is a directory
};
|
||||||
|
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
||||||
|
|
||||||
|
//
|
||||||
|
// TTY utils
|
||||||
|
//
|
||||||
|
|
||||||
|
// Auto-detect if colors can be enabled based on terminal and environment
|
||||||
|
bool tty_can_use_colors();
|
||||||
|
|
||||||
|
//
|
||||||
|
// Model utils
|
||||||
|
//
|
||||||
|
|
||||||
|
struct common_sampler;
|
||||||
|
|
||||||
|
// note: defines the model, context, samplers, ets. lifetimes
|
||||||
|
// Owns the model, context and samplers created from a common_params;
// implementation state is hidden behind a pimpl.
struct common_init_result {
    // Initializes everything from `params` (see common_init_from_params()).
    common_init_result(common_params & params);
    ~common_init_result();

    // Accessors for the owned model and context.
    llama_model * model();
    llama_context * context();

    // Per-sequence sampler access; reset_samplers() resets all samplers.
    common_sampler * sampler(llama_seq_id seq_id);
    void reset_samplers();

    // Loaded LoRA adapters.
    std::vector<llama_adapter_lora_ptr> & lora();

    // Releases the context ahead of object destruction.
    // NOTE(review): exact semantics live in the impl — confirm before relying on reuse after this call.
    void free_context();

private:
    struct impl;
    std::unique_ptr<impl> pimpl; // hidden implementation state
};
|
||||||
|
|
||||||
|
using common_init_result_ptr = std::unique_ptr<common_init_result>;
|
||||||
|
|
||||||
|
common_init_result_ptr common_init_from_params(common_params & params);
|
||||||
|
|
||||||
|
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
||||||
|
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
||||||
|
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
|
||||||
|
|
||||||
|
// clear LoRA adapters from context, then apply new list of adapters
|
||||||
|
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
|
||||||
|
|
||||||
|
std::string get_model_endpoint();
|
||||||
|
|
||||||
|
//
|
||||||
|
// Batch utils
|
||||||
|
//
|
||||||
|
|
||||||
|
void common_batch_clear(struct llama_batch & batch);
|
||||||
|
|
||||||
|
void common_batch_add(
|
||||||
|
struct llama_batch & batch,
|
||||||
|
llama_token id,
|
||||||
|
llama_pos pos,
|
||||||
|
const std::vector<llama_seq_id> & seq_ids,
|
||||||
|
bool logits);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Token utils
|
||||||
|
//
|
||||||
|
|
||||||
|
// longest common prefix
|
||||||
|
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
|
||||||
|
|
||||||
|
// longet common subsequence
|
||||||
|
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Vocab utils
|
||||||
|
//
|
||||||
|
|
||||||
|
// tokenizes a string into a vector of tokens
|
||||||
|
// should work similar to Python's `tokenizer.encode`
|
||||||
|
std::vector<llama_token> common_tokenize(
|
||||||
|
const struct llama_context * ctx,
|
||||||
|
const std::string & text,
|
||||||
|
bool add_special,
|
||||||
|
bool parse_special = false);
|
||||||
|
|
||||||
|
std::vector<llama_token> common_tokenize(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
|
const std::string & text,
|
||||||
|
bool add_special,
|
||||||
|
bool parse_special = false);
|
||||||
|
|
||||||
|
// tokenizes a token into a piece, optionally renders special/control tokens
|
||||||
|
// should work similar to Python's `tokenizer.id_to_piece`
|
||||||
|
std::string common_token_to_piece(
|
||||||
|
const struct llama_context * ctx,
|
||||||
|
llama_token token,
|
||||||
|
bool special = true);
|
||||||
|
|
||||||
|
std::string common_token_to_piece(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
|
llama_token token,
|
||||||
|
bool special = true);
|
||||||
|
|
||||||
|
// detokenizes a vector of tokens into a string
|
||||||
|
// should work similar to Python's `tokenizer.decode`
|
||||||
|
// optionally renders special/control tokens
|
||||||
|
std::string common_detokenize(
|
||||||
|
const struct llama_context * ctx,
|
||||||
|
const std::vector<llama_token> & tokens,
|
||||||
|
bool special = true);
|
||||||
|
|
||||||
|
std::string common_detokenize(
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
|
const std::vector<llama_token> & tokens,
|
||||||
|
bool special = true);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Embedding utils
|
||||||
|
//
|
||||||
|
|
||||||
|
// TODO: repace embd_norm with an enum
|
||||||
|
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
|
||||||
|
|
||||||
|
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Control vector utils
|
||||||
|
//
|
||||||
|
|
||||||
|
struct common_control_vector_data {
|
||||||
|
int n_embd;
|
||||||
|
|
||||||
|
// stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
|
||||||
|
std::vector<float> data;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_control_vector_load_info {
|
||||||
|
float strength;
|
||||||
|
|
||||||
|
std::string fname;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Load control vectors, scale each by strength, and add them together.
|
||||||
|
// On error, returns {-1, empty}
|
||||||
|
common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Split utils
|
||||||
|
//
|
||||||
|
|
||||||
|
// GGUF metadata keys identifying shards of a split (multi-file) model.
// NOTE(review): this appears to be a header; the anonymous namespace gives
// every including TU its own copy of these constants — presumably
// intentional, but worth confirming.
namespace {

const char * const LLM_KV_SPLIT_NO            = "split.no";
const char * const LLM_KV_SPLIT_COUNT         = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";

}
|
||||||
|
|
||||||
|
//
|
||||||
|
// MoE utils
|
||||||
|
//
|
||||||
|
|
||||||
|
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
|
||||||
|
|
||||||
|
// Build a regex matching the FFN expert tensors of transformer block `idx`,
// e.g. idx=7 -> "blk\.7" followed by LLM_FFN_EXPS_REGEX.
static std::string llm_ffn_exps_block_regex(int idx) {
    return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
}
|
||||||
|
|
||||||
|
// Buffer-type override pinning all tensors that match LLM_FFN_EXPS_REGEX
// (the MoE FFN expert weights) to the CPU backend buffer type.
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
    return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// training utils
|
||||||
|
//
|
||||||
|
|
||||||
|
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
|
||||||
|
|
||||||
|
// "adamw" or "sgd" (case insensitive)
|
||||||
|
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,41 @@
|
|||||||
|
// Console functions
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// Output categories accepted by console::set_display(); each selects the
// formatting used for subsequent console output.
enum display_type {
    DISPLAY_TYPE_RESET = 0,  // restore default formatting
    DISPLAY_TYPE_INFO,       // informational output
    DISPLAY_TYPE_PROMPT,     // the prompt text
    DISPLAY_TYPE_REASONING,  // model reasoning output
    DISPLAY_TYPE_USER_INPUT, // text typed by the user
    DISPLAY_TYPE_ERROR       // error output
};
|
||||||
|
|
||||||
|
// Console/TTY helpers for the CLI tools.
namespace console {
    // Set up the terminal. `use_simple_io` improves compatibility with
    // subprocesses and limited consoles; `use_advanced_display` enables
    // per-display-type formatting (see display_type).
    void init(bool use_simple_io, bool use_advanced_display);
    // Restore the terminal to its original state.
    void cleanup();
    // Switch subsequent output to the given display category.
    void set_display(display_type display);
    // Read a line of user input into `line`.
    // NOTE(review): return-value semantics are not visible here — presumably
    // signals that more (multi-line) input follows; confirm in console.cpp.
    bool readline(std::string & line, bool multiline_input);

    // Progress spinner shown while waiting.
    namespace spinner {
        void start();
        void stop();
    }

    // note: the logging API below outputs directly to stdout
    // it can negatively impact performance if used on the inference thread
    // only use it in a dedicated CLI thread
    // for logging in the inference thread, use log.h instead

    // printf-style informational logging.
    LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
    void log(const char * fmt, ...);

    // printf-style error logging.
    LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
    void error(const char * fmt, ...);

    // Flush pending console output.
    void flush();
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,84 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
struct common_params_model;
|
||||||
|
|
||||||
|
using common_header = std::pair<std::string, std::string>;
|
||||||
|
using common_header_list = std::vector<common_header>;
|
||||||
|
|
||||||
|
// Options for common_remote_get_content().
struct common_remote_params {
    common_header_list headers; // extra HTTP headers to send with the request
    long timeout = 0;  // in seconds, 0 means no timeout
    long max_size = 0; // maximum response size in bytes; unlimited if 0
};
|
||||||
|
|
||||||
|
// get remote file content, returns <http_code, raw_response_body>
|
||||||
|
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
|
||||||
|
|
||||||
|
// split HF repo with tag into <repo, tag>
|
||||||
|
// for example: "user/model:tag" -> <"user/model", "tag">
|
||||||
|
// if tag is not present, default to "latest"
|
||||||
|
// example: "user/model" -> <"user/model", "latest">
|
||||||
|
std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag);
|
||||||
|
|
||||||
|
// Describes one model present in the local download cache.
struct common_cached_model_info {
    std::string manifest_path; // path to the cached manifest file
    std::string user;          // HF user/org name
    std::string model;         // model (repo) name
    std::string tag;           // quantization tag, e.g. "q4" or "latest"
    size_t size = 0; // GGUF size in bytes

    // Render as "user/model:tag"; the ":tag" suffix is omitted for "latest".
    std::string to_string() const {
        std::string repr = user;
        repr += "/";
        repr += model;
        if (tag != "latest") {
            repr += ":";
            repr += tag;
        }
        return repr;
    }
};
|
||||||
|
|
||||||
|
// Result of common_get_hf_file().
struct common_hf_file_res {
    std::string repo;       // repo name with ":tag" removed
    std::string ggufFile;   // resolved GGUF file name within the repo
    std::string mmprojFile; // multimodal projector file name, if present
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
|
||||||
|
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
|
||||||
|
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
|
||||||
|
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
|
||||||
|
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
|
||||||
|
*
|
||||||
|
* Return pair of <repo, file> (with "repo" already having tag removed)
|
||||||
|
*
|
||||||
|
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
|
||||||
|
*/
|
||||||
|
common_hf_file_res common_get_hf_file(
|
||||||
|
const std::string & hf_repo_with_tag,
|
||||||
|
const std::string & bearer_token,
|
||||||
|
bool offline,
|
||||||
|
const common_header_list & headers = {}
|
||||||
|
);
|
||||||
|
|
||||||
|
// returns true if download succeeded
|
||||||
|
bool common_download_model(
|
||||||
|
const common_params_model & model,
|
||||||
|
const std::string & bearer_token,
|
||||||
|
bool offline,
|
||||||
|
const common_header_list & headers = {}
|
||||||
|
);
|
||||||
|
|
||||||
|
// returns list of cached models
|
||||||
|
std::vector<common_cached_model_info> common_list_cached_models();
|
||||||
|
|
||||||
|
// download single file from url to local path
|
||||||
|
// returns status code or -1 on error
|
||||||
|
int common_download_file_single(const std::string & url,
|
||||||
|
const std::string & path,
|
||||||
|
const std::string & bearer_token,
|
||||||
|
bool offline,
|
||||||
|
const common_header_list & headers = {});
|
||||||
|
|
||||||
|
// resolve and download model from Docker registry
|
||||||
|
// return local path to downloaded model file
|
||||||
|
std::string common_docker_resolve_model(const std::string & docker);
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cpp-httplib/httplib.h>
|
||||||
|
|
||||||
|
// Components of a parsed http(s) URL. A port, if present, remains part of
// `host` (e.g. "example.com:8080").
struct common_http_url {
    std::string scheme;
    std::string user;
    std::string password;
    std::string host;
    std::string path;
};

// Parse `url` into its parts. Throws std::runtime_error when the scheme is
// missing or is neither "http" nor "https". A missing path becomes "/".
static common_http_url common_http_parse_url(const std::string & url) {
    common_http_url out;

    const auto scheme_end = url.find("://");
    if (scheme_end == std::string::npos) {
        throw std::runtime_error("invalid URL: no scheme");
    }
    out.scheme = url.substr(0, scheme_end);

    if (out.scheme != "http" && out.scheme != "https") {
        throw std::runtime_error("unsupported URL scheme: " + out.scheme);
    }

    std::string remainder = url.substr(scheme_end + 3);

    // optional "user[:password]@" prefix
    const auto at_pos = remainder.find('@');
    if (at_pos != std::string::npos) {
        const std::string credentials = remainder.substr(0, at_pos);
        const auto colon_pos = credentials.find(':');
        if (colon_pos == std::string::npos) {
            out.user = credentials;
        } else {
            out.user     = credentials.substr(0, colon_pos);
            out.password = credentials.substr(colon_pos + 1);
        }
        remainder = remainder.substr(at_pos + 1);
    }

    // split host from path
    const auto slash_pos = remainder.find('/');
    if (slash_pos == std::string::npos) {
        out.host = remainder;
        out.path = "/";
    } else {
        out.host = remainder.substr(0, slash_pos);
        out.path = remainder.substr(slash_pos);
    }
    return out;
}
|
||||||
|
|
||||||
|
static std::pair<httplib::Client, common_http_url> common_http_client(const std::string & url) {
|
||||||
|
common_http_url parts = common_http_parse_url(url);
|
||||||
|
|
||||||
|
if (parts.host.empty()) {
|
||||||
|
throw std::runtime_error("error: invalid URL format");
|
||||||
|
}
|
||||||
|
|
||||||
|
httplib::Client cli(parts.scheme + "://" + parts.host);
|
||||||
|
|
||||||
|
if (!parts.user.empty()) {
|
||||||
|
cli.set_basic_auth(parts.user, parts.password);
|
||||||
|
}
|
||||||
|
|
||||||
|
cli.set_follow_location(true);
|
||||||
|
|
||||||
|
return { std::move(cli), std::move(parts) };
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string common_http_show_masked_url(const common_http_url & parts) {
|
||||||
|
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
|
||||||
|
}
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
#include "json-partial.h"
|
||||||
|
|
||||||
|
#include "log.h"
|
||||||
|
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
enum common_json_stack_element_type {
|
||||||
|
COMMON_JSON_STACK_ELEMENT_OBJECT,
|
||||||
|
COMMON_JSON_STACK_ELEMENT_KEY,
|
||||||
|
COMMON_JSON_STACK_ELEMENT_ARRAY,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_json_stack_element {
|
||||||
|
common_json_stack_element_type type;
|
||||||
|
std::string key;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool common_json_parse(
|
||||||
|
const std::string & input,
|
||||||
|
const std::string & healing_marker,
|
||||||
|
common_json & out)
|
||||||
|
{
|
||||||
|
std::string::const_iterator it = input.begin();
|
||||||
|
const auto end = input.end();
|
||||||
|
return common_json_parse(it, end, healing_marker, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse a (possibly truncated) JSON value starting at `it`.
//
// Strategy:
//  1. SAX-parse to locate the first error position and record the nesting
//     stack at that point.
//  2. If there was no error, parse normally and consume everything.
//  3. Otherwise, try to parse the prefix up to the error; if that fails and a
//     `healing_marker` is provided, append a marker-bearing string plus the
//     closers needed to complete the nesting, so the result parses and can be
//     cut again at the marker (stored in out.healing_marker).
// Returns true on (possibly healed) success; `it` is advanced to the consumed end.
// Throws std::runtime_error when the truncation point cannot be healed.
bool common_json_parse(
    std::string::const_iterator & it,
    const std::string::const_iterator & end,
    const std::string & healing_marker,
    common_json & out)
{
    // https://json.nlohmann.me/features/parsing/sax_interface/
    // SAX handler that records where parsing failed and the nesting stack at
    // that moment (objects / keys / arrays still open).
    struct json_error_locator : public nlohmann::json_sax<json> {
        std::size_t position;          // byte offset of the error (minus 1, see parse_error)
        bool found_error;
        std::string last_token;
        std::string exception_message;
        std::vector<common_json_stack_element> stack;

        json_error_locator() : position(0), found_error(false) {}

        bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
            // NOTE(review): nlohmann reports 1-based positions, hence the -1 —
            // TODO confirm against the library docs.
            this->position = position - 1;
            this->found_error = true;
            this->last_token = last_token;
            this->exception_message = ex.what();
            return false; // stop parsing at the first error
        }
        // A completed value closes any pending KEY entry (key+value pair done).
        void close_value() {
            if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
                stack.pop_back();
            }
        }
        bool null() override { // NOLINT
            close_value();
            return true;
        }
        bool boolean(bool) override { // NOLINT
            close_value();
            return true;
        }
        bool number_integer(number_integer_t) override { // NOLINT
            close_value();
            return true;
        }
        bool number_unsigned(number_unsigned_t) override { // NOLINT
            close_value();
            return true;
        }
        bool number_float(number_float_t, const string_t &) override { // NOLINT
            close_value();
            return true;
        }
        bool string(string_t &) override { // NOLINT
            close_value();
            return true;
        }
        bool binary(binary_t &) override { // NOLINT
            close_value();
            return true;
        }
        bool start_object(std::size_t) override { // NOLINT
            stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
            return true;
        }
        bool end_object() override {
            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
            stack.pop_back();
            close_value(); // the object itself may be a value for a pending key
            return true;
        }
        bool key(string_t & key) override { // NOLINT
            stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
            return true;
        }
        bool start_array(std::size_t) override { // NOLINT
            stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
            return true;
        }
        bool end_array() override {
            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
            stack.pop_back();
            close_value();
            return true;
        }
    };
    json_error_locator err_loc;
    auto start = it;
    json::sax_parse(it, end, &err_loc);

    if (err_loc.found_error) {
        it = start;
        // End of the longest prefix the SAX parser accepted.
        auto temptative_end = it + err_loc.position;
        // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());

        auto input = std::string(it, temptative_end);
        try {
            // The prefix may itself be a complete JSON value (e.g. a number
            // followed by trailing garbage) — try it as-is first.
            out.json = json::parse(input);
            // out.json = json::parse(it, temptative_end);
            it = temptative_end;
            return true;
        } catch (const std::exception & ex) {
            // No, needs healing.
            LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
        }
        // Cheap probe: does `str` parse as complete JSON?
        auto can_parse = [](const std::string & str) {
            try {
                auto _ = json::parse(str); // NOLINT
                return true;
            } catch (const std::exception &) {
                return false;
            }
        };
        if (!healing_marker.empty() && !err_loc.stack.empty()) {
            std::string str(it, temptative_end);
            auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
            if (last_non_sp_pos == std::string::npos) {
                throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
            }
            auto last_non_sp_char = str[last_non_sp_pos];
            // Used to detect stops on a number, which may not be complete.
            // NOTE(review): std::isspace/std::isdigit on a plain char is UB for
            // negative (non-ASCII) values — consider casting to unsigned char.
            auto was_maybe_number = [&]() {
                if (!str.empty() && std::isspace(str.back())) {
                    return false;
                }
                return std::isdigit(last_non_sp_char) ||
                    last_non_sp_char == '.' ||
                    last_non_sp_char == 'e' ||
                    last_non_sp_char == 'E' ||
                    last_non_sp_char == '-';
            };

            // Build the closers ("}"/"]") for every still-open scope, innermost first.
            std::string closing;
            for (size_t i = err_loc.stack.size(); i > 0; i--) {
                auto & el = err_loc.stack[i - 1];
                if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
                    closing += "}";
                } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
                    closing += "]";
                } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
                    throw std::runtime_error("Unexpected stack element type");
                }
            }

            // Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
            static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");

            auto is_high_surrogate = [&](const std::string & s) {
                // Check if a partial of a high surrogate (U+D800-U+DBFF)
                return s.length() >= 4 &&
                    s[0] == '\\' && s[1] == 'u' &&
                    std::tolower(s[2]) == 'd' &&
                    (s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
            };

            // Initialize the unicode marker to a low surrogate to handle the edge case
            // where a high surrogate (U+D800-U+DBFF) is immediately followed by a
            // backslash (\): appending "udc00" to a trailing '\' yields \udc00.
            std::string unicode_marker_padding = "udc00";
            std::smatch last_unicode_seq;

            if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
                std::smatch second_last_seq;
                std::string prelude = str.substr(0, last_unicode_seq.position());

                // Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
                unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');

                if (is_high_surrogate(last_unicode_seq.str())) {
                    // If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
                    unicode_marker_padding += "\\udc00";
                } else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
                    if (is_high_surrogate(second_last_seq.str())) {
                        // If this follows a high surrogate, pad it to be a low surrogate
                        if (last_unicode_seq.length() == 2) {
                            unicode_marker_padding = "dc00";
                        } else if (last_unicode_seq.length() == 3) {
                            unicode_marker_padding = "c00";
                        } else {
                            // The original unicode_marker_padding is already padded with 0s
                        }
                    }
                }
            }

            // Record the marker; `magic_seed` is just an alias for readability below.
            const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";

            // Each branch probes with can_parse() before committing, and stores in
            // json_dump_marker the exact fragment to cut at in the dump.
            if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
                // We're inside an object value
                if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
                    // Was about to create an object value
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                } else if (can_parse(str + ": 1" + closing)) {
                    str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
                } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
                    // Was about to create an object
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
                } else if (can_parse(str + "\"" + closing)) {
                    // Was inside an object value string
                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
                    // Was inside an object value string after an escape
                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
                } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
                    // Was inside an object value string after a partial unicode escape
                    str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
                } else {
                    // find last :
                    auto last_pos = str.find_last_of(':');
                    if (last_pos == std::string::npos) {
                        throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
                    }
                    // Cutting back to opening : for object value
                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                }
            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
                if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
                    // Was about to create an array value
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                } else if (can_parse(str + "\"" + closing)) {
                    // Was inside an array value string
                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
                    // Was inside an array value string after an escape
                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
                } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
                    // Was inside an array value string after a partial unicode escape
                    str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
                } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
                    // Had just finished a value
                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
                } else {
                    auto last_pos = str.find_last_of("[,");
                    if (last_pos == std::string::npos) {
                        throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
                    }
                    // Cutting back to last [ or , for array value
                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                }
            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
                if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
                        (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
                    // Was about to create an object key+value
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
                } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
                    // Was about to create an object key+value
                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
                } else if (can_parse(str + "\": 1" + closing)) {
                    // Was inside an object key string
                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
                    // Was inside an object key string after an escape
                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
                } else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
                    // Was inside an object key string after a partial unicode escape
                    str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
                } else {
                    auto last_pos = str.find_last_of(':');
                    if (last_pos == std::string::npos) {
                        throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
                    }
                    // fprintf(stderr, "Cutting back to last : for object key+value\n");
                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                }
            } else {
                throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
            }
            // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
            out.json = json::parse(str);
            it = temptative_end;
            return true;
        }
        // handle unclosed top-level primitive (nothing on the nesting stack)
        if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
            std::string str(it, temptative_end);
            const auto & magic_seed = out.healing_marker.marker = healing_marker;
            if (can_parse(str + "\"")) {
                // Was inside an string
                str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
            } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
                // Was inside an string after an escape
                str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
            } else {
                // TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
                // fprintf(stderr, "Closing: TODO\n");
                return false;
            }
            out.json = json::parse(str);
            it = temptative_end;
            return true;
        }
        return false;
    }
    // No error: consume and parse the whole range.
    out.json = json::parse(it, end);
    it = end;
    return true;
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
#pragma once

#include <nlohmann/json.hpp>

// Healing marker (empty if the JSON was fully parsed / wasn't healed).
struct common_healing_marker {
    // Raw marker string that was injected into the healed JSON.
    std::string marker;

    // Cutting the `common_json.json.dump()` string at the (only) occurrence of
    // this marker should yield the original partial JSON string (modulo spaces /
    // if it had the same dump format).
    std::string json_dump_marker;
};

// Represents a parsed JSON object, with its optional healing marker (a JSON
// dump fragment that can be used to find the position of healing in the JSON
// dump string).
struct common_json {
    nlohmann::ordered_json json;

    common_healing_marker healing_marker;
};

// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
//
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
//
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
bool common_json_parse(
    const std::string & input,
    const std::string & healing_marker,
    common_json & out);

// Parse the JSON string (see overload above), but advancing an iterator to the
// end of the input when the (potentially partial) parsing succeeds.
bool common_json_parse(
    std::string::const_iterator & it,
    const std::string::const_iterator & end,
    const std::string & healing_marker,
    common_json & out);
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,43 @@
|
|||||||
|
#pragma once

#include <nlohmann/json_fwd.hpp>

#include <functional>
#include <memory>
#include <string>

// Convert a JSON schema to a GBNF grammar string.
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
    bool force_gbnf = false);

class common_schema_converter;

// Probes a JSON schema to extract information about its structure and type
// constraints. Pimpl'd over common_schema_converter; move-only.
class common_schema_info {
    std::unique_ptr<common_schema_converter> impl_;

  public:
    common_schema_info();
    ~common_schema_info();

    common_schema_info(const common_schema_info &) = delete;
    common_schema_info & operator=(const common_schema_info &) = delete;
    common_schema_info(common_schema_info &&) noexcept;
    common_schema_info & operator=(common_schema_info &&) noexcept;

    // Resolve $ref references in-place within `schema`.
    void resolve_refs(nlohmann::ordered_json & schema);
    // Whether the schema constrains the value to a string type.
    bool resolves_to_string(const nlohmann::ordered_json & schema);
};

// Callback bundle handed to grammar-building code.
struct common_grammar_builder {
    // add_rule(name, definition) -> final (possibly uniquified) rule name.
    std::function<std::string(const std::string &, const std::string &)> add_rule;
    // add_schema(name, schema) -> rule name generated for the schema.
    std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
    // Resolve $refs in a schema in-place.
    std::function<void(nlohmann::ordered_json &)> resolve_refs;
};

struct common_grammar_options {
    // When true, '.' in regexes also matches newlines.
    bool dotall = false;
};

// Escape/quote a literal string for inclusion in a GBNF grammar.
std::string gbnf_format_literal(const std::string & literal);

// Build a grammar by invoking `cb` with a common_grammar_builder.
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||||
@@ -0,0 +1,258 @@
|
|||||||
|
#include "sampling.h"
|
||||||
|
#include "log.h"
|
||||||
|
|
||||||
|
#ifdef LLAMA_USE_LLGUIDANCE
|
||||||
|
|
||||||
|
# include "llguidance.h"
|
||||||
|
# include <cmath>
|
||||||
|
|
||||||
|
// Per-sampler state for the llguidance-backed grammar sampler.
struct llama_sampler_llg {
    const llama_vocab * vocab;
    std::string grammar_kind;   // e.g. grammar format identifier passed at init
    std::string grammar_data;   // grammar source text
    LlgTokenizer * tokenizer;   // owned; freed together with `grammar`
    LlgMatcher * grammar;       // owned; nullptr when no grammar is active
};
|
||||||
|
|
||||||
|
// Create an llguidance matcher for the given grammar.
// The LLGUIDANCE_LOG_LEVEL environment variable (if set) controls llg's
// stderr log verbosity. Returns nullptr (after logging and freeing the
// half-constructed matcher) when the grammar is rejected.
static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
                                          const char * grammar_data) {
    LlgConstraintInit cinit;
    llg_constraint_init_set_defaults(&cinit, tokenizer);
    const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
    if (log_level && *log_level) {
        cinit.log_stderr_level = atoi(log_level);
    }
    auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
    if (llg_matcher_get_error(c)) {
        LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
        llg_free_matcher(c);
        return nullptr;
    }

    return c;
}
|
||||||
|
|
||||||
|
// Sampler interface: human-readable sampler name.
static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
    return "llguidance";
}
|
||||||
|
|
||||||
|
// Sampler interface: advance the matcher with the token that was sampled.
// No-op when no grammar is active.
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        llg_matcher_consume_token(ctx->grammar, token);
    }
}
|
||||||
|
|
||||||
|
// Sampler interface: mask out tokens the grammar does not allow by setting
// their logits to -INFINITY. The allow-mask is a bitset with one bit per
// token id (bit (token % 32) of word (token / 32)).
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
        if (mask == nullptr) {
            // Mask not computed yet for this step — compute it now.
            if (llg_matcher_compute_mask(ctx->grammar) == 0) {
                mask = llg_matcher_get_mask(ctx->grammar);
            } else {
                // Matcher is broken: log, drop the grammar, and stop constraining.
                LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
                llg_free_matcher(ctx->grammar);
                ctx->grammar = nullptr;
                return;
            }
        }

        for (size_t i = 0; i < cur_p->size; ++i) {
            auto token = cur_p->data[i].id;
            if ((mask[token / 32] & (1 << (token % 32))) == 0) {
                cur_p->data[i].logit = -INFINITY;
            }
        }
    }
}
|
||||||
|
|
||||||
|
// Sampler interface: reset the matcher to its initial state (if any).
static void llama_sampler_llg_reset(llama_sampler * smpl) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        llg_matcher_reset(ctx->grammar);
    }
}
|
||||||
|
|
||||||
|
// Sampler interface: deep-copy the sampler.
// Creates an empty llg sampler first (nullptr grammar), then clones the
// matcher/tokenizer state into it when the source has an active grammar.
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;

    auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_llg *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_kind = ctx->grammar_kind;
            result_ctx->grammar_data = ctx->grammar_data;
            result_ctx->grammar = llg_clone_matcher(ctx->grammar);
            result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
        }
    }

    return result;
}
|
||||||
|
|
||||||
|
// Sampler interface: release the matcher/tokenizer (owned together — the
// tokenizer is only set when a grammar was created) and the context itself.
static void llama_sampler_llg_free(llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_llg *) smpl->ctx;

    if (ctx->grammar) {
        llg_free_matcher(ctx->grammar);
        llg_free_tokenizer(ctx->tokenizer);
    }

    delete ctx;
}
|
||||||
|
|
||||||
|
// llama_sampler vtable for the llguidance sampler; backend hooks unused.
static llama_sampler_i llama_sampler_llg_i = {
    /* .name              = */ llama_sampler_llg_name,
    /* .accept            = */ llama_sampler_llg_accept_impl,
    /* .apply             = */ llama_sampler_llg_apply,
    /* .reset             = */ llama_sampler_llg_reset,
    /* .clone             = */ llama_sampler_llg_clone,
    /* .free              = */ llama_sampler_llg_free,
    /* .backend_init      = */ NULL,
    /* .backend_accept    = */ NULL,
    /* .backend_apply     = */ NULL,
    /* .backend_set_input = */ NULL,
};
|
||||||
|
|
||||||
|
// Tokenize callback handed to llguidance. `user_data` is the llama_vocab.
// Returns the number of tokens written; when the output buffer is too small,
// llama_tokenize returns a negative value whose magnitude is the required
// size, which is what we return (presumably the llg contract — TODO confirm).
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
                                            uint32_t * output_tokens, size_t output_tokens_len) {
    const llama_vocab * vocab = (const llama_vocab *) user_data;
    int r = 0;
    try {
        r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
                           true);
    } catch (const std::exception & e) {
        GGML_ABORT("llama_tokenize failed: %s\n", e.what());
    }
    if (r < 0) {
        return -r;
    }
    return r;
}
|
||||||
|
|
||||||
|
// Build (or fetch from a one-entry cache) an llg tokenizer for `vocab`.
// The cache keeps ownership of its tokenizer; callers always receive a clone.
// Returns nullptr on failure (after logging).
// NOTE(review): the static cache is not synchronized — concurrent calls from
// multiple threads would race; confirm single-threaded init is guaranteed.
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
    // TODO store the tokenizer in the vocab somehow
    static const llama_vocab * vocab_cache;
    static LlgTokenizer * tokenizer_cache;

    if (vocab_cache == vocab) {
        return llg_clone_tokenizer(tokenizer_cache);
    }

    // Prefer end-of-turn, falling back to end-of-sequence.
    auto tok_eos = llama_vocab_eot(vocab);
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);
    }

    size_t vocab_size = llama_vocab_n_tokens(vocab);

    auto token_lens = new uint32_t[vocab_size];
    // we typically have ~7 bytes per token; let's go on the safe side here
    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
    auto token_bytes = new uint8_t[token_bytes_size];

    // Detokenize every token id into one packed byte buffer; token_lens[i]
    // records each token's byte length so llg can index into it.
    size_t offset = 0;
    for (size_t i = 0; i < vocab_size; i++) {
        size_t max_token = 1024;
        if (token_bytes_size - offset < max_token) {
            GGML_ABORT("token_bytes buffer too small\n");
        }

        llama_token token = i;
        auto dp = (char *) token_bytes + offset;
        auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
        if (size < 0) {
            GGML_ABORT("llama_detokenize failed\n");
        }
        if (size == 0) {
            // Empty without specials: retry rendering the special token, and
            // prefix it with 0xff so llg can tell it apart from plain text.
            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
            if (size < 0) {
                GGML_ABORT("llama_detokenize failed\n");
            }
            if (size != 0) {
                *dp = '\xff'; // special token prefix marker
                size += 1;
            }
        }

        token_lens[i] = size;
        offset += size;
    }

    LlgTokenizerInit tinit = {
        /* .vocab_size                         = */ (uint32_t) vocab_size,
        /* .tok_eos                            = */ (uint32_t) tok_eos,
        /* .token_lens                         = */ token_lens,
        /* .token_bytes                        = */ token_bytes,
        /* .tokenizer_json                     = */ nullptr,
        /* .tokenize_assumes_string            = */ true,
        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
        /* .use_approximate_greedy_tokenize_fn = */ false,
        /* .tokenize_user_data                 = */ vocab,
        /* .slices                             = */ nullptr,
    };

    char error_buffer[1024];
    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

    // llg copies what it needs; the scratch buffers can go either way.
    delete[] token_bytes;
    delete[] token_lens;

    if (tokenizer == nullptr) {
        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
        return tokenizer;
    }

    // Replace the cached tokenizer with the fresh one.
    if (tokenizer_cache) {
        llg_free_tokenizer(tokenizer_cache);
    }
    vocab_cache = vocab;
    tokenizer_cache = tokenizer;

    return llg_clone_tokenizer(tokenizer_cache);
}
|
||||||
|
|
||||||
|
// Create an llguidance-backed sampler.
// With a non-empty grammar_kind, builds a tokenizer + matcher for `vocab`;
// otherwise creates an inert sampler (nullptr grammar/tokenizer), as used by
// clone(). The returned sampler owns its context via llama_sampler_llg_free.
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
                                       const char * grammar_data) {
    auto * ctx = new llama_sampler_llg;

    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
        auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ grammar_kind,
            /* .grammar_data = */ grammar_data,
            /* .tokenizer    = */ tokenizer,
            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
        };
        if (ctx->grammar) {
            // Sanity check: llg's mask size must match one bit per vocab token.
            GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
                        llg_matcher_get_mask_byte_size(ctx->grammar));
        }
    } else {
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ {},
            /* .grammar_data = */ {},
            /* .tokenizer    = */ nullptr,
            /* .grammar      = */ nullptr,
        };
    }

    return llama_sampler_init(
        /* .iface = */ &llama_sampler_llg_i,
        /* .ctx   = */ ctx);
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
// Stub used when llguidance support is compiled out: warn and return nullptr.
llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
    LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
    return nullptr;
}
|
||||||
|
|
||||||
|
#endif // LLAMA_USE_LLGUIDANCE
|
||||||
@@ -0,0 +1,446 @@
|
|||||||
|
#include "common.h"
|
||||||
|
#include "log.h"
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <condition_variable>
|
||||||
|
#include <cstdarg>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <mutex>
|
||||||
|
#include <sstream>
|
||||||
|
#include <thread>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if defined(_WIN32)
|
||||||
|
# include <io.h>
|
||||||
|
# include <windows.h>
|
||||||
|
# define isatty _isatty
|
||||||
|
# define fileno _fileno
|
||||||
|
#else
|
||||||
|
# include <unistd.h>
|
||||||
|
#endif // defined(_WIN32)
|
||||||
|
|
||||||
|
// Global verbosity threshold consulted when printing log entries.
int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;

// Set the global verbosity threshold. Plain assignment with no visible
// synchronization — presumably intended for startup-time configuration.
void common_log_set_verbosity_thold(int verbosity) {
    common_log_verbosity_thold = verbosity;
}
|
||||||
|
|
||||||
|
// Current wall-clock time as microseconds since the Unix epoch.
static int64_t t_us() {
    const auto now   = std::chrono::system_clock::now();
    const auto since = now.time_since_epoch();
    return std::chrono::duration_cast<std::chrono::microseconds>(since).count();
}
|
||||||
|
|
||||||
|
// colors
|
||||||
|
// Indices into g_col for ANSI color escape strings used when printing.
enum common_log_col : int {
    COMMON_LOG_COL_DEFAULT = 0,
    COMMON_LOG_COL_BOLD,
    COMMON_LOG_COL_RED,
    COMMON_LOG_COL_GREEN,
    COMMON_LOG_COL_YELLOW,
    COMMON_LOG_COL_BLUE,
    COMMON_LOG_COL_MAGENTA,
    COMMON_LOG_COL_CYAN,
    COMMON_LOG_COL_WHITE,
};
|
||||||
|
|
||||||
|
// disable colors by default
|
||||||
|
// Color escape strings, indexed by common_log_col.
// All empty by default, i.e. colors are disabled until populated elsewhere.
static std::vector<const char *> g_col = {
    "", // COMMON_LOG_COL_DEFAULT
    "", // COMMON_LOG_COL_BOLD
    "", // COMMON_LOG_COL_RED
    "", // COMMON_LOG_COL_GREEN
    "", // COMMON_LOG_COL_YELLOW
    "", // COMMON_LOG_COL_BLUE
    "", // COMMON_LOG_COL_MAGENTA
    "", // COMMON_LOG_COL_CYAN
    "", // COMMON_LOG_COL_WHITE
};
|
||||||
|
|
||||||
|
// One buffered log message plus the metadata needed to render it.
struct common_log_entry {
    enum ggml_log_level level;

    // Whether to print the "<timestamp> <level> " prefix.
    bool prefix;

    // Microseconds relative to log start; 0 suppresses the timestamp field.
    int64_t timestamp;

    // NUL-terminated message text.
    std::vector<char> msg;

    // signals the worker thread to stop
    bool is_end;

    // Render this entry to `file`, or, when file is nullptr, to stdout
    // (stderr for any level other than NONE), honoring the global verbosity
    // threshold for DEBUG messages.
    void print(FILE * file = nullptr) const {
        FILE * fcur = file;
        if (!fcur) {
            // stderr displays DBG messages only when their verbosity level is not higher than the threshold
            // these messages will still be logged to a file
            if (level == GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) {
                return;
            }

            fcur = stdout;

            if (level != GGML_LOG_LEVEL_NONE) {
                fcur = stderr;
            }
        }

        // CONT entries continue a previous line, so they get no prefix.
        if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
            if (timestamp) {
                // [M.s.ms.us]
                fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
                        g_col[COMMON_LOG_COL_BLUE],
                        (int) (timestamp / 1000000 / 60),
                        (int) (timestamp / 1000000 % 60),
                        (int) (timestamp / 1000 % 1000),
                        (int) (timestamp % 1000),
                        g_col[COMMON_LOG_COL_DEFAULT]);
            }

            // Single-letter level tag, colored per level.
            switch (level) {
                case GGML_LOG_LEVEL_INFO:  fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN],   g_col[COMMON_LOG_COL_DEFAULT]); break;
                case GGML_LOG_LEVEL_WARN:  fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], ""                            ); break;
                case GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED],     ""                            ); break;
                case GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW],  ""                            ); break;
                default:
                    break;
            }
        }

        fprintf(fcur, "%s", msg.data());

        // WARN/ERROR/DEBUG leave their color open across the message; reset it.
        if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_DEBUG) {
            fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]);
        }

        fflush(fcur);
    }
};
|
||||||
|
|
||||||
|
// Asynchronous logger. add() formats a message into a ring buffer and a
// dedicated worker thread dequeues entries and prints them (and optionally
// writes them to a file). pause()/resume() stop and restart the worker;
// while paused, incoming messages are discarded.
struct common_log {
    // default capacity - will be expanded if needed
    common_log() : common_log(256) {}

    // capacity: initial number of slots in the message ring buffer
    common_log(size_t capacity) {
        file = nullptr;
        prefix = false;
        timestamps = false;
        running = false;
        t_start = t_us();

        // initial message size - will be expanded if longer messages arrive
        entries.resize(capacity);
        for (auto & entry : entries) {
            entry.msg.resize(256);
        }

        head = 0;
        tail = 0;

        // start the worker thread immediately
        resume();
    }

    ~common_log() {
        // drain pending messages and join the worker before closing the file
        pause();
        if (file) {
            fclose(file);
        }
    }

private:
    std::mutex mtx;
    std::thread thrd;
    std::condition_variable cv;

    // optional log file (owned); nullptr when file logging is disabled
    FILE * file;

    bool prefix;     // prepend level/timestamp prefix to each message
    bool timestamps; // include timestamps in the prefix
    bool running;    // worker thread state; guarded by mtx

    int64_t t_start; // t_us() at construction; entry timestamps are relative to this

    // ring buffer of entries
    // pending messages live in [head, tail); head == tail means empty
    std::vector<common_log_entry> entries;
    size_t head;
    size_t tail;

    // worker thread copies into this
    common_log_entry cur;

public:
    // Format the message into the ring buffer and wake the worker thread.
    // Messages are discarded while the worker is paused. The buffer doubles
    // in size when it fills up, so add() never blocks on the consumer.
    void add(enum ggml_log_level level, const char * fmt, va_list args) {
        std::lock_guard<std::mutex> lock(mtx);

        if (!running) {
            // discard messages while the worker thread is paused
            return;
        }

        auto & entry = entries[tail];

        {
            // cannot use args twice, so make a copy in case we need to expand the buffer
            va_list args_copy;
            va_copy(args_copy, args);

#if 1
            // NOTE(review): vsnprintf returns int and can be negative on an
            // encoding error; assigning it to size_t would then request a huge
            // resize - assumes fmt/args are always well-formed.
            const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
            if (n >= entry.msg.size()) {
                // first pass truncated - grow the buffer and format again
                entry.msg.resize(n + 1);
                vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy);
            }
#else
            // hack for bolding arguments

            std::stringstream ss;
            for (int i = 0; fmt[i] != 0; i++) {
                if (fmt[i] == '%') {
                    ss << LOG_COL_BOLD;
                    while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++];
                    ss << LOG_COL_DEFAULT;
                    if (fmt[i] == 0) break;
                }
                ss << fmt[i];
            }
            const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args);
            if (n >= entry.msg.size()) {
                entry.msg.resize(n + 1);
                vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
            }
#endif
            va_end(args_copy);
        }

        entry.level = level;
        entry.prefix = prefix;
        entry.timestamp = 0;
        if (timestamps) {
            entry.timestamp = t_us() - t_start;
        }
        entry.is_end = false;

        tail = (tail + 1) % entries.size();
        if (tail == head) {
            // buffer is full - expand to 2x, compacting the pending entries
            // to the front of the new buffer
            std::vector<common_log_entry> new_entries(2*entries.size());

            size_t new_tail = 0;

            do {
                new_entries[new_tail] = std::move(entries[head]);

                head = (head + 1) % entries.size();
                new_tail = (new_tail + 1);
            } while (head != tail);

            head = 0;
            tail = new_tail;

            // pre-size the message buffers of the fresh slots
            for (size_t i = tail; i < new_entries.size(); i++) {
                new_entries[i].msg.resize(256);
            }

            entries = std::move(new_entries);
        }

        cv.notify_one();
    }

    // Start the worker thread. No-op when already running.
    // Not safe against concurrent pause()/resume() calls (see header docs).
    void resume() {
        std::lock_guard<std::mutex> lock(mtx);

        if (running) {
            return;
        }

        running = true;

        thrd = std::thread([this]() {
            while (true) {
                {
                    std::unique_lock<std::mutex> lock(mtx);
                    cv.wait(lock, [this]() { return head != tail; });

                    // copy the entry out so printing happens outside the lock
                    cur = entries[head];

                    head = (head + 1) % entries.size();
                }

                // sentinel pushed by pause() - exit the worker loop
                if (cur.is_end) {
                    break;
                }

                cur.print(); // stdout and stderr

                if (file) {
                    cur.print(file);
                }
            }
        });
    }

    // Stop the worker thread after it drains all pending entries, then join it.
    // No-op when already paused.
    void pause() {
        {
            std::lock_guard<std::mutex> lock(mtx);

            if (!running) {
                return;
            }

            running = false;

            // push an entry to signal the worker thread to stop
            {
                auto & entry = entries[tail];
                entry.is_end = true;

                tail = (tail + 1) % entries.size();
            }

            cv.notify_one();
        }

        thrd.join();
    }

    // Redirect file logging to path, or disable it when path is nullptr.
    // Pauses the worker so no entry is written across the switch.
    // NOTE(review): fopen failure is not reported - file logging is then
    // silently disabled.
    void set_file(const char * path) {
        pause();

        if (file) {
            fclose(file);
        }

        if (path) {
            file = fopen(path, "w");
        } else {
            file = nullptr;
        }

        resume();
    }

    // Enable or disable ANSI colors by rewriting the global g_col palette.
    // Pauses the worker so no entry is printed with a half-updated palette.
    void set_colors(bool colors) {
        pause();

        if (colors) {
            g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT;
            g_col[COMMON_LOG_COL_BOLD] = LOG_COL_BOLD;
            g_col[COMMON_LOG_COL_RED] = LOG_COL_RED;
            g_col[COMMON_LOG_COL_GREEN] = LOG_COL_GREEN;
            g_col[COMMON_LOG_COL_YELLOW] = LOG_COL_YELLOW;
            g_col[COMMON_LOG_COL_BLUE] = LOG_COL_BLUE;
            g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA;
            g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN;
            g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE;
        } else {
            // empty strings make every color escape a no-op
            for (size_t i = 0; i < g_col.size(); i++) {
                g_col[i] = "";
            }
        }

        resume();
    }

    // Toggle the per-message level prefix (applies to subsequently added entries).
    void set_prefix(bool prefix) {
        std::lock_guard<std::mutex> lock(mtx);

        this->prefix = prefix;
    }

    // Toggle timestamps in the prefix (applies to subsequently added entries).
    void set_timestamps(bool timestamps) {
        std::lock_guard<std::mutex> lock(mtx);

        this->timestamps = timestamps;
    }
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// public API
|
||||||
|
//
|
||||||
|
|
||||||
|
// Allocate a fresh logger; the caller owns it and must release it with
// common_log_free().
struct common_log * common_log_init() {
    return new common_log;
}
|
||||||
|
|
||||||
|
// Return the process-wide singleton logger.
// On first use, colors are enabled or disabled based on terminal support
// (tty_can_use_colors()); std::call_once makes that initialization
// thread-safe. The static instance destroys itself at program exit.
struct common_log * common_log_main() {
    static struct common_log log;
    static std::once_flag init_flag;
    std::call_once(init_flag, [&]() {
        // Set default to auto-detect colors
        log.set_colors(tty_can_use_colors());
    });

    return &log;
}
|
||||||
|
|
||||||
|
// Pause the logger's worker thread; messages added while paused are discarded.
void common_log_pause(struct common_log * log) {
    log->pause();
}
|
||||||
|
|
||||||
|
// Restart the logger's worker thread after a pause.
void common_log_resume(struct common_log * log) {
    log->resume();
}
|
||||||
|
|
||||||
|
// Destroy a logger created with common_log_init() (drains pending messages
// via the destructor's pause()).
void common_log_free(struct common_log * log) {
    delete log;
}
|
||||||
|
|
||||||
|
// printf-style entry point: forward the variadic arguments to the logger's
// internal add() as a va_list.
void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
    log->add(level, fmt, args);
    va_end(args);
}
|
||||||
|
|
||||||
|
// Set (or clear, with nullptr) the file that log messages are mirrored to.
void common_log_set_file(struct common_log * log, const char * file) {
    log->set_file(file);
}
|
||||||
|
|
||||||
|
void common_log_set_colors(struct common_log * log, log_colors colors) {
|
||||||
|
if (colors == LOG_COLORS_AUTO) {
|
||||||
|
log->set_colors(tty_can_use_colors());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (colors == LOG_COLORS_DISABLED) {
|
||||||
|
log->set_colors(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(colors == LOG_COLORS_ENABLED);
|
||||||
|
log->set_colors(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable/disable the per-line level prefix.
void common_log_set_prefix(struct common_log * log, bool prefix) {
    log->set_prefix(prefix);
}
|
||||||
|
|
||||||
|
// Enable/disable timestamps in the per-line prefix.
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
    log->set_timestamps(timestamps);
}
|
||||||
|
|
||||||
|
// Flush all pending messages: pause() joins the worker only after it drains
// the ring buffer up to the stop sentinel, and resume() restarts it.
void common_log_flush(struct common_log * log) {
    log->pause();
    log->resume();
}
|
||||||
|
|
||||||
|
static int common_get_verbosity(enum ggml_log_level level) {
|
||||||
|
switch (level) {
|
||||||
|
case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
|
||||||
|
case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO;
|
||||||
|
case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN;
|
||||||
|
case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
|
||||||
|
case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO
|
||||||
|
case GGML_LOG_LEVEL_NONE:
|
||||||
|
default:
|
||||||
|
return LOG_LEVEL_OUTPUT;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default ggml logging callback: translate the ggml level to this logger's
// verbosity scale and forward the text to the singleton logger, dropping
// messages more verbose than the global threshold.
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
    auto verbosity = common_get_verbosity(level);
    if (verbosity <= common_log_verbosity_thold) {
        common_log_add(common_log_main(), level, "%s", text);
    }
}
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h" // for ggml_log_level
|
||||||
|
|
||||||
|
#define LOG_CLR_TO_EOL "\033[K\r"
|
||||||
|
#define LOG_COL_DEFAULT "\033[0m"
|
||||||
|
#define LOG_COL_BOLD "\033[1m"
|
||||||
|
#define LOG_COL_RED "\033[31m"
|
||||||
|
#define LOG_COL_GREEN "\033[32m"
|
||||||
|
#define LOG_COL_YELLOW "\033[33m"
|
||||||
|
#define LOG_COL_BLUE "\033[34m"
|
||||||
|
#define LOG_COL_MAGENTA "\033[35m"
|
||||||
|
#define LOG_COL_CYAN "\033[36m"
|
||||||
|
#define LOG_COL_WHITE "\033[37m"
|
||||||
|
|
||||||
|
#ifndef __GNUC__
|
||||||
|
# define LOG_ATTRIBUTE_FORMAT(...)
|
||||||
|
#elif defined(__MINGW32__) && !defined(__clang__)
|
||||||
|
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||||
|
#else
|
||||||
|
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define LOG_LEVEL_DEBUG 4
|
||||||
|
#define LOG_LEVEL_INFO 3
|
||||||
|
#define LOG_LEVEL_WARN 2
|
||||||
|
#define LOG_LEVEL_ERROR 1
|
||||||
|
#define LOG_LEVEL_OUTPUT 0 // output data from tools
|
||||||
|
|
||||||
|
#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
|
||||||
|
#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO
|
||||||
|
|
||||||
|
enum log_colors {
|
||||||
|
LOG_COLORS_AUTO = -1,
|
||||||
|
LOG_COLORS_DISABLED = 0,
|
||||||
|
LOG_COLORS_ENABLED = 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
|
||||||
|
// set via common_log_set_verbosity()
|
||||||
|
extern int common_log_verbosity_thold;
|
||||||
|
|
||||||
|
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
||||||
|
|
||||||
|
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
|
||||||
|
|
||||||
|
// the common_log uses an internal worker thread to print/write log messages
|
||||||
|
// when the worker thread is paused, incoming log messages are discarded
|
||||||
|
struct common_log;
|
||||||
|
|
||||||
|
struct common_log * common_log_init();
|
||||||
|
struct common_log * common_log_main(); // singleton, automatically destroys itself on exit
|
||||||
|
void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe
|
||||||
|
void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe
|
||||||
|
void common_log_free (struct common_log * log);
|
||||||
|
|
||||||
|
LOG_ATTRIBUTE_FORMAT(3, 4)
|
||||||
|
void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...);
|
||||||
|
|
||||||
|
// defaults: file = NULL, colors = false, prefix = false, timestamps = false
|
||||||
|
//
|
||||||
|
// regular log output:
|
||||||
|
//
|
||||||
|
// ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
|
||||||
|
// llm_load_tensors: ggml ctx size = 0.27 MiB
|
||||||
|
// llm_load_tensors: offloading 32 repeating layers to GPU
|
||||||
|
// llm_load_tensors: offloading non-repeating layers to GPU
|
||||||
|
//
|
||||||
|
// with prefix = true, timestamps = true, the log output will look like this:
|
||||||
|
//
|
||||||
|
// 0.00.035.060 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
|
||||||
|
// 0.00.035.064 I llm_load_tensors: ggml ctx size = 0.27 MiB
|
||||||
|
// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
|
||||||
|
// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
|
||||||
|
//
|
||||||
|
// D - debug   (stderr, V = LOG_LEVEL_DEBUG)
// I - info    (stderr, V = LOG_LEVEL_INFO)
// W - warning (stderr, V = LOG_LEVEL_WARN)
// E - error   (stderr, V = LOG_LEVEL_ERROR)
// O - output  (stdout, V = LOG_LEVEL_OUTPUT)
|
||||||
|
//
|
||||||
|
|
||||||
|
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
|
||||||
|
void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe
|
||||||
|
void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
|
||||||
|
void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
|
||||||
|
void common_log_flush (struct common_log * log); // flush all pending log messages
|
||||||
|
|
||||||
|
// helper macros for logging
|
||||||
|
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
|
||||||
|
//
|
||||||
|
// for example:
|
||||||
|
//
|
||||||
|
// LOG_DBG("this is a debug message: %d\n", expensive_function());
|
||||||
|
//
|
||||||
|
// this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold
|
||||||
|
//
|
||||||
|
|
||||||
|
#define LOG_TMPL(level, verbosity, ...) \
|
||||||
|
do { \
|
||||||
|
if ((verbosity) <= common_log_verbosity_thold) { \
|
||||||
|
common_log_add(common_log_main(), (level), __VA_ARGS__); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
|
||||||
|
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||||
|
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__)
|
||||||
|
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__)
|
||||||
|
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||||
|
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO
|
||||||
|
|
||||||
|
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
|
||||||
|
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
|
||||||
|
#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
|
||||||
|
#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
|
||||||
|
#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__)
|
||||||
@@ -0,0 +1,286 @@
|
|||||||
|
#include "ngram-cache.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "log.h"
|
||||||
|
|
||||||
|
#include <cinttypes>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <fstream>
|
||||||
|
#include <thread>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
// Count (ngram -> next token) occurrences in the newly appended tokens of inp
// and fold them into ngram_cache.
// ngram_cache:    the cache to modify.
// ngram_min/max:  the range of ngram sizes to extract.
// inp:            the full token sequence (must only ever be appended to).
// nnew:           how many tokens were appended since the previous call.
// print_progress: report progress + ETA to stderr every 10M updates.
void common_ngram_cache_update(common_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
                               std::vector<llama_token> & inp, int nnew, bool print_progress) {
    const int64_t t_start_ms = ggml_time_ms();
    const int64_t inp_size = inp.size();

    // total number of (position, ngram size) pairs, for progress reporting
    const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1);
    int64_t n_done = 0;

    for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
        // only revisit positions affected by the nnew appended tokens
        const int64_t i_start = std::max(inp_size - nnew, ngram_size);
        for (int64_t i = i_start; i < inp_size; ++i) {
            const int64_t ngram_start = i - ngram_size;
            common_ngram ngram(&inp[ngram_start], ngram_size);
            const llama_token token = inp[i];

            // increment the (ngram -> token) count, creating missing entries
            common_ngram_cache::iterator part_it = ngram_cache.find(ngram);
            if (part_it == ngram_cache.end()) {
                common_ngram_cache_part part;
                part.emplace(token, 1);
                ngram_cache.emplace(ngram, part);
            } else {
                common_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
                if (token_count_it == part_it->second.end()) {
                    part_it->second.emplace(token, 1);
                } else {
                    token_count_it->second++;
                }
            }
            ++n_done;

            if (print_progress && n_done % 10000000 == 0) {
                const int64_t t_now_ms = ggml_time_ms();
                // reuse n_todo instead of recomputing inp_size*(ngram_max-ngram_min+1);
                // n_done >= 1 here, so the division is safe
                const int64_t eta_ms = (n_todo - n_done) * (t_now_ms - t_start_ms) / n_done;
                const int64_t eta_min = eta_ms / (60*1000);
                const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;

                fprintf(stderr, "%s: %" PRId64 "/%" PRId64 " done, ETA: %02" PRId64 ":%02" PRId64 "\n", __func__, n_done, n_todo, eta_min, eta_s);
            }
        }
    }
}
|
||||||
|
|
||||||
|
// Helper function to get a token from the combined, speculative sequence of inp and draft.
|
||||||
|
static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
|
||||||
|
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
|
||||||
|
}
|
||||||
|
|
||||||
|
// If sample size or percentage are below these thresholds the draft is aborted early:
// Indexed by (ngram size - 1). The "lax" thresholds are passed for the context
// cache and the "strict" thresholds for the dynamic cache (see
// common_ngram_cache_draft); the static-only path uses the lax values at
// index LLAMA_NGRAM_STATIC-1.
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
|
||||||
|
|
||||||
|
// Helper function that tries to draft a token from only the static ngram cache:
|
||||||
|
static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) {
|
||||||
|
common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
|
||||||
|
if (part_static_it == nc_static.end()) {
|
||||||
|
return LLAMA_TOKEN_NULL;
|
||||||
|
}
|
||||||
|
const common_ngram_cache_part part_static = part_static_it->second;
|
||||||
|
|
||||||
|
int max_count_static = 0;
|
||||||
|
int sum_count_static = 0;
|
||||||
|
llama_token max_token = LLAMA_TOKEN_NULL;
|
||||||
|
|
||||||
|
for (std::pair<llama_token, int> token_count_static : part_static) {
|
||||||
|
const llama_token token = token_count_static.first;
|
||||||
|
const int32_t count_static = token_count_static.second;
|
||||||
|
|
||||||
|
if (count_static > max_count_static) {
|
||||||
|
max_token = token;
|
||||||
|
max_count_static = count_static;
|
||||||
|
}
|
||||||
|
sum_count_static += count_static;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
|
||||||
|
return LLAMA_TOKEN_NULL;
|
||||||
|
}
|
||||||
|
if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
|
||||||
|
return LLAMA_TOKEN_NULL;
|
||||||
|
}
|
||||||
|
return max_token;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to draft a token from primary cache (context/dynamic), validate with static cache:
|
||||||
|
static llama_token try_draft(
|
||||||
|
common_ngram_cache & nc_primary, const std::vector<common_ngram> & ngrams_primary, common_ngram_cache_part & part_static,
|
||||||
|
const int * min_sample_size, const int * min_percent) {
|
||||||
|
|
||||||
|
llama_token drafted_token = LLAMA_TOKEN_NULL;
|
||||||
|
|
||||||
|
for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == LLAMA_TOKEN_NULL; --i) {
|
||||||
|
const common_ngram ngram_primary = ngrams_primary[i];
|
||||||
|
|
||||||
|
common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
|
||||||
|
if (part_primary_it == nc_primary.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const common_ngram_cache_part part_primary = part_primary_it->second;
|
||||||
|
|
||||||
|
int max_count_primary = 0;
|
||||||
|
int max_count_static = 0;
|
||||||
|
int sum_count_primary = 0;
|
||||||
|
llama_token max_token = LLAMA_TOKEN_NULL;
|
||||||
|
|
||||||
|
for (std::pair<llama_token, int> token_count_primary : part_primary) {
|
||||||
|
const llama_token token = token_count_primary.first;
|
||||||
|
|
||||||
|
common_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
|
||||||
|
|
||||||
|
const int32_t count_primary = token_count_primary.second;
|
||||||
|
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
|
||||||
|
|
||||||
|
if (count_primary*count_static > max_count_primary*max_count_static) {
|
||||||
|
max_token = token;
|
||||||
|
max_count_primary = count_primary;
|
||||||
|
max_count_static = count_static;
|
||||||
|
}
|
||||||
|
sum_count_primary += count_primary;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sum_count_primary < min_sample_size[i]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (100*max_count_primary < min_percent[i]*sum_count_primary) {
|
||||||
|
continue;;
|
||||||
|
}
|
||||||
|
drafted_token = max_token;
|
||||||
|
}
|
||||||
|
|
||||||
|
return drafted_token;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extend draft with up to n_draft speculative tokens by consulting, in order,
// the context cache (lax thresholds), the dynamic cache (strict thresholds),
// and finally the static cache alone. Stops as soon as no cache produces a
// confident candidate. draft must initially contain exactly the previously
// sampled token.
void common_ngram_cache_draft(
    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
    common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static
) {
    GGML_ASSERT(draft.size() == 1);
    const int inp_size = inp.size();

    // not enough context for even the static-size ngram lookup
    if (inp_size < LLAMA_NGRAM_STATIC) {
        return;
    }

    while ((int) draft.size()-1 < n_draft) {
        llama_token drafted_token = LLAMA_TOKEN_NULL;

        // build the trailing static-size ngram over inp + already-drafted tokens
        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
        common_ngram ngram_static;
        for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
            ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
        }
        common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
        common_ngram_cache_part part_static;
        if (part_static_it != nc_static.end()) {
            part_static = part_static_it->second;
        }

        // cd = context + dynamic
        // one trailing ngram per size in [ngram_min, ngram_max]
        std::vector<common_ngram> ngrams_cd;
        for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
            const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
            common_ngram ngram_cd;
            for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
            }
            ngrams_cd.push_back(ngram_cd);
        }
        // try the caches in decreasing order of preference
        if (drafted_token == LLAMA_TOKEN_NULL) {
            drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
        }
        if (drafted_token == LLAMA_TOKEN_NULL) {
            drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
        }
        if (drafted_token == LLAMA_TOKEN_NULL) {
            drafted_token = try_draft(nc_static, ngram_static);
        }

        // no confident candidate anywhere - stop drafting
        if (drafted_token == LLAMA_TOKEN_NULL) {
            break;
        }

        LOG(" - draft candidate: token=%d\n", drafted_token);
        draft.push_back(drafted_token);
    }
}
|
||||||
|
|
||||||
|
// Serialize the cache as a flat binary stream; per entry: the raw common_ngram
// bytes, an int32 token count, then (token, count) pairs.
// NOTE(review): neither the open nor the writes are checked - a bad path
// fails silently; confirm callers validate the filename.
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
    std::ofstream file_out(filename, std::ios::binary);
    for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
        const common_ngram ngram = item.first;
        common_ngram_cache_part token_counts = item.second;
        // empty parts would break the load-side "ntokens > 0" invariant
        GGML_ASSERT(!token_counts.empty());
        const int32_t ntokens = token_counts.size();
        GGML_ASSERT(ntokens > 0);

        file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(common_ngram));
        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
        for (std::pair<llama_token, int32_t> item2 : token_counts) {
            const llama_token token = item2.first;
            const int32_t count = item2.second;
            GGML_ASSERT(count > 0);

            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
        }
    }

}
|
||||||
|
|
||||||
|
// Deserialize a cache written by common_ngram_cache_save.
// Throws std::ifstream::failure when the file cannot be opened; malformed
// content trips a GGML_ASSERT.
common_ngram_cache common_ngram_cache_load(std::string & filename) {
    std::ifstream hashmap_file(filename, std::ios::binary);
    if (!hashmap_file) {
        throw std::ifstream::failure("Unable to open file " + filename);
    }
    common_ngram_cache ngram_cache;

    // scratch variables reused for every record
    common_ngram ngram;
    int32_t ntokens;
    llama_token token;
    int32_t count;

    // raw byte views for the binary reads below
    char * ngramc = reinterpret_cast<char*>(&ngram);
    char * ntokensc = reinterpret_cast<char*>(&ntokens);
    char * tokenc = reinterpret_cast<char*>(&token);
    char * countc = reinterpret_cast<char*>(&count);
    // NOTE(review): several reads happen inside GGML_ASSERT, so loading relies
    // on the assert macro always evaluating its argument - confirm GGML_ASSERT
    // is never compiled out.
    while(hashmap_file.read(ngramc, sizeof(common_ngram))) {
        GGML_ASSERT(!hashmap_file.eof());
        GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
        GGML_ASSERT(ntokens > 0);
        common_ngram_cache_part token_counts;

        for (int i = 0; i < ntokens; ++i) {
            GGML_ASSERT(!hashmap_file.eof());
            GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
            GGML_ASSERT(!hashmap_file.eof());
            GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
            GGML_ASSERT(count > 0);
            token_counts.emplace(token, count);
        }

        ngram_cache.emplace(ngram, token_counts);
    }
    // the loop must terminate exactly at end-of-file, not on a partial record
    GGML_ASSERT(hashmap_file.eof());

    return ngram_cache;
}
|
||||||
|
|
||||||
|
// Fold ngram_cache_add into ngram_cache_target: unseen ngrams are copied
// wholesale, otherwise the per-token counts are summed.
void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add) {
    for (const auto & ngram_part : ngram_cache_add) {
        const common_ngram & ngram = ngram_part.first;
        const common_ngram_cache_part & part = ngram_part.second;

        auto target_it = ngram_cache_target.find(ngram);
        if (target_it == ngram_cache_target.end()) {
            // ngram not seen before - copy its whole distribution
            ngram_cache_target.emplace(ngram, part);
            continue;
        }

        for (const auto & token_count : part) {
            GGML_ASSERT(token_count.second > 0);
            // operator[] default-initializes a missing count to 0 before adding
            target_it->second[token_count.first] += token_count.second;
        }
    }
}
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define LLAMA_NGRAM_MIN 1
|
||||||
|
#define LLAMA_NGRAM_MAX 4
|
||||||
|
#define LLAMA_NGRAM_STATIC 2
|
||||||
|
|
||||||
|
// Data structures to map n-grams to empirical token probabilities:
|
||||||
|
|
||||||
|
struct common_ngram {
|
||||||
|
llama_token tokens[LLAMA_NGRAM_MAX];
|
||||||
|
|
||||||
|
common_ngram() {
|
||||||
|
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
|
||||||
|
tokens[i] = LLAMA_TOKEN_NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
common_ngram(const llama_token * input, const int ngram_size) {
|
||||||
|
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
|
||||||
|
tokens[i] = i < ngram_size ? input[i] : LLAMA_TOKEN_NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(const common_ngram & other) const {
|
||||||
|
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
|
||||||
|
if (tokens[i] != other.tokens[i]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Hash functor for a single token (Fibonacci hashing: multiply by
// 2^64 / golden ratio).
struct common_token_hash_function {
    size_t operator()(const llama_token token) const {
        // see https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/
        return token * 11400714819323198485llu;
    }
};
|
||||||
|
|
||||||
|
// Hash functor for an ngram: hash each token slot and XOR the results together.
struct common_ngram_hash_function {
    size_t operator()(const common_ngram & ngram) const {
        size_t hash = common_token_hash_function{}(ngram.tokens[0]);
        for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) {
            hash ^= common_token_hash_function{}(ngram.tokens[i]);
        }
        return hash;
    }
};
|
||||||
|
|
||||||
|
// token -> number of times token has been seen
|
||||||
|
typedef std::unordered_map<llama_token, int32_t> common_ngram_cache_part;
|
||||||
|
|
||||||
|
// n-gram -> empirical distribution of following tokens
|
||||||
|
typedef std::unordered_map<common_ngram, common_ngram_cache_part, common_ngram_hash_function> common_ngram_cache;
|
||||||
|
|
||||||
|
|
||||||
|
// Update an ngram cache with tokens.
|
||||||
|
// ngram_cache: the cache to modify.
|
||||||
|
// ngram_min/ngram_max: the min/max size of the ngrams to extract from inp_data.
|
||||||
|
// inp_data: the token sequence with which to update ngram_cache.
|
||||||
|
// nnew: how many new tokens have been appended to inp_data since the last call to this function.
|
||||||
|
// print_progress: whether to print progress to stderr.
|
||||||
|
//
|
||||||
|
// In order to get correct results inp_data can ONLY BE APPENDED TO.
|
||||||
|
// Changes in the middle need a complete rebuild.
|
||||||
|
void common_ngram_cache_update(
|
||||||
|
common_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector<llama_token> & inp_data, int nnew, bool print_progress);
|
||||||
|
|
||||||
|
// Try to draft tokens from ngram caches.
|
||||||
|
// inp: the tokens generated so far.
|
||||||
|
// draft: the token sequence to draft. Expected to initially contain the previously sampled token.
|
||||||
|
// n_draft: maximum number of tokens to add to draft.
|
||||||
|
// ngram_min/gram_max: the min/max size of the ngrams in nc_context and nc_dynamic.
|
||||||
|
// nc_context: ngram cache based on current context.
|
||||||
|
// nc_dynamic: ngram cache based on previous user generations.
|
||||||
|
// nc_static: ngram cache generated from a large text corpus, used for validation.
|
||||||
|
void common_ngram_cache_draft(
|
||||||
|
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
|
||||||
|
common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static);
|
||||||
|
|
||||||
|
// Save an ngram cache to a file.
|
||||||
|
// ngram_cache: the ngram cache to save.
|
||||||
|
// filename: the path under which to save the ngram cache.
|
||||||
|
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
|
||||||
|
|
||||||
|
// Load an ngram cache saved with common_ngram_cache_save.
|
||||||
|
// filename: the path from which to load the ngram cache.
|
||||||
|
// returns: an ngram cache containing the information saved to filename.
|
||||||
|
common_ngram_cache common_ngram_cache_load(std::string & filename);
|
||||||
|
|
||||||
|
// Merge two ngram caches.
|
||||||
|
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
|
||||||
|
// ngram_cache_add: the ngram cache to add to ngram_cache_target.
|
||||||
|
void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add);
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,459 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <nlohmann/json_fwd.hpp>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <string>
|
||||||
|
#include <string_view>
|
||||||
|
#include <functional>
|
||||||
|
#include <vector>
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
struct common_grammar_builder;
|
||||||
|
|
||||||
|
class common_peg_parser_builder;
|
||||||
|
|
||||||
|
using common_peg_parser_id = size_t;
|
||||||
|
constexpr common_peg_parser_id COMMON_PEG_INVALID_PARSER_ID = static_cast<common_peg_parser_id>(-1);
|
||||||
|
|
||||||
|
using common_peg_ast_id = size_t;
|
||||||
|
constexpr common_peg_ast_id COMMON_PEG_INVALID_AST_ID = static_cast<common_peg_ast_id>(-1);
|
||||||
|
|
||||||
|
// Lightweight wrapper around common_peg_parser_id for convenience
|
||||||
|
class common_peg_parser {
|
||||||
|
common_peg_parser_id id_;
|
||||||
|
common_peg_parser_builder & builder_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
common_peg_parser(const common_peg_parser & other) : id_(other.id_), builder_(other.builder_) {}
|
||||||
|
common_peg_parser(common_peg_parser_id id, common_peg_parser_builder & builder) : id_(id), builder_(builder) {}
|
||||||
|
|
||||||
|
common_peg_parser & operator=(const common_peg_parser & other);
|
||||||
|
common_peg_parser & operator+=(const common_peg_parser & other);
|
||||||
|
common_peg_parser & operator|=(const common_peg_parser & other);
|
||||||
|
|
||||||
|
operator common_peg_parser_id() const { return id_; }
|
||||||
|
common_peg_parser_id id() const { return id_; }
|
||||||
|
|
||||||
|
common_peg_parser_builder & builder() const { return builder_; }
|
||||||
|
|
||||||
|
// Creates a sequence
|
||||||
|
common_peg_parser operator+(const common_peg_parser & other) const;
|
||||||
|
|
||||||
|
// Creates a sequence separated by spaces.
|
||||||
|
common_peg_parser operator<<(const common_peg_parser & other) const;
|
||||||
|
|
||||||
|
// Creates a choice
|
||||||
|
common_peg_parser operator|(const common_peg_parser & other) const;
|
||||||
|
|
||||||
|
common_peg_parser operator+(const char * str) const;
|
||||||
|
common_peg_parser operator+(const std::string & str) const;
|
||||||
|
common_peg_parser operator<<(const char * str) const;
|
||||||
|
common_peg_parser operator<<(const std::string & str) const;
|
||||||
|
common_peg_parser operator|(const char * str) const;
|
||||||
|
common_peg_parser operator|(const std::string & str) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
common_peg_parser operator+(const char * str, const common_peg_parser & p);
|
||||||
|
common_peg_parser operator+(const std::string & str, const common_peg_parser & p);
|
||||||
|
common_peg_parser operator<<(const char * str, const common_peg_parser & p);
|
||||||
|
common_peg_parser operator<<(const std::string & str, const common_peg_parser & p);
|
||||||
|
common_peg_parser operator|(const char * str, const common_peg_parser & p);
|
||||||
|
common_peg_parser operator|(const std::string & str, const common_peg_parser & p);
|
||||||
|
|
||||||
|
enum common_peg_parse_result_type {
|
||||||
|
COMMON_PEG_PARSE_RESULT_FAIL = 0,
|
||||||
|
COMMON_PEG_PARSE_RESULT_SUCCESS = 1,
|
||||||
|
COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
const char * common_peg_parse_result_type_name(common_peg_parse_result_type type);
|
||||||
|
|
||||||
|
struct common_peg_ast_node {
|
||||||
|
common_peg_ast_id id;
|
||||||
|
std::string rule;
|
||||||
|
std::string tag;
|
||||||
|
size_t start;
|
||||||
|
size_t end;
|
||||||
|
std::string_view text;
|
||||||
|
std::vector<common_peg_ast_id> children;
|
||||||
|
|
||||||
|
bool is_partial = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_parse_result;
|
||||||
|
|
||||||
|
using common_peg_ast_visitor = std::function<void(const common_peg_ast_node & node)>;
|
||||||
|
|
||||||
|
class common_peg_ast_arena {
|
||||||
|
std::vector<common_peg_ast_node> nodes_;
|
||||||
|
public:
|
||||||
|
common_peg_ast_id add_node(
|
||||||
|
const std::string & rule,
|
||||||
|
const std::string & tag,
|
||||||
|
size_t start,
|
||||||
|
size_t end,
|
||||||
|
std::string_view text,
|
||||||
|
std::vector<common_peg_ast_id> children,
|
||||||
|
bool is_partial = false
|
||||||
|
) {
|
||||||
|
common_peg_ast_id id = nodes_.size();
|
||||||
|
nodes_.push_back({id, rule, tag, start, end, text, std::move(children), is_partial});
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
const common_peg_ast_node & get(common_peg_ast_id id) const { return nodes_.at(id); }
|
||||||
|
|
||||||
|
size_t size() const { return nodes_.size(); }
|
||||||
|
|
||||||
|
void clear() { nodes_.clear(); }
|
||||||
|
|
||||||
|
void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const;
|
||||||
|
void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_parse_result {
|
||||||
|
common_peg_parse_result_type type = COMMON_PEG_PARSE_RESULT_FAIL;
|
||||||
|
size_t start = 0;
|
||||||
|
size_t end = 0;
|
||||||
|
|
||||||
|
std::vector<common_peg_ast_id> nodes;
|
||||||
|
|
||||||
|
common_peg_parse_result() = default;
|
||||||
|
|
||||||
|
common_peg_parse_result(common_peg_parse_result_type type, size_t start)
|
||||||
|
: type(type), start(start), end(start) {}
|
||||||
|
|
||||||
|
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end)
|
||||||
|
: type(type), start(start), end(end) {}
|
||||||
|
|
||||||
|
common_peg_parse_result(common_peg_parse_result_type type, size_t start, size_t end, std::vector<common_peg_ast_id> nodes)
|
||||||
|
: type(type), start(start), end(end), nodes(std::move(nodes)) {}
|
||||||
|
|
||||||
|
bool fail() const { return type == COMMON_PEG_PARSE_RESULT_FAIL; }
|
||||||
|
bool need_more_input() const { return type == COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT; }
|
||||||
|
bool success() const { return type == COMMON_PEG_PARSE_RESULT_SUCCESS; }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_parse_context {
|
||||||
|
std::string input;
|
||||||
|
bool is_partial;
|
||||||
|
common_peg_ast_arena ast;
|
||||||
|
|
||||||
|
int parse_depth;
|
||||||
|
|
||||||
|
common_peg_parse_context()
|
||||||
|
: is_partial(false), parse_depth(0) {}
|
||||||
|
|
||||||
|
common_peg_parse_context(const std::string & input)
|
||||||
|
: input(input), is_partial(false), parse_depth(0) {}
|
||||||
|
|
||||||
|
common_peg_parse_context(const std::string & input, bool is_partial)
|
||||||
|
: input(input), is_partial(is_partial), parse_depth(0) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_peg_arena;
|
||||||
|
|
||||||
|
// Parser variants
|
||||||
|
struct common_peg_epsilon_parser {};
|
||||||
|
|
||||||
|
struct common_peg_start_parser {};
|
||||||
|
|
||||||
|
struct common_peg_end_parser {};
|
||||||
|
|
||||||
|
struct common_peg_literal_parser {
|
||||||
|
std::string literal;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_sequence_parser {
|
||||||
|
std::vector<common_peg_parser_id> children;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_choice_parser {
|
||||||
|
std::vector<common_peg_parser_id> children;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_repetition_parser {
|
||||||
|
common_peg_parser_id child;
|
||||||
|
int min_count;
|
||||||
|
int max_count; // -1 for unbounded
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_and_parser {
|
||||||
|
common_peg_parser_id child;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_not_parser {
|
||||||
|
common_peg_parser_id child;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_any_parser {};
|
||||||
|
|
||||||
|
struct common_peg_space_parser {};
|
||||||
|
|
||||||
|
struct common_peg_chars_parser {
|
||||||
|
struct char_range {
|
||||||
|
uint32_t start;
|
||||||
|
uint32_t end;
|
||||||
|
bool contains(uint32_t codepoint) const { return codepoint >= start && codepoint <= end; }
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string pattern;
|
||||||
|
std::vector<char_range> ranges;
|
||||||
|
bool negated;
|
||||||
|
int min_count;
|
||||||
|
int max_count; // -1 for unbounded
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_json_string_parser {};
|
||||||
|
|
||||||
|
struct common_peg_until_parser {
|
||||||
|
std::vector<std::string> delimiters;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_schema_parser {
|
||||||
|
common_peg_parser_id child;
|
||||||
|
std::string name;
|
||||||
|
std::shared_ptr<nlohmann::ordered_json> schema;
|
||||||
|
|
||||||
|
// Indicates if the GBNF should accept a raw string that matches the schema.
|
||||||
|
bool raw;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_rule_parser {
|
||||||
|
std::string name;
|
||||||
|
common_peg_parser_id child;
|
||||||
|
bool trigger;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_ref_parser {
|
||||||
|
std::string name;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_atomic_parser {
|
||||||
|
common_peg_parser_id child;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_peg_tag_parser {
|
||||||
|
common_peg_parser_id child;
|
||||||
|
std::string tag;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Variant holding all parser types
|
||||||
|
using common_peg_parser_variant = std::variant<
|
||||||
|
common_peg_epsilon_parser,
|
||||||
|
common_peg_start_parser,
|
||||||
|
common_peg_end_parser,
|
||||||
|
common_peg_literal_parser,
|
||||||
|
common_peg_sequence_parser,
|
||||||
|
common_peg_choice_parser,
|
||||||
|
common_peg_repetition_parser,
|
||||||
|
common_peg_and_parser,
|
||||||
|
common_peg_not_parser,
|
||||||
|
common_peg_any_parser,
|
||||||
|
common_peg_space_parser,
|
||||||
|
common_peg_chars_parser,
|
||||||
|
common_peg_json_string_parser,
|
||||||
|
common_peg_until_parser,
|
||||||
|
common_peg_schema_parser,
|
||||||
|
common_peg_rule_parser,
|
||||||
|
common_peg_ref_parser,
|
||||||
|
common_peg_atomic_parser,
|
||||||
|
common_peg_tag_parser
|
||||||
|
>;
|
||||||
|
|
||||||
|
class common_peg_arena {
|
||||||
|
std::vector<common_peg_parser_variant> parsers_;
|
||||||
|
std::unordered_map<std::string, common_peg_parser_id> rules_;
|
||||||
|
common_peg_parser_id root_ = COMMON_PEG_INVALID_PARSER_ID;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const common_peg_parser_variant & get(common_peg_parser_id id) const { return parsers_.at(id); }
|
||||||
|
common_peg_parser_variant & get(common_peg_parser_id id) { return parsers_.at(id); }
|
||||||
|
|
||||||
|
size_t size() const { return parsers_.size(); }
|
||||||
|
bool empty() const { return parsers_.empty(); }
|
||||||
|
|
||||||
|
common_peg_parser_id get_rule(const std::string & name) const;
|
||||||
|
bool has_rule(const std::string & name) const { return rules_.find(name) != rules_.end(); }
|
||||||
|
|
||||||
|
common_peg_parser_id root() const { return root_; }
|
||||||
|
void set_root(common_peg_parser_id id) { root_ = id; }
|
||||||
|
|
||||||
|
common_peg_parse_result parse(common_peg_parse_context & ctx, size_t start = 0) const;
|
||||||
|
common_peg_parse_result parse(common_peg_parser_id id, common_peg_parse_context & ctx, size_t start) const;
|
||||||
|
|
||||||
|
void resolve_refs();
|
||||||
|
|
||||||
|
void build_grammar(const common_grammar_builder & builder, bool lazy = false) const;
|
||||||
|
|
||||||
|
std::string dump(common_peg_parser_id id) const;
|
||||||
|
|
||||||
|
nlohmann::json to_json() const;
|
||||||
|
static common_peg_arena from_json(const nlohmann::json & j);
|
||||||
|
|
||||||
|
std::string save() const;
|
||||||
|
void load(const std::string & data);
|
||||||
|
|
||||||
|
friend class common_peg_parser_builder;
|
||||||
|
|
||||||
|
private:
|
||||||
|
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||||
|
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||||
|
|
||||||
|
common_peg_parser_id resolve_ref(common_peg_parser_id id);
|
||||||
|
};
|
||||||
|
|
||||||
|
class common_peg_parser_builder {
|
||||||
|
common_peg_arena arena_;
|
||||||
|
|
||||||
|
common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); }
|
||||||
|
common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); }
|
||||||
|
|
||||||
|
public:
|
||||||
|
common_peg_parser_builder();
|
||||||
|
|
||||||
|
// Match nothing, always succeed.
|
||||||
|
// S -> ε
|
||||||
|
common_peg_parser eps() { return add(common_peg_epsilon_parser{}); }
|
||||||
|
|
||||||
|
// Matches the start of the input.
|
||||||
|
// S -> ^
|
||||||
|
common_peg_parser start() { return add(common_peg_start_parser{}); }
|
||||||
|
|
||||||
|
// Matches the end of the input.
|
||||||
|
// S -> $
|
||||||
|
common_peg_parser end() { return add(common_peg_end_parser{}); }
|
||||||
|
|
||||||
|
// Matches an exact literal string.
|
||||||
|
// S -> "hello"
|
||||||
|
common_peg_parser literal(const std::string & literal) { return add(common_peg_literal_parser{literal}); }
|
||||||
|
|
||||||
|
// Matches a sequence of parsers in order, all must succeed.
|
||||||
|
// S -> A B C
|
||||||
|
common_peg_parser sequence() { return add(common_peg_sequence_parser{}); }
|
||||||
|
common_peg_parser sequence(const std::vector<common_peg_parser_id> & parsers);
|
||||||
|
common_peg_parser sequence(const std::vector<common_peg_parser> & parsers);
|
||||||
|
common_peg_parser sequence(std::initializer_list<common_peg_parser> parsers);
|
||||||
|
|
||||||
|
// Matches the first parser that succeeds from a list of alternatives.
|
||||||
|
// S -> A | B | C
|
||||||
|
common_peg_parser choice() { return add(common_peg_choice_parser{}); }
|
||||||
|
common_peg_parser choice(const std::vector<common_peg_parser_id> & parsers);
|
||||||
|
common_peg_parser choice(const std::vector<common_peg_parser> & parsers);
|
||||||
|
common_peg_parser choice(std::initializer_list<common_peg_parser> parsers);
|
||||||
|
|
||||||
|
// Matches one or more repetitions of a parser.
|
||||||
|
// S -> A+
|
||||||
|
common_peg_parser one_or_more(const common_peg_parser & p) { return repeat(p, 1, -1); }
|
||||||
|
|
||||||
|
// Matches zero or more repetitions of a parser, always succeeds.
|
||||||
|
// S -> A*
|
||||||
|
common_peg_parser zero_or_more(const common_peg_parser & p) { return repeat(p, 0, -1); }
|
||||||
|
|
||||||
|
// Matches zero or one occurrence of a parser, always succeeds.
|
||||||
|
// S -> A?
|
||||||
|
common_peg_parser optional(const common_peg_parser & p) { return repeat(p, 0, 1); }
|
||||||
|
|
||||||
|
// Positive lookahead: succeeds if child parser succeeds, consumes no input.
|
||||||
|
// S -> &A
|
||||||
|
common_peg_parser peek(const common_peg_parser & p) { return add(common_peg_and_parser{p}); }
|
||||||
|
|
||||||
|
// Negative lookahead: succeeds if child parser fails, consumes no input.
|
||||||
|
// S -> !A
|
||||||
|
common_peg_parser negate(const common_peg_parser & p) { return add(common_peg_not_parser{p}); }
|
||||||
|
|
||||||
|
// Matches any single character.
|
||||||
|
// S -> .
|
||||||
|
common_peg_parser any() { return add(common_peg_any_parser{}); }
|
||||||
|
|
||||||
|
// Matches between min and max repetitions of characters from a character class.
|
||||||
|
// S -> [a-z]{m,n}
|
||||||
|
//
|
||||||
|
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
||||||
|
common_peg_parser chars(const std::string & classes, int min = 1, int max = -1);
|
||||||
|
|
||||||
|
// Creates a lightweight reference to a named rule (resolved during build()).
|
||||||
|
// Use this for forward references in recursive grammars.
|
||||||
|
// expr_ref -> expr
|
||||||
|
common_peg_parser ref(const std::string & name) { return add(common_peg_ref_parser{name}); }
|
||||||
|
|
||||||
|
// Matches zero or more whitespace characters (space, tab, newline).
|
||||||
|
// S -> [ \t\n]*
|
||||||
|
common_peg_parser space() { return add(common_peg_space_parser{}); }
|
||||||
|
|
||||||
|
// Matches all characters until a delimiter is found (delimiter not consumed).
|
||||||
|
// S -> (!delim .)*
|
||||||
|
common_peg_parser until(const std::string & delimiter) { return add(common_peg_until_parser{{delimiter}}); }
|
||||||
|
|
||||||
|
// Matches all characters until one of the delimiters in the list is found (delimiter not consumed).
|
||||||
|
// S -> (!delim .)*
|
||||||
|
common_peg_parser until_one_of(const std::vector<std::string> & delimiters) { return add(common_peg_until_parser{delimiters}); }
|
||||||
|
|
||||||
|
// Matches everything
|
||||||
|
// S -> .*
|
||||||
|
common_peg_parser rest() { return until_one_of({}); }
|
||||||
|
|
||||||
|
// Matches between min and max repetitions of a parser (inclusive).
|
||||||
|
// S -> A{m,n}
|
||||||
|
// Use -1 for max to represent unbounded repetition (equivalent to {m,})
|
||||||
|
common_peg_parser repeat(const common_peg_parser & p, int min, int max) { return add(common_peg_repetition_parser{p, min,max}); }
|
||||||
|
|
||||||
|
// Matches exactly n repetitions of a parser.
|
||||||
|
// S -> A{n}
|
||||||
|
common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); }
|
||||||
|
|
||||||
|
// Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null.
|
||||||
|
// value -> object | array | string | number | true | false | null
|
||||||
|
common_peg_parser json();
|
||||||
|
common_peg_parser json_object();
|
||||||
|
common_peg_parser json_string();
|
||||||
|
common_peg_parser json_array();
|
||||||
|
common_peg_parser json_number();
|
||||||
|
common_peg_parser json_bool();
|
||||||
|
common_peg_parser json_null();
|
||||||
|
|
||||||
|
// Matches JSON string content without the surrounding quotes.
|
||||||
|
// Useful for extracting content within a JSON string.
|
||||||
|
common_peg_parser json_string_content();
|
||||||
|
|
||||||
|
// Matches a JSON object member with a key and associated parser as the
|
||||||
|
// value.
|
||||||
|
common_peg_parser json_member(const std::string & key, const common_peg_parser & p);
|
||||||
|
|
||||||
|
// Wraps a parser with JSON schema metadata for grammar generation.
|
||||||
|
// Used internally to convert JSON schemas to GBNF grammar rules.
|
||||||
|
common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false);
|
||||||
|
|
||||||
|
// Creates a named rule, stores it in the grammar, and returns a ref.
|
||||||
|
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
||||||
|
// auto json = p.rule("json", json_obj | json_arr | ...)
|
||||||
|
common_peg_parser rule(const std::string & name, const common_peg_parser & p, bool trigger = false);
|
||||||
|
|
||||||
|
// Creates a named rule using a builder function, and returns a ref.
|
||||||
|
// If trigger=true, marks this rule as an entry point for lazy grammar generation.
|
||||||
|
// auto json = p.rule("json", [&]() { return json_object() | json_array() | ... })
|
||||||
|
common_peg_parser rule(const std::string & name, const std::function<common_peg_parser()> & builder, bool trigger = false);
|
||||||
|
|
||||||
|
// Creates a trigger rule. When generating a lazy grammar from the parser,
|
||||||
|
// only trigger rules and descendents are emitted.
|
||||||
|
common_peg_parser trigger_rule(const std::string & name, const common_peg_parser & p) { return rule(name, p, true); }
|
||||||
|
common_peg_parser trigger_rule(const std::string & name, const std::function<common_peg_parser()> & builder) { return rule(name, builder, true); }
|
||||||
|
|
||||||
|
// Creates an atomic parser. Atomic parsers do not create an AST node if
|
||||||
|
// the child results in a partial parse, i.e. NEEDS_MORE_INPUT. This is
|
||||||
|
// intended for situations where partial output is undesirable.
|
||||||
|
common_peg_parser atomic(const common_peg_parser & p) { return add(common_peg_atomic_parser{p}); }
|
||||||
|
|
||||||
|
// Tags create nodes in the generated AST for semantic purposes.
|
||||||
|
// Unlike rules, you can tag multiple nodes with the same tag.
|
||||||
|
common_peg_parser tag(const std::string & tag, const common_peg_parser & p) { return add(common_peg_tag_parser{p.id(), tag}); }
|
||||||
|
|
||||||
|
void set_root(const common_peg_parser & p);
|
||||||
|
|
||||||
|
common_peg_arena build();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper function for building parsers
|
||||||
|
common_peg_arena build_peg_parser(const std::function<common_peg_parser(common_peg_parser_builder & builder)> & fn);
|
||||||
@@ -0,0 +1,483 @@
|
|||||||
|
#include "arg.h"
|
||||||
|
#include "preset.h"
|
||||||
|
#include "peg-parser.h"
|
||||||
|
#include "log.h"
|
||||||
|
#include "download.h"
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <filesystem>
|
||||||
|
|
||||||
|
static std::string rm_leading_dashes(const std::string & str) {
|
||||||
|
size_t pos = 0;
|
||||||
|
while (pos < str.size() && str[pos] == '-') {
|
||||||
|
++pos;
|
||||||
|
}
|
||||||
|
return str.substr(pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
// only allow a subset of args for remote presets for security reasons
|
||||||
|
// do not add more args unless absolutely necessary
|
||||||
|
// args that output to files are strictly prohibited
|
||||||
|
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
|
||||||
|
static const std::set<std::string> allowed_options = {
|
||||||
|
"model-url",
|
||||||
|
"hf-repo",
|
||||||
|
"hf-repo-draft",
|
||||||
|
"hf-repo-v", // vocoder
|
||||||
|
"hf-file-v", // vocoder
|
||||||
|
"mmproj-url",
|
||||||
|
"pooling",
|
||||||
|
"jinja",
|
||||||
|
"batch-size",
|
||||||
|
"ubatch-size",
|
||||||
|
"cache-reuse",
|
||||||
|
"chat-template-kwargs",
|
||||||
|
"mmap",
|
||||||
|
// note: sampling params are automatically allowed by default
|
||||||
|
// negated args will be added automatically if the positive arg is specified above
|
||||||
|
};
|
||||||
|
|
||||||
|
std::set<std::string> allowed_keys;
|
||||||
|
|
||||||
|
for (const auto & it : key_to_opt) {
|
||||||
|
const std::string & key = it.first;
|
||||||
|
const common_arg & opt = it.second;
|
||||||
|
if (allowed_options.find(key) != allowed_options.end() || opt.is_sparam) {
|
||||||
|
allowed_keys.insert(key);
|
||||||
|
// also add variant keys (args without leading dashes and env vars)
|
||||||
|
for (const auto & arg : opt.get_args()) {
|
||||||
|
allowed_keys.insert(rm_leading_dashes(arg));
|
||||||
|
}
|
||||||
|
for (const auto & env : opt.get_env()) {
|
||||||
|
allowed_keys.insert(env);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return allowed_keys;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
|
||||||
|
std::vector<std::string> args;
|
||||||
|
|
||||||
|
if (!bin_path.empty()) {
|
||||||
|
args.push_back(bin_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto & [opt, value] : options) {
|
||||||
|
if (opt.is_preset_only) {
|
||||||
|
continue; // skip preset-only options (they are not CLI args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// use the last arg as the main arg (i.e. --long-form)
|
||||||
|
args.push_back(opt.args.back());
|
||||||
|
|
||||||
|
// handle value(s)
|
||||||
|
if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
|
||||||
|
// flag option, no value
|
||||||
|
if (common_arg_utils::is_falsey(value)) {
|
||||||
|
// use negative arg if available
|
||||||
|
if (!opt.args_neg.empty()) {
|
||||||
|
args.back() = opt.args_neg.back();
|
||||||
|
} else {
|
||||||
|
// otherwise, skip the flag
|
||||||
|
// TODO: maybe throw an error instead?
|
||||||
|
args.pop_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (opt.value_hint != nullptr) {
|
||||||
|
// single value
|
||||||
|
args.push_back(value);
|
||||||
|
}
|
||||||
|
if (opt.value_hint != nullptr && opt.value_hint_2 != nullptr) {
|
||||||
|
throw std::runtime_error(string_format(
|
||||||
|
"common_preset::to_args(): option '%s' has two values, which is not supported yet",
|
||||||
|
opt.args.back()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string common_preset::to_ini() const {
|
||||||
|
std::ostringstream ss;
|
||||||
|
|
||||||
|
ss << "[" << name << "]\n";
|
||||||
|
for (const auto & [opt, value] : options) {
|
||||||
|
auto espaced_value = value;
|
||||||
|
string_replace_all(espaced_value, "\n", "\\\n");
|
||||||
|
ss << rm_leading_dashes(opt.args.back()) << " = ";
|
||||||
|
ss << espaced_value << "\n";
|
||||||
|
}
|
||||||
|
ss << "\n";
|
||||||
|
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_preset::set_option(const common_preset_context & ctx, const std::string & env, const std::string & value) {
|
||||||
|
// try if option exists, update it
|
||||||
|
for (auto & [opt, val] : options) {
|
||||||
|
if (opt.env && env == opt.env) {
|
||||||
|
val = value;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if option does not exist, we need to add it
|
||||||
|
if (ctx.key_to_opt.find(env) == ctx.key_to_opt.end()) {
|
||||||
|
throw std::runtime_error(string_format(
|
||||||
|
"%s: option with env '%s' not found in ctx_params",
|
||||||
|
__func__, env.c_str()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
options[ctx.key_to_opt.at(env)] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_preset::unset_option(const std::string & env) {
|
||||||
|
for (auto it = options.begin(); it != options.end(); ) {
|
||||||
|
const common_arg & opt = it->first;
|
||||||
|
if (opt.env && env == opt.env) {
|
||||||
|
it = options.erase(it);
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
++it;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_preset::get_option(const std::string & env, std::string & value) const {
|
||||||
|
for (const auto & [opt, val] : options) {
|
||||||
|
if (opt.env && env == opt.env) {
|
||||||
|
value = val;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_preset::merge(const common_preset & other) {
|
||||||
|
for (const auto & [opt, val] : other.options) {
|
||||||
|
options[opt] = val; // overwrite existing options
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_preset::apply_to_params(common_params & params) const {
|
||||||
|
for (const auto & [opt, val] : options) {
|
||||||
|
// apply each option to params
|
||||||
|
if (opt.handler_string) {
|
||||||
|
opt.handler_string(params, val);
|
||||||
|
} else if (opt.handler_int) {
|
||||||
|
opt.handler_int(params, std::stoi(val));
|
||||||
|
} else if (opt.handler_bool) {
|
||||||
|
opt.handler_bool(params, common_arg_utils::is_truthy(val));
|
||||||
|
} else if (opt.handler_str_str) {
|
||||||
|
// not supported yet
|
||||||
|
throw std::runtime_error(string_format(
|
||||||
|
"%s: option with two values is not supported yet",
|
||||||
|
__func__
|
||||||
|
));
|
||||||
|
} else if (opt.handler_void) {
|
||||||
|
opt.handler_void(params);
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("unknown handler type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
|
||||||
|
std::map<std::string, std::map<std::string, std::string>> parsed;
|
||||||
|
|
||||||
|
if (!std::filesystem::exists(path)) {
|
||||||
|
throw std::runtime_error("preset file does not exist: " + path);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ifstream file(path);
|
||||||
|
if (!file.good()) {
|
||||||
|
throw std::runtime_error("failed to open server preset file: " + path);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string contents((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||||
|
|
||||||
|
static const auto parser = build_peg_parser([](auto & p) {
|
||||||
|
// newline ::= "\r\n" / "\n" / "\r"
|
||||||
|
auto newline = p.rule("newline", p.literal("\r\n") | p.literal("\n") | p.literal("\r"));
|
||||||
|
|
||||||
|
// ws ::= [ \t]*
|
||||||
|
auto ws = p.rule("ws", p.chars("[ \t]", 0, -1));
|
||||||
|
|
||||||
|
// comment ::= [;#] (!newline .)*
|
||||||
|
auto comment = p.rule("comment", p.chars("[;#]", 1, 1) + p.zero_or_more(p.negate(newline) + p.any()));
|
||||||
|
|
||||||
|
// eol ::= ws comment? (newline / EOF)
|
||||||
|
auto eol = p.rule("eol", ws + p.optional(comment) + (newline | p.end()));
|
||||||
|
|
||||||
|
// ident ::= [a-zA-Z_] [a-zA-Z0-9_.-]*
|
||||||
|
auto ident = p.rule("ident", p.chars("[a-zA-Z_]", 1, 1) + p.chars("[a-zA-Z0-9_.-]", 0, -1));
|
||||||
|
|
||||||
|
// value ::= (!eol-start .)*
|
||||||
|
auto eol_start = p.rule("eol-start", ws + (p.chars("[;#]", 1, 1) | newline | p.end()));
|
||||||
|
auto value = p.rule("value", p.zero_or_more(p.negate(eol_start) + p.any()));
|
||||||
|
|
||||||
|
// header-line ::= "[" ws ident ws "]" eol
|
||||||
|
auto header_line = p.rule("header-line", "[" + ws + p.tag("section-name", p.chars("[^]]")) + ws + "]" + eol);
|
||||||
|
|
||||||
|
// kv-line ::= ident ws "=" ws value eol
|
||||||
|
auto kv_line = p.rule("kv-line", p.tag("key", ident) + ws + "=" + ws + p.tag("value", value) + eol);
|
||||||
|
|
||||||
|
// comment-line ::= ws comment (newline / EOF)
|
||||||
|
auto comment_line = p.rule("comment-line", ws + comment + (newline | p.end()));
|
||||||
|
|
||||||
|
// blank-line ::= ws (newline / EOF)
|
||||||
|
auto blank_line = p.rule("blank-line", ws + (newline | p.end()));
|
||||||
|
|
||||||
|
// line ::= header-line / kv-line / comment-line / blank-line
|
||||||
|
auto line = p.rule("line", header_line | kv_line | comment_line | blank_line);
|
||||||
|
|
||||||
|
// ini ::= line* EOF
|
||||||
|
auto ini = p.rule("ini", p.zero_or_more(line) + p.end());
|
||||||
|
|
||||||
|
return ini;
|
||||||
|
});
|
||||||
|
|
||||||
|
common_peg_parse_context ctx(contents);
|
||||||
|
const auto result = parser.parse(ctx);
|
||||||
|
if (!result.success()) {
|
||||||
|
throw std::runtime_error("failed to parse server config file: " + path);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string current_section = COMMON_PRESET_DEFAULT_NAME;
|
||||||
|
std::string current_key;
|
||||||
|
|
||||||
|
ctx.ast.visit(result, [&](const auto & node) {
|
||||||
|
if (node.tag == "section-name") {
|
||||||
|
const std::string section = std::string(node.text);
|
||||||
|
current_section = section;
|
||||||
|
parsed[current_section] = {};
|
||||||
|
} else if (node.tag == "key") {
|
||||||
|
const std::string key = std::string(node.text);
|
||||||
|
current_key = key;
|
||||||
|
} else if (node.tag == "value" && !current_key.empty() && !current_section.empty()) {
|
||||||
|
parsed[current_section][current_key] = std::string(node.text);
|
||||||
|
current_key.clear();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return parsed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a lookup table from every key a preset may use (an env-var name or a
// dash-stripped CLI flag) to the common_arg option it configures.
static std::map<std::string, common_arg> get_map_key_opt(common_params_context & ctx_params) {
    std::map<std::string, common_arg> mapping;
    for (const auto & opt : ctx_params.options) {
        // environment variable names are used verbatim as keys
        for (const auto & env_name : opt.get_env()) {
            mapping[env_name] = opt;
        }
        // CLI flags become keys with their leading dashes removed
        for (const auto & cli_arg : opt.get_args()) {
            mapping[rm_leading_dashes(cli_arg)] = opt;
        }
    }
    return mapping;
}
|
||||||
|
|
||||||
|
static bool is_bool_arg(const common_arg & arg) {
|
||||||
|
return !arg.args_neg.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string parse_bool_arg(const common_arg & arg, const std::string & key, const std::string & value) {
|
||||||
|
// if this is a negated arg, we need to reverse the value
|
||||||
|
for (const auto & neg_arg : arg.args_neg) {
|
||||||
|
if (rm_leading_dashes(neg_arg) == key) {
|
||||||
|
return common_arg_utils::is_truthy(value) ? "false" : "true";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// otherwise, not negated
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a preset context for the given example/tool `ex`.
// Initializes the CLI option table, augments it with preset-only options, and
// precomputes the key -> option lookup map. When `only_remote_allowed` is
// true, key filtering is enabled so that only whitelisted keys are accepted
// (used for presets coming from untrusted/remote sources).
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
    : ctx_params(common_params_parser_init(default_params, ex)) {
    common_params_add_preset_options(ctx_params.options);
    // map every env-var name / dash-stripped CLI flag to its option;
    // must run after the preset options were added above
    key_to_opt = get_map_key_opt(ctx_params);

    // setup allowed keys if only_remote_allowed is true
    if (only_remote_allowed) {
        filter_allowed_keys = true;
        allowed_keys = get_remote_preset_whitelist(key_to_opt);
    }
}
|
||||||
|
|
||||||
|
// Load presets from the INI file at `path`.
// Each INI section becomes one preset; the implicit (empty-named) section is
// renamed to COMMON_PRESET_DEFAULT_NAME. A section named "*" is treated as the
// global preset and returned through `global` instead of the result map.
// Throws std::runtime_error when the file fails to parse, when a key is not a
// recognized option, or (in remote-filter mode) when a key is not whitelisted.
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
    common_presets out;
    auto ini_data = parse_ini_from_file(path);

    for (auto section : ini_data) {
        common_preset preset;
        if (section.first.empty()) {
            preset.name = COMMON_PRESET_DEFAULT_NAME;
        } else {
            preset.name = section.first;
        }
        LOG_DBG("loading preset: %s\n", preset.name.c_str());
        for (const auto & [key, value] : section.second) {
            if (key == "version") {
                // skip version key (reserved for future use)
                continue;
            }

            LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
            // remote-filter mode: reject any key outside the whitelist
            if (filter_allowed_keys && allowed_keys.find(key) == allowed_keys.end()) {
                throw std::runtime_error(string_format(
                    "option '%s' is not allowed in remote presets",
                    key.c_str()
                ));
            }
            if (key_to_opt.find(key) != key_to_opt.end()) {
                const auto & opt = key_to_opt.at(key);
                if (is_bool_arg(opt)) {
                    // boolean flags may be spelled negated; normalize the value
                    preset.options[opt] = parse_bool_arg(opt, key, value);
                } else {
                    preset.options[opt] = value;
                }
                LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str());
            } else {
                throw std::runtime_error(string_format(
                    "option '%s' not recognized in preset '%s'",
                    key.c_str(), preset.name.c_str()
                ));
            }
        }

        if (preset.name == "*") {
            // handle global preset
            global = preset;
        } else {
            out[preset.name] = preset;
        }
    }

    return out;
}
|
||||||
|
|
||||||
|
// Generate one preset per cached model. Each preset simply points
// LLAMA_ARG_HF_REPO at the model's identifier string.
common_presets common_preset_context::load_from_cache() const {
    common_presets out;

    auto cached_models = common_list_cached_models();
    for (const auto & model : cached_models) {
        const std::string model_id = model.to_string();
        common_preset preset;
        preset.name = model_id;
        preset.set_option(*this, "LLAMA_ARG_HF_REPO", model_id);
        out[model_id] = preset;
    }

    return out;
}
|
||||||
|
|
||||||
|
// A model discovered on disk by load_from_models_dir.
struct local_model {
    std::string name;        // preset name (file name without ".gguf", or directory name)
    std::string path;        // path to the main gguf file (or the first shard of a split model)
    std::string path_mmproj; // path to the multimodal projector gguf; empty when none
};
|
||||||
|
|
||||||
|
// Scan `models_dir` and build one preset per model found.
// Two layouts are supported (see "Using multiple models" in server/README.md):
//   - a top-level "*.gguf" file -> preset named after the file (without extension)
//   - a subdirectory containing gguf file(s) -> preset named after the directory;
//     a gguf whose name contains "mmproj" is treated as the multimodal projector,
//     and one containing "-00001-of-" as the first shard of a split model.
// Throws std::runtime_error when `models_dir` is missing or not a directory.
common_presets common_preset_context::load_from_models_dir(const std::string & models_dir) const {
    if (!std::filesystem::exists(models_dir) || !std::filesystem::is_directory(models_dir)) {
        throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", models_dir.c_str()));
    }

    std::vector<local_model> models;
    // examine one subdirectory: pick the model file (preferring the first shard
    // of a split model) plus an optional mmproj file
    auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
        auto files = fs_list(subdir_path, false);
        common_file_info model_file;
        common_file_info first_shard_file;
        common_file_info mmproj_file;
        for (const auto & file : files) {
            if (string_ends_with(file.name, ".gguf")) {
                if (file.name.find("mmproj") != std::string::npos) {
                    mmproj_file = file;
                } else if (file.name.find("-00001-of-") != std::string::npos) {
                    first_shard_file = file;
                } else {
                    model_file = file;
                }
            }
        }
        // NOTE(review): if a subdirectory holds several non-shard gguf files,
        // the last one listed wins — confirm this tie-break is intended
        local_model model{
            /* name */ name,
            /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
            /* path_mmproj */ mmproj_file.path // can be empty
        };
        if (!model.path.empty()) {
            models.push_back(model);
        }
    };

    auto files = fs_list(models_dir, true);
    for (const auto & file : files) {
        if (file.is_dir) {
            scan_subdir(file.path, file.name);
        } else if (string_ends_with(file.name, ".gguf")) {
            // single file model
            std::string name = file.name;
            string_replace_all(name, ".gguf", "");
            local_model model{
                /* name */ name,
                /* path */ file.path,
                /* path_mmproj */ ""
            };
            models.push_back(model);
        }
    }

    // convert local models to presets
    common_presets out;
    for (const auto & model : models) {
        common_preset preset;
        preset.name = model.name;
        preset.set_option(*this, "LLAMA_ARG_MODEL", model.path);
        if (!model.path_mmproj.empty()) {
            preset.set_option(*this, "LLAMA_ARG_MMPROJ", model.path_mmproj);
        }
        out[preset.name] = preset;
    }

    return out;
}
|
||||||
|
|
||||||
|
// Convert the CLI arguments into a single preset named
// COMMON_PRESET_DEFAULT_NAME. Throws on a parse failure.
common_preset common_preset_context::load_from_args(int argc, char ** argv) const {
    common_preset preset;
    preset.name = COMMON_PRESET_DEFAULT_NAME;

    if (!common_params_to_map(argc, argv, ctx_params.ex, preset.options)) {
        throw std::runtime_error("failed to parse CLI arguments into preset");
    }

    return preset;
}
|
||||||
|
|
||||||
|
// Overlay `added` on top of `base`: presets present in both are merged
// (added wins per-option); presets only in `added` are copied verbatim.
common_presets common_preset_context::cascade(const common_presets & base, const common_presets & added) const {
    common_presets out = base; // start from a copy of the base set
    for (const auto & [name, preset_added] : added) {
        auto it = out.find(name);
        if (it == out.end()) {
            // unknown name: take the added preset as-is
            out[name] = preset_added;
        } else {
            // known name: layer the added options over the base ones
            it->second.merge(preset_added);
        }
    }
    return out;
}
|
||||||
|
|
||||||
|
// Apply each preset over a shared base (CSS-style cascading): every result
// starts as a copy of `base` and then has the specific preset layered on.
common_presets common_preset_context::cascade(const common_preset & base, const common_presets & presets) const {
    common_presets out;
    for (const auto & [name, preset] : presets) {
        common_preset merged = base; // copy of the shared base
        merged.name = name;
        merged.merge(preset);
        out[name] = std::move(merged);
    }
    return out;
}
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "arg.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
|
//
|
||||||
|
// INI preset parser and writer
|
||||||
|
//
|
||||||
|
|
||||||
|
constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default";
|
||||||
|
|
||||||
|
struct common_preset_context;
|
||||||
|
|
||||||
|
// One named preset: a set of CLI options and their values.
struct common_preset {
    std::string name;

    // options are stored as common_arg to string mapping, representing CLI arg and its value
    std::map<common_arg, std::string> options;

    // convert preset to CLI argument list
    // (bin_path is included in the produced list — presumably as argv[0]; verify against callers)
    std::vector<std::string> to_args(const std::string & bin_path = "") const;

    // convert preset to INI format string
    std::string to_ini() const;

    // TODO: maybe implement to_env() if needed

    // modify preset options where argument is identified by its env variable
    void set_option(const common_preset_context & ctx, const std::string & env, const std::string & value);

    // unset option by its env variable
    void unset_option(const std::string & env);

    // get option value by its env variable, return false if not found
    bool get_option(const std::string & env, std::string & value) const;

    // merge another preset into this one, overwriting existing options
    void merge(const common_preset & other);

    // apply preset options to common_params
    void apply_to_params(common_params & params) const;
};
|
||||||
|
|
||||||
|
// interface for multiple presets in one file
|
||||||
|
using common_presets = std::map<std::string, common_preset>;
|
||||||
|
|
||||||
|
// context for loading and editing presets
|
||||||
|
// context for loading and editing presets
struct common_preset_context {
    common_params default_params; // unused for now
    // parsed option table for the selected example/tool
    common_params_context ctx_params;
    // maps env-var names and dash-stripped CLI flags to their option
    std::map<std::string, common_arg> key_to_opt;

    // when true, only keys present in `allowed_keys` are accepted
    bool filter_allowed_keys = false;
    std::set<std::string> allowed_keys;

    // if only_remote_allowed is true, only accept whitelisted keys
    common_preset_context(llama_example ex, bool only_remote_allowed = false);

    // load presets from INI file
    common_presets load_from_ini(const std::string & path, common_preset & global) const;

    // generate presets from cached models
    common_presets load_from_cache() const;

    // generate presets from local models directory
    // for the directory structure, see "Using multiple models" in server/README.md
    common_presets load_from_models_dir(const std::string & models_dir) const;

    // generate one preset from CLI arguments
    common_preset load_from_args(int argc, char ** argv) const;

    // cascade multiple presets if exist on both: base < added
    // if preset does not exist in base, it will be added without modification
    common_presets cascade(const common_presets & base, const common_presets & added) const;

    // apply presets over a base preset (same idea as CSS cascading)
    common_presets cascade(const common_preset & base, const common_presets & presets) const;
};
|
||||||
@@ -0,0 +1,204 @@
|
|||||||
|
#include "regex-partial.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include <functional>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
// Compile both the forward regex and the reversed-partial variant used to
// detect a partial match at the end of an input (see
// regex_to_reversed_partial_regex). Throws if the pattern is invalid.
common_regex::common_regex(const std::string & pattern) :
    pattern(pattern),
    rx(pattern),
    rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
|
||||||
|
|
||||||
|
// Search `input` starting at `pos`.
// Returns a FULL match when the forward regex matches; otherwise checks
// whether the input *ends with* a prefix of the pattern (a PARTIAL match)
// using the reversed-partial regex; otherwise returns a NONE result.
// With as_match=true, anchored regex_match semantics are used instead of
// regex_search. Throws when `pos` is past the end of `input`.
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
    std::smatch match;
    if (pos > input.size()) {
        throw std::runtime_error("Position out of bounds");
    }
    auto start = input.begin() + pos;
    auto found = as_match
        ? std::regex_match(start, input.end(), match, rx)
        : std::regex_search(start, input.end(), match, rx);
    if (found) {
        common_regex_match res;
        res.type = COMMON_REGEX_MATCH_TYPE_FULL;
        for (size_t i = 0; i < match.size(); ++i) {
            // group positions are relative to `start`; shift to absolute offsets
            auto begin = pos + match.position(i);
            res.groups.emplace_back(begin, begin + match.length(i));
        }
        return res;
    }
    // no full match: run the reversed pattern over the reversed input,
    // anchored at the end (match_continuous from rbegin == input's last char);
    // rend() - pos stops the reverse scan at offset `pos`
    std::match_results<std::string::const_reverse_iterator> srmatch;
    if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
        auto group = srmatch[1].str();
        if (group.length() != 0) {
            // base() converts the reverse iterator at the group's end into the
            // forward iterator where the partial match begins in `input`
            auto it = srmatch[1].second.base();
            // auto position = static_cast<size_t>(std::distance(input.begin(), it));
            // for as_match, the partial match must start at the very beginning
            if ((!as_match) || it == input.begin()) {
                common_regex_match res;
                res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
                const size_t begin = std::distance(input.begin(), it);
                const size_t end = input.size();
                if (begin == std::string::npos || end == std::string::npos || begin > end) {
                    throw std::runtime_error("Invalid range");
                }
                res.groups.push_back({begin, end});
                return res;
            }
        }
    }
    return {};
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
|
||||||
|
|
||||||
|
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
|
||||||
|
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
||||||
|
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
||||||
|
|
||||||
|
- /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
|
||||||
|
- /a|b/ -> ^(a|b)
|
||||||
|
- /a*?/ -> error, could match ""
|
||||||
|
- /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
|
||||||
|
- /.*?ab/ -> ^((?:b)?a) (omit .*)
|
||||||
|
- /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
|
||||||
|
- /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
|
||||||
|
- /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
|
||||||
|
- /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
|
||||||
|
|
||||||
|
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
|
||||||
|
All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
|
||||||
|
*/
|
||||||
|
// Transform `pattern` into a regex that, when run over the REVERSED input with
// anchored semantics, matches any suffix of the input that is a prefix of a
// possible `pattern` match. See the block comment above for examples.
// The recursive lambda `process` consumes one alternation level; `it`/`end`
// are shared parsing state across recursion levels.
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
    auto it = pattern.begin();
    const auto end = pattern.end();

    std::function<std::string()> process = [&]() {
        // alternatives[i] is the list of atomic elements of the i-th branch
        std::vector<std::vector<std::string>> alternatives(1);
        std::vector<std::string> * sequence = &alternatives.back();

        while (it != end) {
            if (*it == '[') {
                // character class: copied verbatim (escapes inside are skipped)
                auto start = it;
                ++it;
                while (it != end) {
                    if ((*it == '\\') && (++it != end)) {
                        ++it;
                    } else if ((it != end) && (*it == ']')) {
                        break;
                    } else {
                        ++it;
                    }
                }
                if (it == end) {
                    throw std::runtime_error("Unmatched '[' in pattern");
                }
                ++it;
                sequence->push_back(std::string(start, it));
            } else if (*it == '*' || *it == '?' || *it == '+') {
                // quantifier: attach to the previous element; a reluctant '*?'
                // is collapsed to eager '*'
                if (sequence->empty()) {
                    throw std::runtime_error("Quantifier without preceding element");
                }
                sequence->back() += *it;
                auto is_star = *it == '*';
                ++it;
                if (is_star) {
                    // NOTE(review): dereferences `it` without an `it != end`
                    // check — a pattern ending in '*' reads one past the end;
                    // confirm whether inputs guarantee this cannot happen
                    if (*it == '?') {
                        ++it;
                    }
                }
            } else if (*it == '{') {
                // bounded repetition {m}, {m,}, {m,n}
                if (sequence->empty()) {
                    throw std::runtime_error("Repetition without preceding element");
                }
                ++it;
                auto start = it;
                while (it != end && *it != '}') {
                    ++it;
                }
                if (it == end) {
                    throw std::runtime_error("Unmatched '{' in pattern");
                }
                auto parts = string_split(std::string(start, it), ",");
                ++it;
                if (parts.size() > 2) {
                    throw std::runtime_error("Invalid repetition range in pattern");
                }

                // empty bound -> default (min defaults to 0, max to unbounded)
                auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
                    if (s.empty()) {
                        return def;
                    }
                    return std::stoi(s);
                };
                auto min = parseOptInt(parts[0], 0);
                auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
                if (min && max && *max < *min) {
                    throw std::runtime_error("Invalid repetition range in pattern");
                }
                // Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
                auto part = sequence->back();
                sequence->pop_back();
                for (int i = 0; i < *min; i++) {
                    sequence->push_back(part);
                }
                if (max) {
                    for (int i = *min; i < *max; i++) {
                        sequence->push_back(part + "?");
                    }
                } else {
                    sequence->push_back(part + "*");
                }
            } else if (*it == '(') {
                // group: recurse; both capturing and (?: groups become (?:
                ++it;
                if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
                    it += 2;
                }
                auto sub = process();
                // NOTE(review): `*it` here can dereference end() when the
                // pattern ends without a closing ')' — confirm acceptable
                if (*it != ')') {
                    throw std::runtime_error("Unmatched '(' in pattern");
                }
                ++it;
                auto & part = sequence->emplace_back("(?:");
                part += sub;
                part += ")";
            } else if (*it == ')') {
                // end of the current group: let the caller consume ')'
                break;
            } else if (*it == '|') {
                // start a new alternation branch
                ++it;
                alternatives.emplace_back();
                sequence = &alternatives.back();
            } else if (*it == '\\' && (++it != end)) {
                // escaped character: keep the escape as one element
                auto str = std::string("\\") + *it;
                sequence->push_back(str);
                ++it;
            } else if (it != end) {
                // plain character
                sequence->push_back(std::string(1, *it));
                ++it;
            }
        }

        // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
        // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
        // We'll do the outermost capturing group and final .* in the enclosing function.
        std::vector<std::string> res_alts;
        for (const auto & parts : alternatives) {
            auto & res = res_alts.emplace_back();
            for (size_t i = 0; i < parts.size() - 1; i++) {
                res += "(?:";
            }
            // emit the elements reversed, each earlier one optional
            for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
                res += *it;
                if (it != parts.rend() - 1) {
                    res += ")?";
                }
            }
        }
        return string_join(res_alts, "|");
    };
    auto res = process();
    if (it != end) {
        throw std::runtime_error("Unmatched '(' in pattern");
    }

    // single outer capturing group: its extent marks the partial-match span
    return "^(" + res + ")";
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <regex>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// Result classification for common_regex::search().
enum common_regex_match_type {
    COMMON_REGEX_MATCH_TYPE_NONE,     // no match found
    COMMON_REGEX_MATCH_TYPE_PARTIAL,  // input ends with a prefix of a possible match
    COMMON_REGEX_MATCH_TYPE_FULL,     // complete match found
};
|
||||||
|
|
||||||
|
// Half-open index range [begin, end) into a string.
struct common_string_range {
    size_t begin;
    size_t end;

    // A range must be constructed with explicit, ordered bounds.
    common_string_range(size_t b, size_t e) : begin(b), end(e) {
        if (end < begin) {
            throw std::runtime_error("Invalid range");
        }
    }

    // no default construction: an uninitialized range is meaningless
    common_string_range() = delete;

    // true when the range covers zero characters
    bool empty() const {
        return end == begin;
    }

    bool operator==(const common_string_range & other) const {
        return other.begin == begin && other.end == end;
    }
};
|
||||||
|
|
||||||
|
// A match result: its type plus the ranges of the capture groups
// (group 0 is the whole match, mirroring std::smatch numbering).
// NOTE(review): uses std::vector but this header includes only <regex> and
// <string> — relies on a transitive include; consider adding <vector>.
struct common_regex_match {
    common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
    std::vector<common_string_range> groups;

    bool operator==(const common_regex_match & other) const {
        return type == other.type && groups == other.groups;
    }
    bool operator!=(const common_regex_match & other) const {
        return !(*this == other);
    }
};
|
||||||
|
|
||||||
|
// Wrapper around std::regex that can also report *partial* matches ending at
// the end of the input (via a reversed-pattern transformation; see
// regex_to_reversed_partial_regex).
class common_regex {
    std::string pattern;
    std::regex rx;                  // the pattern compiled as-is
    std::regex rx_reversed_partial; // reversed variant used for partial detection

  public:
    explicit common_regex(const std::string & pattern);

    // Search `input` starting at `pos`; as_match=true selects anchored
    // regex_match semantics instead of regex_search.
    common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;

    // the original pattern string
    const std::string & str() const { return pattern; }
};
|
||||||
|
|
||||||
|
// For testing only (pretty print of failures).
|
||||||
|
std::string regex_to_reversed_partial_regex(const std::string & pattern);
|
||||||
@@ -0,0 +1,712 @@
|
|||||||
|
#include "sampling.h"
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "log.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstring>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
||||||
|
// TODO: deduplicate with llama-impl.h
|
||||||
|
// Fixed-capacity FIFO ring buffer (similar to std::deque but bounded):
// push_back() on a full buffer overwrites the oldest element.
// Invariants: `first` indexes the oldest element, `pos` the next write slot,
// `sz` the current element count.
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // oldest element; throws when empty
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // newest element; throws when empty
    // FIX: previously returned data[pos], but `pos` is the *next write slot*,
    // so this yielded an uninitialized/stale element instead of the last push
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    // append; when full, the oldest element is dropped
    void push_back(const T & value) {
        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // remove and return the oldest element; throws when empty
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // reverse access: rat(0) is the newest element, rat(sz-1) the oldest
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // copy the contents oldest-to-newest into a vector
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz = 0;
    size_t first = 0;
    size_t pos = 0;
    std::vector<T> data;
};
|
||||||
|
|
||||||
|
// Sampler state: grammar sampler, regular sampler chain, and a history of
// recently accepted tokens.
struct common_sampler {
    common_params_sampling params;

    struct llama_sampler * grmr;  // grammar-constrained sampler
    struct llama_sampler * chain; // chain of the regular sampling stages

    // fixed-size history of recent tokens
    ring_buffer<llama_token> prev;

    // candidate buffer for the current step; cur_p points into it
    std::vector<llama_token_data> cur;

    llama_token_data_array cur_p; // current candidates

    // clear the token history and reset the sampler chain state
    // (note: does not reset `grmr`)
    void reset() {
        prev.clear();

        llama_sampler_reset(chain);
    }

    // Populate `cur`/`cur_p` with the candidates for output `idx` of `ctx`.
    // Prefers backend-sampled probs, then backend-sampled logits, and falls
    // back to the raw logits over the full vocabulary.
    void set_logits(struct llama_context * ctx, int idx) {
        const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
        const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
        const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);

        const llama_model * model = llama_get_model(ctx);
        const llama_vocab * vocab = llama_model_get_vocab(model);

        const int n_vocab = llama_vocab_n_tokens(vocab);

        if (sampled_probs) {
            // probs available: candidates carry both logits and probabilities
            const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
            cur.resize(sampled_probs_count);
            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
            }
        } else if (sampled_logits) {
            // only logits available: probabilities are left at 0
            const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
            cur.resize(sampled_logits_count);
            for (uint32_t i = 0; i < sampled_logits_count; i++) {
                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
            }
        } else {
            // fall back to the raw logits over the entire vocabulary
            const auto * logits = llama_get_logits_ith(ctx, idx);
            GGML_ASSERT(logits != nullptr);
            cur.resize(n_vocab);
            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
            }
        }

        // -1: no token selected yet; false: not sorted
        cur_p = { cur.data(), cur.size(), -1, false };
    }

    // scoped timer accumulating into t_total_us (no-op when params.no_perf)
    common_time_meas tm() {
        return common_time_meas(t_total_us, params.no_perf);
    }

    mutable int64_t t_total_us = 0;
};
|
||||||
|
|
||||||
|
// Render the sampling parameters as a human-readable multi-line string
// (for logging). Note: "mirostat_lr" prints mirostat_eta and "mirostat_ent"
// prints mirostat_tau.
std::string common_params_sampling::print() const {
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
            mirostat, mirostat_eta, mirostat_tau);

    return std::string(result);
}
|
||||||
|
|
||||||
|
// Construct a common_sampler for `model` from the given sampling parameters.
// Builds the grammar sampler (if any) and the llama_sampler chain.
// note: can mutate `params` - backend sampling is disabled when a grammar is configured
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

    lparams.no_perf = params.no_perf;

    llama_sampler * grmr  = nullptr;
    llama_sampler * chain = llama_sampler_chain_init(lparams);

    std::vector<llama_sampler *> samplers;

    // grammars prefixed with "%llguidance" are delegated to the llguidance backend
    if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
    } else {
        // collect lazy-grammar triggers:
        //  - WORD triggers become regex-escaped literal patterns
        //  - PATTERN triggers are used as-is
        //  - PATTERN_FULL triggers are anchored with ^...$ so they must match the whole string
        //  - TOKEN triggers are collected separately as token ids
        std::vector<std::string> trigger_patterns;
        std::vector<llama_token> trigger_tokens;
        for (const auto & trigger : params.grammar_triggers) {
            switch (trigger.type) {
                case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                {
                    const auto & word = trigger.value;
                    trigger_patterns.push_back(regex_escape(word));
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                {
                    trigger_patterns.push_back(trigger.value);
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                {
                    const auto & pattern = trigger.value;
                    // an empty pattern anchors to the empty string only
                    std::string anchored = "^$";
                    if (!pattern.empty()) {
                        anchored = (pattern.front() != '^' ? "^" : "")
                            + pattern
                            + (pattern.back() != '$' ? "$" : "");
                    }
                    trigger_patterns.push_back(anchored);
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
                {
                    const auto token = trigger.token;
                    trigger_tokens.push_back(token);
                    break;
                }
                default:
                    GGML_ASSERT(false && "unknown trigger type");
            }
        }

        // the C API takes raw C-string pointers into the pattern storage above
        std::vector<const char *> trigger_patterns_c;
        trigger_patterns_c.reserve(trigger_patterns.size());
        for (const auto & regex : trigger_patterns) {
            trigger_patterns_c.push_back(regex.c_str());
        }

        if (!params.grammar.empty()) {
            if (params.grammar_lazy) {
                grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
                        trigger_patterns_c.data(), trigger_patterns_c.size(),
                        trigger_tokens.data(), trigger_tokens.size());
            } else {
                grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
            }
        }
    }

    if (params.has_logit_bias()) {
        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
    }

    if (params.mirostat == 0) {
        // no mirostat: build the chain from the user-configured sampler sequence
        for (const auto & cnstr : params.samplers) {
            switch (cnstr) {
                case COMMON_SAMPLER_TYPE_DRY:
                    {
                        // DRY needs its sequence breakers as C strings
                        std::vector<const char *> c_breakers;
                        c_breakers.reserve(params.dry_sequence_breakers.size());
                        for (const auto & str : params.dry_sequence_breakers) {
                            c_breakers.push_back(str.c_str());
                        }

                        samplers.push_back(llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                    }
                    break;
                case COMMON_SAMPLER_TYPE_TOP_K:
                    samplers.push_back(llama_sampler_init_top_k    (params.top_k));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_P:
                    samplers.push_back(llama_sampler_init_top_p    (params.top_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
                    samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                    break;
                case COMMON_SAMPLER_TYPE_MIN_P:
                    samplers.push_back(llama_sampler_init_min_p    (params.min_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_XTC:
                    samplers.push_back(llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                    break;
                case COMMON_SAMPLER_TYPE_TYPICAL_P:
                    samplers.push_back(llama_sampler_init_typical  (params.typ_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TEMPERATURE:
                    samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                    break;
                case COMMON_SAMPLER_TYPE_INFILL:
                    samplers.push_back(llama_sampler_init_infill   (vocab));
                    break;
                case COMMON_SAMPLER_TYPE_PENALTIES:
                    samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                    break;
                default:
                    GGML_ASSERT(false && "unknown sampler type");
            }
        }

        // final step: draw from the (filtered) distribution
        samplers.push_back(llama_sampler_init_dist(params.seed));
    } else if (params.mirostat == 1) {
        samplers.push_back(llama_sampler_init_temp(params.temp));
        samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
        samplers.push_back(llama_sampler_init_temp(params.temp));
        samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
    } else {
        GGML_ASSERT(false && "unknown mirostat version");
    }

    for (auto * smpl : samplers) {
        llama_sampler_chain_add(chain, smpl);
    }

    // grammar constraints run on the CPU, so they cannot be combined with backend sampling
    if (grmr && params.backend_sampling) {
        LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);

        params.backend_sampling = false;
    }

    auto * result = new common_sampler {
        /* .params = */ params,
        /* .grmr   = */ grmr,
        /* .chain  = */ chain,
        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur    = */ {},
        /* .cur_p  = */ {},
    };

    return result;
}
|
||||||
|
|
||||||
|
void common_sampler_free(struct common_sampler * gsmpl) {
|
||||||
|
if (gsmpl) {
|
||||||
|
llama_sampler_free(gsmpl->grmr);
|
||||||
|
llama_sampler_free(gsmpl->chain);
|
||||||
|
|
||||||
|
delete gsmpl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record an accepted token: optionally feed it to the grammar, then to the
// sampler chain, and append it to the token history.
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
    const auto tm = gsmpl->tm(); // scoped timing measurement

    // the grammar only sees the token when the caller requests it and a grammar exists
    if (accept_grammar && gsmpl->grmr) {
        llama_sampler_accept(gsmpl->grmr, token);
    }

    llama_sampler_accept(gsmpl->chain, token);
    gsmpl->prev.push_back(token);
}
|
||||||
|
|
||||||
|
// Reset the sampler state via the internal helper.
void common_sampler_reset(struct common_sampler * gsmpl) {
    gsmpl->reset();
}
|
||||||
|
|
||||||
|
// Clone a sampler: deep-copies the grammar and chain samplers, and copies the
// token history and candidate buffers.
// NOTE(review): cur_p.data points into the source's cur vector (see how cur_p is
// built in common_sampler_sample), so the copied cur_p still references the
// original's buffer until set_logits reinitializes it - verify this is intended.
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
        /* .params = */ gsmpl->params,
        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
        /* .prev   = */ gsmpl->prev,
        /* .cur    = */ gsmpl->cur,
        /* .cur_p  = */ gsmpl->cur_p,
    };
}
|
||||||
|
|
||||||
|
// Print sampling and context performance statistics.
// Either argument may be nullptr to skip the corresponding section.
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
    // TODO: measure grammar performance

    // t_total_us is in microseconds; 1e-3 converts to milliseconds
    const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;

    llama_perf_sampler_data data_smpl;
    llama_perf_context_data data_ctx;

    memset(&data_smpl, 0, sizeof(data_smpl));
    memset(&data_ctx, 0, sizeof(data_ctx));

    if (gsmpl) {
        auto & data = data_smpl;

        data = llama_perf_sampler(gsmpl->chain);

        // note: the sampling time includes the samplers time + extra time spent in common/sampling
        LOG_INF("%s:    sampling time = %10.2f ms\n", __func__, t_sampling_ms);
        LOG_INF("%s:    samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
    }

    if (ctx) {
        auto & data = data_ctx;

        data = llama_perf_context(ctx);

        const double t_end_ms = 1e-3 * ggml_time_us();

        // "unaccounted" = wall time not attributed to sampling, prompt eval or eval
        const double t_total_ms = t_end_ms - data.t_start_ms;
        const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
        const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;

        LOG_INF("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
        LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
        LOG_INF("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
        LOG_INF("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
        LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
        LOG_INF("%s:    graphs reused = %10d\n", __func__, data.n_reused);

        llama_memory_breakdown_print(ctx);
    }
}
|
||||||
|
|
||||||
|
// Access the underlying llama_sampler chain.
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
    return gsmpl->chain;
}
|
||||||
|
|
||||||
|
// Sample a token for output index `idx`:
//   1. if a backend sampler already picked a token, return it directly
//   2. otherwise apply the CPU sampler chain (grammar first, if requested)
//   3. verify the sampled token against the grammar and resample if it is rejected
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    llama_synchronize(ctx);

    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
    const auto tm = gsmpl->tm();

    llama_token id = LLAMA_TOKEN_NULL;

    auto & grmr  = gsmpl->grmr;
    auto & chain = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

    // Check if a backend sampler has already sampled a token in which case we
    // return that token id directly.
    {
        id = llama_get_sampled_token_ith(ctx, idx);

        if (id != LLAMA_TOKEN_NULL) {
            LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);

            GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");

            // TODO: simplify
            // expose the backend-sampled token through cur_p as a single-candidate array
            gsmpl->cur.resize(1);
            gsmpl->cur[0] = { id, 0.0f, 1.0f };
            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };

            return id;
        }
    }

    gsmpl->set_logits(ctx, idx);

    if (grammar_first) {
        llama_sampler_apply(grmr, &cur_p);
    }

    llama_sampler_apply(chain, &cur_p);

    id = cur_p.data[cur_p.selected].id;

    if (grammar_first) {
        // the grammar was already applied to all candidates - no verification needed
        return id;
    }

    // check if the sampled token fits the grammar (grammar-based rejection sampling)
    // NOTE(review): grmr can be nullptr when no grammar is configured (see
    // common_sampler_init) - presumably llama_sampler_apply tolerates that; confirm.
    {
        llama_token_data single_token_data = { id, 1.0f, 0.0f };
        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

        llama_sampler_apply(grmr, &single_token_data_array);

        // the grammar rejects a token by setting its logit to -INFINITY
        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
        if (is_valid) {
            return id;
        }
    }

    // resampling:
    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grmr,  &cur_p);
    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    id = cur_p.data[cur_p.selected].id;

    return id;
}
|
||||||
|
|
||||||
|
// Sample and accept up to idxs.size() tokens, cross-checking against `draft`.
// Stops after the first sampled token that disagrees with the draft, so the
// result contains at least 1 and at most idxs.size() tokens.
// requires: idxs.size() == draft.size() + 1
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    size_t i = 0;
    for (; i < draft.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);

        // stop at the first disagreement with the draft
        if (draft[i] != id) {
            break;
        }
    }

    // all draft tokens matched - sample one extra token past the draft
    if (i == draft.size()) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);
    }

    return result;
}
|
||||||
|
|
||||||
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
||||||
|
std::vector<int> idxs(draft.size() + 1);
|
||||||
|
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||||
|
idxs[i] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the RNG seed used by the sampler chain.
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
    return llama_sampler_get_seed(gsmpl->chain);
}
|
||||||
|
|
||||||
|
// helpers
|
||||||
|
|
||||||
|
// Access the internal candidate array.
// If do_sort is true and the candidates are not yet sorted, they are sorted in
// descending order of probability and the selected index is re-resolved.
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
    const auto tm = gsmpl->tm();

    auto * res = &gsmpl->cur_p;

    if (do_sort && !res->sorted) {
        // remember the selected token before sorting
        const llama_token id = res->data[res->selected].id;

        std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
            return a.p > b.p;
        });

        // restore the selected token after sorting
        for (size_t i = 0; i < res->size; ++i) {
            if (res->data[i].id == id) {
                res->selected = i;
                break;
            }
        }

        res->sorted = true;
    }

    return res;
}
|
||||||
|
|
||||||
|
// Return the last accepted token (rat(0) reads the history in reverse order).
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
    return gsmpl->prev.rat(0);
}
|
||||||
|
|
||||||
|
std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
||||||
|
std::string result = "logits ";
|
||||||
|
|
||||||
|
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
||||||
|
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||||
|
result += std::string("-> ");
|
||||||
|
result += std::string(llama_sampler_name(smpl)) + " ";
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detokenize the last `n` accepted tokens into a string (oldest first).
// Returns "" when the history is empty or n <= 0.
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
    n = std::min(n, (int) gsmpl->prev.size());

    if (n <= 0) {
        return "";
    }

    std::string result;
    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    // rat(i) indexes the history in reverse, so iterate from n-1 down to 0
    // to produce the tokens in chronological order
    for (int i = n - 1; i >= 0; i--) {
        const llama_token id = gsmpl->prev.rat(i);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        result += common_token_to_piece(ctx_main, id);
    }

    return result;
}
|
||||||
|
|
||||||
|
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
||||||
|
switch (cnstr) {
|
||||||
|
case COMMON_SAMPLER_TYPE_DRY: return 'd';
|
||||||
|
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
|
||||||
|
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
|
||||||
|
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
|
||||||
|
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
|
||||||
|
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
||||||
|
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
||||||
|
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
||||||
|
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
||||||
|
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
||||||
|
default : return '?';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
||||||
|
switch (cnstr) {
|
||||||
|
case COMMON_SAMPLER_TYPE_DRY: return "dry";
|
||||||
|
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
|
||||||
|
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
|
||||||
|
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
|
||||||
|
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
|
||||||
|
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
||||||
|
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
||||||
|
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
||||||
|
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
||||||
|
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
||||||
|
default : return "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve sampler names to sampler types. Canonical names are always accepted;
// common alternative spellings are accepted only when allow_alt_names is true.
// Unrecognized names are skipped with a warning.
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
        { "dry",         COMMON_SAMPLER_TYPE_DRY },
        { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
        { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
    };

    // since samplers names are written multiple ways
    // make it ready for both system names and input names
    std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
        { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
        { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
        { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
    };

    std::vector<common_sampler_type> samplers;
    samplers.reserve(names.size());

    for (const auto & name : names) {
        // canonical names take precedence over alternatives
        auto sampler = sampler_canonical_name_map.find(name);
        if (sampler != sampler_canonical_name_map.end()) {
            samplers.push_back(sampler->second);
            continue;
        }
        if (allow_alt_names) {
            sampler = sampler_alt_name_map.find(name);
            if (sampler != sampler_alt_name_map.end()) {
                samplers.push_back(sampler->second);
                continue;
            }
        }
        LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
    }

    return samplers;
}
|
||||||
|
|
||||||
|
// Resolve single-character sampler shorthands (see common_sampler_type_to_chr)
// to sampler types. Unrecognized characters are skipped with a warning.
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
    // build the reverse of the char mapping so the two stay consistent
    std::unordered_map<char, common_sampler_type> sampler_name_map = {
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
    };

    std::vector<common_sampler_type> samplers;
    samplers.reserve(chars.size());

    for (const auto & c : chars) {
        const auto sampler = sampler_name_map.find(c);
        if (sampler != sampler_name_map.end()) {
            samplers.push_back(sampler->second);
        } else {
            LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
        }
    }

    return samplers;
}
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// common_sampler extends llama_sampler with additional functionality:
|
||||||
|
//
|
||||||
|
// - grammar support
|
||||||
|
// - custom sampler logic based on the parameters
|
||||||
|
// - history of the last accepted tokens
|
||||||
|
// - performance metrics
|
||||||
|
//
|
||||||
|
// The goal is to have a common implementation of the sampling logic shared across the examples.
|
||||||
|
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
|
||||||
|
// complex (top-k, top-p, etc).
|
||||||
|
//
|
||||||
|
// Another example is related to the grammar. In general, the grammar constraints applied on the full
|
||||||
|
// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
|
||||||
|
// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
|
||||||
|
// grammar constraints are applied to the full vocabulary and the token is resampled.
|
||||||
|
//
|
||||||
|
// The common_sampler also maintains a container with the last accepted tokens. In the future, this can
|
||||||
|
// be moved into the core llama library.
|
||||||
|
//
|
||||||
|
// For convenience, the common_sampler also maintains a container with the current candidate tokens.
|
||||||
|
// This can be used to access the probabilities of the rest of the non-sampled tokens.
|
||||||
|
//
|
||||||
|
// TODO: measure grammar performance
|
||||||
|
//
|
||||||
|
|
||||||
|
struct common_sampler;
|
||||||
|
|
||||||
|
// llama_sampler API overloads
|
||||||
|
|
||||||
|
// note: can mutate params in some cases
|
||||||
|
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
|
||||||
|
|
||||||
|
void common_sampler_free(struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
||||||
|
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
|
||||||
|
void common_sampler_reset (struct common_sampler * gsmpl);
|
||||||
|
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
// arguments can be nullptr to skip printing
|
||||||
|
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
// get the underlying llama_sampler_chain
|
||||||
|
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
// extended sampling implementation:
|
||||||
|
//
|
||||||
|
// - set logits
|
||||||
|
// - apply the configured sampler chain
|
||||||
|
// - check if the token fits the grammar (if any)
|
||||||
|
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
|
||||||
|
//
|
||||||
|
// if grammar_first is true, the grammar is applied before the samplers (slower)
|
||||||
|
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
|
||||||
|
//
|
||||||
|
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
||||||
|
|
||||||
|
// generalized version of common_sampler_sample
|
||||||
|
//
|
||||||
|
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
|
||||||
|
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
|
||||||
|
//
|
||||||
|
//   common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {});
|
||||||
|
//
|
||||||
|
// is equivalent to
|
||||||
|
//
|
||||||
|
// common_sampler_sample(gsmpl, ctx, idx);
|
||||||
|
// common_sampler_accept(gsmpl, token, true);
|
||||||
|
//
|
||||||
|
// requires: idxs.size() == draft.size() + 1
|
||||||
|
//
|
||||||
|
// returns at least 1 token, up to idxs.size()
|
||||||
|
//
|
||||||
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
||||||
|
|
||||||
|
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
||||||
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
||||||
|
|
||||||
|
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
// helpers
|
||||||
|
|
||||||
|
// access the internal list of current candidate tokens
|
||||||
|
// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
|
||||||
|
// the .sorted flag of the result indicates whether the returned candidates are sorted
|
||||||
|
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
|
||||||
|
|
||||||
|
// get the last accepted token
|
||||||
|
llama_token common_sampler_last(const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
// print the sampler chain into a string
|
||||||
|
std::string common_sampler_print(const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
// get a string representation of the last accepted tokens
|
||||||
|
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
|
||||||
|
|
||||||
|
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
|
||||||
|
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
|
||||||
|
|
||||||
|
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
||||||
|
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
|
||||||
|
|
||||||
|
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
||||||
|
const char * grammar_kind, const char * grammar_data);
|
||||||
|
|
||||||
|
// deleter for std::unique_ptr so common_sampler_free() runs automatically
struct common_sampler_deleter {
    void operator()(common_sampler * s) { common_sampler_free(s); }
};
|
||||||
|
|
||||||
|
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
|
||||||
@@ -0,0 +1,361 @@
|
|||||||
|
#include "speculative.h"
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "llama.h"
|
||||||
|
#include "log.h"
|
||||||
|
#include "common.h"
|
||||||
|
#include "sampling.h"
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||||
|
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||||
|
|
||||||
|
// State for speculative decoding with a draft model.
struct common_speculative {
    struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
    struct llama_context * ctx_dft; // draft model context
    struct common_sampler * smpl;   // sampler used to draw draft tokens

    llama_batch batch;              // scratch batch for draft decoding
    llama_tokens prompt_dft;        // prompt as seen by the draft model
    bool vocab_dft_compatible = true; // whether retokenization is needed
    // target-string -> draft-string substitutions applied when retokenizing
    std::map<std::string, std::string> tgt_dft_replacements = {};
};
|
||||||
|
|
||||||
|
// Create a speculative-decoding helper for the given target/draft contexts.
// Builds an internal greedy-ish sampler for the draft model and checks vocab
// compatibility between the two models.
struct common_speculative * common_speculative_init(
        struct llama_context * ctx_tgt,
        struct llama_context * ctx_dft) {
    // note: tgt_dft_replacements is not listed and is value-initialized (empty map)
    auto * result = new common_speculative {
        /* .ctx_tgt    = */ ctx_tgt,
        /* .ctx_dft    = */ ctx_dft,
        /* .smpl       = */ nullptr,
        /* .batch      = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
        /* .prompt_dft = */ {},
        /* .vocab_dft_compatible = */ false,
    };

    // TODO: optimize or pass from outside?
#if 0
    {
        common_params_sampling params;
        params.no_perf = false;

        params.top_k = 40;
        params.top_p = 0.9;

        params.samplers = {
            COMMON_SAMPLER_TYPE_TOP_K,
            COMMON_SAMPLER_TYPE_TOP_P,
            COMMON_SAMPLER_TYPE_INFILL,
        };

        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
    }
#else
    {
        common_params_sampling params;
        params.no_perf = false;

        params.top_k = 10;

        params.samplers = {
            COMMON_SAMPLER_TYPE_TOP_K,
        };

        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
    }
#endif

    result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
    LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);

    return result;
}
|
||||||
|
|
||||||
|
// Destroy a speculative-decoding helper. Safe to call with nullptr.
void common_speculative_free(struct common_speculative * spec) {
    if (spec != nullptr) {
        common_sampler_free(spec->smpl);
        llama_batch_free(spec->batch);

        delete spec;
    }
}
|
||||||
|
|
||||||
|
bool common_speculative_are_compatible(
|
||||||
|
const struct llama_context * ctx_tgt,
|
||||||
|
const struct llama_context * ctx_dft) {
|
||||||
|
const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
|
||||||
|
const struct llama_model * model_dft = llama_get_model(ctx_dft);
|
||||||
|
|
||||||
|
const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||||
|
const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||||
|
|
||||||
|
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
|
||||||
|
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
|
||||||
|
|
||||||
|
const bool vocab_type_dft = llama_vocab_type(vocab_dft);
|
||||||
|
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
|
||||||
|
|
||||||
|
if (vocab_type_tgt != vocab_type_dft) {
|
||||||
|
LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
|
||||||
|
LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
||||||
|
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
||||||
|
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
|
||||||
|
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
|
||||||
|
) {
|
||||||
|
LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
||||||
|
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
||||||
|
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
||||||
|
? n_vocab_tgt - n_vocab_dft
|
||||||
|
: n_vocab_dft - n_vocab_tgt;
|
||||||
|
|
||||||
|
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||||
|
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
|
||||||
|
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||||
|
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
||||||
|
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
||||||
|
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||||
|
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||||
|
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
|
||||||
|
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
|
||||||
|
common_token_to_piece(ctx_tgt, i).c_str(),
|
||||||
|
common_token_to_piece(ctx_dft, i).c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_add_replacement_tgt_dft(
|
||||||
|
struct common_speculative * spec,
|
||||||
|
const char *source, const char *dest) {
|
||||||
|
spec->tgt_dft_replacements[source] = dest;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string replace_to_dft(
|
||||||
|
struct common_speculative * spec,
|
||||||
|
const std::string& input) {
|
||||||
|
std::string result = input;
|
||||||
|
for (const auto & pair : spec->tgt_dft_replacements) {
|
||||||
|
size_t pos = result.find(pair.first);
|
||||||
|
while (pos != std::string::npos) {
|
||||||
|
result.replace(pos, pair.first.length(), pair.second);
|
||||||
|
pos = result.find(pair.first, pos + pair.second.length());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string replace_to_tgt(
|
||||||
|
struct common_speculative * spec,
|
||||||
|
const std::string& input) {
|
||||||
|
std::string result = input;
|
||||||
|
for (const auto& pair : spec->tgt_dft_replacements) {
|
||||||
|
size_t pos = result.find(pair.second);
|
||||||
|
while (pos != std::string::npos) {
|
||||||
|
result.replace(pos, pair.second.length(), pair.first);
|
||||||
|
pos = result.find(pair.second, pos + pair.first.length());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Generate up to params.n_draft tokens with the draft model, to be verified
// later by the target model. `prompt_tgt_main_model` is the current prompt in
// the TARGET model's vocab; `id_last` is the most recently accepted token.
// Returns the drafted tokens, expressed in the target model's vocab.
llama_tokens common_speculative_gen_draft(
        struct common_speculative * spec,
        struct common_speculative_params params,
        const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
        llama_token id_last) {
    auto & batch      = spec->batch;
    auto & ctx_tgt    = spec->ctx_tgt;
    auto & ctx_dft    = spec->ctx_dft;
    auto & smpl       = spec->smpl;
    auto & prompt_dft = spec->prompt_dft;   // tokens currently held in the draft KV cache

    auto * mem_dft = llama_get_memory(ctx_dft);

    // best KV-cache reuse window found so far: start index and length
    int reuse_i = 0;
    int reuse_n = 0;

    // leave room in the draft context for the tokens we are about to draft
    const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;

    llama_tokens prompt_tgt_draft_model;
    if (!spec->vocab_dft_compatible) {
        // Vocabs differ: round-trip the prompt through text, applying the
        // registered target->draft replacements, then re-tokenize for the draft.
        std::string text;
        text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
        text = replace_to_dft(spec, text);
        LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str());
        prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);

        // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation
        const auto * model_tgt = llama_get_model(ctx_tgt);
        const auto * vocab_tgt = llama_model_get_vocab(model_tgt);

        // First call with a null buffer returns the negative of the required size.
        // NOTE(review): the assert fires only when n_chars is NON-negative, so the
        // "failed to detokenize" message really guards the size-probe contract —
        // confirm llama_detokenize never legitimately returns 0 here.
        int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false);
        GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last");
        text.resize(-n_chars);
        llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false);
        text = replace_to_dft(spec, text);

        LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
        // keep only the first token of the re-tokenization of id_last
        id_last = common_tokenize(ctx_dft, text, false, true)[0];
    }
    // prompt_tgt's tokens will always be compatible with ctx_dft
    const llama_tokens &prompt_tgt =
        spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;

    // only the last n_ctx tokens of the prompt fit in the draft context
    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);

    // reuse as much as possible from the old draft context
    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
    for (int i = 0; i < (int) prompt_dft.size(); ++i) {
        // length of the common run between prompt_tgt[i_start..] and prompt_dft[i..]
        int cur = 0;
        while (i_start + cur < (int) prompt_tgt.size() &&
               i + cur < (int) prompt_dft.size() &&
               prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
            cur++;
        }

        // accept a match if it is long enough (>= n_reuse) or the whole prompt fits
        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
            reuse_i = i;
            reuse_n = cur;
        }
    }

    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());

    llama_tokens result;
    result.reserve(params.n_draft);

    if (reuse_n == 0) {
        // nothing reusable — start the draft context from scratch
        llama_memory_clear(mem_dft, false);
        prompt_dft.clear();
    } else {
        // this happens when a previous draft has been discarded (for example, due to being too small), but the
        // target model agreed with it. in this case, we simply pass back the previous results to save compute
        if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
            for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
                result.push_back(prompt_dft[i]);

                if (params.n_draft <= (int) result.size()) {
                    break;
                }
            }

            return result;
        }

        if (reuse_i > 0) {
            // drop the prefix before the reused window and shift the KV cache left
            llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
            llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);

            prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
        }

        if (reuse_n < (int) prompt_dft.size()) {
            // drop everything after the reused window
            llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
            prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
        }
    }

    // prepare a batch to evaluate any new tokens in the prompt
    common_batch_clear(batch);

    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);

        prompt_dft.push_back(prompt_tgt[i]);
    }

    // we should rarely end-up here during normal decoding
    if (batch.n_tokens > 0) {
        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

        // NOTE(review): the llama_decode return value is ignored here — confirm
        // failures are acceptable to drop silently in this path.
        llama_decode(ctx_dft, batch);
    }

    const llama_pos n_past = prompt_dft.size();

    LOG_DBG("%s: n_past = %d\n", __func__, n_past);

    // evaluate id_last with logits enabled so we can sample the first draft token
    common_batch_clear(batch);
    common_batch_add (batch, id_last, n_past, { 0 }, true);

    prompt_dft.push_back(id_last);

    LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());

    llama_decode(ctx_dft, batch);

    common_sampler_reset(smpl);

    // sample n_draft tokens from the draft model
    for (int i = 0; i < params.n_draft; ++i) {
        common_batch_clear(batch);

        // sample (return value unused — candidates are read back below)
        common_sampler_sample(smpl, ctx_dft, 0, true);

        const auto * cur_p = common_sampler_get_candidates(smpl, true);

        // log the top-3 candidates for this draft position
        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
        }

        // add drafted token for each sequence
        const llama_token id = cur_p->data[0].id;

        common_sampler_accept(smpl, id, true);

        result.push_back(id);

        if (params.n_draft <= (int) result.size()) {
            break;
        }

        // only collect very high-confidence draft tokens
        if (cur_p->data[0].p < params.p_min) {
            break;
        }

        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

        // evaluate the drafted tokens on the draft model
        llama_decode(ctx_dft, batch);

        prompt_dft.push_back(id);
    }

    if (!spec->vocab_dft_compatible) {
        // Vocabs differ: convert the drafted tokens back to the target vocab
        // via text, applying the inverse draft->target replacements.
        std::string detokenized = common_detokenize(ctx_dft, result, true);
        detokenized = replace_to_tgt(spec, detokenized);
        LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
        result = common_tokenize(ctx_tgt, detokenized, false, true);
        // re-tokenization may yield more tokens than requested — trim
        if (result.size() > (size_t)params.n_draft) {
            result.resize(params.n_draft);
        }
    }
    return result;
}
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
#pragma once

#include "llama.h"
#include "common.h"

// Opaque speculative-decoding state (draft/target contexts, sampler, batch).
struct common_speculative;

// Tunables for draft generation.
struct common_speculative_params {
    int n_draft = 16;  // max drafted tokens
    int n_reuse = 256; // min matching run length required to reuse the draft KV cache

    float p_min = 0.75f; // min probability required to accept a token in the draft
};

// Create speculative-decoding state for a target/draft context pair.
struct common_speculative * common_speculative_init(
        struct llama_context * ctx_tgt,
        struct llama_context * ctx_dft
);

// Free state created by common_speculative_init (safe to pass nullptr).
void common_speculative_free(struct common_speculative * spec);

// Check vocab compatibility between the target and draft models.
bool common_speculative_are_compatible(
        const struct llama_context * ctx_tgt,
        const struct llama_context * ctx_dft);

// Register a target->draft text replacement, used when the vocabs are
// incompatible and prompts are converted through detokenized text.
void common_speculative_add_replacement_tgt_dft(
        struct common_speculative * spec,
        const char *source, const char *dest);

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens common_speculative_gen_draft(
        struct common_speculative * spec,
        struct common_speculative_params params,
        const llama_tokens & prompt,
        llama_token id_last);
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
#include "unicode.h"
|
||||||
|
|
||||||
|
// implementation adopted from src/unicode.cpp
|
||||||
|
|
||||||
|
// Expected byte length of a UTF-8 sequence, derived from the top nibble of
// its first byte. Continuation bytes (0x80-0xBF) and ASCII both map to 1, so
// callers always advance by at least one byte.
size_t utf8_sequence_length(unsigned char first_byte) {
    switch (first_byte >> 4) {
        case 0xc: case 0xd: return 2;   // 110xxxxx
        case 0xe:           return 3;   // 1110xxxx
        case 0xf:           return 4;   // 11110xxx (and invalid 0xf8-0xff)
        default:            return 1;   // ASCII and continuation bytes
    }
}
|
||||||
|
|
||||||
|
// Decode one UTF-8 codepoint from `input` starting at `offset`.
// Truncated sequences report INCOMPLETE (streaming-friendly); a leading
// continuation byte, a bad continuation byte, or a 0xf8-0xff first byte
// reports INVALID. No overlong/surrogate rejection is performed.
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) {
    if (offset >= input.size()) {
        return utf8_parse_result(utf8_parse_result::INCOMPLETE);
    }

    const unsigned char lead = static_cast<unsigned char>(input[offset]);

    // true when the byte at `idx` is a 10xxxxxx continuation byte
    auto is_cont = [&input](size_t idx) {
        return (static_cast<unsigned char>(input[idx]) & 0xc0) == 0x80;
    };

    // ASCII fast path
    if (lead < 0x80) {
        return utf8_parse_result(utf8_parse_result::SUCCESS, lead, 1);
    }

    // Invalid: continuation byte as first byte
    if (lead < 0xc0) {
        return utf8_parse_result(utf8_parse_result::INVALID);
    }

    // Determine sequence length and the payload bits of the lead byte.
    size_t   len;
    uint32_t cp;
    if (lead < 0xe0) {          // 110xxxxx -> 2 bytes
        len = 2;
        cp  = lead & 0x1f;
    } else if (lead < 0xf0) {   // 1110xxxx -> 3 bytes
        len = 3;
        cp  = lead & 0x0f;
    } else if (lead < 0xf8) {   // 11110xxx -> 4 bytes
        len = 4;
        cp  = lead & 0x07;
    } else {                    // 0xf8-0xff: invalid first byte
        return utf8_parse_result(utf8_parse_result::INVALID);
    }

    // Bounds first (matches the original: a short buffer is INCOMPLETE even
    // if the bytes present are not valid continuations).
    if (offset + len - 1 >= input.size()) {
        return utf8_parse_result(utf8_parse_result::INCOMPLETE);
    }
    for (size_t k = 1; k < len; ++k) {
        if (!is_cont(offset + k)) {
            return utf8_parse_result(utf8_parse_result::INVALID);
        }
    }

    // Fold the 6-bit payloads of the continuation bytes into the codepoint.
    for (size_t k = 1; k < len; ++k) {
        cp = (cp << 6) | (static_cast<unsigned char>(input[offset + k]) & 0x3f);
    }
    return utf8_parse_result(utf8_parse_result::SUCCESS, cp, len);
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
#pragma once

#include <cstddef>  // size_t — was previously relied on via transitive include
#include <cstdint>
#include <string_view>

// UTF-8 parsing utilities for streaming-aware unicode support

// Result of attempting to decode a single UTF-8 codepoint.
struct utf8_parse_result {
    uint32_t codepoint;     // Decoded codepoint (only valid if status == SUCCESS)
    size_t bytes_consumed;  // How many bytes this codepoint uses (1-4)
    enum status { SUCCESS, INCOMPLETE, INVALID } status;

    utf8_parse_result(enum status s, uint32_t cp = 0, size_t bytes = 0)
        : codepoint(cp), bytes_consumed(bytes), status(s) {}
};

// Determine the expected length of a UTF-8 sequence from its first byte.
// Returns 1 for ASCII *and* for continuation/invalid first bytes (the
// implementation never returns 0), so callers always advance by >= 1 byte.
size_t utf8_sequence_length(unsigned char first_byte);

// Parse a single UTF-8 codepoint from `input` starting at `offset`.
// Reports INCOMPLETE when the buffer ends mid-sequence and INVALID for
// malformed lead/continuation bytes.
utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset);
||||||
+11334
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,491 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
project("ggml" C CXX ASM)

### GGML Version
# Semantic version components; GGML_VERSION_BASE is "MAJOR.MINOR.PATCH".
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 5)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||||
|
|
||||||
|
# Best-effort git metadata; both queries are ERROR_QUIET so a non-git checkout
# (e.g. a release tarball) configures without warnings.
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
if(GIT_EXE)
    # Get current git commit hash
    execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        OUTPUT_VARIABLE GGML_BUILD_COMMIT
        OUTPUT_STRIP_TRAILING_WHITESPACE
        ERROR_QUIET
    )

    # Check if the working directory is dirty (i.e., has uncommitted changes)
    # diff-index exits non-zero when there are changes; the exit code lands in
    # GGML_GIT_DIRTY.
    execute_process(COMMAND ${GIT_EXE} diff-index --quiet HEAD -- .
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        RESULT_VARIABLE GGML_GIT_DIRTY
        ERROR_QUIET
    )
endif()
|
||||||
|
|
||||||
|
set(GGML_VERSION "${GGML_VERSION_BASE}")

# Fall back when git was unavailable or rev-parse failed above.
if(NOT GGML_BUILD_COMMIT)
    set(GGML_BUILD_COMMIT "unknown")
endif()

# Build the commit string with optional dirty flag
if(DEFINED GGML_GIT_DIRTY AND GGML_GIT_DIRTY EQUAL 1)
    set(GGML_BUILD_COMMIT "${GGML_BUILD_COMMIT}-dirty")
endif()

include(CheckIncludeFileCXX)

# Emit compile_commands.json for clangd/clang-tidy tooling.
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|
||||||
|
# Default to a Release build when the user did not choose a build type.
# Use GENERATOR_IS_MULTI_CONFIG instead of only checking XCODE/MSVC: this also
# covers other multi-config generators (e.g. "Ninja Multi-Config"), where
# CMAKE_BUILD_TYPE is meaningless and must not be forced.
get_property(GGML_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if (NOT GGML_IS_MULTI_CONFIG AND NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()
|
||||||
|
|
||||||
|
# GGML_STANDALONE: ON when ggml is the top-level project, OFF when it is
# embedded as a subproject (e.g. inside llama.cpp).
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
    set(GGML_STANDALONE ON)

    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

    # configure project version
    # TODO
else()
    set(GGML_STANDALONE OFF)

    # respect a runtime output directory already chosen by the parent project
    if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY)
        set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
    endif()
endif()
|
||||||
|
|
||||||
|
# Shared-library default: static for Emscripten and MinGW, shared elsewhere.
if (EMSCRIPTEN)
    set(BUILD_SHARED_LIBS_DEFAULT OFF)

    option(GGML_WASM_SINGLE_FILE "ggml: embed WASM inside the generated ggml.js" ON)
else()
    if (MINGW)
        set(BUILD_SHARED_LIBS_DEFAULT OFF)
    else()
        set(BUILD_SHARED_LIBS_DEFAULT ON)
    endif()
endif()

# remove the lib prefix on win32 mingw
if (WIN32)
    set(CMAKE_STATIC_LIBRARY_PREFIX "")
    set(CMAKE_SHARED_LIBRARY_PREFIX "")
    set(CMAKE_SHARED_MODULE_PREFIX "")
endif()
|
||||||
|
|
||||||
|
# Core library / backend-loading switches.
option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT})
option(GGML_BACKEND_DL "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF)
# Fixed: the help string previously had an unbalanced "(requires GGML_BACKEND_DL".
set(GGML_BACKEND_DIR "" CACHE PATH "ggml: directory to load dynamic backends from (requires GGML_BACKEND_DL)")
|
||||||
|
|
||||||
|
#
# option list
#

# TODO: mark all options as advanced when not GGML_STANDALONE

# Platform-dependent defaults: Metal + Apple's Accelerate BLAS on macOS.
if (APPLE)
    set(GGML_METAL_DEFAULT ON)
    set(GGML_BLAS_DEFAULT ON)
    set(GGML_BLAS_VENDOR_DEFAULT "Apple")
else()
    set(GGML_METAL_DEFAULT OFF)
    set(GGML_BLAS_DEFAULT OFF)
    set(GGML_BLAS_VENDOR_DEFAULT "Generic")
endif()

# Native (-march=native style) optimization is disabled when cross-compiling
# or when SOURCE_DATE_EPOCH is set — presumably for reproducible builds.
if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH})
    message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF")
    set(GGML_NATIVE_DEFAULT OFF)
else()
    set(GGML_NATIVE_DEFAULT ON)
endif()

# defaults
if (NOT GGML_LLAMAFILE_DEFAULT)
    set(GGML_LLAMAFILE_DEFAULT OFF)
endif()

if (NOT GGML_CUDA_GRAPHS_DEFAULT)
    set(GGML_CUDA_GRAPHS_DEFAULT OFF)
endif()
|
||||||
|
|
||||||
|
# general
option(GGML_STATIC "ggml: static link libraries" OFF)
option(GGML_NATIVE "ggml: optimize the build for the current system" ${GGML_NATIVE_DEFAULT})
option(GGML_LTO "ggml: enable link time optimization" OFF)
option(GGML_CCACHE "ggml: use ccache if available" ON)

# debug
option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON)
option(GGML_ALL_WARNINGS_3RD_PARTY "ggml: enable all compiler warnings in 3rd party libs" OFF)
option(GGML_GPROF "ggml: enable gprof" OFF)

# build
option(GGML_FATAL_WARNINGS "ggml: enable -Werror flag" OFF)

# sanitizers
option(GGML_SANITIZE_THREAD "ggml: enable thread sanitizer" OFF)
option(GGML_SANITIZE_ADDRESS "ggml: enable address sanitizer" OFF)
option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)

# instruction set specific
# INS_ENB gates the per-ISA option defaults below: explicit ISA flags default
# ON only when native optimization is off (native builds detect ISA themselves).
if (GGML_NATIVE OR NOT GGML_NATIVE_DEFAULT)
    set(INS_ENB OFF)
else()
    set(INS_ENB ON)
endif()

message(DEBUG "GGML_NATIVE : ${GGML_NATIVE}")
message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}")
message(DEBUG "INS_ENB : ${INS_ENB}")

option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
option(GGML_CPU_REPACK "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
option(GGML_SSE42 "ggml: enable SSE 4.2" ${INS_ENB})
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
option(GGML_BMI2 "ggml: enable BMI2" ${INS_ENB})
option(GGML_AVX512 "ggml: enable AVX512F" OFF)
option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF)
option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF)
option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF)
if (NOT MSVC)
    # in MSVC F16C and FMA is implied with AVX2/AVX512
    option(GGML_FMA "ggml: enable FMA" ${INS_ENB})
    option(GGML_F16C "ggml: enable F16C" ${INS_ENB})
    # MSVC does not seem to support AMX
    option(GGML_AMX_TILE "ggml: enable AMX-TILE" OFF)
    option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
    option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
endif()
# LoongArch / RISC-V / s390x ISA switches
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})

option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")

# ggml core
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
option(GGML_CPU "ggml: enable CPU backend" ON)
option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF)
|
||||||
|
|
||||||
|
# 3rd party libs / backends
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT})
set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
    "ggml: BLAS library vendor")
option(GGML_LLAMAFILE "ggml: use LLAMAFILE" ${GGML_LLAMAFILE_DEFAULT})

# CUDA / MUSA
option(GGML_CUDA "ggml: use CUDA" OFF)
option(GGML_MUSA "ggml: use MUSA" OFF)
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
    "ggml: max. batch size for using peer access")
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING
    "ggml: cuda link binary compression mode; requires cuda 12.8+")
set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size")

# HIP (ROCm)
option(GGML_HIP "ggml: use HIP" OFF)
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF)
# Vulkan
option(GGML_VULKAN "ggml: use Vulkan" OFF)
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF)
option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
# WebGPU
option(GGML_WEBGPU "ggml: use WebGPU" OFF)
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON)
option(GGML_ZDNN "ggml: use zDNN" OFF)
# Metal
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})
set (GGML_METAL_MACOSX_VERSION_MIN "" CACHE STRING
    "ggml: metal minimum macOS version")
set (GGML_METAL_STD "" CACHE STRING "ggml: metal standard version (-std flag)")
option(GGML_OPENMP "ggml: use OpenMP" ON)
option(GGML_RPC "ggml: use RPC" OFF)
# SYCL
option(GGML_SYCL "ggml: use SYCL" OFF)
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
    "ggml: sycl target device")
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
    "ggml: sycl device architecture")
|
||||||
|
|
||||||
|
# OpenCL backend
option(GGML_OPENCL "ggml: use OpenCL" OFF)
option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF)
option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON)
option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON)
# Fixed "gmml" -> "ggml" typo in the cache doc string.
set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
    "ggml: OpenCL API version to target")
|
||||||
|
|
||||||
|
# Hexagon backend
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")

# toolchain for vulkan-shaders-gen
# Needed when cross-compiling: the shader generator must run on the host.
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
||||||
|
|
||||||
|
# ZenDNN backend
option(GGML_ZENDNN "ggml: use ZenDNN" OFF)
# ZENDNN_ROOT is a path, not a boolean — option() is only valid for ON/OFF
# values, so declare it as a PATH cache variable instead.
set(ZENDNN_ROOT "" CACHE PATH "ggml: path to ZenDNN installation")
|
||||||
|
|
||||||
|
# extra artifacts
# Tests/examples default ON only for standalone builds, not when embedded.
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})
option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
|
||||||
|
|
||||||
|
#
# dependencies
#

# C11 / C++17, strictly required (no fallback to older standards).
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED true)

# prefer -pthread over -lpthread where applicable
set(THREADS_PREFER_PTHREAD_FLAG ON)

find_package(Threads REQUIRED)

include(GNUInstallDirs)
|
||||||
|
|
||||||
|
#
# build the library
#

add_subdirectory(src)

#
# tests and examples
#

if (GGML_BUILD_TESTS)
    enable_testing()
    add_subdirectory(tests)
endif ()

if (GGML_BUILD_EXAMPLES)
    add_subdirectory(examples)
endif ()
|
||||||
|
|
||||||
|
#
# install
#

include(CMakePackageConfigHelpers)

# all public headers
# NOTE(review): backend headers are listed unconditionally, even for backends
# that are disabled — confirm that is intended for installed packages.
set(GGML_PUBLIC_HEADERS
    include/ggml.h
    include/ggml-cpu.h
    include/ggml-alloc.h
    include/ggml-backend.h
    include/ggml-blas.h
    include/ggml-cann.h
    include/ggml-cpp.h
    include/ggml-cuda.h
    include/ggml-opt.h
    include/ggml-metal.h
    include/ggml-rpc.h
    include/ggml-sycl.h
    include/ggml-vulkan.h
    include/ggml-webgpu.h
    include/ggml-zendnn.h
    include/gguf.h)

set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
#if (GGML_METAL)
#    set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
#endif()
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
install(TARGETS ggml-base LIBRARY)

# pkg-config file only makes sense for a standalone install
if (GGML_STANDALONE)
    configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in
        ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
        @ONLY)

    install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
        DESTINATION share/pkgconfig)
endif()
|
||||||
|
|
||||||
|
#
|
||||||
|
# Create CMake package
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Capture variables prefixed with GGML_.
|
||||||
|
|
||||||
|
set(variable_set_statements
|
||||||
|
"
|
||||||
|
####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() #######
|
||||||
|
####### Any changes to this file will be overwritten by the next CMake run #######
|
||||||
|
|
||||||
|
")
|
||||||
|
|
||||||
|
set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS})
|
||||||
|
|
||||||
|
get_cmake_property(all_variables VARIABLES)
|
||||||
|
foreach(variable_name IN LISTS all_variables)
|
||||||
|
if(variable_name MATCHES "^GGML_")
|
||||||
|
string(REPLACE ";" "\\;"
|
||||||
|
variable_value "${${variable_name}}")
|
||||||
|
|
||||||
|
set(variable_set_statements
|
||||||
|
"${variable_set_statements}set(${variable_name} \"${variable_value}\")\n")
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
set(GGML_VARIABLES_EXPANDED ${variable_set_statements})
|
||||||
|
|
||||||
|
# Create the CMake package and set install location.
|
||||||
|
|
||||||
|
set(GGML_INSTALL_VERSION ${GGML_VERSION})
|
||||||
|
set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
|
||||||
|
set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
||||||
|
set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
||||||
|
|
||||||
|
configure_package_config_file(
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
||||||
|
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml
|
||||||
|
PATH_VARS GGML_INCLUDE_INSTALL_DIR
|
||||||
|
GGML_LIB_INSTALL_DIR
|
||||||
|
GGML_BIN_INSTALL_DIR)
|
||||||
|
|
||||||
|
write_basic_package_version_file(
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
||||||
|
VERSION ${GGML_INSTALL_VERSION}
|
||||||
|
COMPATIBILITY SameMajorVersion)
|
||||||
|
|
||||||
|
target_compile_definitions(ggml-base PRIVATE
|
||||||
|
GGML_VERSION="${GGML_INSTALL_VERSION}"
|
||||||
|
GGML_COMMIT="${GGML_BUILD_COMMIT}"
|
||||||
|
)
|
||||||
|
message(STATUS "ggml version: ${GGML_INSTALL_VERSION}")
|
||||||
|
message(STATUS "ggml commit: ${GGML_BUILD_COMMIT}")
|
||||||
|
|
||||||
|
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
||||||
|
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
|
||||||
|
|
||||||
|
if (MSVC)
|
||||||
|
set(MSVC_WARNING_FLAGS
|
||||||
|
/wd4005 # Macro redefinition
|
||||||
|
/wd4244 # Conversion from one type to another type, possible loss of data
|
||||||
|
/wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data
|
||||||
|
/wd4305 # Conversion from 'type1' to 'type2', possible loss of data
|
||||||
|
/wd4566 # Conversion from 'char' to 'wchar_t', possible loss of data
|
||||||
|
/wd4996 # Disable POSIX deprecation warnings
|
||||||
|
/wd4702 # Unreachable code warnings
|
||||||
|
)
|
||||||
|
set(MSVC_COMPILE_OPTIONS
|
||||||
|
"$<$<COMPILE_LANGUAGE:C>:/utf-8>"
|
||||||
|
"$<$<COMPILE_LANGUAGE:CXX>:/utf-8>"
|
||||||
|
)
|
||||||
|
function(configure_msvc_target target_name)
|
||||||
|
if(TARGET ${target_name})
|
||||||
|
target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS})
|
||||||
|
target_compile_options(${target_name} PRIVATE ${MSVC_COMPILE_OPTIONS})
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
configure_msvc_target(ggml-base)
|
||||||
|
configure_msvc_target(ggml)
|
||||||
|
configure_msvc_target(ggml-cpu)
|
||||||
|
configure_msvc_target(ggml-cpu-x64)
|
||||||
|
configure_msvc_target(ggml-cpu-sse42)
|
||||||
|
configure_msvc_target(ggml-cpu-sandybridge)
|
||||||
|
# __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
|
||||||
|
# skipping ggml-cpu-ivybridge
|
||||||
|
# skipping ggml-cpu-piledriver
|
||||||
|
configure_msvc_target(ggml-cpu-haswell)
|
||||||
|
configure_msvc_target(ggml-cpu-skylakex)
|
||||||
|
configure_msvc_target(ggml-cpu-cannonlake)
|
||||||
|
configure_msvc_target(ggml-cpu-cascadelake)
|
||||||
|
configure_msvc_target(ggml-cpu-icelake)
|
||||||
|
# MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
|
||||||
|
# https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
|
||||||
|
# https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
|
||||||
|
# skipping ggml-cpu-cooperlake
|
||||||
|
# skipping ggml-cpu-zen4
|
||||||
|
configure_msvc_target(ggml-cpu-alderlake)
|
||||||
|
# MSVC doesn't support AMX
|
||||||
|
# skipping ggml-cpu-sapphirerapids
|
||||||
|
|
||||||
|
if (GGML_BUILD_EXAMPLES)
|
||||||
|
configure_msvc_target(common-ggml)
|
||||||
|
configure_msvc_target(common)
|
||||||
|
|
||||||
|
configure_msvc_target(mnist-common)
|
||||||
|
configure_msvc_target(mnist-eval)
|
||||||
|
configure_msvc_target(mnist-train)
|
||||||
|
|
||||||
|
configure_msvc_target(gpt-2-ctx)
|
||||||
|
configure_msvc_target(gpt-2-alloc)
|
||||||
|
configure_msvc_target(gpt-2-backend)
|
||||||
|
configure_msvc_target(gpt-2-sched)
|
||||||
|
configure_msvc_target(gpt-2-quantize)
|
||||||
|
configure_msvc_target(gpt-2-batched)
|
||||||
|
|
||||||
|
configure_msvc_target(gpt-j)
|
||||||
|
configure_msvc_target(gpt-j-quantize)
|
||||||
|
|
||||||
|
configure_msvc_target(magika)
|
||||||
|
configure_msvc_target(yolov3-tiny)
|
||||||
|
configure_msvc_target(sam)
|
||||||
|
|
||||||
|
configure_msvc_target(simple-ctx)
|
||||||
|
configure_msvc_target(simple-backend)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_BUILD_TESTS)
|
||||||
|
configure_msvc_target(test-mul-mat)
|
||||||
|
configure_msvc_target(test-arange)
|
||||||
|
configure_msvc_target(test-backend-ops)
|
||||||
|
configure_msvc_target(test-cont)
|
||||||
|
configure_msvc_target(test-conv-transpose)
|
||||||
|
configure_msvc_target(test-conv-transpose-1d)
|
||||||
|
configure_msvc_target(test-conv1d)
|
||||||
|
configure_msvc_target(test-conv2d)
|
||||||
|
configure_msvc_target(test-conv2d-dw)
|
||||||
|
configure_msvc_target(test-customop)
|
||||||
|
configure_msvc_target(test-dup)
|
||||||
|
configure_msvc_target(test-opt)
|
||||||
|
configure_msvc_target(test-pool)
|
||||||
|
endif ()
|
||||||
|
endif()
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
find_package(Git)
|
||||||
|
|
||||||
|
# the commit's SHA1
|
||||||
|
execute_process(COMMAND
|
||||||
|
"${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8
|
||||||
|
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||||
|
OUTPUT_VARIABLE GIT_SHA1
|
||||||
|
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
|
|
||||||
|
# the date of the commit
|
||||||
|
execute_process(COMMAND
|
||||||
|
"${GIT_EXECUTABLE}" log -1 --format=%ad --date=local
|
||||||
|
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||||
|
OUTPUT_VARIABLE GIT_DATE
|
||||||
|
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
|
|
||||||
|
# the subject of the commit
|
||||||
|
execute_process(COMMAND
|
||||||
|
"${GIT_EXECUTABLE}" log -1 --format=%s
|
||||||
|
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||||
|
OUTPUT_VARIABLE GIT_COMMIT_SUBJECT
|
||||||
|
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
function(ggml_get_flags CCID CCVER)
|
||||||
|
set(C_FLAGS "")
|
||||||
|
set(CXX_FLAGS "")
|
||||||
|
|
||||||
|
if (CCID MATCHES "Clang")
|
||||||
|
set(C_FLAGS -Wunreachable-code-break -Wunreachable-code-return)
|
||||||
|
set(CXX_FLAGS -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi)
|
||||||
|
|
||||||
|
if (
|
||||||
|
(CCID STREQUAL "Clang" AND CCVER VERSION_GREATER_EQUAL 3.8.0) OR
|
||||||
|
(CCID STREQUAL "AppleClang" AND CCVER VERSION_GREATER_EQUAL 7.3.0)
|
||||||
|
)
|
||||||
|
list(APPEND C_FLAGS -Wdouble-promotion)
|
||||||
|
endif()
|
||||||
|
elseif (CCID STREQUAL "GNU")
|
||||||
|
set(C_FLAGS -Wdouble-promotion)
|
||||||
|
set(CXX_FLAGS -Wno-array-bounds)
|
||||||
|
|
||||||
|
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
|
||||||
|
list(APPEND CXX_FLAGS -Wextra-semi)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE)
|
||||||
|
set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
function(ggml_get_system_arch)
|
||||||
|
if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
|
||||||
|
CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
|
||||||
|
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
||||||
|
CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))
|
||||||
|
set(GGML_SYSTEM_ARCH "ARM" PARENT_SCOPE)
|
||||||
|
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR
|
||||||
|
CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
|
||||||
|
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
||||||
|
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$"))
|
||||||
|
set(GGML_SYSTEM_ARCH "x86" PARENT_SCOPE)
|
||||||
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc|power")
|
||||||
|
set(GGML_SYSTEM_ARCH "PowerPC" PARENT_SCOPE)
|
||||||
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
||||||
|
set(GGML_SYSTEM_ARCH "loongarch64" PARENT_SCOPE)
|
||||||
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64")
|
||||||
|
set(GGML_SYSTEM_ARCH "riscv64" PARENT_SCOPE)
|
||||||
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||||
|
set(GGML_SYSTEM_ARCH "s390x" PARENT_SCOPE)
|
||||||
|
else()
|
||||||
|
set(GGML_SYSTEM_ARCH "UNKNOWN" PARENT_SCOPE)
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
@@ -0,0 +1,191 @@
|
|||||||
|
@PACKAGE_INIT@
|
||||||
|
|
||||||
|
@GGML_VARIABLES_EXPANDED@
|
||||||
|
|
||||||
|
# Find all dependencies before creating any target.
|
||||||
|
include(CMakeFindDependencyMacro)
|
||||||
|
find_dependency(Threads)
|
||||||
|
if (NOT GGML_SHARED_LIB)
|
||||||
|
set(GGML_CPU_INTERFACE_LINK_LIBRARIES "")
|
||||||
|
set(GGML_CPU_INTERFACE_LINK_OPTIONS "")
|
||||||
|
|
||||||
|
if (APPLE AND GGML_ACCELERATE)
|
||||||
|
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||||
|
if(NOT ACCELERATE_FRAMEWORK)
|
||||||
|
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
|
||||||
|
return()
|
||||||
|
endif()
|
||||||
|
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_OPENMP_ENABLED)
|
||||||
|
find_dependency(OpenMP)
|
||||||
|
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_CPU_HBM)
|
||||||
|
find_library(memkind memkind)
|
||||||
|
if(NOT memkind)
|
||||||
|
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
|
||||||
|
return()
|
||||||
|
endif()
|
||||||
|
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_BLAS)
|
||||||
|
find_dependency(BLAS)
|
||||||
|
list(APPEND GGML_BLAS_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
|
||||||
|
list(APPEND GGML_BLAS_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_CUDA)
|
||||||
|
set(GGML_CUDA_INTERFACE_LINK_LIBRARIES "")
|
||||||
|
find_dependency(CUDAToolkit)
|
||||||
|
if (GGML_STATIC)
|
||||||
|
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cudart_static>)
|
||||||
|
if (WIN32)
|
||||||
|
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas> $<LINK_ONLY:CUDA::cublasLt>)
|
||||||
|
else()
|
||||||
|
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cublas_static> $<LINK_ONLY:CUDA::cublasLt_static>)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
if (NOT GGML_CUDA_NO_VMM)
|
||||||
|
list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:CUDA::cuda_driver>)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_METAL)
|
||||||
|
find_library(FOUNDATION_LIBRARY Foundation)
|
||||||
|
find_library(METAL_FRAMEWORK Metal)
|
||||||
|
find_library(METALKIT_FRAMEWORK MetalKit)
|
||||||
|
if(NOT FOUNDATION_LIBRARY OR NOT METAL_FRAMEWORK OR NOT METALKIT_FRAMEWORK)
|
||||||
|
set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0)
|
||||||
|
return()
|
||||||
|
endif()
|
||||||
|
set(GGML_METAL_INTERFACE_LINK_LIBRARIES
|
||||||
|
${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_OPENCL)
|
||||||
|
find_dependency(OpenCL)
|
||||||
|
set(GGML_OPENCL_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:OpenCL::OpenCL>)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_VULKAN)
|
||||||
|
find_dependency(Vulkan)
|
||||||
|
set(GGML_VULKAN_INTERFACE_LINK_LIBRARIES $<LINK_ONLY:Vulkan::Vulkan>)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_HIP)
|
||||||
|
find_dependency(hip)
|
||||||
|
find_dependency(hipblas)
|
||||||
|
find_dependency(rocblas)
|
||||||
|
set(GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (GGML_SYCL)
|
||||||
|
set(GGML_SYCL_INTERFACE_LINK_LIBRARIES "")
|
||||||
|
find_package(DNNL)
|
||||||
|
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||||
|
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
|
||||||
|
endif()
|
||||||
|
if (WIN32)
|
||||||
|
find_dependency(IntelSYCL)
|
||||||
|
find_dependency(MKL)
|
||||||
|
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
|
||||||
|
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
|
||||||
|
#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
|
||||||
|
|
||||||
|
if(NOT TARGET ggml::ggml)
|
||||||
|
find_package(Threads REQUIRED)
|
||||||
|
|
||||||
|
find_library(GGML_LIBRARY ggml
|
||||||
|
REQUIRED
|
||||||
|
HINTS ${GGML_LIB_DIR}
|
||||||
|
NO_CMAKE_FIND_ROOT_PATH)
|
||||||
|
|
||||||
|
add_library(ggml::ggml UNKNOWN IMPORTED)
|
||||||
|
set_target_properties(ggml::ggml
|
||||||
|
PROPERTIES
|
||||||
|
IMPORTED_LOCATION "${GGML_LIBRARY}")
|
||||||
|
|
||||||
|
find_library(GGML_BASE_LIBRARY ggml-base
|
||||||
|
REQUIRED
|
||||||
|
HINTS ${GGML_LIB_DIR}
|
||||||
|
NO_CMAKE_FIND_ROOT_PATH)
|
||||||
|
|
||||||
|
add_library(ggml::ggml-base UNKNOWN IMPORTED)
|
||||||
|
set_target_properties(ggml::ggml-base
|
||||||
|
PROPERTIES
|
||||||
|
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
|
||||||
|
|
||||||
|
set(_ggml_all_targets "")
|
||||||
|
if (NOT GGML_BACKEND_DL)
|
||||||
|
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
|
||||||
|
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
|
||||||
|
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
|
||||||
|
|
||||||
|
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
|
||||||
|
REQUIRED
|
||||||
|
HINTS ${GGML_LIB_DIR}
|
||||||
|
NO_CMAKE_FIND_ROOT_PATH)
|
||||||
|
|
||||||
|
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
|
||||||
|
|
||||||
|
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
|
||||||
|
set_target_properties(ggml::${_ggml_backend}
|
||||||
|
PROPERTIES
|
||||||
|
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
|
||||||
|
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
||||||
|
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
|
||||||
|
INTERFACE_COMPILE_FEATURES c_std_90
|
||||||
|
POSITION_INDEPENDENT_CODE ON)
|
||||||
|
|
||||||
|
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
|
||||||
|
if(is_cpu_variant)
|
||||||
|
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
|
||||||
|
set_target_properties(ggml::${_ggml_backend}
|
||||||
|
PROPERTIES
|
||||||
|
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
|
||||||
|
|
||||||
|
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
|
||||||
|
set_target_properties(ggml::${_ggml_backend}
|
||||||
|
PROPERTIES
|
||||||
|
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
else()
|
||||||
|
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
|
||||||
|
set_target_properties(ggml::${_ggml_backend}
|
||||||
|
PROPERTIES
|
||||||
|
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
|
||||||
|
|
||||||
|
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
|
||||||
|
set_target_properties(ggml::${_ggml_backend}
|
||||||
|
PROPERTIES
|
||||||
|
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
|
||||||
|
endforeach()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
|
||||||
|
set_target_properties(ggml::ggml
|
||||||
|
PROPERTIES
|
||||||
|
INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}")
|
||||||
|
|
||||||
|
add_library(ggml::all INTERFACE IMPORTED)
|
||||||
|
set_target_properties(ggml::all
|
||||||
|
PROPERTIES
|
||||||
|
INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
|
||||||
|
|
||||||
|
endif()
|
||||||
|
|
||||||
|
check_required_components(ggml)
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
|
||||||
|
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
|
||||||
|
typedef struct ggml_backend * ggml_backend_t;
|
||||||
|
|
||||||
|
// Tensor allocator
|
||||||
|
struct ggml_tallocr {
|
||||||
|
ggml_backend_buffer_t buffer;
|
||||||
|
void * base;
|
||||||
|
size_t alignment;
|
||||||
|
size_t offset;
|
||||||
|
};
|
||||||
|
|
||||||
|
GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
|
||||||
|
|
||||||
|
// Graph allocator
|
||||||
|
/*
|
||||||
|
Example usage:
|
||||||
|
ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
|
||||||
|
|
||||||
|
// optional: create a worst-case graph and reserve the buffers to avoid reallocations
|
||||||
|
ggml_gallocr_reserve(galloc, build_graph(max_batch));
|
||||||
|
|
||||||
|
// allocate the graph
|
||||||
|
struct ggml_cgraph * graph = build_graph(batch);
|
||||||
|
ggml_gallocr_alloc_graph(galloc, graph);
|
||||||
|
|
||||||
|
printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
|
||||||
|
|
||||||
|
// evaluate the graph
|
||||||
|
ggml_backend_graph_compute(backend, graph);
|
||||||
|
*/
|
||||||
|
|
||||||
|
// special tensor flags for use with the graph allocator:
|
||||||
|
// ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
|
||||||
|
// ggml_set_output(): output tensors are never freed and never overwritten
|
||||||
|
|
||||||
|
typedef struct ggml_gallocr * ggml_gallocr_t;
|
||||||
|
|
||||||
|
GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
|
||||||
|
GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
|
||||||
|
GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
|
||||||
|
|
||||||
|
// pre-allocate buffers from a measure graph - does not allocate or modify the graph
|
||||||
|
// call with a worst-case graph to avoid buffer reallocations
|
||||||
|
// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
|
||||||
|
// returns false if the buffer allocation failed
|
||||||
|
// ggml_gallocr_resrve_n_size writes the buffer sizes per galloc buffer that would be allocated by ggml_gallocr_reserve_n to sizes
|
||||||
|
GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
|
||||||
|
GGML_API void ggml_gallocr_reserve_n_size(
|
||||||
|
ggml_gallocr_t galloc,
|
||||||
|
struct ggml_cgraph * graph,
|
||||||
|
const int * node_buffer_ids,
|
||||||
|
const int * leaf_buffer_ids,
|
||||||
|
size_t * sizes);
|
||||||
|
GGML_API bool ggml_gallocr_reserve_n(
|
||||||
|
ggml_gallocr_t galloc,
|
||||||
|
struct ggml_cgraph * graph,
|
||||||
|
const int * node_buffer_ids,
|
||||||
|
const int * leaf_buffer_ids);
|
||||||
|
|
||||||
|
// automatic reallocation if the topology changes when using a single buffer
|
||||||
|
// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
|
||||||
|
GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
|
||||||
|
|
||||||
|
GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
|
||||||
|
|
||||||
|
// Utils
|
||||||
|
// Create a buffer and allocate all the tensors in a ggml_context
|
||||||
|
// ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft
|
||||||
|
GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
||||||
|
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
||||||
|
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,373 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-alloc.h"
|
||||||
|
|
||||||
|
#ifdef GGML_BACKEND_SHARED
|
||||||
|
# if defined(_WIN32) && !defined(__MINGW32__)
|
||||||
|
# ifdef GGML_BACKEND_BUILD
|
||||||
|
# define GGML_BACKEND_API __declspec(dllexport) extern
|
||||||
|
# else
|
||||||
|
# define GGML_BACKEND_API __declspec(dllimport) extern
|
||||||
|
# endif
|
||||||
|
# else
|
||||||
|
# define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern
|
||||||
|
# endif
|
||||||
|
#else
|
||||||
|
# define GGML_BACKEND_API extern
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
|
||||||
|
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
|
||||||
|
typedef struct ggml_backend_event * ggml_backend_event_t;
|
||||||
|
typedef struct ggml_backend * ggml_backend_t;
|
||||||
|
typedef void * ggml_backend_graph_plan_t;
|
||||||
|
typedef struct ggml_backend_reg * ggml_backend_reg_t;
|
||||||
|
typedef struct ggml_backend_device * ggml_backend_dev_t;
|
||||||
|
|
||||||
|
|
||||||
|
//
|
||||||
|
// Backend buffer type
|
||||||
|
//
|
||||||
|
|
||||||
|
GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft);
|
||||||
|
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
|
||||||
|
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
|
||||||
|
GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
|
||||||
|
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
|
||||||
|
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
|
||||||
|
GGML_API ggml_backend_dev_t ggml_backend_buft_get_device (ggml_backend_buffer_type_t buft);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Backend buffer
|
||||||
|
//
|
||||||
|
|
||||||
|
enum ggml_backend_buffer_usage {
|
||||||
|
GGML_BACKEND_BUFFER_USAGE_ANY = 0,
|
||||||
|
GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
|
||||||
|
GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API enum ggml_status ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
|
||||||
|
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor);
|
||||||
|
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
|
||||||
|
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
|
||||||
|
GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer);
|
||||||
|
GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
|
||||||
|
|
||||||
|
// tensor copy between different backends
|
||||||
|
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Backend (stream)
|
||||||
|
//
|
||||||
|
|
||||||
|
GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend);
|
||||||
|
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
|
||||||
|
GGML_API void ggml_backend_free(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
|
||||||
|
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
|
||||||
|
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
|
||||||
|
GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||||
|
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||||
|
|
||||||
|
// "offset" refers to the offset in tensor->data for setting/getting data
|
||||||
|
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||||
|
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||||
|
GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
|
||||||
|
|
||||||
|
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
|
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
|
|
||||||
|
GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
|
||||||
|
GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
|
GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
|
||||||
|
|
||||||
|
// NOTE: will be removed, use device version instead
|
||||||
|
GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
|
||||||
|
GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
|
||||||
|
GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
|
||||||
|
|
||||||
|
// asynchronous copy
|
||||||
|
// the copy is performed after all the currently queued operations in backend_src
|
||||||
|
// backend_dst will wait for the copy to complete before performing other operations
|
||||||
|
// automatic fallback to sync copy if async is not supported
|
||||||
|
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
|
||||||
|
|
||||||
|
GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Events
|
||||||
|
//
|
||||||
|
|
||||||
|
GGML_API ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device);
|
||||||
|
GGML_API void ggml_backend_event_free(ggml_backend_event_t event);
|
||||||
|
GGML_API void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend);
|
||||||
|
GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event);
|
||||||
|
GGML_API void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Backend device
|
||||||
|
//
|
||||||
|
|
||||||
|
enum ggml_backend_dev_type {
|
||||||
|
// CPU device using system memory
|
||||||
|
GGML_BACKEND_DEVICE_TYPE_CPU,
|
||||||
|
// GPU device using dedicated memory
|
||||||
|
GGML_BACKEND_DEVICE_TYPE_GPU,
|
||||||
|
// integrated GPU device using host memory
|
||||||
|
GGML_BACKEND_DEVICE_TYPE_IGPU,
|
||||||
|
// accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
|
||||||
|
GGML_BACKEND_DEVICE_TYPE_ACCEL
|
||||||
|
};
|
||||||
|
|
||||||
|
// functionality supported by the device
|
||||||
|
struct ggml_backend_dev_caps {
|
||||||
|
// asynchronous operations
|
||||||
|
bool async;
|
||||||
|
// pinned host buffer
|
||||||
|
bool host_buffer;
|
||||||
|
// creating buffers from host ptr
|
||||||
|
bool buffer_from_host_ptr;
|
||||||
|
// event synchronization
|
||||||
|
bool events;
|
||||||
|
};
|
||||||
|
|
||||||
|
// all the device properties
|
||||||
|
struct ggml_backend_dev_props {
|
||||||
|
// device name
|
||||||
|
const char * name;
|
||||||
|
// device description
|
||||||
|
const char * description;
|
||||||
|
// device free memory in bytes
|
||||||
|
size_t memory_free;
|
||||||
|
// device total memory in bytes
|
||||||
|
size_t memory_total;
|
||||||
|
// device type
|
||||||
|
enum ggml_backend_dev_type type;
|
||||||
|
// device id
|
||||||
|
// for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
|
||||||
|
// if the id is unknown, this should be NULL
|
||||||
|
const char * device_id;
|
||||||
|
// device capabilities
|
||||||
|
struct ggml_backend_dev_caps caps;
|
||||||
|
};
|
||||||
|
|
||||||
|
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
|
||||||
|
GGML_API const char * ggml_backend_dev_description(ggml_backend_dev_t device);
|
||||||
|
GGML_API void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total);
|
||||||
|
GGML_API enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device);
|
||||||
|
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
||||||
|
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
||||||
|
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||||
|
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||||
|
|
||||||
|
GGML_API bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
|
||||||
|
GGML_API bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft);
|
||||||
|
GGML_API bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Backend (reg)
|
||||||
|
//
|
||||||
|
|
||||||
|
GGML_API const char * ggml_backend_reg_name(ggml_backend_reg_t reg);
|
||||||
|
GGML_API size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg);
|
||||||
|
GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index);
|
||||||
|
GGML_API void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name);
|
||||||
|
|
||||||
|
// Common functions that may be obtained using ggml_backend_reg_get_proc_address
|
||||||
|
|
||||||
|
// Split buffer type for tensor parallelism
|
||||||
|
typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
|
||||||
|
// Set the number of threads for the backend
|
||||||
|
typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
|
||||||
|
// Get additional buffer types provided by the device (returns a NULL-terminated array)
|
||||||
|
typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device);
|
||||||
|
// Set the abort callback for the backend
|
||||||
|
typedef void (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data);
|
||||||
|
// Get a list of feature flags supported by the backend (returns a NULL-terminated array)
|
||||||
|
struct ggml_backend_feature {
|
||||||
|
const char * name;
|
||||||
|
const char * value;
|
||||||
|
};
|
||||||
|
typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Backend registry
|
||||||
|
//
|
||||||
|
|
||||||
|
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
|
||||||
|
|
||||||
|
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
|
||||||
|
|
||||||
|
// Backend (reg) enumeration
|
||||||
|
GGML_API size_t ggml_backend_reg_count(void);
|
||||||
|
GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
|
||||||
|
GGML_API ggml_backend_reg_t ggml_backend_reg_by_name(const char * name);
|
||||||
|
|
||||||
|
// Device enumeration
|
||||||
|
GGML_API size_t ggml_backend_dev_count(void);
|
||||||
|
GGML_API ggml_backend_dev_t ggml_backend_dev_get(size_t index);
|
||||||
|
GGML_API ggml_backend_dev_t ggml_backend_dev_by_name(const char * name);
|
||||||
|
GGML_API ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type);
|
||||||
|
|
||||||
|
// Direct backend (stream) initialization
|
||||||
|
// = ggml_backend_dev_init(ggml_backend_dev_by_name(name), params)
|
||||||
|
GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params);
|
||||||
|
// = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params)
|
||||||
|
GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params);
|
||||||
|
// = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL)
|
||||||
|
GGML_API ggml_backend_t ggml_backend_init_best(void);
|
||||||
|
|
||||||
|
// Load a backend from a dynamic library and register it
|
||||||
|
GGML_API ggml_backend_reg_t ggml_backend_load(const char * path);
|
||||||
|
// Unload a backend if loaded dynamically and unregister it
|
||||||
|
GGML_API void ggml_backend_unload(ggml_backend_reg_t reg);
|
||||||
|
// Load all known backends from dynamic libraries
|
||||||
|
GGML_API void ggml_backend_load_all(void);
|
||||||
|
GGML_API void ggml_backend_load_all_from_path(const char * dir_path);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Backend scheduler
|
||||||
|
//
|
||||||
|
|
||||||
|
// The backend scheduler allows for multiple backend devices to be used together
|
||||||
|
// Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
|
||||||
|
// The backends are selected based on:
|
||||||
|
// - the backend that supports the operation
|
||||||
|
// - the location of the pre-allocated tensors (e.g. the weights)
|
||||||
|
/*
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
// operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
|
||||||
|
// preferrably to run on the same backend as the buffer
|
||||||
|
ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||||
|
|
||||||
|
sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true);
|
||||||
|
|
||||||
|
// initialize buffers from a max size graph (optional)
|
||||||
|
reserve_graph = build_graph(sched, max_batch_size);
|
||||||
|
|
||||||
|
// manually assign nodes to a backend (optional, should not be needed in most cases)
|
||||||
|
struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
|
||||||
|
ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
|
||||||
|
|
||||||
|
ggml_backend_sched_reserve(sched, reserve_graph);
|
||||||
|
|
||||||
|
// compute
|
||||||
|
graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation
|
||||||
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically
|
||||||
|
}
|
||||||
|
|
||||||
|
// if there are graph inputs:
|
||||||
|
graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called)
|
||||||
|
ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
|
||||||
|
ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it
|
||||||
|
ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors
|
||||||
|
ggml_backend_sched_graph_compute(sched, graph); // execute the graph
|
||||||
|
|
||||||
|
// as an alternative to the above it is also possible to assign the inputs to a dedicated context and
|
||||||
|
// allocate them statically via ggml_backend_alloc_ctx_tensors
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
typedef struct ggml_backend_sched * ggml_backend_sched_t;
|
||||||
|
|
||||||
|
// Evaluation callback for each node in the graph (set with ggml_backend_sched_set_eval_callback)
|
||||||
|
// when ask == true, the scheduler wants to know if the user wants to observe this node
|
||||||
|
// this allows the scheduler to batch nodes together in order to evaluate them in a single call
|
||||||
|
//
|
||||||
|
// when ask == false, the scheduler is passing the node tensor to the user for observation
|
||||||
|
// if the user returns false, the scheduler will cancel the graph compute
|
||||||
|
//
|
||||||
|
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
|
||||||
|
|
||||||
|
// Initialize a backend scheduler, backends with low index are given priority over backends with high index
|
||||||
|
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
|
||||||
|
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
|
||||||
|
|
||||||
|
// Initialize backend buffers from a measure graph
|
||||||
|
GGML_API void ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes);
|
||||||
|
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
|
||||||
|
|
||||||
|
GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
|
||||||
|
GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
|
||||||
|
|
||||||
|
// Get the number of splits of the last graph
|
||||||
|
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
|
||||||
|
GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
|
||||||
|
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t sched, ggml_backend_t backend);
|
||||||
|
GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
||||||
|
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
|
||||||
|
|
||||||
|
// Split graph without allocating it
|
||||||
|
GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
||||||
|
|
||||||
|
// Allocate and compute graph on the backend scheduler
|
||||||
|
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
|
||||||
|
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
||||||
|
GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
|
||||||
|
GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
|
||||||
|
|
||||||
|
// Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph.
|
||||||
|
// This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers.
|
||||||
|
// The correct way to use this API is to discard the deallocated tensors and create new ones.
|
||||||
|
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
|
||||||
|
|
||||||
|
// Set a callback to be called for each resulting node during graph compute
|
||||||
|
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Utils
|
||||||
|
//
|
||||||
|
|
||||||
|
struct ggml_backend_graph_copy {
|
||||||
|
ggml_backend_buffer_t buffer;
|
||||||
|
struct ggml_context * ctx_allocated;
|
||||||
|
struct ggml_context * ctx_unallocated;
|
||||||
|
struct ggml_cgraph * graph;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Copy a graph to a different backend
|
||||||
|
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
|
||||||
|
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
|
||||||
|
|
||||||
|
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
|
||||||
|
|
||||||
|
// Compare the output of two backends
|
||||||
|
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes);
|
||||||
|
|
||||||
|
// Tensor initialization
|
||||||
|
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
|
||||||
|
GGML_API enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor);
|
||||||
|
|
||||||
|
// CPU buffer types are always available
|
||||||
|
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// backend API
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend);
|
||||||
|
|
||||||
|
// number of threads used for conversion to float
|
||||||
|
// for openblas and blis, this will also set the number of threads used for blas operations
|
||||||
|
GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void);
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2023-2024 The ggml authors
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
* of this software and associated documentation files (the "Software"), to
|
||||||
|
* deal in the Software without restriction, including without limitation the
|
||||||
|
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
|
||||||
|
* sell copies of the Software, and to permit persons to whom the Software is
|
||||||
|
* furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
|
||||||
|
* IN THE SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Maximum number of CANN devices supported.
|
||||||
|
*/
|
||||||
|
#define GGML_CANN_MAX_DEVICES 16
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Initializes the CANN backend for a specified device.
|
||||||
|
*
|
||||||
|
* This function initializes the CANN backend for the given device.
|
||||||
|
* It verifies the device index, allocates a context, and creates a backend
|
||||||
|
* instance.
|
||||||
|
*
|
||||||
|
* @param device The index of the device to initialize.
|
||||||
|
* @return A pointer to the initialized backend instance, or nullptr on failure.
|
||||||
|
*/
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Checks if a given backend is a CANN backend.
|
||||||
|
*
|
||||||
|
* This function verifies if the provided backend is a CANN backend by comparing
|
||||||
|
* its GUID with the CANN backend's GUID.
|
||||||
|
*
|
||||||
|
* @param backend The backend instance to check.
|
||||||
|
* @return True if the backend is a CANN backend, false otherwise.
|
||||||
|
*/
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Retrieves the CANN buffer type for a specified device.
|
||||||
|
*
|
||||||
|
* This function initializes and returns the buffer type interface associated
|
||||||
|
* with the given device. It ensures thread-safe access using a mutex.
|
||||||
|
*
|
||||||
|
* @param device The device index for which to retrieve the buffer type.
|
||||||
|
* @return A pointer to the buffer type interface for the specified device, or
|
||||||
|
* nullptr if the device index is out of range.
|
||||||
|
*/
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t
|
||||||
|
ggml_backend_cann_buffer_type(int32_t device);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Retrieves the number of CANN devices available.
|
||||||
|
*
|
||||||
|
* This function returns the number of CANN devices available based on
|
||||||
|
* information obtained from `ggml_cann_info()`.
|
||||||
|
*
|
||||||
|
* @return The number of CANN devices available.
|
||||||
|
*/
|
||||||
|
GGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU.
|
||||||
|
*
|
||||||
|
* @return A pointer to the host buffer type interface.
|
||||||
|
*/
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Retrieves the description of a specific CANN device.
|
||||||
|
*
|
||||||
|
* This function sets the specified device, retrieves the SoC name,
|
||||||
|
* and writes it into the provided description buffer.
|
||||||
|
*
|
||||||
|
* @param device The device index to retrieve the description for.
|
||||||
|
* @param description Pointer to a buffer where the description will be written.
|
||||||
|
* @param description_size Size of the description buffer.
|
||||||
|
*/
|
||||||
|
GGML_BACKEND_API void ggml_backend_cann_get_device_description(
|
||||||
|
int32_t device, char* description, size_t description_size);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Retrieves the memory information of a specific CANN device.
|
||||||
|
*
|
||||||
|
* This function sets the specified device, retrieves the free and total
|
||||||
|
* memory information of the specified type (ACL_HBM_MEM), and stores them
|
||||||
|
* in the provided pointers.
|
||||||
|
*
|
||||||
|
* @param device The device index to retrieve memory information for.
|
||||||
|
* @param free Pointer to a variable where the free memory size will be stored.
|
||||||
|
* @param total Pointer to a variable where the total memory size will be
|
||||||
|
* stored.
|
||||||
|
*/
|
||||||
|
GGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device,
|
||||||
|
size_t* free,
|
||||||
|
size_t* total);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifndef __cplusplus
|
||||||
|
#error "This header is for C++ only"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-alloc.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
#include "gguf.h"
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
// Smart pointers for ggml types
|
||||||
|
|
||||||
|
// ggml
|
||||||
|
|
||||||
|
struct ggml_context_deleter { void operator()(ggml_context * ctx) { ggml_free(ctx); } };
|
||||||
|
struct gguf_context_deleter { void operator()(gguf_context * ctx) { gguf_free(ctx); } };
|
||||||
|
|
||||||
|
typedef std::unique_ptr<ggml_context, ggml_context_deleter> ggml_context_ptr;
|
||||||
|
typedef std::unique_ptr<gguf_context, gguf_context_deleter> gguf_context_ptr;
|
||||||
|
|
||||||
|
// ggml-alloc
|
||||||
|
|
||||||
|
struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } };
|
||||||
|
|
||||||
|
typedef std::unique_ptr<ggml_gallocr, ggml_gallocr_deleter> ggml_gallocr_ptr;
|
||||||
|
|
||||||
|
// ggml-backend
|
||||||
|
|
||||||
|
struct ggml_backend_deleter { void operator()(ggml_backend_t backend) { ggml_backend_free(backend); } };
|
||||||
|
struct ggml_backend_buffer_deleter { void operator()(ggml_backend_buffer_t buffer) { ggml_backend_buffer_free(buffer); } };
|
||||||
|
struct ggml_backend_event_deleter { void operator()(ggml_backend_event_t event) { ggml_backend_event_free(event); } };
|
||||||
|
struct ggml_backend_sched_deleter { void operator()(ggml_backend_sched_t sched) { ggml_backend_sched_free(sched); } };
|
||||||
|
|
||||||
|
typedef std::unique_ptr<ggml_backend, ggml_backend_deleter> ggml_backend_ptr;
|
||||||
|
typedef std::unique_ptr<ggml_backend_buffer, ggml_backend_buffer_deleter> ggml_backend_buffer_ptr;
|
||||||
|
typedef std::unique_ptr<ggml_backend_event, ggml_backend_event_deleter> ggml_backend_event_ptr;
|
||||||
|
typedef std::unique_ptr<ggml_backend_sched, ggml_backend_sched_deleter> ggml_backend_sched_ptr;
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// the compute plan that needs to be prepared for ggml_graph_compute()
|
||||||
|
// since https://github.com/ggml-org/ggml/issues/287
|
||||||
|
struct ggml_cplan {
|
||||||
|
size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
|
||||||
|
uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
|
||||||
|
|
||||||
|
int n_threads;
|
||||||
|
struct ggml_threadpool * threadpool;
|
||||||
|
|
||||||
|
// abort ggml_graph_compute when true
|
||||||
|
ggml_abort_callback abort_callback;
|
||||||
|
void * abort_callback_data;
|
||||||
|
};
|
||||||
|
|
||||||
|
// numa strategies
|
||||||
|
enum ggml_numa_strategy {
|
||||||
|
GGML_NUMA_STRATEGY_DISABLED = 0,
|
||||||
|
GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
|
||||||
|
GGML_NUMA_STRATEGY_ISOLATE = 2,
|
||||||
|
GGML_NUMA_STRATEGY_NUMACTL = 3,
|
||||||
|
GGML_NUMA_STRATEGY_MIRROR = 4,
|
||||||
|
GGML_NUMA_STRATEGY_COUNT
|
||||||
|
};
|
||||||
|
|
||||||
|
GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
|
||||||
|
GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
|
||||||
|
|
||||||
|
GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
|
||||||
|
GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
|
||||||
|
|
||||||
|
GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
|
||||||
|
GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
|
||||||
|
|
||||||
|
GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
|
||||||
|
GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
|
||||||
|
|
||||||
|
GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
|
||||||
|
GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
|
||||||
|
|
||||||
|
GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
|
||||||
|
GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
|
||||||
|
|
||||||
|
GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
|
||||||
|
GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
|
||||||
|
|
||||||
|
GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params);
|
||||||
|
GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool);
|
||||||
|
GGML_BACKEND_API int ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool);
|
||||||
|
GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool);
|
||||||
|
GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool);
|
||||||
|
|
||||||
|
// ggml_graph_plan() has to be called before ggml_graph_compute()
|
||||||
|
// when plan.work_size > 0, caller must allocate memory for plan.work_data
|
||||||
|
GGML_BACKEND_API struct ggml_cplan ggml_graph_plan(
|
||||||
|
const struct ggml_cgraph * cgraph,
|
||||||
|
int n_threads, /* = GGML_DEFAULT_N_THREADS */
|
||||||
|
struct ggml_threadpool * threadpool /* = NULL */ );
|
||||||
|
GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
|
||||||
|
|
||||||
|
// same as ggml_graph_compute() but the work data is allocated as a part of the context
|
||||||
|
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
|
||||||
|
GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
|
||||||
|
|
||||||
|
//
|
||||||
|
// system info
|
||||||
|
//
|
||||||
|
|
||||||
|
// x86
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_sse3 (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_ssse3 (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_avx (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_avx2 (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_bmi2 (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_f16c (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_fma (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_avx512 (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void);
|
||||||
|
// ARM
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_neon (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_arm_fma (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_fp16_va (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_dotprod (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_sve (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_sme (void);
|
||||||
|
// other
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_get_rvv_vlen (void); // risc-v vector length in bytes
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_vsx (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_vxe (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
|
||||||
|
GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
|
||||||
|
|
||||||
|
// Internal types and functions exposed for tests and benchmarks
|
||||||
|
|
||||||
|
typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
|
||||||
|
const void * GGML_RESTRICT y, size_t by, int nrc);
|
||||||
|
|
||||||
|
struct ggml_type_traits_cpu {
|
||||||
|
ggml_from_float_t from_float;
|
||||||
|
ggml_vec_dot_t vec_dot;
|
||||||
|
enum ggml_type vec_dot_type;
|
||||||
|
int64_t nrows; // number of rows to process simultaneously
|
||||||
|
};
|
||||||
|
|
||||||
|
GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
|
||||||
|
|
||||||
|
GGML_BACKEND_API void ggml_cpu_init(void);
|
||||||
|
|
||||||
|
//
|
||||||
|
// CPU backend
|
||||||
|
//
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend);
|
||||||
|
GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
|
||||||
|
GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
|
||||||
|
GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
|
||||||
|
GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
|
||||||
|
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
|
||||||
|
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
|
||||||
|
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
|
||||||
|
GGML_BACKEND_API void ggml_cpu_bf16_to_fp32(const ggml_bf16_t *, float *, int64_t);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef GGML_USE_HIP
|
||||||
|
#define GGML_CUDA_NAME "ROCm"
|
||||||
|
#define GGML_CUBLAS_NAME "hipBLAS"
|
||||||
|
#elif defined(GGML_USE_MUSA)
|
||||||
|
#define GGML_CUDA_NAME "MUSA"
|
||||||
|
#define GGML_CUBLAS_NAME "muBLAS"
|
||||||
|
#else
|
||||||
|
#define GGML_CUDA_NAME "CUDA"
|
||||||
|
#define GGML_CUBLAS_NAME "cuBLAS"
|
||||||
|
#endif
|
||||||
|
#define GGML_CUDA_MAX_DEVICES 16
|
||||||
|
|
||||||
|
// backend API
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
|
||||||
|
|
||||||
|
// device buffer
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
|
||||||
|
|
||||||
|
// split tensor buffer that splits matrices by rows across multiple devices
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
|
||||||
|
|
||||||
|
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API int ggml_backend_cuda_get_device_count(void);
|
||||||
|
GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
|
||||||
|
GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
|
||||||
|
GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// backend API
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
// Note: this description is outdated
|
||||||
|
//
|
||||||
|
// An interface allowing to compute ggml_cgraph with Metal
|
||||||
|
//
|
||||||
|
// This is a fully functional interface that extends ggml with GPU support for Apple devices.
|
||||||
|
// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
|
||||||
|
//
|
||||||
|
// How it works?
|
||||||
|
//
|
||||||
|
// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
|
||||||
|
// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
|
||||||
|
// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
|
||||||
|
//
|
||||||
|
// You only need to make sure that all memory buffers that you used during the graph creation
|
||||||
|
// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
|
||||||
|
// used during the graph evaluation to determine the arguments of the compute kernels.
|
||||||
|
//
|
||||||
|
// Synchronization between device and host memory (for example for input and output tensors)
|
||||||
|
// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
struct ggml_tensor;
|
||||||
|
struct ggml_cgraph;
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
//
|
||||||
|
// backend API
|
||||||
|
// user-code should use only these functions
|
||||||
|
//
|
||||||
|
|
||||||
|
// TODO: remove in the future
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
|
||||||
|
|
||||||
|
// helper to check if the device supports a specific family
|
||||||
|
// ideally, the user code should be doing these checks
|
||||||
|
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
||||||
|
GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
|
||||||
|
|
||||||
|
// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
|
||||||
|
GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
#ifndef GGML_OPENCL_H
|
||||||
|
#define GGML_OPENCL_H
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
//
|
||||||
|
// backend API
|
||||||
|
//
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_opencl_init(void);
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_opencl(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void);
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_opencl_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // GGML_OPENCL_H
|
||||||
@@ -0,0 +1,256 @@
|
|||||||
|
// This file contains functionality for training models using GGML.
|
||||||
|
// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets.
|
||||||
|
// At the bottom of this file especially there are relatively high-level functions that are suitable use or adaptation in user code.
|
||||||
|
//
|
||||||
|
// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
struct ggml_opt_dataset;
|
||||||
|
struct ggml_opt_context;
|
||||||
|
struct ggml_opt_result;
|
||||||
|
|
||||||
|
typedef struct ggml_opt_dataset * ggml_opt_dataset_t;
|
||||||
|
typedef struct ggml_opt_context * ggml_opt_context_t;
|
||||||
|
typedef struct ggml_opt_result * ggml_opt_result_t;
|
||||||
|
|
||||||
|
// ====== Loss ======
|
||||||
|
|
||||||
|
// built-in loss types, i.e. the built-in quantities minimized by the optimizer
|
||||||
|
// custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value
|
||||||
|
enum ggml_opt_loss_type {
|
||||||
|
GGML_OPT_LOSS_TYPE_MEAN,
|
||||||
|
GGML_OPT_LOSS_TYPE_SUM,
|
||||||
|
GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
|
||||||
|
GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
|
||||||
|
};
|
||||||
|
|
||||||
|
// ====== Dataset ======
|
||||||
|
|
||||||
|
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
|
||||||
|
enum ggml_type type_data, // the type for the internal data tensor
|
||||||
|
enum ggml_type type_label, // the type for the internal labels tensor
|
||||||
|
int64_t ne_datapoint, // number of elements per datapoint
|
||||||
|
int64_t ne_label, // number of elements per label
|
||||||
|
int64_t ndata, // total number of datapoints/labels
|
||||||
|
int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
|
||||||
|
GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
|
||||||
|
|
||||||
|
// get underlying tensors that store the data
|
||||||
|
GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
|
||||||
|
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
|
||||||
|
GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata]
|
||||||
|
|
||||||
|
// shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative
|
||||||
|
GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata);
|
||||||
|
|
||||||
|
// get batch at position ibatch from dataset and copy the data to data_batch and labels_batch
|
||||||
|
GGML_API void ggml_opt_dataset_get_batch(
|
||||||
|
ggml_opt_dataset_t dataset,
|
||||||
|
struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch]
|
||||||
|
struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
|
||||||
|
int64_t ibatch);
|
||||||
|
GGML_API void ggml_opt_dataset_get_batch_host(
|
||||||
|
ggml_opt_dataset_t dataset,
|
||||||
|
void * data_batch,
|
||||||
|
size_t nb_data_batch,
|
||||||
|
void * labels_batch,
|
||||||
|
int64_t ibatch);
|
||||||
|
|
||||||
|
// ====== Model / Context ======
|
||||||
|
|
||||||
|
enum ggml_opt_build_type {
|
||||||
|
GGML_OPT_BUILD_TYPE_FORWARD = 10,
|
||||||
|
GGML_OPT_BUILD_TYPE_GRAD = 20,
|
||||||
|
GGML_OPT_BUILD_TYPE_OPT = 30,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum ggml_opt_optimizer_type {
|
||||||
|
GGML_OPT_OPTIMIZER_TYPE_ADAMW,
|
||||||
|
GGML_OPT_OPTIMIZER_TYPE_SGD,
|
||||||
|
|
||||||
|
GGML_OPT_OPTIMIZER_TYPE_COUNT
|
||||||
|
};
|
||||||
|
|
||||||
|
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
|
||||||
|
struct ggml_opt_optimizer_params {
|
||||||
|
struct {
|
||||||
|
float alpha; // learning rate
|
||||||
|
float beta1; // first AdamW momentum
|
||||||
|
float beta2; // second AdamW momentum
|
||||||
|
float eps; // epsilon for numerical stability
|
||||||
|
float wd; // weight decay - 0.0f to disable
|
||||||
|
} adamw;
|
||||||
|
struct {
|
||||||
|
float alpha; // learning rate
|
||||||
|
float wd; // weight decay
|
||||||
|
} sgd;
|
||||||
|
};
|
||||||
|
|
||||||
|
// callback to calculate optimizer parameters prior to a backward pass
|
||||||
|
// userdata can be used to pass arbitrary data
|
||||||
|
typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
|
||||||
|
|
||||||
|
// returns the default optimizer params (constant, hard-coded values)
|
||||||
|
// userdata is not used
|
||||||
|
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
|
||||||
|
|
||||||
|
// casts userdata to ggml_opt_optimizer_params and returns it
|
||||||
|
GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata);
|
||||||
|
|
||||||
|
// parameters for initializing a new optimization context
|
||||||
|
struct ggml_opt_params {
|
||||||
|
ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
|
||||||
|
|
||||||
|
// by default the forward graph needs to be reconstructed for each eval
|
||||||
|
// if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically
|
||||||
|
struct ggml_context * ctx_compute;
|
||||||
|
struct ggml_tensor * inputs;
|
||||||
|
struct ggml_tensor * outputs;
|
||||||
|
|
||||||
|
enum ggml_opt_loss_type loss_type;
|
||||||
|
enum ggml_opt_build_type build_type;
|
||||||
|
|
||||||
|
int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
|
||||||
|
|
||||||
|
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
||||||
|
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
||||||
|
|
||||||
|
// only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
|
||||||
|
enum ggml_opt_optimizer_type optimizer;
|
||||||
|
};
|
||||||
|
|
||||||
|
// get parameters for an optimization context with defaults set where possible
|
||||||
|
// parameters for which no sensible defaults exist are supplied as arguments to this function
|
||||||
|
GGML_API struct ggml_opt_params ggml_opt_default_params(
|
||||||
|
ggml_backend_sched_t backend_sched,
|
||||||
|
enum ggml_opt_loss_type loss_type);
|
||||||
|
|
||||||
|
GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
|
||||||
|
GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
|
||||||
|
|
||||||
|
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
||||||
|
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
||||||
|
|
||||||
|
GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
|
||||||
|
|
||||||
|
// get underlying tensors that store data
|
||||||
|
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
||||||
|
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
||||||
|
GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
|
||||||
|
GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against
|
||||||
|
GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
|
||||||
|
GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs
|
||||||
|
GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
|
||||||
|
|
||||||
|
// get the gradient accumulator for a node from the forward graph
|
||||||
|
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
|
||||||
|
|
||||||
|
GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
|
||||||
|
|
||||||
|
GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
|
||||||
|
|
||||||
|
// ====== Optimization Result ======
|
||||||
|
|
||||||
|
GGML_API ggml_opt_result_t ggml_opt_result_init(void);
|
||||||
|
GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
|
||||||
|
GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
|
||||||
|
|
||||||
|
// get data from result, uncertainties are optional and can be ignored by passing NULL
|
||||||
|
GGML_API void ggml_opt_result_ndata( ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints
|
||||||
|
GGML_API void ggml_opt_result_loss( ggml_opt_result_t result, double * loss, double * unc); // writes 1 value
|
||||||
|
GGML_API void ggml_opt_result_pred( ggml_opt_result_t result, int32_t * pred); // writes ndata values
|
||||||
|
GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value
|
||||||
|
|
||||||
|
// ====== Computation ======
|
||||||
|
|
||||||
|
// if not using static graphs, this function must be called prior to ggml_opt_alloc
|
||||||
|
GGML_API void ggml_opt_prepare_alloc(
|
||||||
|
ggml_opt_context_t opt_ctx,
|
||||||
|
struct ggml_context * ctx_compute,
|
||||||
|
struct ggml_cgraph * gf,
|
||||||
|
struct ggml_tensor * inputs,
|
||||||
|
struct ggml_tensor * outputs);
|
||||||
|
|
||||||
|
// allocate the next graph for evaluation, either forward or forward + backward
|
||||||
|
// must be called exactly once prior to calling ggml_opt_eval
|
||||||
|
GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward);
|
||||||
|
|
||||||
|
// do forward pass, increment result if not NULL, do backward pass if allocated
|
||||||
|
GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
|
||||||
|
|
||||||
|
// ############################################################################
|
||||||
|
// ## The high-level functions start here. They do not depend on any private ##
|
||||||
|
// ## functions or structs and can be copied to and adapted for user code. ##
|
||||||
|
// ############################################################################
|
||||||
|
|
||||||
|
// ====== Intended Usage ======
|
||||||
|
//
|
||||||
|
// 1. Select the appropriate loss for your problem.
|
||||||
|
// 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them.
|
||||||
|
// Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster).
|
||||||
|
// 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors.
|
||||||
|
// The first context should contain the model parameters and inputs and be allocated statically in user code.
|
||||||
|
// The second context should contain all other tensors and will be (re)allocated automatically.
|
||||||
|
// Due to this automated allocation the data of the second context is not defined when accessed in user code.
|
||||||
|
// Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors.
|
||||||
|
// 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.
|
||||||
|
|
||||||
|
// signature for a callback while evaluating opt_ctx on dataset, called after an evaluation
|
||||||
|
typedef void (*ggml_opt_epoch_callback)(
|
||||||
|
bool train, // true after training evaluation, false after validation evaluation
|
||||||
|
ggml_opt_context_t opt_ctx,
|
||||||
|
ggml_opt_dataset_t dataset,
|
||||||
|
ggml_opt_result_t result, // result associated with the dataset subsection
|
||||||
|
int64_t ibatch, // number of batches that have been evaluated so far
|
||||||
|
int64_t ibatch_max, // total number of batches in this dataset subsection
|
||||||
|
int64_t t_start_us); // time at which the evaluation on the dataset subsection was started
|
||||||
|
|
||||||
|
// do training on front of dataset, do evaluation only on back of dataset
|
||||||
|
GGML_API void ggml_opt_epoch(
|
||||||
|
ggml_opt_context_t opt_ctx,
|
||||||
|
ggml_opt_dataset_t dataset,
|
||||||
|
ggml_opt_result_t result_train, // result to increment during training, ignored if NULL
|
||||||
|
ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL
|
||||||
|
int64_t idata_split, // data index at which to split training and evaluation
|
||||||
|
ggml_opt_epoch_callback callback_train,
|
||||||
|
ggml_opt_epoch_callback callback_eval);
|
||||||
|
|
||||||
|
// callback that prints a progress bar on stderr
|
||||||
|
GGML_API void ggml_opt_epoch_callback_progress_bar(
|
||||||
|
bool train,
|
||||||
|
ggml_opt_context_t opt_ctx,
|
||||||
|
ggml_opt_dataset_t dataset,
|
||||||
|
ggml_opt_result_t result,
|
||||||
|
int64_t ibatch,
|
||||||
|
int64_t ibatch_max,
|
||||||
|
int64_t t_start_us);
|
||||||
|
|
||||||
|
// fit model defined by inputs and outputs to dataset
|
||||||
|
GGML_API void ggml_opt_fit(
|
||||||
|
ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs
|
||||||
|
struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs
|
||||||
|
struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch]
|
||||||
|
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
|
||||||
|
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
|
||||||
|
enum ggml_opt_loss_type loss_type, // loss to minimize
|
||||||
|
enum ggml_opt_optimizer_type optimizer, // sgd or adamw
|
||||||
|
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
|
||||||
|
int64_t nepoch, // how many times the dataset should be iterated over
|
||||||
|
int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
|
||||||
|
float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
|
||||||
|
bool silent); // whether or not info prints to stderr should be suppressed
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define RPC_PROTO_MAJOR_VERSION 3
|
||||||
|
#define RPC_PROTO_MINOR_VERSION 6
|
||||||
|
#define RPC_PROTO_PATCH_VERSION 0
|
||||||
|
#define GGML_RPC_MAX_SERVERS 16
|
||||||
|
|
||||||
|
// backend API
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
|
||||||
|
|
||||||
|
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
|
||||||
|
|
||||||
|
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
||||||
|
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
//
|
||||||
|
// MIT license
|
||||||
|
// Copyright (C) 2024 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#define GGML_SYCL_NAME "SYCL"
|
||||||
|
#define GGML_SYCL_MAX_DEVICES 48
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// backend API
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_sycl_init(int device);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_sycl(ggml_backend_t backend);
|
||||||
|
|
||||||
|
// devide buffer
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
|
||||||
|
|
||||||
|
// split tensor buffer that splits matrices by rows across multiple devices
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
|
||||||
|
|
||||||
|
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API void ggml_backend_sycl_print_sycl_devices(void);
|
||||||
|
GGML_BACKEND_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
|
||||||
|
GGML_BACKEND_API void ggml_backend_sycl_get_device_description(int device,
|
||||||
|
char *description,
|
||||||
|
size_t description_size);
|
||||||
|
GGML_BACKEND_API int ggml_backend_sycl_get_device_count();
|
||||||
|
GGML_BACKEND_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
|
||||||
|
|
||||||
|
// SYCL doesn't support registering host memory, keep here for reference
|
||||||
|
// GGML_BACKEND_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
|
||||||
|
// GGML_BACKEND_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define GGML_VK_NAME "Vulkan"
|
||||||
|
#define GGML_VK_MAX_DEVICES 16
|
||||||
|
|
||||||
|
// backend API
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_vk(ggml_backend_t backend);
|
||||||
|
GGML_BACKEND_API int ggml_backend_vk_get_device_count(void);
|
||||||
|
GGML_BACKEND_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
|
||||||
|
GGML_BACKEND_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
|
||||||
|
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_vk_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define GGML_WEBGPU_NAME "WebGPU"
|
||||||
|
|
||||||
|
// Needed for examples in ggml
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_webgpu_init(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_webgpu_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// device buffer
|
||||||
|
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_zdnn_buffer_type(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// backend API
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_zendnn_init(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_zendnn(ggml_backend_t backend);
|
||||||
|
|
||||||
|
// number of threads used for zendnn operations
|
||||||
|
GGML_BACKEND_API void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zendnn_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,202 @@
|
|||||||
|
// This file contains functionality related to "GGUF" files, the binary file format used by ggml.
|
||||||
|
// GGUF files have the following structure:
|
||||||
|
//
|
||||||
|
// 1. File magic "GGUF" (4 bytes).
|
||||||
|
// 2. File version (uint32_t).
|
||||||
|
// 3. Number of ggml tensors in file (int64_t).
|
||||||
|
// 4. Number of key-value-pairs in file (int64_t).
|
||||||
|
// 5. For each KV pair:
|
||||||
|
// 1. The key (string).
|
||||||
|
// 2. The value type (gguf_type).
|
||||||
|
// 3a. If the value type is GGUF_TYPE_ARRAY:
|
||||||
|
// 1. The type of the array (gguf_type).
|
||||||
|
// 2. The number of elements in the array (uint64_t).
|
||||||
|
// 3. The binary representation of each element in the array.
|
||||||
|
// 3b. Otherwise:
|
||||||
|
// 1. The binary representation of the value.
|
||||||
|
// 6. For each ggml tensor:
|
||||||
|
// 1. The tensor name (string).
|
||||||
|
// 2. The number of dimensions of the tensor (uint32_t).
|
||||||
|
// 3. For each dimension:
|
||||||
|
// 1. The size of the tensor in the dimension (int64_t).
|
||||||
|
// 4. The tensor data type (ggml_type).
|
||||||
|
// 5. The tensor data offset in the tensor data binary blob (uint64_t).
|
||||||
|
// 7. The tensor data binary blob (optional, aligned).
|
||||||
|
//
|
||||||
|
// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator.
|
||||||
|
// All enums are stored as int32_t.
|
||||||
|
// All bool values are stored as int8_t.
|
||||||
|
// If the special key "general.alignment" (uint32_t) is defined it is used for alignment,
|
||||||
|
// otherwise GGUF_DEFAULT_ALIGNMENT is used.
|
||||||
|
//
|
||||||
|
// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#define GGUF_MAGIC "GGUF"
|
||||||
|
#define GGUF_VERSION 3
|
||||||
|
|
||||||
|
#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment"
|
||||||
|
|
||||||
|
#define GGUF_DEFAULT_ALIGNMENT 32
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// types that can be stored as GGUF KV data
|
||||||
|
enum gguf_type {
|
||||||
|
GGUF_TYPE_UINT8 = 0,
|
||||||
|
GGUF_TYPE_INT8 = 1,
|
||||||
|
GGUF_TYPE_UINT16 = 2,
|
||||||
|
GGUF_TYPE_INT16 = 3,
|
||||||
|
GGUF_TYPE_UINT32 = 4,
|
||||||
|
GGUF_TYPE_INT32 = 5,
|
||||||
|
GGUF_TYPE_FLOAT32 = 6,
|
||||||
|
GGUF_TYPE_BOOL = 7,
|
||||||
|
GGUF_TYPE_STRING = 8,
|
||||||
|
GGUF_TYPE_ARRAY = 9,
|
||||||
|
GGUF_TYPE_UINT64 = 10,
|
||||||
|
GGUF_TYPE_INT64 = 11,
|
||||||
|
GGUF_TYPE_FLOAT64 = 12,
|
||||||
|
GGUF_TYPE_COUNT, // marks the end of the enum
|
||||||
|
};
|
||||||
|
|
||||||
|
struct gguf_context;
|
||||||
|
|
||||||
|
struct gguf_init_params {
|
||||||
|
bool no_alloc;
|
||||||
|
|
||||||
|
// if not NULL, create a ggml_context and allocate the tensor data in it
|
||||||
|
struct ggml_context ** ctx;
|
||||||
|
};
|
||||||
|
|
||||||
|
GGML_API struct gguf_context * gguf_init_empty(void);
|
||||||
|
GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
|
||||||
|
//GGML_API struct gguf_context * gguf_init_from_buffer(..);
|
||||||
|
|
||||||
|
GGML_API void gguf_free(struct gguf_context * ctx);
|
||||||
|
|
||||||
|
GGML_API const char * gguf_type_name(enum gguf_type type);
|
||||||
|
|
||||||
|
GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx);
|
||||||
|
GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx);
|
||||||
|
GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
|
||||||
|
|
||||||
|
GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx);
|
||||||
|
GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found
|
||||||
|
GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
|
||||||
|
GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
|
||||||
|
// will abort if the wrong type is used for the key
|
||||||
|
GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
GGML_API size_t gguf_get_arr_n (const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
|
||||||
|
// get raw pointer to the first element of the array with the given key_id
|
||||||
|
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
|
||||||
|
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
|
||||||
|
|
||||||
|
// get ith C string from array with given key_id
|
||||||
|
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
||||||
|
|
||||||
|
GGML_API int64_t gguf_get_n_tensors (const struct gguf_context * ctx);
|
||||||
|
GGML_API int64_t gguf_find_tensor (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found
|
||||||
|
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id);
|
||||||
|
GGML_API const char * gguf_get_tensor_name (const struct gguf_context * ctx, int64_t tensor_id);
|
||||||
|
GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int64_t tensor_id);
|
||||||
|
GGML_API size_t gguf_get_tensor_size (const struct gguf_context * ctx, int64_t tensor_id);
|
||||||
|
|
||||||
|
// removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist)
|
||||||
|
GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key);
|
||||||
|
|
||||||
|
// overrides an existing KV pair or adds a new one, the new KV pair is always at the back
|
||||||
|
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
||||||
|
GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
|
||||||
|
GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
|
||||||
|
GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
|
||||||
|
GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
|
||||||
|
GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
|
||||||
|
GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
|
||||||
|
GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
|
||||||
|
GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
|
||||||
|
GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
|
||||||
|
GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
|
||||||
|
GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
|
||||||
|
|
||||||
|
// creates a new array with n elements of the given type and copies the corresponding number of bytes from data
|
||||||
|
GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n);
|
||||||
|
|
||||||
|
// creates a new array with n strings and copies the corresponding strings from data
|
||||||
|
GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n);
|
||||||
|
|
||||||
|
// set or add KV pairs from another context
|
||||||
|
GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src);
|
||||||
|
|
||||||
|
// add tensor to GGUF context, tensor name must be unique
|
||||||
|
GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
|
||||||
|
|
||||||
|
// after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated
|
||||||
|
// in such a way that the tensor data remains as one contiguous block (except for padding)
|
||||||
|
GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
|
||||||
|
|
||||||
|
// assumes that at least gguf_get_tensor_size bytes can be read from data
|
||||||
|
GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data);
|
||||||
|
|
||||||
|
// writing gguf files can be done in 3 ways:
|
||||||
|
//
|
||||||
|
// - write the entire gguf_context to a binary file in a single pass:
|
||||||
|
//
|
||||||
|
// gguf_write_to_file(ctx, fname, /*only_meta =*/ false);
|
||||||
|
//
|
||||||
|
// - write only the meta data to a file, then re-open the file and append the tensor data:
|
||||||
|
//
|
||||||
|
// gguf_write_to_file(ctx, fname, /*only_meta =*/ true);
|
||||||
|
// FILE * f = fopen(fname, "ab");
|
||||||
|
// fwrite(f, ...); // write tensor data
|
||||||
|
// fclose(f);
|
||||||
|
//
|
||||||
|
// - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
|
||||||
|
//
|
||||||
|
// FILE * f = fopen(fname, "wb");
|
||||||
|
// const size_t size_meta = gguf_get_meta_size(ctx);
|
||||||
|
// fseek(f, size_meta, SEEK_SET);
|
||||||
|
// fwrite(f, ...); // write tensor data
|
||||||
|
// void * data = malloc(size_meta);
|
||||||
|
// gguf_get_meta_data(ctx, data);
|
||||||
|
// rewind(f);
|
||||||
|
// fwrite(data, 1, data, f);
|
||||||
|
// free(data);
|
||||||
|
// fclose(f);
|
||||||
|
//
|
||||||
|
|
||||||
|
// write the entire context to a binary file
|
||||||
|
GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
|
||||||
|
|
||||||
|
// get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
|
||||||
|
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
|
||||||
|
|
||||||
|
// writes the meta data to pointer "data"
|
||||||
|
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -0,0 +1,490 @@
|
|||||||
|
include(CheckCXXCompilerFlag)
include("../cmake/common.cmake")

add_compile_definitions(GGML_SCHED_MAX_COPIES=${GGML_SCHED_MAX_COPIES})

# enable libstdc++ assertions for debug builds
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
    add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)
endif()

# sanitizer instrumentation must be present on both the compile and the link
# line; use add_link_options (already used below for GGML_STATIC) instead of
# link_libraries, which is meant for library names rather than linker flags
if (NOT MSVC)
    if (GGML_SANITIZE_THREAD)
        add_compile_options(-fsanitize=thread)
        add_link_options   (-fsanitize=thread)
    endif()

    if (GGML_SANITIZE_ADDRESS)
        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
        add_link_options   (-fsanitize=address)
    endif()

    if (GGML_SANITIZE_UNDEFINED)
        add_compile_options(-fsanitize=undefined)
        add_link_options   (-fsanitize=undefined)
    endif()
endif()
|
||||||
|
|
||||||
|
# promote warnings to errors when requested
if (GGML_FATAL_WARNINGS)
    if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
        list(APPEND C_FLAGS   -Werror)
        list(APPEND CXX_FLAGS -Werror)
    elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
        add_compile_options(/WX)
    endif()
endif()

# build up per-language warning flag lists and apply them via
# COMPILE_LANGUAGE generator expressions so C and C++ get separate sets
if (GGML_ALL_WARNINGS)
    if (NOT MSVC)
        # flags shared by C and C++
        list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
        # C-only flags
        list(APPEND C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
                            -Werror=implicit-int -Werror=implicit-function-declaration)
        # C++-only flags
        list(APPEND CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn)

        list(APPEND C_FLAGS   ${WARNING_FLAGS})
        list(APPEND CXX_FLAGS ${WARNING_FLAGS})

        # populates GF_C_FLAGS / GF_CXX_FLAGS with compiler/version-specific
        # flags (helper from ../cmake/common.cmake)
        ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})

        add_compile_options("$<$<COMPILE_LANGUAGE:C>:${C_FLAGS};${GF_C_FLAGS}>"
                            "$<$<COMPILE_LANGUAGE:CXX>:${CXX_FLAGS};${GF_CXX_FLAGS}>")
    else()
        # todo : msvc
        set(C_FLAGS   "")
        set(CXX_FLAGS "")
    endif()
endif()
|
||||||
|
|
||||||
|
# enable link-time / interprocedural optimization if the toolchain supports it
if (GGML_LTO)
    include(CheckIPOSupported)
    check_ipo_supported(RESULT result OUTPUT output)
    if (result)
        set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
    else()
        message(WARNING "IPO is not supported: ${output}")
    endif()
endif()

# use ccache (or sccache) as a compiler launcher when available, unless the
# user has already configured a launcher explicitly
if (GGML_CCACHE AND NOT CMAKE_C_COMPILER_LAUNCHER AND NOT CMAKE_CXX_COMPILER_LAUNCHER)
    find_program(GGML_CCACHE_FOUND ccache)
    find_program(GGML_SCCACHE_FOUND sccache)

    if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND)
        # prefer ccache over sccache when both are installed
        if(GGML_CCACHE_FOUND)
            set(GGML_CCACHE_VARIANT ccache)
        else()
            set(GGML_CCACHE_VARIANT sccache)
        endif()
        # TODO: should not be set globally
        if (GGML_SYCL AND GGML_CCACHE_FOUND AND WIN32)
            # the SYCL icl driver on Windows needs an explicit compiler type hint
            set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache compiler_type=icl")
        else ()
            set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
        endif ()
        # NOTE(review): set(ENV{...}) only changes the configure-time environment,
        # not the environment the build tool runs the compiler in — confirm this
        # actually influences ccache during the build
        set(ENV{CCACHE_SLOPPINESS} time_macros)
        message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
    else()
        message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF")
    endif ()
endif()
|
||||||
|
|
||||||
|
# this version of Apple ld64 is buggy
# probe the linker version string by invoking the C compiler with -Wl,-v;
# the version report goes to stderr, hence ERROR_VARIABLE
execute_process(
    COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
    ERROR_VARIABLE output
    OUTPUT_QUIET
)

# fix: use "\\." so the regex contains a literal dot; with "\." CMake's string
# escaping turns it into a bare "." which matches any character
if (output MATCHES "dyld-1015\\.7")
    add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
endif()

# architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if (MSVC)
    string(TOLOWER "${CMAKE_GENERATOR_PLATFORM}" CMAKE_GENERATOR_PLATFORM_LWR)
    message(STATUS "CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}")
else ()
    set(CMAKE_GENERATOR_PLATFORM_LWR "")
endif ()
# sets GGML_SYSTEM_ARCH (helper from ../cmake/common.cmake)
ggml_get_system_arch()
message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}")
|
||||||
|
|
||||||
|
if (NOT MSVC)
    # fully static linking when requested
    if (GGML_STATIC)
        if (UNIX AND NOT APPLE)
            # prefer static archives over shared objects during find_library
            set(CMAKE_FIND_LIBRARY_SUFFIXES ".a;.so")
        endif()
        add_link_options(-static)
        if (MINGW)
            # also link the GCC/C++ runtimes statically on MinGW
            add_link_options(-static-libgcc -static-libstdc++)
        endif()
    endif()
    # gprof profiling instrumentation
    if (GGML_GPROF)
        add_compile_options(-pg)
    endif()
endif()
|
||||||
|
|
||||||
|
#
# POSIX conformance
#

# clock_gettime came in POSIX.1b (1993)
# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
# posix_memalign came in POSIX.1-2001 / SUSv3
# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)

# Somehow in OpenBSD whenever POSIX conformance is specified
# some string functions rely on locale_t availability,
# which was introduced in POSIX.1-2008, forcing us to go higher
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
    add_compile_definitions(_XOPEN_SOURCE=700)
elseif (CMAKE_SYSTEM_NAME MATCHES "AIX")
    # Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default,
    # in order to define _SC_PHYS_PAGES.
else()
    add_compile_definitions(_XOPEN_SOURCE=600)
endif()

# Data types, macros and functions related to controlling CPU affinity and
# some memory allocation are available on Linux through GNU extensions in libc
if (CMAKE_SYSTEM_NAME MATCHES "Linux" OR CMAKE_SYSTEM_NAME MATCHES "Android")
    add_compile_definitions(_GNU_SOURCE)
endif()

# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
# and on macOS its availability depends on enabling Darwin extensions
# similarly on DragonFly, enabling BSD extensions is necessary
if (
    CMAKE_SYSTEM_NAME MATCHES "Darwin" OR
    CMAKE_SYSTEM_NAME MATCHES "iOS" OR
    CMAKE_SYSTEM_NAME MATCHES "tvOS" OR
    CMAKE_SYSTEM_NAME MATCHES "DragonFly"
)
    add_compile_definitions(_DARWIN_C_SOURCE)
endif()

# alloca is a non-standard interface that is not visible on BSDs when
# POSIX conformance is specified, but not all of them provide a clean way
# to enable it in such cases
if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
    add_compile_definitions(__BSD_VISIBLE)
endif()
if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
    add_compile_definitions(_NETBSD_SOURCE)
endif()
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
    add_compile_definitions(_BSD_SOURCE)
endif()

# silence the MSVC CRT warnings about "insecure" libc functions
if (WIN32)
    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
endif()

# ggml

# dynamically loadable backends require building shared libraries
if (GGML_BACKEND_DL AND NOT BUILD_SHARED_LIBS)
    message(FATAL_ERROR "GGML_BACKEND_DL requires BUILD_SHARED_LIBS")
endif()
|
||||||
|
|
||||||
|
# ggml-base: the core library — tensors, allocation, backend interface,
# quantization, optimizer and GGUF file support
add_library(ggml-base
            ../include/ggml.h
            ../include/ggml-alloc.h
            ../include/ggml-backend.h
            ../include/ggml-cpp.h
            ../include/ggml-opt.h
            ../include/gguf.h
            ggml.c
            ggml.cpp
            ggml-alloc.c
            ggml-backend.cpp
            ggml-opt.cpp
            ggml-threading.cpp
            ggml-threading.h
            ggml-quants.c
            ggml-quants.h
            gguf.cpp)

set_target_properties(ggml-base PROPERTIES
    VERSION ${GGML_VERSION}
    SOVERSION ${GGML_VERSION_MAJOR}
)

target_include_directories(ggml-base PRIVATE .)
if (GGML_BACKEND_DL)
    target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
endif()

if (GGML_SCHED_NO_REALLOC)
    target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)
endif()

# ggml: the user-facing library; holds the backend registry and pulls in
# ggml-base plus any statically linked backends
add_library(ggml
            ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)

set_target_properties(ggml PROPERTIES
    VERSION ${GGML_VERSION}
    SOVERSION ${GGML_VERSION_MAJOR}
)

if (GGML_BACKEND_DIR)
    # a custom backend directory only makes sense for dynamically loaded backends
    if (NOT GGML_BACKEND_DL)
        message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL")
    endif()
    target_compile_definitions(ggml PUBLIC GGML_BACKEND_DIR="${GGML_BACKEND_DIR}")
endif()

target_link_libraries(ggml PUBLIC ggml-base)

if (CMAKE_SYSTEM_NAME MATCHES "Linux")
    # fix: link the platform's dynamic-loader library via CMAKE_DL_LIBS
    # instead of the hardcoded name "dl"
    target_link_libraries(ggml PRIVATE ${CMAKE_DL_LIBS})
endif()
|
||||||
|
|
||||||
|
# Creates a backend library target named ${backend} from the sources in ARGN
# and wires it into the main ggml target. With GGML_BACKEND_DL the backend is
# built as a dynamically loadable MODULE; otherwise it is linked into ggml.
function(ggml_add_backend_library backend)
    if (GGML_BACKEND_DL)
        add_library(${backend} MODULE ${ARGN})
        # write the shared library to the output directory
        set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
        target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL)
        # MODULE libraries are not linked, so only a build-order dependency is needed
        add_dependencies(ggml ${backend})
        if (GGML_BACKEND_DIR)
            install(TARGETS ${backend} LIBRARY DESTINATION ${GGML_BACKEND_DIR})
        else()
            install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})
        endif()
    else()
        add_library(${backend} ${ARGN})
        target_link_libraries(ggml PUBLIC ${backend})
        install(TARGETS ${backend} LIBRARY)
    endif()

    target_link_libraries(${backend} PRIVATE ggml-base)
    target_include_directories(${backend} PRIVATE ..)

    # fix: test the variable name directly; the previous `if (${BUILD_SHARED_LIBS})`
    # form double-dereferences and errors out when the variable is unset
    if (BUILD_SHARED_LIBS)
        target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD)
        target_compile_definitions(${backend} PUBLIC  GGML_BACKEND_SHARED)
    endif()

    # Set versioning properties for all backend libraries
    # Building a MODULE library with a version is not supported on macOS (https://gitlab.kitware.com/cmake/cmake/-/issues/20782)
    if (NOT (APPLE AND GGML_BACKEND_DL))
        set_target_properties(${backend} PROPERTIES
            VERSION ${GGML_VERSION}
            SOVERSION ${GGML_VERSION_MAJOR}
        )
    endif()

    # record this backend (once) in the cached list consumed by the cmake package
    if(NOT GGML_AVAILABLE_BACKENDS)
        set(GGML_AVAILABLE_BACKENDS "${backend}"
            CACHE INTERNAL "List of backends for cmake package")
    else()
        list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend)
        if(has_backend EQUAL -1)
            set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}"
                CACHE INTERNAL "List of backends for cmake package")
        endif()
    endif()
endfunction()
|
||||||
|
|
||||||
|
# Adds the subdirectory for a backend when its GGML_<BACKEND> option is ON.
# When backends are linked statically (no GGML_BACKEND_DL), also exports the
# GGML_USE_<BACKEND> compile definition to consumers of the ggml target.
function(ggml_add_backend backend)
    # GGML_<BACKEND> is the user-facing switch for this backend
    string(TOUPPER "GGML_${backend}" enable_option)
    if (NOT ${enable_option})
        return()
    endif()

    string(TOLOWER "ggml-${backend}" backend_subdir)
    add_subdirectory(${backend_subdir})
    message(STATUS "Including ${backend} backend")

    if (NOT GGML_BACKEND_DL)
        string(TOUPPER "GGML_USE_${backend}" use_define)
        target_compile_definitions(ggml PUBLIC ${use_define})
    endif()
endfunction()
|
||||||
|
|
||||||
|
# Configures the feature flags for one CPU backend variant named ${tag_name}
# (features passed in ARGN) and builds it via ggml_add_cpu_backend_variant_impl.
# NOTE: the set() calls are function-scoped; the impl function is invoked in
# this same scope below and therefore sees these variables.
function(ggml_add_cpu_backend_variant tag_name)
    set(GGML_CPU_TAG_NAME ${tag_name})
    # other: OPENMP LLAMAFILE CPU_HBM
    if (GGML_SYSTEM_ARCH STREQUAL "x86")
        # start from a clean slate: explicitly disable every known x86 feature...
        foreach (feat NATIVE
                      SSE42
                      AVX AVX2 BMI2 AVX_VNNI FMA F16C
                      AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
                      AMX_TILE AMX_INT8 AMX_BF16)
            set(GGML_${feat} OFF)
        endforeach()

        # ...then enable only the features requested for this variant
        foreach (feat ${ARGN})
            set(GGML_${feat} ON)
        endforeach()
    elseif (GGML_SYSTEM_ARCH STREQUAL "ARM")
        # non-x86 arches forward the requested features via GGML_INTERNAL_<feat>
        foreach (feat ${ARGN})
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
    elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
        foreach (feat ${ARGN})
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
    elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
        # reset the known s390x features before enabling the requested set
        foreach (feat VXE2 NNPA)
            set(GGML_INTERNAL_${feat} OFF)
        endforeach()

        foreach (feat ${ARGN})
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
    elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
        # reset the known riscv64 features before enabling the requested set
        foreach (feat RVV)
            set(GGML_INTERNAL_${feat} OFF)
        endforeach()

        foreach (feat ${ARGN})
            set(GGML_INTERNAL_${feat} ON)
        endforeach()
    endif()

    ggml_add_cpu_backend_variant_impl(${tag_name})
endfunction()
|
||||||
|
|
||||||
|
ggml_add_backend(CPU)

# With GGML_CPU_ALL_VARIANTS, build one dynamically loadable CPU backend per
# feature level so the best one can be selected at runtime.
if (GGML_CPU_ALL_VARIANTS)
    if (NOT GGML_BACKEND_DL)
        message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
    elseif (GGML_CPU_ARM_ARCH)
        message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS")
    endif()
    if (GGML_SYSTEM_ARCH STREQUAL "x86")
        # variants named after the first CPU generation supporting the feature set
        ggml_add_cpu_backend_variant(x64)
        ggml_add_cpu_backend_variant(sse42        SSE42)
        ggml_add_cpu_backend_variant(sandybridge  SSE42 AVX)
        if (NOT MSVC)
            # __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
            ggml_add_cpu_backend_variant(ivybridge    SSE42 AVX F16C)
            ggml_add_cpu_backend_variant(piledriver   SSE42 AVX F16C FMA)
        endif()
        ggml_add_cpu_backend_variant(haswell      SSE42 AVX F16C FMA AVX2 BMI2)
        ggml_add_cpu_backend_variant(skylakex     SSE42 AVX F16C FMA AVX2 BMI2 AVX512)
        ggml_add_cpu_backend_variant(cannonlake   SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI)
        ggml_add_cpu_backend_variant(cascadelake  SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI)
        ggml_add_cpu_backend_variant(icelake      SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI)
        if (NOT MSVC)
            # MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?!
            # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170
            # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170
            ggml_add_cpu_backend_variant(cooperlake   SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI AVX512_BF16)
            ggml_add_cpu_backend_variant(zen4         SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16)
        endif()
        ggml_add_cpu_backend_variant(alderlake    SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI)
        if (NOT MSVC)
            # MSVC doesn't support AMX
            ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
        endif()
    elseif(GGML_SYSTEM_ARCH STREQUAL "ARM")
        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
            # Many of these features are optional so we build versions with popular
            # combinations and name the backends based on the version they were
            # first released with
            ggml_add_cpu_backend_variant(armv8.0_1)
            ggml_add_cpu_backend_variant(armv8.2_1    DOTPROD)
            ggml_add_cpu_backend_variant(armv8.2_2    DOTPROD FP16_VECTOR_ARITHMETIC)
            ggml_add_cpu_backend_variant(armv8.2_3    DOTPROD FP16_VECTOR_ARITHMETIC SVE)
            ggml_add_cpu_backend_variant(armv8.6_1    DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8)
            ggml_add_cpu_backend_variant(armv8.6_2    DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2)
            ggml_add_cpu_backend_variant(armv9.2_1    DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SME)
            ggml_add_cpu_backend_variant(armv9.2_2    DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2 SME)
        elseif (CMAKE_SYSTEM_NAME MATCHES "Android")
            # Android-specific backends with SoC-compatible feature sets
            ggml_add_cpu_backend_variant(android_armv8.0_1)
            ggml_add_cpu_backend_variant(android_armv8.2_1    DOTPROD)
            ggml_add_cpu_backend_variant(android_armv8.2_2    DOTPROD FP16_VECTOR_ARITHMETIC)
            ggml_add_cpu_backend_variant(android_armv8.6_1    DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
            ggml_add_cpu_backend_variant(android_armv9.0_1    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
            ggml_add_cpu_backend_variant(android_armv9.2_1    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
            ggml_add_cpu_backend_variant(android_armv9.2_2    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME)
        elseif (APPLE)
            ggml_add_cpu_backend_variant(apple_m1     DOTPROD)
            ggml_add_cpu_backend_variant(apple_m2_m3  DOTPROD MATMUL_INT8)
            ggml_add_cpu_backend_variant(apple_m4     DOTPROD MATMUL_INT8 NOSVE SME)
        else()
            message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
    elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
            ggml_add_cpu_backend_variant(power0)
            ggml_add_cpu_backend_variant(power7_1 POWER7)
            ggml_add_cpu_backend_variant(power7_2 POWER7 VSX)
            ggml_add_cpu_backend_variant(power8_1 POWER8)
            ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
            ggml_add_cpu_backend_variant(power9   POWER9 VSX)
            ggml_add_cpu_backend_variant(power10  POWER10 VSX)
            ggml_add_cpu_backend_variant(power11  POWER11 VSX)
        else()
            message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
    elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
            ggml_add_cpu_backend_variant(z15 Z15 VXE2)
            ggml_add_cpu_backend_variant(z16 Z16 VXE2 NNPA)
        else()
            message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
    elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
        if (CMAKE_SYSTEM_NAME MATCHES "Linux")
            ggml_add_cpu_backend_variant(riscv64_0)
            ggml_add_cpu_backend_variant(riscv64_v RVV)
        else()
            message(FATAL_ERROR "Unsupported RISC-V target OS: ${CMAKE_SYSTEM_NAME}")
        endif()
    else()
        message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
    endif()
elseif (GGML_CPU)
    # single CPU backend built with an empty variant tag
    ggml_add_cpu_backend_variant_impl("")
endif()
|
||||||
|
|
||||||
|
# non-CPU backends; each is only included when its GGML_<BACKEND> option is ON
ggml_add_backend(BLAS)
ggml_add_backend(CANN)
ggml_add_backend(CUDA)
ggml_add_backend(HIP)
ggml_add_backend(METAL)
ggml_add_backend(MUSA)
ggml_add_backend(RPC)
ggml_add_backend(SYCL)
ggml_add_backend(Vulkan)
ggml_add_backend(WebGPU)
ggml_add_backend(zDNN)
ggml_add_backend(OpenCL)
ggml_add_backend(Hexagon)
ggml_add_backend(ZenDNN)

foreach (target ggml-base ggml)
    # public headers live in ../include in the build tree, include/ when installed
    target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
    target_compile_features   (${target} PRIVATE c_std_11 cxx_std_17) # don't bump
endforeach()

target_link_libraries(ggml-base PRIVATE Threads::Threads)

# link libm where it exists as a separate library
find_library(MATH_LIBRARY m)
if (MATH_LIBRARY)
    if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT})
        target_link_libraries(ggml-base PRIVATE m)
    endif()
endif()

if (CMAKE_SYSTEM_NAME MATCHES "Android")
    # fix: use CMAKE_DL_LIBS instead of the hardcoded library name "dl"
    target_link_libraries(ggml-base PRIVATE ${CMAKE_DL_LIBS})
endif()

if(CMAKE_SYSTEM_NAME MATCHES "visionOS")
    target_compile_definitions(ggml-base PUBLIC _DARWIN_C_SOURCE)
endif()

if (BUILD_SHARED_LIBS)
    foreach (target ggml-base ggml)
        set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON)
        # GGML_BUILD marks code compiled into the library itself;
        # GGML_SHARED is exported so consumers import symbols correctly
        target_compile_definitions(${target} PRIVATE GGML_BUILD)
        target_compile_definitions(${target} PUBLIC  GGML_SHARED)
    endforeach()
endif()
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user