added v1.12.0 sources

This commit is contained in:
TheK0tYaRa 2026-03-06 23:48:46 +03:00
commit 63e90f94a8
1402 changed files with 322616 additions and 0 deletions

View file

@ -0,0 +1,12 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
[build]
rustflags = ["--cfg", "tokio_unstable"]
# On x86_64, we target the x86-64-v2 psABI, as it is a good compromise between
# modern CPU instructions and compatibility.
[target.x86_64-unknown-linux-gnu]
rustflags = ["--cfg", "tokio_unstable", "-C", "target-cpu=x86-64-v2"]

View file

@ -0,0 +1,10 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
comment: false
flag_management:
default_rules:
carryforward: true

View file

@ -0,0 +1,7 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
[profile.default]
retries = 1

View file

@ -0,0 +1,17 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
target/
crates/*/target
crates/*/node_modules
frontend/node_modules
frontend/dist
docs/
.devcontainer/
.github/
.gitignore
Dockerfile
.dockerignore
docker-bake.hcl

View file

@ -0,0 +1,16 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
root = true
[*]
charset=utf-8
end_of_line = lf
[*.{ts,tsx,cts,mts,js,cjs,mjs,css,json,graphql}]
indent_style = space
indent_size = 2
insert_final_newline = true
trim_trailing_whitespace = true

View file

@ -0,0 +1 @@
* @element-hq/mas-maintainers

View file

@ -0,0 +1,25 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Build the frontend assets
description: Installs Node.js and builds the frontend assets from the frontend directory
runs:
using: composite
steps:
- name: Install Node
uses: actions/setup-node@v6.0.0
with:
node-version: "24"
- name: Install dependencies
run: npm ci
working-directory: ./frontend
shell: sh
- name: Build the frontend assets
run: npm run build
working-directory: ./frontend
shell: sh

View file

@ -0,0 +1,21 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Build the Open Policy Agent policies
description: Installs OPA and builds the policies
runs:
using: composite
steps:
- name: Install Open Policy Agent
uses: open-policy-agent/setup-opa@v2.2.0
with:
# Keep in sync with the Dockerfile and policies/Makefile
version: 1.13.1
- name: Build the policies
run: make
working-directory: ./policies
shell: sh

View file

@ -0,0 +1,113 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
version: 2
updates:
- package-ecosystem: "cargo"
directory: "/"
labels:
- "A-Dependencies"
- "Z-Deps-Backend"
schedule:
interval: "daily"
ignore:
# We plan to remove apalis soon, let's ignore it for now
- dependency-name: "apalis"
- dependency-name: "apalis-*"
groups:
axum:
patterns:
- "axum"
- "axum-*"
opentelemetry:
patterns:
- "opentelemetry"
- "opentelemetry_sdk"
- "opentelemetry-*"
- "tracing-opentelemetry"
sea-query:
patterns:
- "sea-query"
- "sea-query-*"
sentry:
patterns:
- "sentry"
- "sentry-*"
tracing:
patterns:
- "tracing-*"
exclude-patterns:
- "tracing-opentelemetry"
icu:
patterns:
- "icu"
- "icu_*"
- package-ecosystem: "github-actions"
directory: "/"
labels:
- "A-Dependencies"
- "Z-Deps-CI"
schedule:
interval: "daily"
- package-ecosystem: "npm"
directory: "/frontend/"
labels:
- "A-Dependencies"
- "Z-Deps-Frontend"
schedule:
interval: "daily"
groups:
storybook:
patterns:
- "storybook"
- "storybook-*"
- "@storybook/*"
fontsource:
patterns:
- "@fontsource/*"
vitest:
patterns:
- "vitest"
- "@vitest/*"
vite:
patterns:
- "vite"
- "@vitejs/*"
- "vite-*"
i18next:
patterns:
- "i18next"
- "i18next-*"
- "react-i18next"
react:
patterns:
- "react"
- "react-*"
exclude-patterns:
- "react-i18next"
jotai:
patterns:
- "jotai"
- "jotai-*"
graphql-codegen:
patterns:
- "@graphql-codegen/*"
tanstack-router:
patterns:
- "@tanstack/react-router"
- "@tanstack/react-router-*"
- "@tanstack/router-*"
tanstack-query:
patterns:
- "@tanstack/react-query"
- "@tanstack/react-query-*"
types:
patterns:
- "@types/*"
browser-logos:
patterns:
- "@browser-logos/*"

View file

@ -0,0 +1,45 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
changelog:
categories:
- title: Bug Fixes
labels:
- T-Defect
- title: New Features
labels:
- T-Enhancement
exclude:
labels:
- A-Admin-API
- A-Documentation
- title: Changes to the admin API
labels:
- A-Admin-API
- title: Documentation
labels:
- A-Documentation
- title: Translations
labels:
- A-I18n
- title: Internal Changes
labels:
- T-Task
- title: Other Changes
labels:
- "*"
exclude:
labels:
- A-Dependencies
- title: Dependency Updates
labels:
- A-Dependencies

View file

@ -0,0 +1,7 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
node_modules/
package-lock.json

View file

@ -0,0 +1,44 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
// @ts-check
/** @param {import('@actions/github-script').AsyncFunctionArguments} AsyncFunctionArguments */
module.exports = async ({ github, context }) => {
const metadataJson = process.env.BUILD_IMAGE_MANIFEST;
if (!metadataJson) throw new Error("BUILD_IMAGE_MANIFEST is not defined");
/** @type {Record<string, {tags: string[]}>} */
const metadata = JSON.parse(metadataJson);
await github.rest.issues.removeLabel({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
name: "Z-Build-Workflow",
});
const tagListMarkdown = metadata.regular.tags
.map((tag) => `- \`${tag}\``)
.join("\n");
// Get the workflow run
const run = await github.rest.actions.getWorkflowRun({
owner: context.repo.owner,
repo: context.repo.repo,
run_id: context.runId,
});
await github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: `A build for this PR at commit <kbd>${context.sha}</kbd> has been created through the <kbd>Z-Build-Workflow</kbd> label by <kbd>${context.actor}</kbd>.
Docker image is available at:
${tagListMarkdown}
Pre-built binaries are available through the [workflow run artifacts](${run.data.html_url}).`,
});
};

View file

@ -0,0 +1,66 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
// @ts-check
/** @param {import('@actions/github-script').AsyncFunctionArguments} AsyncFunctionArguments */
module.exports = async ({ github, context }) => {
const fs = require("node:fs/promises");
const { owner, repo } = context.repo;
const version = process.env.VERSION;
const parent = context.sha;
if (!version) throw new Error("VERSION is not defined");
const files = ["Cargo.toml", "Cargo.lock"];
/** @type {{path: string, mode: "100644", type: "blob", sha: string}[]} */
const tree = [];
for (const file of files) {
const content = await fs.readFile(file);
const blob = await github.rest.git.createBlob({
owner,
repo,
content: content.toString("base64"),
encoding: "base64",
});
console.log(`Created blob for ${file}:`, blob.data.url);
tree.push({
path: file,
mode: "100644",
type: "blob",
sha: blob.data.sha,
});
}
const treeObject = await github.rest.git.createTree({
owner,
repo,
tree,
base_tree: parent,
});
console.log("Created tree:", treeObject.data.url);
const commit = await github.rest.git.createCommit({
owner,
repo,
message: version,
parents: [parent],
tree: treeObject.data.sha,
});
console.log("Created commit:", commit.data.url);
const tag = await github.rest.git.createTag({
owner,
repo,
tag: `v${version}`,
message: version,
type: "commit",
object: commit.data.sha,
});
console.log("Created tag:", tag.data.url);
return { commit: commit.data.sha, tag: tag.data.sha };
};

View file

@ -0,0 +1,22 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
// @ts-check
/** @param {import('@actions/github-script').AsyncFunctionArguments} AsyncFunctionArguments */
module.exports = async ({ github, context }) => {
const { owner, repo } = context.repo;
const branch = process.env.BRANCH;
const sha = process.env.SHA;
if (!sha) throw new Error("SHA is not defined");
await github.rest.git.createRef({
owner,
repo,
ref: `refs/heads/${branch}`,
sha,
});
console.log(`Created branch ${branch} from ${sha}`);
};

View file

@ -0,0 +1,24 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
// @ts-check
/** @param {import('@actions/github-script').AsyncFunctionArguments} AsyncFunctionArguments */
module.exports = async ({ github, context }) => {
const { owner, repo } = context.repo;
const version = process.env.VERSION;
const tagSha = process.env.TAG_SHA;
if (!version) throw new Error("VERSION is not defined");
if (!tagSha) throw new Error("TAG_SHA is not defined");
const tag = await github.rest.git.createRef({
owner,
repo,
ref: `refs/tags/v${version}`,
sha: tagSha,
});
console.log("Created tag ref:", tag.data.url);
};

View file

@ -0,0 +1,60 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
// @ts-check
/** @param {import('@actions/github-script').AsyncFunctionArguments} AsyncFunctionArguments */
module.exports = async ({ github, context }) => {
const { owner, repo } = context.repo;
const sha = process.env.SHA;
const branch = `ref-merge/${sha}`;
if (!sha) throw new Error("SHA is not defined");
await github.rest.git.createRef({
owner,
repo,
ref: `refs/heads/${branch}`,
sha,
});
console.log(`Created branch ${branch} to ${sha}`);
// Create a PR to merge the branch back to main
const pr = await github.rest.pulls.create({
owner,
repo,
head: branch,
base: "main",
title: "Automatic merge back to main",
body: "This pull request was automatically created by the release workflow. It merges the release branch back to main.",
maintainer_can_modify: true,
});
console.log(
`Created pull request #${pr.data.number} to merge the release branch back to main`,
);
console.log(`PR URL: ${pr.data.html_url}`);
// Add the `T-Task` label to the PR
await github.rest.issues.addLabels({
owner,
repo,
issue_number: pr.data.number,
labels: ["T-Task"],
});
// Enable auto-merge on the PR
await github.graphql(
`
mutation AutoMerge($id: ID!) {
enablePullRequestAutoMerge(input: {
pullRequestId: $id,
mergeMethod: MERGE,
}) {
clientMutationId
}
}
`,
{ id: pr.data.node_id },
);
};

View file

@ -0,0 +1,7 @@
{
"private": true,
"devDependencies": {
"@actions/github-script": "github:actions/github-script",
"typescript": "^5.7.3"
}
}

View file

@ -0,0 +1,22 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
// @ts-check
/** @param {import('@actions/github-script').AsyncFunctionArguments} AsyncFunctionArguments */
module.exports = async ({ github, context }) => {
const { owner, repo } = context.repo;
const branch = process.env.BRANCH;
const sha = process.env.SHA;
if (!sha) throw new Error("SHA is not defined");
await github.rest.git.updateRef({
owner,
repo,
ref: `heads/${branch}`,
sha,
});
console.log(`Updated branch ${branch} to ${sha}`);
};

View file

@ -0,0 +1,21 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
// @ts-check
/** @param {import('@actions/github-script').AsyncFunctionArguments} AsyncFunctionArguments */
module.exports = async ({ github, context }) => {
const { owner, repo } = context.repo;
const sha = context.sha;
const tag = await github.rest.git.updateRef({
owner,
repo,
force: true,
ref: "tags/unstable",
sha,
});
console.log("Updated tag ref:", tag.data.url);
};

View file

@ -0,0 +1,469 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Build
on:
push:
branches:
- main
- "release/**"
tags:
- "v*"
# Run when there is a label change on the pull request
# This runs only if the 'Z-Build-Workflow' is added to the pull request
pull_request:
types: [labeled]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
CARGO_TERM_COLOR: always
CARGO_NET_GIT_FETCH_WITH_CLI: "true"
SCCACHE_GHA_ENABLED: "true"
RUSTC_WRAPPER: "sccache"
IMAGE: ghcr.io/element-hq/matrix-authentication-service
BUILDCACHE: ghcr.io/element-hq/matrix-authentication-service/buildcache
DOCKER_METADATA_ANNOTATIONS_LEVELS: manifest,index
jobs:
compute-version:
name: Compute version using git describe
if: github.event_name == 'push' || github.event.label.name == 'Z-Build-Workflow'
runs-on: ubuntu-24.04
permissions:
contents: read
outputs:
describe: ${{ steps.git.outputs.describe }}
timestamp: ${{ steps.git.outputs.timestamp }}
steps:
- name: Checkout the code
uses: actions/checkout@v6
with:
# Need a full clone so that `git describe` reports the right version
fetch-depth: 0
- name: Compute version and timestamp out of git history
id: git
run: |
echo "describe=$(git describe --tags --match 'v*.*.*' --always)" >> $GITHUB_OUTPUT
echo "timestamp=$(git log -1 --format=%ct)" >> $GITHUB_OUTPUT
build-assets:
name: Build assets
if: github.event_name == 'push' || github.event.label.name == 'Z-Build-Workflow'
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- uses: ./.github/actions/build-frontend
- uses: ./.github/actions/build-policies
- name: Prepare assets artifact
run: |
mkdir -p assets-dist/share
cp policies/policy.wasm assets-dist/share/policy.wasm
cp frontend/dist/manifest.json assets-dist/share/manifest.json
cp -r frontend/dist/ assets-dist/share/assets
cp -r templates/ assets-dist/share/templates
cp -r translations/ assets-dist/share/translations
cp LICENSE assets-dist/LICENSE
chmod -R u=rwX,go=rX assets-dist/
- name: Upload assets
uses: actions/upload-artifact@v6.0.0
with:
name: assets
path: assets-dist
build-binaries:
name: Build binaries
if: github.event_name == 'push' || github.event.label.name == 'Z-Build-Workflow'
runs-on: ubuntu-24.04
needs:
- compute-version
strategy:
matrix:
include:
- target: x86_64-unknown-linux-gnu
- target: aarch64-unknown-linux-gnu
env:
VERGEN_GIT_DESCRIBE: ${{ needs.compute-version.outputs.describe }}
SOURCE_DATE_EPOCH: ${{ needs.compute-version.outputs.timestamp }}
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
targets: |
${{ matrix.target }}
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.9
- name: Install zig
uses: goto-bus-stop/setup-zig@v2
with:
version: 0.13.0
- name: Install cargo-zigbuild
uses: taiki-e/install-action@v2
with:
tool: cargo-zigbuild
- name: Build the binary
run: |
cargo zigbuild \
--release \
--target ${{ matrix.target }}.2.17 \
--no-default-features \
--features dist \
-p mas-cli
- name: Upload binary artifact
uses: actions/upload-artifact@v6.0.0
with:
name: binary-${{ matrix.target }}
path: target/${{ matrix.target }}/release/mas-cli
assemble-archives:
name: Assemble release archives
if: github.event_name == 'push' || github.event.label.name == 'Z-Build-Workflow'
runs-on: ubuntu-24.04
needs:
- build-assets
- build-binaries
permissions:
contents: read
steps:
- name: Download assets
uses: actions/download-artifact@v7
with:
name: assets
path: assets-dist
- name: Download binary x86_64
uses: actions/download-artifact@v7
with:
name: binary-x86_64-unknown-linux-gnu
path: binary-x86_64
- name: Download binary aarch64
uses: actions/download-artifact@v7
with:
name: binary-aarch64-unknown-linux-gnu
path: binary-aarch64
- name: Create final archives
run: |
for arch in x86_64 aarch64; do
mkdir -p dist/${arch}/share
cp -r assets-dist/share/* dist/${arch}/share/
cp assets-dist/LICENSE dist/${arch}/LICENSE
cp binary-$arch/mas-cli dist/${arch}/mas-cli
chmod -R u=rwX,go=rX dist/${arch}/
chmod u=rwx,go=rx dist/${arch}/mas-cli
tar -czvf mas-cli-${arch}-linux.tar.gz --owner=0 --group=0 -C dist/${arch}/ .
done
- name: Upload aarch64 archive
uses: actions/upload-artifact@v6.0.0
with:
name: mas-cli-aarch64-linux
path: mas-cli-aarch64-linux.tar.gz
- name: Upload x86_64 archive
uses: actions/upload-artifact@v6.0.0
with:
name: mas-cli-x86_64-linux
path: mas-cli-x86_64-linux.tar.gz
build-image:
name: Build and push Docker image
if: github.event_name == 'push' || github.event.label.name == 'Z-Build-Workflow'
runs-on: ubuntu-24.04
outputs:
metadata: ${{ steps.output.outputs.metadata }}
permissions:
contents: read
packages: write
id-token: write
needs:
- compute-version
env:
VERGEN_GIT_DESCRIBE: ${{ needs.compute-version.outputs.describe }}
SOURCE_DATE_EPOCH: ${{ needs.compute-version.outputs.timestamp }}
steps:
- name: Docker meta
id: meta
uses: docker/metadata-action@v5.10.0
with:
images: "${{ env.IMAGE }}"
bake-target: docker-metadata-action
flavor: |
latest=auto
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=sha
- name: Docker meta (debug variant)
id: meta-debug
uses: docker/metadata-action@v5.10.0
with:
images: "${{ env.IMAGE }}"
bake-target: docker-metadata-action-debug
flavor: |
latest=auto
suffix=-debug,onlatest=true
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=sha
- name: Setup Cosign
uses: sigstore/cosign-installer@v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3.12.0
with:
buildkitd-config-inline: |
[registry."docker.io"]
mirrors = ["mirror.gcr.io"]
- name: Login to GitHub Container Registry
uses: docker/login-action@v3.7.0
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push
id: bake
uses: docker/bake-action@v6.10.0
with:
files: |
./docker-bake.hcl
cwd://${{ steps.meta.outputs.bake-file }}
cwd://${{ steps.meta-debug.outputs.bake-file }}
set: |
base.output=type=image,push=true
base.cache-from=type=registry,ref=${{ env.BUILDCACHE }}:buildcache
base.cache-to=type=registry,ref=${{ env.BUILDCACHE }}:buildcache,mode=max
- name: Transform bake output
# This transforms the output to an object which looks like this:
# { regular: { digest: "…", tags: ["…", "…"] }, debug: { digest: "…", tags: ["…"] }, … }
id: output
run: |
echo 'metadata<<EOF' >> $GITHUB_OUTPUT
echo '${{ steps.bake.outputs.metadata }}' | jq -c 'with_entries(select(.value | (type == "object" and has("containerimage.digest")))) | map_values({ digest: .["containerimage.digest"], tags: (.["image.name"] | split(",")) })' >> $GITHUB_OUTPUT
echo 'EOF' >> $GITHUB_OUTPUT
- name: Sign the images with GitHub Actions provided token
# Only sign on tags and on commits on main branch
if: |
github.event_name != 'pull_request'
&& (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main')
env:
REGULAR_DIGEST: ${{ steps.output.outputs.metadata && fromJSON(steps.output.outputs.metadata).regular.digest }}
DEBUG_DIGEST: ${{ steps.output.outputs.metadata && fromJSON(steps.output.outputs.metadata).debug.digest }}
run: |-
cosign sign --yes \
"$IMAGE@$REGULAR_DIGEST" \
"$IMAGE@$DEBUG_DIGEST" \
release:
name: Release
if: startsWith(github.ref, 'refs/tags/')
runs-on: ubuntu-24.04
needs:
- assemble-archives
- build-image
steps:
- name: Download the artifacts from the previous job
uses: actions/download-artifact@v7
with:
pattern: mas-cli-*
path: artifacts
merge-multiple: true
- name: Prepare a release
uses: softprops/action-gh-release@v2.5.0
with:
generate_release_notes: true
body: |
### Docker image
Regular image:
- Digest:
```
${{ env.IMAGE }}@${{ fromJSON(needs.build-image.outputs.metadata).regular.digest }}
```
- Tags:
```
${{ join(fromJSON(needs.build-image.outputs.metadata).regular.tags, '
') }}
```
Debug variant:
- Digest:
```
${{ env.IMAGE }}@${{ fromJSON(needs.build-image.outputs.metadata).debug.digest }}
```
- Tags:
```
${{ join(fromJSON(needs.build-image.outputs.metadata).debug.tags, '
') }}
```
files: |
artifacts/mas-cli-aarch64-linux.tar.gz
artifacts/mas-cli-x86_64-linux.tar.gz
draft: true
unstable:
name: Update the unstable release
if: github.ref == 'refs/heads/main'
runs-on: ubuntu-24.04
needs:
- assemble-archives
- build-image
permissions:
contents: write
steps:
- name: Checkout the code
uses: actions/checkout@v6
with:
sparse-checkout: |
.github/scripts
- name: Download the artifacts from the previous job
uses: actions/download-artifact@v7
with:
pattern: mas-cli-*
path: artifacts
merge-multiple: true
- name: Update unstable git tag
uses: actions/github-script@v8.0.0
with:
script: |
const script = require('./.github/scripts/update-unstable-tag.cjs');
await script({ core, github, context });
- name: Update unstable release
uses: softprops/action-gh-release@v2.5.0
with:
name: "Unstable build"
tag_name: unstable
body: |
This is an automatically updated unstable release containing the latest builds from the main branch.
**⚠️ Warning: These are development builds and may be unstable.**
Last updated: ${{ github.event.head_commit.timestamp }}
Commit: ${{ github.sha }}
### Docker image
Regular image:
- Digest:
```
${{ env.IMAGE }}@${{ fromJSON(needs.build-image.outputs.metadata).regular.digest }}
```
- Tags:
```
${{ join(fromJSON(needs.build-image.outputs.metadata).regular.tags, '
') }}
```
Debug variant:
- Digest:
```
${{ env.IMAGE }}@${{ fromJSON(needs.build-image.outputs.metadata).debug.digest }}
```
- Tags:
```
${{ join(fromJSON(needs.build-image.outputs.metadata).debug.tags, '
') }}
```
files: |
artifacts/mas-cli-aarch64-linux.tar.gz
artifacts/mas-cli-x86_64-linux.tar.gz
prerelease: true
make_latest: false
pr-cleanup:
name: "Remove workflow build PR label and comment on it"
runs-on: ubuntu-24.04
if: github.event_name == 'pull_request' && github.event.label.name == 'Z-Build-Workflow'
needs:
- build-image
permissions:
contents: read
pull-requests: write
steps:
- name: Checkout the code
uses: actions/checkout@v6
with:
sparse-checkout: |
.github/scripts
- name: Remove label and comment
uses: actions/github-script@v8.0.0
env:
BUILD_IMAGE_MANIFEST: ${{ needs.build-image.outputs.metadata }}
with:
script: |
const script = require('./.github/scripts/cleanup-pr.cjs');
await script({ core, github, context });

View file

@ -0,0 +1,338 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: CI
on:
push:
branches:
- main
- "release/**"
tags:
- "v*"
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
CARGO_TERM_COLOR: always
CARGO_NET_GIT_FETCH_WITH_CLI: "true"
SCCACHE_GHA_ENABLED: "true"
RUSTC_WRAPPER: "sccache"
jobs:
opa-lint:
name: Lint and test OPA policies
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- uses: ./.github/actions/build-policies
- name: Setup Regal
uses: StyraInc/setup-regal@v1
with:
# Keep in sync with policies/Makefile
version: 0.38.1
- name: Lint policies
working-directory: ./policies
run: make lint
- name: Run OPA tests
working-directory: ./policies
run: make test
frontend-lint:
name: Check frontend style
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Node
uses: actions/setup-node@v6.2.0
with:
node-version: 24
- name: Install Node dependencies
working-directory: ./frontend
run: npm ci
- name: Lint
working-directory: ./frontend
run: npm run lint
frontend-test:
name: Run the frontend test suite
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Node
uses: actions/setup-node@v6.2.0
with:
node-version: 24
- name: Install Node dependencies
working-directory: ./frontend
run: npm ci
- name: Test
working-directory: ./frontend
run: npm test
frontend-knip:
name: Check the frontend for unused dependencies
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Node
uses: actions/setup-node@v6.2.0
with:
node-version: 24
- name: Install Node dependencies
working-directory: ./frontend
run: npm ci
- name: Check for unused dependencies
working-directory: ./frontend
run: npm run knip
rustfmt:
name: Check Rust style
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@nightly
with:
components: rustfmt
- name: Check style
run: cargo fmt --all -- --check
cargo-deny:
name: Run `cargo deny` checks
runs-on: ubuntu-24.04
env:
# We need to remove the sccache wrapper because we don't install it in this job
RUSTC_WRAPPER: ""
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Run `cargo-deny`
uses: EmbarkStudios/cargo-deny-action@v2.0.15
with:
rust-version: stable
check-schema:
name: Check schema
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
run: |
rustup toolchain install stable
rustup default stable
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.9
- uses: ./.github/actions/build-frontend
- name: Update the schemas
run: sh ./misc/update.sh
- name: Check that the workspace is clean
run: |
if ! [[ -z $(git status -s) ]]; then
echo "::error title=Workspace is not clean::Please run 'sh ./misc/update.sh' and commit the changes"
(
echo '## Diff after running `sh ./misc/update.sh`:'
echo
echo '```diff'
git diff
echo '```'
) >> $GITHUB_STEP_SUMMARY
exit 1
fi
clippy:
name: Run Clippy
needs: [rustfmt, opa-lint]
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@1.93.0
with:
components: clippy
- uses: ./.github/actions/build-policies
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.9
- name: Run clippy
run: |
cargo clippy --workspace --tests --bins --lib -- -D warnings
compile-test-artifacts:
name: Compile test artifacts
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Install nextest
uses: taiki-e/install-action@v2
with:
tool: cargo-nextest
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.9
- name: Build and archive tests
run: cargo nextest archive --workspace --archive-file nextest-archive.tar.zst
env:
SQLX_OFFLINE: "1"
- name: Upload archive to workflow
uses: actions/upload-artifact@v6.0.0
with:
name: nextest-archive
path: nextest-archive.tar.zst
test:
name: Run test suite with Rust stable
needs: [rustfmt, opa-lint, compile-test-artifacts]
runs-on: ubuntu-24.04
permissions:
contents: read
strategy:
matrix:
partition: [1, 2, 3]
services:
postgres:
image: docker.io/library/postgres:15.3
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- "5432:5432"
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Install nextest
uses: taiki-e/install-action@v2
with:
tool: cargo-nextest
- uses: ./.github/actions/build-frontend
- uses: ./.github/actions/build-policies
- name: Download archive
uses: actions/download-artifact@v7
with:
name: nextest-archive
- name: Test
env:
DATABASE_URL: postgresql://postgres:postgres@localhost/postgres
run: |
~/.cargo/bin/cargo-nextest nextest run \
--archive-file nextest-archive.tar.zst \
--partition count:${{ matrix.partition }}/3
tests-done:
name: Tests done
if: ${{ always() }}
needs:
- opa-lint
- frontend-lint
- frontend-test
- frontend-knip
- rustfmt
- cargo-deny
- clippy
- check-schema
- test
runs-on: ubuntu-24.04
steps:
- uses: matrix-org/done-action@v3
with:
needs: ${{ toJSON(needs) }}

View file

@ -0,0 +1,139 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Coverage
on:
push:
branches: [main]
pull_request:
branches: [main]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: false
env:
CARGO_TERM_COLOR: always
CARGO_NET_GIT_FETCH_WITH_CLI: "true"
jobs:
opa:
name: Run OPA test suite with coverage
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- uses: ./.github/actions/build-policies
- name: Run OPA tests with coverage
working-directory: ./policies
run: make coverage
- name: Upload to codecov.io
uses: codecov/codecov-action@v5.5.2
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: policies/coverage.json
flags: policies
frontend:
name: Run frontend test suite with coverage
runs-on: ubuntu-24.04
permissions:
id-token: write
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- uses: ./.github/actions/build-frontend
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- name: Test
working-directory: ./frontend
run: npm run coverage
- name: Upload to codecov.io
uses: codecov/codecov-action@v5.5.2
with:
token: ${{ secrets.CODECOV_TOKEN }}
directory: frontend/coverage/
flags: frontend
rust:
name: Run Rust test suite with coverage
runs-on: ubuntu-24.04
permissions:
contents: read
env:
SCCACHE_GHA_ENABLED: "true"
RUSTC_WRAPPER: "sccache"
services:
postgres:
image: docker.io/library/postgres:15.3
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- "5432:5432"
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
components: llvm-tools-preview
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.9
- name: Install grcov
uses: taiki-e/install-action@v2
with:
tool: grcov
- uses: ./.github/actions/build-frontend
- uses: ./.github/actions/build-policies
- name: Run test suite with profiling enabled
run: |
cargo test --no-fail-fast --workspace
env:
RUSTFLAGS: "-Cinstrument-coverage --cfg tokio_unstable"
LLVM_PROFILE_FILE: "cargo-test-%p-%m.profraw"
DATABASE_URL: postgresql://postgres:postgres@localhost/postgres
SQLX_OFFLINE: "1"
- name: Build grcov report
run: |
mkdir -p target/coverage
grcov . --binary-path ./target/debug/deps/ -s . -t lcov --branch --ignore-not-existing --ignore '../*' --ignore "/*" -o target/coverage/tests.lcov
- name: Upload to codecov.io
uses: codecov/codecov-action@v5.5.2
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: target/coverage/*.lcov
flags: unit

View file

@ -0,0 +1,77 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Build and deploy the documentation
on:
push:
branches: [main]
pull_request:
branches: [main]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
env:
CARGO_TERM_COLOR: always
CARGO_NET_GIT_FETCH_WITH_CLI: "true"
jobs:
build:
name: Build the documentation
runs-on: ubuntu-24.04
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.9
- name: Install mdbook
uses: taiki-e/install-action@v2
with:
tool: mdbook
- name: Install Node
uses: actions/setup-node@v6.2.0
with:
node-version: 24
- name: Build the documentation
run: sh misc/build-docs.sh
- name: Fix permissions
run: |
chmod -c -R +rX "target/book/" | while read line; do
echo "::warning title=Invalid file permissions automatically fixed::$line"
done
- name: Upload GitHub Pages artifacts
uses: actions/upload-pages-artifact@v4.0.0
with:
path: target/book/
deploy:
name: Deploy the documentation on GitHub Pages
runs-on: ubuntu-24.04
needs: build
if: github.ref == 'refs/heads/main'
permissions:
pages: write
id-token: write
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4.0.5

View file

@ -0,0 +1,40 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Merge back a reference to main
on:
workflow_call:
inputs:
sha:
required: true
type: string
secrets:
BOT_GITHUB_TOKEN:
required: true
jobs:
merge-back:
name: Merge back the reference to main
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
with:
sparse-checkout: |
.github/scripts
- name: Push branch and open a PR
uses: actions/github-script@v8.0.0
env:
SHA: ${{ inputs.sha }}
with:
github-token: ${{ secrets.BOT_GITHUB_TOKEN }}
script: |
const script = require('./.github/scripts/merge-back.cjs');
await script({ core, github, context });

View file

@ -0,0 +1,123 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Create a new release branch
on:
workflow_dispatch:
inputs:
kind:
description: Kind of release (major = v1.2.3 -> v2.0.0-rc.0, minor = v1.2.3 -> v1.3.0-rc.0)
required: true
type: choice
default: minor
options:
- major
- minor
jobs:
compute-version:
name: Compute the next ${{ inputs.kind }} RC version
runs-on: ubuntu-24.04
permissions:
contents: read
outputs:
full: ${{ steps.next.outputs.full }}
short: ${{ steps.next.outputs.short }}
steps:
- name: Fail the workflow if this is not the main branch
if: ${{ github.ref_name != 'main' }}
run: exit 1
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Compute the new minor RC
id: next
env:
BUMP: pre${{ inputs.kind }}
run: |
CURRENT_VERSION="$(cargo metadata --format-version 1 --no-deps | jq -r '.packages[] | select(.name == "mas-cli") | .version')"
NEXT_VERSION="$(npx --yes semver@7.5.4 -i "$BUMP" --preid rc "${CURRENT_VERSION}")"
# compute the short minor version, e.g. 0.1.0-rc.1 -> 0.1
SHORT_VERSION="$(echo "${NEXT_VERSION}" | cut -d. -f1-2)"
echo "full=${NEXT_VERSION}" >> "$GITHUB_OUTPUT"
echo "short=${SHORT_VERSION}" >> "$GITHUB_OUTPUT"
localazy:
name: Create a new branch in Localazy
runs-on: ubuntu-24.04
needs: [compute-version]
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Node
uses: actions/setup-node@v6.2.0
with:
node-version: 24
- name: Install Localazy CLI
run: npm install -g @localazy/cli
- name: Create a new branch in Localazy
run: localazy branch -w "$LOCALAZY_WRITE_KEY" create main "$BRANCH"
env:
LOCALAZY_WRITE_KEY: ${{ secrets.LOCALAZY_WRITE_KEY }}
# Localazy doesn't like slashes in branch names, so we just use the short version
# For example, a 0.13.0 release will create a localazy branch named "v0.13" and a git branch named "release/v0.13"
BRANCH: v${{ needs.compute-version.outputs.short }}
tag:
uses: ./.github/workflows/tag.yaml
needs: [compute-version]
with:
version: ${{ needs.compute-version.outputs.full }}
secrets:
BOT_GITHUB_TOKEN: ${{ secrets.BOT_GITHUB_TOKEN }}
merge-back:
uses: ./.github/workflows/merge-back.yaml
needs: [tag]
with:
sha: ${{ needs.tag.outputs.sha }}
secrets:
BOT_GITHUB_TOKEN: ${{ secrets.BOT_GITHUB_TOKEN }}
branch:
name: Create a new release branch
runs-on: ubuntu-24.04
permissions:
contents: write
pull-requests: write
needs: [tag, compute-version, localazy]
steps:
- name: Checkout the code
uses: actions/checkout@v6
with:
sparse-checkout: |
.github/scripts
- name: Create a new release branch
uses: actions/github-script@v8.0.0
env:
BRANCH: release/v${{ needs.compute-version.outputs.short }}
SHA: ${{ needs.tag.outputs.sha }}
with:
github-token: ${{ secrets.BOT_GITHUB_TOKEN }}
script: |
const script = require('./.github/scripts/create-release-branch.cjs');
await script({ core, github, context });

View file

@ -0,0 +1,93 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Bump the version on a release branch
on:
workflow_dispatch:
inputs:
rc:
description: "Is it a release candidate?"
type: boolean
default: false
merge-back:
description: "Should we merge back the release branch to main?"
type: boolean
default: true
jobs:
compute-version:
name: Compute the next version
runs-on: ubuntu-24.04
permissions:
contents: read
outputs:
version: ${{ steps.next.outputs.version }}
steps:
- name: Fail the workflow if not on a release branch
if: ${{ !startsWith(github.ref_name, 'release/v') }}
run: exit 1
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Extract the current version
id: current
run: echo "version=$(cargo metadata --format-version 1 --no-deps | jq -r '.packages[] | select(.name == "mas-cli") | .version')" >> "$GITHUB_OUTPUT"
- name: Compute the new minor RC
id: next
env:
BUMP: ${{ inputs.rc && 'prerelease' || 'patch' }}
VERSION: ${{ steps.current.outputs.version }}
run: echo "version=$(npx --yes semver@7.5.4 -i "$BUMP" --preid rc "$VERSION")" >> "$GITHUB_OUTPUT"
tag:
uses: ./.github/workflows/tag.yaml
needs: [compute-version]
with:
version: ${{ needs.compute-version.outputs.version }}
secrets:
BOT_GITHUB_TOKEN: ${{ secrets.BOT_GITHUB_TOKEN }}
merge-back:
uses: ./.github/workflows/merge-back.yaml
needs: [tag]
if: inputs.merge-back
with:
sha: ${{ needs.tag.outputs.sha }}
secrets:
BOT_GITHUB_TOKEN: ${{ secrets.BOT_GITHUB_TOKEN }}
update-branch:
name: Update the release branch
runs-on: ubuntu-24.04
permissions:
pull-requests: write
needs: [tag, compute-version]
steps:
- name: Checkout the code
uses: actions/checkout@v6
with:
sparse-checkout: |
.github/scripts
- name: Update the release branch
uses: actions/github-script@v8.0.0
env:
BRANCH: "${{ github.ref_name }}"
SHA: ${{ needs.tag.outputs.sha }}
with:
github-token: ${{ secrets.BOT_GITHUB_TOKEN }}
script: |
const script = require('./.github/scripts/update-release-branch.cjs');
await script({ core, github, context });

View file

@ -0,0 +1,71 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Tag a new version
on:
workflow_call:
inputs:
version:
required: true
type: string
outputs:
sha:
description: "The SHA of the commit made which bumps the version"
value: ${{ jobs.tag.outputs.sha }}
secrets:
BOT_GITHUB_TOKEN:
required: true
jobs:
tag:
name: Tag a new version
runs-on: ubuntu-24.04
permissions:
contents: write
outputs:
sha: ${{ fromJSON(steps.commit.outputs.result).commit }}
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Set the crates version
env:
VERSION: ${{ inputs.version }}
run: |
sed -i "s/^package.version = .*/package.version = \"$VERSION\"/" Cargo.toml
sed -i "/path = \".\/crates\//s/version = \".*\"/version = \"=$VERSION\"/" Cargo.toml
- name: Run `cargo metadata` to make sure the lockfile is up to date
run: cargo metadata --format-version 1
- name: Commit and tag using the GitHub API
uses: actions/github-script@v8.0.0
id: commit
env:
VERSION: ${{ inputs.version }}
with:
# Commit & tag with the actions token, so that they get signed
# This returns the commit sha and the tag object sha
script: |
const script = require('./.github/scripts/commit-and-tag.cjs');
return await script({ core, github, context });
- name: Update the refs
uses: actions/github-script@v8.0.0
env:
VERSION: ${{ inputs.version }}
TAG_SHA: ${{ fromJSON(steps.commit.outputs.result).tag }}
COMMIT_SHA: ${{ fromJSON(steps.commit.outputs.result).commit }}
with:
# Update the refs with the bot token, so that workflows are triggered
github-token: ${{ secrets.BOT_GITHUB_TOKEN }}
script: |
const script = require('./.github/scripts/create-version-tag.cjs');
await script({ core, github, context });

View file

@ -0,0 +1,63 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Download translation files from Localazy
on:
workflow_dispatch:
jobs:
download:
runs-on: ubuntu-24.04
permissions:
contents: write
steps:
- name: Fail the workflow if not on the main branch or a release branch
if: ${{ !(startsWith(github.ref_name, 'release/v') || github.ref_name == 'main') }}
run: exit 1
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Node
uses: actions/setup-node@v6.2.0
with:
node-version: 24
- name: Install Localazy CLI
run: npm install -g @localazy/cli
- name: Compute the Localazy branch name
id: branch
# This will strip the "release/" prefix if present, keeping 'main' as-is
run: echo "name=${GITHUB_REF_NAME#release/}" >> "$GITHUB_OUTPUT"
- name: Download translations from Localazy
run: localazy download -w "$LOCALAZY_WRITE_KEY" -b "$BRANCH"
env:
LOCALAZY_WRITE_KEY: ${{ secrets.LOCALAZY_WRITE_KEY }}
BRANCH: ${{ steps.branch.outputs.name }}
- name: Create Pull Request
id: cpr
uses: peter-evans/create-pull-request@v8.1.0
with:
sign-commits: true
token: ${{ secrets.BOT_GITHUB_TOKEN }}
branch-token: ${{ secrets.GITHUB_TOKEN }}
branch: actions/localazy-download/${{ steps.branch.outputs.name }}
delete-branch: true
title: Translations updates for ${{ steps.branch.outputs.name }}
labels: |
T-Task
A-I18n
commit-message: Translations updates
- name: Enable automerge
run: gh pr merge --merge --auto "$PR_NUMBER"
if: steps.cpr.outputs.pull-request-operation == 'created'
env:
GH_TOKEN: ${{ secrets.BOT_GITHUB_TOKEN }}
PR_NUMBER: ${{ steps.cpr.outputs.pull-request-number }}

View file

@ -0,0 +1,41 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
name: Upload translation files to Localazy
on:
push:
branches:
- main
- release/v**
jobs:
upload:
runs-on: ubuntu-24.04
permissions:
contents: read
steps:
- name: Checkout the code
uses: actions/checkout@v6
- name: Install Node
uses: actions/setup-node@v6.2.0
with:
node-version: 24
- name: Install Localazy CLI
run: npm install -g @localazy/cli
- name: Compute the Localazy branch name
id: branch
run: |
# This will strip the "release/" prefix if present, keeping 'main' as-is
echo "name=${GITHUB_REF_NAME#release/}" >> "$GITHUB_OUTPUT"
- name: Upload translations to Localazy
run: localazy upload -w "$LOCALAZY_WRITE_KEY" -b "$BRANCH"
env:
LOCALAZY_WRITE_KEY: ${{ secrets.LOCALAZY_WRITE_KEY }}
BRANCH: ${{ steps.branch.outputs.name }}

View file

@ -0,0 +1,14 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
# Rust
target
# Editors
.idea
.nova
# OS garbage
.DS_Store

View file

@ -0,0 +1,11 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
max_width = 100
comment_width = 80
wrap_comments = true
imports_granularity = "Crate"
use_small_heuristics = "Default"
group_imports = "StdExternalCrate"

View file

@ -0,0 +1,5 @@
# Contributing to MAS
Thank you for taking the time to contribute to Matrix!
Please see the [contributors' guide](https://element-hq.github.io/matrix-authentication-service/development/contributing.html) in our rendered documentation.

7919
matrix-authentication-service/Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,762 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
[workspace]
default-members = ["crates/cli"]
members = ["crates/*"]
resolver = "2"
# Updated in the CI with a `sed` command
package.version = "1.12.0"
package.license = "AGPL-3.0-only OR LicenseRef-Element-Commercial"
package.authors = ["Element Backend Team"]
package.edition = "2024"
package.homepage = "https://element-hq.github.io/matrix-authentication-service/"
package.repository = "https://github.com/element-hq/matrix-authentication-service/"
package.publish = false
[workspace.lints.rust]
unsafe_code = "deny"
[workspace.lints.clippy]
# We use groups as good defaults, but with a lower priority so that we can override them
all = { level = "deny", priority = -1 }
pedantic = { level = "warn", priority = -1 }
str_to_string = "deny"
too_many_lines = "allow"
[workspace.lints.rustdoc]
broken_intra_doc_links = "deny"
[workspace.dependencies]
# Workspace crates
mas-axum-utils = { path = "./crates/axum-utils/", version = "=1.12.0" }
mas-cli = { path = "./crates/cli/", version = "=1.12.0" }
mas-config = { path = "./crates/config/", version = "=1.12.0" }
mas-context = { path = "./crates/context/", version = "=1.12.0" }
mas-data-model = { path = "./crates/data-model/", version = "=1.12.0" }
mas-email = { path = "./crates/email/", version = "=1.12.0" }
mas-graphql = { path = "./crates/graphql/", version = "=1.12.0" }
mas-handlers = { path = "./crates/handlers/", version = "=1.12.0" }
mas-http = { path = "./crates/http/", version = "=1.12.0" }
mas-i18n = { path = "./crates/i18n/", version = "=1.12.0" }
mas-i18n-scan = { path = "./crates/i18n-scan/", version = "=1.12.0" }
mas-iana = { path = "./crates/iana/", version = "=1.12.0" }
mas-iana-codegen = { path = "./crates/iana-codegen/", version = "=1.12.0" }
mas-jose = { path = "./crates/jose/", version = "=1.12.0" }
mas-keystore = { path = "./crates/keystore/", version = "=1.12.0" }
mas-listener = { path = "./crates/listener/", version = "=1.12.0" }
mas-matrix = { path = "./crates/matrix/", version = "=1.12.0" }
mas-matrix-synapse = { path = "./crates/matrix-synapse/", version = "=1.12.0" }
mas-oidc-client = { path = "./crates/oidc-client/", version = "=1.12.0" }
mas-policy = { path = "./crates/policy/", version = "=1.12.0" }
mas-router = { path = "./crates/router/", version = "=1.12.0" }
mas-spa = { path = "./crates/spa/", version = "=1.12.0" }
mas-storage = { path = "./crates/storage/", version = "=1.12.0" }
mas-storage-pg = { path = "./crates/storage-pg/", version = "=1.12.0" }
mas-tasks = { path = "./crates/tasks/", version = "=1.12.0" }
mas-templates = { path = "./crates/templates/", version = "=1.12.0" }
mas-tower = { path = "./crates/tower/", version = "=1.12.0" }
oauth2-types = { path = "./crates/oauth2-types/", version = "=1.12.0" }
syn2mas = { path = "./crates/syn2mas", version = "=1.12.0" }
# OpenAPI schema generation and validation
[workspace.dependencies.aide]
version = "0.15.1"
features = ["axum", "axum-extra", "axum-extra-query", "axum-json", "macros"]
# An `Arc` that can be atomically updated
[workspace.dependencies.arc-swap]
version = "1.8.1"
# GraphQL server
[workspace.dependencies.async-graphql]
version = "7.0.17"
default-features = false
features = ["chrono", "url", "tracing", "playground"]
[workspace.dependencies.async-stream]
version = "0.3.6"
# Utility to write and implement async traits
[workspace.dependencies.async-trait]
version = "0.1.89"
# High-level error handling
[workspace.dependencies.anyhow]
version = "1.0.100"
# Assert that a value matches a pattern
[workspace.dependencies.assert_matches]
version = "1.5.0"
# HTTP router
[workspace.dependencies.axum]
version = "0.8.6"
# Extra utilities for Axum
[workspace.dependencies.axum-extra]
version = "0.10.3"
features = ["cookie-private", "cookie-key-expansion", "typed-header", "query"]
# Axum macros
[workspace.dependencies.axum-macros]
version = "0.5.0"
# AEAD (Authenticated Encryption with Associated Data)
[workspace.dependencies.aead]
version = "0.5.2"
features = ["std"]
# Argon2 password hashing
[workspace.dependencies.argon2]
version = "0.5.3"
features = ["password-hash", "std"]
# Constant-time base64
[workspace.dependencies.base64ct]
version = "1.8.0"
features = ["std"]
# Bcrypt password hashing
[workspace.dependencies.bcrypt]
version = "0.18.0"
default-features = true
# Packed bitfields
[workspace.dependencies.bitflags]
version = "2.10.0"
# Bytes
[workspace.dependencies.bytes]
version = "1.10.1"
# UTF-8 paths
[workspace.dependencies.camino]
version = "1.2.1"
features = ["serde1"]
# ChaCha20Poly1305 AEAD
[workspace.dependencies.chacha20poly1305]
version = "0.10.1"
features = ["std"]
# Memory optimisation for short strings
[workspace.dependencies.compact_str]
version = "0.9.0"
# Terminal formatting
[workspace.dependencies.console]
version = "0.15.11"
# Cookie store
[workspace.dependencies.cookie_store]
version = "0.22.0"
default-features = false
features = ["serde_json"]
# Time utilities
[workspace.dependencies.chrono]
version = "0.4.42"
default-features = false
features = ["serde", "clock"]
# CLI argument parsing
[workspace.dependencies.clap]
version = "4.5.50"
features = ["derive"]
# Object Identifiers (OIDs) as constants
[workspace.dependencies.const-oid]
version = "0.9.6"
features = ["std"]
# Utility for converting between different cases
[workspace.dependencies.convert_case]
version = "0.9.0"
# CRC calculation
[workspace.dependencies.crc]
version = "3.3.0"
# Cron expressions
[workspace.dependencies.cron]
version = "0.15.0"
# CSV parsing and writing
[workspace.dependencies.csv]
version = "1.4.0"
# DER encoding
[workspace.dependencies.der]
version = "0.7.10"
features = ["std"]
# Interactive CLI dialogs
[workspace.dependencies.dialoguer]
version = "0.11.0"
default-features = false
features = ["fuzzy-select", "password"]
# Cryptographic digest algorithms
[workspace.dependencies.digest]
version = "0.10.7"
# Load environment variables from .env files
[workspace.dependencies.dotenvy]
version = "0.15.7"
# ECDSA algorithms
[workspace.dependencies.ecdsa]
version = "0.16.9"
features = ["signing", "verifying"]
# Elliptic curve cryptography
[workspace.dependencies.elliptic-curve]
version = "0.13.8"
features = ["std", "pem", "sec1"]
# Configuration loading
[workspace.dependencies.figment]
version = "0.10.19"
features = ["env", "yaml", "test"]
# URL form encoding
[workspace.dependencies.form_urlencoded]
version = "1.2.2"
# Utilities for dealing with futures
[workspace.dependencies.futures-util]
version = "0.3.31"
# Fixed-size arrays with trait implementations
[workspace.dependencies.generic-array]
version = "0.14.7"
# Rate-limiting
[workspace.dependencies.governor]
version = "0.10.1"
default-features = false
features = ["std", "dashmap", "quanta"]
# HMAC calculation
[workspace.dependencies.hmac]
version = "0.12.1"
# HTTP headers
[workspace.dependencies.headers]
version = "0.4.1"
# Hex encoding and decoding
[workspace.dependencies.hex]
version = "0.4.3"
# HTTP request/response
[workspace.dependencies.http]
version = "1.3.1"
# HTTP body trait
[workspace.dependencies.http-body]
version = "1.0.1"
# http-body utilities
[workspace.dependencies.http-body-util]
version = "0.1.3"
# HTTP client and server
[workspace.dependencies.hyper]
version = "1.7.0"
features = ["client", "server", "http1", "http2"]
# Additional Hyper utilities
[workspace.dependencies.hyper-util]
version = "0.1.18"
features = [
"client",
"server",
"server-auto",
"service",
"http1",
"http2",
"tokio",
]
# Hyper Rustls support
[workspace.dependencies.hyper-rustls]
version = "0.27.7"
features = ["http1", "http2"]
default-features = false
# ICU libraries for internationalization
[workspace.dependencies.icu_calendar]
version = "1.5.2"
features = ["compiled_data", "std"]
[workspace.dependencies.icu_datetime]
version = "1.5.1"
features = ["compiled_data", "std"]
[workspace.dependencies.icu_experimental]
version = "0.1.0"
features = ["compiled_data", "std"]
[workspace.dependencies.icu_locid]
version = "1.5.0"
features = ["std"]
[workspace.dependencies.icu_locid_transform]
version = "1.5.0"
features = ["compiled_data", "std"]
[workspace.dependencies.icu_normalizer]
version = "1.5.0"
[workspace.dependencies.icu_plurals]
version = "1.5.0"
features = ["compiled_data", "std"]
[workspace.dependencies.icu_provider]
version = "1.5.0"
features = ["std", "sync"]
[workspace.dependencies.icu_provider_adapters]
version = "1.5.0"
features = ["std"]
# HashMap which preserves insertion order
[workspace.dependencies.indexmap]
version = "2.11.4"
features = ["serde"]
# Indented string literals
[workspace.dependencies.indoc]
version = "2.0.6"
# Snapshot testing
[workspace.dependencies.insta]
version = "1.46.3"
features = ["yaml", "json"]
# IP network address types
[workspace.dependencies.ipnetwork]
version = "0.20.0"
features = ["serde"]
# Iterator utilities
[workspace.dependencies.itertools]
version = "0.14.0"
# K256 elliptic curve
[workspace.dependencies.k256]
version = "0.13.4"
features = ["std"]
# RFC 5646 language tags
[workspace.dependencies.language-tags]
version = "0.3.2"
features = ["serde"]
# Email sending
[workspace.dependencies.lettre]
version = "0.11.19"
default-features = false
features = [
"tokio1-rustls",
"rustls-platform-verifier",
"aws-lc-rs",
"hostname",
"builder",
"tracing",
"pool",
"smtp-transport",
"sendmail-transport",
]
# Listening on passed FDs
[workspace.dependencies.listenfd]
version = "1.0.2"
# MIME type support
[workspace.dependencies.mime]
version = "0.3.17"
# Templates
[workspace.dependencies.minijinja]
version = "2.15.1"
features = ["urlencode", "loader", "json", "speedups", "unstable_machinery"]
# Additional filters for minijinja
[workspace.dependencies.minijinja-contrib]
version = "2.12.0"
features = ["pycompat"]
# Utilities to deal with non-zero values
[workspace.dependencies.nonzero_ext]
version = "0.3.0"
# Open Policy Agent support through WASM
[workspace.dependencies.opa-wasm]
version = "0.1.9"
# OpenTelemetry
[workspace.dependencies.opentelemetry]
version = "0.31.0"
features = ["trace", "metrics"]
[workspace.dependencies.opentelemetry-http]
version = "0.31.0"
features = ["reqwest"]
[workspace.dependencies.opentelemetry-instrumentation-process]
version = "0.1.2"
[workspace.dependencies.opentelemetry-instrumentation-tokio]
version = "0.1.2"
[workspace.dependencies.opentelemetry-jaeger-propagator]
version = "0.31.0"
[workspace.dependencies.opentelemetry-otlp]
version = "0.31.0"
default-features = false
features = ["trace", "metrics", "http-proto"]
[workspace.dependencies.opentelemetry-prometheus-text-exporter]
version = "0.2.1"
[workspace.dependencies.opentelemetry-resource-detectors]
version = "0.10.0"
[workspace.dependencies.opentelemetry-semantic-conventions]
version = "0.31.0"
features = ["semconv_experimental"]
[workspace.dependencies.opentelemetry-stdout]
version = "0.31.0"
features = ["trace", "metrics"]
[workspace.dependencies.opentelemetry_sdk]
version = "0.31.0"
features = [
"experimental_trace_batch_span_processor_with_async_runtime",
"experimental_metrics_periodicreader_with_async_runtime",
"rt-tokio",
]
[workspace.dependencies.tracing-opentelemetry]
version = "0.32.0"
default-features = false
# P256 elliptic curve
[workspace.dependencies.p256]
version = "0.13.2"
features = ["std"]
# P384 elliptic curve
[workspace.dependencies.p384]
version = "0.13.1"
features = ["std"]
# Text padding utilities
[workspace.dependencies.pad]
version = "0.1.6"
# PBKDF2 password hashing
[workspace.dependencies.pbkdf2]
version = "0.12.2"
features = ["password-hash", "std", "simple", "parallel"]
# PEM encoding/decoding
[workspace.dependencies.pem-rfc7468]
version = "0.7.0"
features = ["std"]
# Parser generator
[workspace.dependencies.pest]
version = "2.8.3"
# Pest derive macros
[workspace.dependencies.pest_derive]
version = "2.8.3"
# Pin projection
[workspace.dependencies.pin-project-lite]
version = "0.2.16"
# PKCS#1 encoding
[workspace.dependencies.pkcs1]
version = "0.7.5"
features = ["std"]
# PKCS#8 encoding
[workspace.dependencies.pkcs8]
version = "0.10.2"
features = ["std", "pkcs5", "encryption"]
# Public Suffix List
[workspace.dependencies.psl]
version = "2.1.162"
# High-precision clock
[workspace.dependencies.quanta]
version = "0.12.6"
# Random values
[workspace.dependencies.rand]
version = "0.8.5"
[workspace.dependencies.rand_chacha]
version = "0.3.1"
[workspace.dependencies.rand_core]
version = "0.6.4"
# Regular expressions
[workspace.dependencies.regex]
version = "1.12.2"
# High-level HTTP client
[workspace.dependencies.reqwest]
version = "0.12.24"
default-features = false
features = [
"http2",
"rustls-tls-manual-roots-no-provider",
"charset",
"json",
"socks",
]
# RSA cryptography
[workspace.dependencies.rsa]
version = "0.9.10"
features = ["std", "pem"]
# Fast hash algorithm for HashMap
[workspace.dependencies.rustc-hash]
version = "2.1.1"
# Matrix-related types
[workspace.dependencies.ruma-common]
version = "0.16.0"
# TLS stack
[workspace.dependencies.rustls]
version = "0.23.35"
# PKI types for rustls
[workspace.dependencies.rustls-pki-types]
version = "1.13.0"
# Use platform-specific verifier for TLS
[workspace.dependencies.rustls-platform-verifier]
version = "0.6.1"
# systemd service status notification
[workspace.dependencies.sd-notify]
version = "0.4.5"
# JSON Schema generation
[workspace.dependencies.schemars]
version = "0.9.0"
features = ["url2", "chrono04", "preserve_order"]
# SEC1 encoding format
[workspace.dependencies.sec1]
version = "0.7.3"
features = ["std"]
# Query builder
[workspace.dependencies.sea-query]
version = "0.32.7"
features = ["derive", "attr", "with-uuid", "with-chrono", "postgres-array"]
# Query builder
[workspace.dependencies.sea-query-binder]
version = "0.7.0"
features = [
"sqlx",
"sqlx-postgres",
"with-uuid",
"with-chrono",
"postgres-array",
]
# Sentry error tracking
[workspace.dependencies.sentry]
version = "0.46.2"
default-features = false
features = ["backtrace", "contexts", "panic", "tower", "reqwest"]
# Sentry tower layer
[workspace.dependencies.sentry-tower]
version = "0.46.0"
features = ["http", "axum-matched-path"]
# Sentry tracing integration
[workspace.dependencies.sentry-tracing]
version = "0.46.0"
# Serialization and deserialization
[workspace.dependencies.serde]
version = "1.0.228"
features = ["derive"] # Most of the time, if we need serde, we need derive
# JSON serialization and deserialization
[workspace.dependencies.serde_json]
version = "1.0.145"
features = ["preserve_order"]
# URL encoded form serialization
[workspace.dependencies.serde_urlencoded]
version = "0.7.1"
# Custom serialization helpers
[workspace.dependencies.serde_with]
version = "3.14.0"
features = ["hex", "chrono"]
# YAML serialization
[workspace.dependencies.serde_yaml]
version = "0.9.34"
# SHA-2 cryptographic hash algorithm
[workspace.dependencies.sha2]
version = "0.10.9"
features = ["oid"]
# Digital signature traits
[workspace.dependencies.signature]
version = "2.2.0"
# Low-level socket manipulation
[workspace.dependencies.socket2]
version = "0.6.2"
# Subject Public Key Info
[workspace.dependencies.spki]
version = "0.7.3"
features = ["std"]
# SQL database support
[workspace.dependencies.sqlx]
version = "0.8.6"
features = [
"runtime-tokio",
"tls-rustls-aws-lc-rs",
"postgres",
"migrate",
"chrono",
"json",
"uuid",
"ipnetwork",
]
# Custom error types
[workspace.dependencies.thiserror]
version = "2.0.17"
[workspace.dependencies.thiserror-ext]
version = "0.3.0"
# Async runtime
[workspace.dependencies.tokio]
version = "1.48.0"
features = ["full"]
[workspace.dependencies.tokio-stream]
version = "0.1.17"
# Tokio rustls integration
[workspace.dependencies.tokio-rustls]
version = "0.26.4"
# Tokio test utilities
[workspace.dependencies.tokio-test]
version = "0.4.4"
# Useful async utilities
[workspace.dependencies.tokio-util]
version = "0.7.16"
features = ["rt"]
# Tower services
[workspace.dependencies.tower]
version = "0.5.2"
features = ["util"]
# Tower service trait
[workspace.dependencies.tower-service]
version = "0.3.3"
# Tower layer trait
[workspace.dependencies.tower-layer]
version = "0.3.3"
# Tower HTTP layers
[workspace.dependencies.tower-http]
version = "0.6.6"
features = ["cors", "fs", "add-extension", "set-header"]
# Logging and tracing
[workspace.dependencies.tracing]
version = "0.1.41"
[workspace.dependencies.tracing-subscriber]
version = "0.3.22"
features = ["env-filter"]
[workspace.dependencies.tracing-appender]
version = "0.2.4"
# URL manipulation
[workspace.dependencies.url]
version = "2.5.7"
features = ["serde"]
# URL encoding
[workspace.dependencies.urlencoding]
version = "2.1.3"
# ULID support
[workspace.dependencies.ulid]
version = "=1.1.4" # Pinned to the latest version which used rand 0.8
features = ["serde", "uuid"]
# UUID support
[workspace.dependencies.uuid]
version = "1.18.1"
# HTML escaping
[workspace.dependencies.v_htmlescape]
version = "0.15.8"
# Version information generation
[workspace.dependencies.vergen-gitcl]
version = "1.0.8"
features = ["rustc"]
# Directory traversal
[workspace.dependencies.walkdir]
version = "2.5.0"
# HTTP mock server
[workspace.dependencies.wiremock]
version = "0.6.5"
# User-agent parser
[workspace.dependencies.woothee]
version = "0.13.0"
# String writing interface
[workspace.dependencies.writeable]
version = "0.5.5"
# Zero memory after use
[workspace.dependencies.zeroize]
version = "1.8.2"
# Password strength estimation
[workspace.dependencies.zxcvbn]
version = "3.1.0"
[profile.release]
codegen-units = 1 # Reduce the number of codegen units to increase optimizations
lto = true # Enable fat LTO
# A few profile opt-level tweaks to make the test suite run faster
[profile.dev.package]
argon2.opt-level = 3
bcrypt.opt-level = 3
block-buffer.opt-level = 3
cranelift-codegen.opt-level = 3
digest.opt-level = 3
hmac.opt-level = 3
generic-array.opt-level = 3
num-bigint-dig.opt-level = 3
pbkdf2.opt-level = 3
rayon.opt-level = 3
regalloc2.opt-level = 3
sha2.opt-level = 3
sqlx-macros.opt-level = 3

View file

@ -0,0 +1,169 @@
# syntax = docker/dockerfile:1.21.0
# Copyright 2025, 2026 Element Creations Ltd.
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
# Builds a minimal image with the binary only. It is multi-arch capable,
# cross-building to aarch64 and x86_64. When cross-compiling, Docker sets two
# implicit BUILDARG: BUILDPLATFORM being the host platform and TARGETPLATFORM
# being the platform being built.
# The Debian version and version name must be in sync
ARG DEBIAN_VERSION=13
ARG DEBIAN_VERSION_NAME=trixie
# Keep in sync with .github/workflows/ci.yaml
ARG RUSTC_VERSION=1.93.0
ARG NODEJS_VERSION=24.13.0
# Keep in sync with .github/actions/build-policies/action.yml and policies/Makefile
ARG OPA_VERSION=1.13.1
ARG CARGO_AUDITABLE_VERSION=0.7.2
##########################################
## Build stage that builds the frontend ##
##########################################
# Runs on the host platform: the frontend build output is platform-independent
FROM --platform=${BUILDPLATFORM} docker.io/library/node:${NODEJS_VERSION}-${DEBIAN_VERSION_NAME} AS frontend
WORKDIR /app/frontend
# Copy only the lockfile/manifest first so the `npm ci` layer is cached until
# dependencies actually change
COPY ./frontend/.npmrc ./frontend/package.json ./frontend/package-lock.json /app/frontend/
# Network access: to fetch dependencies
RUN --network=default \
npm ci
COPY ./frontend/ /app/frontend/
COPY ./templates/ /app/templates/
RUN --network=none \
npm run build
# Move the built files
# The manifest is kept next to (not inside) the assets directory, and the
# index.html/manifest.json artifacts (plus any compressed variants matched by
# the globs) are stripped from the served assets.
# NOTE(review): `cp ./dist/*` is non-recursive — assumes the build emits a
# flat dist/ directory; confirm no subdirectories are produced.
RUN --network=none \
mkdir -p /share/assets && \
cp ./dist/manifest.json /share/manifest.json && \
rm -f ./dist/index.html* ./dist/manifest.json* && \
cp ./dist/* /share/assets/
##############################################
## Build stage that builds the OPA policies ##
##############################################
# Runs on the host platform: the compiled policy.wasm is platform-independent
FROM --platform=${BUILDPLATFORM} docker.io/library/buildpack-deps:${DEBIAN_VERSION_NAME} AS policy
ARG BUILDOS
ARG BUILDARCH
ARG OPA_VERSION
# Download Open Policy Agent
ADD --chmod=755 https://github.com/open-policy-agent/opa/releases/download/v${OPA_VERSION}/opa_${BUILDOS}_${BUILDARCH}_static /usr/local/bin/opa
WORKDIR /app/policies
COPY ./policies /app/policies
# -B forces a rebuild; ensure the artifact is world-readable for the
# nonroot runtime user
RUN --network=none \
make -B && \
chmod a+r ./policy.wasm
########################################
## Build stage that builds the binary ##
########################################
# Runs on the host platform and cross-compiles to both supported targets
FROM --platform=${BUILDPLATFORM} docker.io/library/rust:${RUSTC_VERSION}-${DEBIAN_VERSION_NAME} AS builder
ARG CARGO_AUDITABLE_VERSION
ARG RUSTC_VERSION
# Install pinned versions of cargo-auditable
# Network access: to fetch dependencies
RUN --network=default \
cargo install --locked \
cargo-auditable@=${CARGO_AUDITABLE_VERSION}
# Install all cross-compilation targets
# Network access: to download the targets
RUN --network=default \
rustup target add \
--toolchain "${RUSTC_VERSION}" \
x86_64-unknown-linux-gnu \
aarch64-unknown-linux-gnu
# Enable multi-arch apt packages so the foreign-arch cross libc can be installed
RUN --network=none \
dpkg --add-architecture arm64 && \
dpkg --add-architecture amd64
ARG BUILDPLATFORM
# Install cross-compilation toolchains for all supported targets
# Only the cross g++ for the *other* architecture is needed; the native one
# is covered by the plain `g++` package.
# Network access: to install apt packages
RUN --network=default \
apt-get update && apt-get install -y \
$(if [ "${BUILDPLATFORM}" != "linux/arm64" ]; then echo "g++-aarch64-linux-gnu"; fi) \
$(if [ "${BUILDPLATFORM}" != "linux/amd64" ]; then echo "g++-x86-64-linux-gnu"; fi) \
libc6-dev-amd64-cross \
libc6-dev-arm64-cross \
g++
# Setup the cross-compilation environment
# Point cargo's linker and cc-rs's C/C++ compilers at the per-target GNU toolchains
ENV \
CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=aarch64-linux-gnu-gcc \
CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc \
CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++ \
CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER=x86_64-linux-gnu-gcc \
CC_x86_64_unknown_linux_gnu=x86_64-linux-gnu-gcc \
CXX_x86_64_unknown_linux_gnu=x86_64-linux-gnu-g++
# Set the working directory
WORKDIR /app
# Copy the code
COPY ./ /app
# Use the checked-in sqlx query metadata instead of a live database connection
ENV SQLX_OFFLINE=true
# Propagate the git describe output (the checkout in Docker may be shallow or
# absent, so it is passed in as a build argument)
ARG VERGEN_GIT_DESCRIBE
ENV VERGEN_GIT_DESCRIBE=${VERGEN_GIT_DESCRIBE}
# Network access: cargo auditable needs it
# Both binaries must be moved out of /app/target before the RUN ends, because
# the target directory is a cache mount that is detached afterwards.
RUN --network=default \
--mount=type=cache,target=/root/.cargo/registry \
--mount=type=cache,target=/app/target \
cargo auditable build \
--locked \
--release \
--bin mas-cli \
--no-default-features \
--features docker \
--target x86_64-unknown-linux-gnu \
--target aarch64-unknown-linux-gnu \
&& mv "target/x86_64-unknown-linux-gnu/release/mas-cli" /usr/local/bin/mas-cli-amd64 \
&& mv "target/aarch64-unknown-linux-gnu/release/mas-cli" /usr/local/bin/mas-cli-arm64
#######################################
## Prepare /usr/local/share/mas-cli/ ##
#######################################
# Staging-only layer assembling everything that lands in /usr/local/share/mas-cli
FROM --platform=${BUILDPLATFORM} scratch AS share
COPY --from=frontend /share /share
COPY --from=policy /app/policies/policy.wasm /share/policy.wasm
COPY ./templates/ /share/templates
COPY ./translations/ /share/translations
##################################
## Runtime stage, debug variant ##
##################################
# The :debug-nonroot distroless base additionally ships a busybox shell,
# making it possible to exec into the container for troubleshooting
FROM gcr.io/distroless/cc-debian${DEBIAN_VERSION}:debug-nonroot AS debug
ARG TARGETARCH
# TARGETARCH selects the matching cross-compiled binary (amd64/arm64)
COPY --from=builder /usr/local/bin/mas-cli-${TARGETARCH} /usr/local/bin/mas-cli
COPY --from=share /share /usr/local/share/mas-cli
WORKDIR /
ENTRYPOINT ["/usr/local/bin/mas-cli"]
###################
## Runtime stage ##
###################
# Final (default) image: distroless, nonroot, binary + shared assets only
FROM gcr.io/distroless/cc-debian${DEBIAN_VERSION}:nonroot
ARG TARGETARCH
COPY --from=builder /usr/local/bin/mas-cli-${TARGETARCH} /usr/local/bin/mas-cli
COPY --from=share /share /usr/local/share/mas-cli
WORKDIR /
ENTRYPOINT ["/usr/local/bin/mas-cli"]

View file

@ -0,0 +1,661 @@
GNU AFFERO GENERAL PUBLIC LICENSE
Version 3, 19 November 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.
A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate. Many developers of free software are heartened and
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.
The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.
An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals. This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU Affero General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Remote Network Interaction; Use with the GNU General Public License.
Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software. This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time. Such new versions
will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU Affero General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU Affero General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU Affero General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source. For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code. There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.

View file

@ -0,0 +1,6 @@
Licensees holding a valid commercial license with Element may use this
software in accordance with the terms contained in a written agreement
between you and Element.
To purchase a commercial license please contact our sales team at
licensing@element.io

View file

@ -0,0 +1,61 @@
# Matrix Authentication Service
MAS (Matrix Authentication Service) is a user management and authentication service for [Matrix](https://matrix.org/) homeservers, written and maintained by [Element](https://element.io/). You can directly run and manage the source code in this repository, available under an AGPL license (or alternatively under a commercial license from Element). Support is not provided by Element unless you have a subscription.
It has been created to support the migration of Matrix to a next-generation of auth APIs per [MSC3861](https://github.com/matrix-org/matrix-doc/pull/3861).
See the [Documentation](https://element-hq.github.io/matrix-authentication-service/index.html) for information on installation and use.
You can learn more about Matrix and next-generation auth at [areweoidcyet.com](https://areweoidcyet.com/).
## 🚀 Getting started
This component is developed and maintained by [Element](https://element.io). It gets shipped as part of the **Element Server Suite (ESS)** which provides the official means of deployment.
ESS is a Matrix distribution from Element with focus on quality and ease of use. It ships a full Matrix stack tailored to the respective use case.
There are three editions of ESS:
- [ESS Community](https://github.com/element-hq/ess-helm) - the free Matrix
distribution from Element tailored to small-/mid-scale, non-commercial
community use cases
- [ESS Pro](https://element.io/server-suite) - the commercial Matrix
distribution from Element for professional use
- [ESS TI-M](https://element.io/server-suite/ti-messenger) - a special version
of ESS Pro focused on the requirements of TI-Messenger Pro and ePA as
specified by the German National Digital Health Agency Gematik
## 💬 Community room
Developers and users of Matrix Authentication Service can chat in the [#matrix-auth:matrix.org](https://matrix.to/#/#matrix-auth:matrix.org) room on Matrix.
## 🛠️ Standalone installation and configuration
The best way to get a modern Element Matrix stack is through the [Element Server Suite](https://element.io/en/server-suite), which includes MAS.
The MAS documentation describes [how to install and configure MAS](https://element-hq.github.io/matrix-authentication-service/setup/).
We recommend using the [Docker image](https://element-hq.github.io/matrix-authentication-service/setup/installation.html#using-the-docker-image) or the [pre-built binaries](https://element-hq.github.io/matrix-authentication-service/setup/installation.html#pre-built-binaries).
## 📖 Translations
Matrix Authentication Service is available in multiple languages.
Anyone can contribute to translations through [Localazy](https://localazy.com/element-matrix-authentication-service/).
## 🏗️ Contributing
See the [contribution guidelines](https://element-hq.github.io/matrix-authentication-service/development/contributing.html) for information on how to contribute to this project.
## ⚖️ Copyright & License
Copyright 2021-2024 The Matrix.org Foundation C.I.C.
Copyright 2024, 2025 New Vector Ltd.
Copyright 2025, 2026 Element Creations Ltd.
This software is dual-licensed by Element Creations Ltd (Element). It can be used either:
(1) for free under the terms of the GNU Affero General Public License (as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version); OR
(2) under the terms of a paid-for Element Commercial License agreement between you and Element (the terms of which may vary depending on what you and Element have agreed to).
Unless required by applicable law or agreed to in writing, software distributed under the Licenses is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the Licenses for the specific language governing permissions and limitations under the Licenses.

View file

@ -0,0 +1,60 @@
{
"$schema": "https://biomejs.dev/schemas/2.2.4/schema.json",
"assist": { "actions": { "source": { "organizeImports": "on" } } },
"vcs": {
"enabled": true,
"clientKind": "git",
"useIgnoreFile": true
},
"files": {
"includes": [
"**",
"!**/.devcontainer/**",
"!**/docs/**",
"!**/translations/**",
"!**/policies/**",
"!**/crates/**",
"!**/frontend/package.json",
"!**/frontend/src/gql/**",
"!**/frontend/src/routeTree.gen.ts",
"!**/frontend/.storybook/locales.ts",
"!**/frontend/.storybook/public/mockServiceWorker.js",
"!**/frontend/locales/**/*.json",
"!**/coverage/**",
"!**/dist/**"
]
},
"formatter": {
"enabled": true,
"useEditorconfig": true
},
"linter": {
"enabled": true,
"rules": {
"recommended": true,
"complexity": {
"noImportantStyles": "off"
},
"suspicious": {
"noUnknownAtRules": "off"
},
"correctness": {
"noUnusedImports": "warn",
"noUnusedVariables": "warn"
},
"style": {
"noParameterAssign": "error",
"useAsConstAssertion": "error",
"useDefaultParameterLast": "error",
"useEnumInitializers": "error",
"useSelfClosingElements": "error",
"useSingleVarDeclarator": "error",
"noUnusedTemplateLiteral": "error",
"useNumberNamespace": "error",
"noInferrableTypes": "error",
"noUselessElse": "error",
"noDescendingSpecificity": "off"
}
}
}
}

View file

@ -0,0 +1,26 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
# Documentation for possible options in this file is at
# https://rust-lang.github.io/mdBook/format/config.html
[book]
title = "Matrix Authentication Service"
authors = ["Element Backend Team"]
language = "en"
src = "docs"
[build]
build-dir = "target/book"
[output.html]
# The URL visitors will be directed to when they try to edit a page
edit-url-template = "https://github.com/element-hq/matrix-authentication-service/edit/main/{path}"
# The source code URL of the repository
git-repository-url = "https://github.com/element-hq/matrix-authentication-service"
# The path that the docs are hosted on
site-url = "/matrix-authentication-service/"

View file

@ -0,0 +1,22 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
doc-valid-idents = ["OpenID", "OAuth", "UserInfo", "..", "PostgreSQL", "SQLite"]
disallowed-methods = [
{ path = "rand::thread_rng", reason = "do not create rngs on the fly, pass them as parameters" },
{ path = "chrono::Utc::now", reason = "source the current time from the clock instead" },
{ path = "ulid::Ulid::from_datetime", reason = "use Ulid::from_datetime_with_source instead" },
{ path = "ulid::Ulid::new", reason = "use Ulid::from_datetime_with_source instead" },
{ path = "reqwest::Client::new", reason = "use mas_http::reqwest_client instead" },
{ path = "reqwest::RequestBuilder::send", reason = "use send_traced instead" },
]
disallowed-types = [
{ path = "std::path::PathBuf", reason = "use camino::Utf8PathBuf instead" },
{ path = "std::path::Path", reason = "use camino::Utf8Path instead" },
{ path = "axum::extract::Query", reason = "use axum_extra::extract::Query instead. The built-in version doesn't deserialise lists."},
{ path = "axum::extract::rejection::QueryRejection", reason = "use axum_extra::extract::QueryRejection instead"}
]

View file

@ -0,0 +1,48 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
[package]
name = "mas-axum-utils"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
publish.workspace = true
[lints]
workspace = true
[dependencies]
anyhow.workspace = true
axum.workspace = true
axum-extra.workspace = true
base64ct.workspace = true
chrono.workspace = true
headers.workspace = true
http.workspace = true
icu_locid.workspace = true
mime.workspace = true
rand.workspace = true
reqwest.workspace = true
sentry.workspace = true
serde.workspace = true
serde_with.workspace = true
serde_json.workspace = true
thiserror.workspace = true
tokio.workspace = true
tracing.workspace = true
url.workspace = true
ulid.workspace = true
oauth2-types.workspace = true
mas-data-model.workspace = true
mas-http.workspace = true
mas-iana.workspace = true
mas-jose.workspace = true
mas-keystore.workspace = true
mas-storage.workspace = true
mas-templates.workspace = true

View file

@ -0,0 +1,739 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::collections::HashMap;
use axum::{
BoxError, Json,
extract::{
Form, FromRequest,
rejection::{FailedToDeserializeForm, FormRejection},
},
response::IntoResponse,
};
use headers::authorization::{Basic, Bearer, Credentials as _};
use http::{Request, StatusCode};
use mas_data_model::{Client, JwksOrJwksUri};
use mas_http::RequestBuilderExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
use mas_keystore::Encrypter;
use mas_storage::{RepositoryAccess, oauth2::OAuth2ClientRepository};
use oauth2_types::errors::{ClientError, ClientErrorCode};
use serde::{Deserialize, de::DeserializeOwned};
use serde_json::Value;
use thiserror::Error;
use crate::record_error;
/// The `client_assertion_type` value for JWT bearer client assertions,
/// as defined by RFC 7523 for OAuth 2.0 JWT client authentication.
static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";

/// The client-authentication fields that may appear in a request body,
/// wrapped around the caller's own form type `F`.
///
/// Every field is optional because which ones are present depends on the
/// authentication method the client chose; the combination is validated later
/// in [`ClientAuthorization::from_request`].
#[derive(Deserialize)]
struct AuthorizedForm<F = ()> {
    client_id: Option<String>,
    client_secret: Option<String>,
    client_assertion_type: Option<String>,
    client_assertion: Option<String>,
    // The remaining fields of the request body, captured for the caller.
    #[serde(flatten)]
    inner: F,
}
/// Credentials presented by an OAuth 2.0 client, extracted from the
/// `Authorization` header and/or the request body.
#[derive(Debug, PartialEq, Eq)]
pub enum Credentials {
    /// Only a `client_id` was provided, with no proof of identity.
    None {
        client_id: String,
    },
    /// `client_id`/`client_secret` via HTTP Basic authentication.
    ClientSecretBasic {
        client_id: String,
        client_secret: String,
    },
    /// `client_id`/`client_secret` sent as form fields in the request body.
    ClientSecretPost {
        client_id: String,
        client_secret: String,
    },
    /// A signed JWT `client_assertion` (RFC 7523); the `client_id` is taken
    /// from the JWT's `sub` claim.
    ClientAssertionJwtBearer {
        client_id: String,
        // Boxed to keep the enum small; the JWT payload is an arbitrary claim map.
        jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
    },
    /// A bearer token from the `Authorization` header; not tied to a client_id.
    BearerToken {
        token: String,
    },
}
impl Credentials {
    /// Get the `client_id` of the credentials
    ///
    /// Returns `None` for bearer-token credentials, which carry no client
    /// identity.
    #[must_use]
    pub fn client_id(&self) -> Option<&str> {
        match self {
            Credentials::None { client_id }
            | Credentials::ClientSecretBasic { client_id, .. }
            | Credentials::ClientSecretPost { client_id, .. }
            | Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id),
            Credentials::BearerToken { .. } => None,
        }
    }

    /// Get the bearer token from the credentials.
    ///
    /// Returns `None` for every non-bearer credential variant.
    #[must_use]
    pub fn bearer_token(&self) -> Option<&str> {
        match self {
            Credentials::BearerToken { token } => Some(token),
            _ => None,
        }
    }

    /// Fetch the client from the database
    ///
    /// Looks up the client by `client_id`; bearer-token credentials have no
    /// client_id, so they resolve to `Ok(None)` without touching the
    /// repository.
    ///
    /// # Errors
    ///
    /// Returns an error if the client could not be found or if the underlying
    /// repository errored.
    pub async fn fetch<E>(
        &self,
        repo: &mut impl RepositoryAccess<Error = E>,
    ) -> Result<Option<Client>, E> {
        let client_id = match self {
            Credentials::None { client_id }
            | Credentials::ClientSecretBasic { client_id, .. }
            | Credentials::ClientSecretPost { client_id, .. }
            | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
            Credentials::BearerToken { .. } => return Ok(None),
        };

        repo.oauth2_client().find_by_client_id(client_id).await
    }

    /// Verify credentials presented by the client for authentication
    ///
    /// The presented credential variant must match the authentication
    /// `method` registered for the client, otherwise
    /// [`CredentialsVerificationError::AuthenticationMethodMismatch`] is
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the credentials are invalid.
    #[tracing::instrument(skip_all)]
    pub async fn verify(
        &self,
        http_client: &reqwest::Client,
        encrypter: &Encrypter,
        method: &OAuthClientAuthenticationMethod,
        client: &Client,
    ) -> Result<(), CredentialsVerificationError> {
        match (self, method) {
            // `none` auth method: nothing to verify beyond the client_id.
            (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}

            (
                Credentials::ClientSecretPost { client_secret, .. },
                OAuthClientAuthenticationMethod::ClientSecretPost,
            )
            | (
                Credentials::ClientSecretBasic { client_secret, .. },
                OAuthClientAuthenticationMethod::ClientSecretBasic,
            ) => {
                // Decrypt the client_secret
                let encrypted_client_secret = client
                    .encrypted_client_secret
                    .as_ref()
                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;

                let decrypted_client_secret = encrypter
                    .decrypt_string(encrypted_client_secret)
                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;

                // Check if the client_secret matches
                // NOTE(review): this is not a constant-time comparison; if
                // timing side-channels are in the threat model, consider a
                // constant-time equality check — TODO confirm.
                if client_secret.as_bytes() != decrypted_client_secret {
                    return Err(CredentialsVerificationError::ClientSecretMismatch);
                }
            }

            (
                Credentials::ClientAssertionJwtBearer { jwt, .. },
                OAuthClientAuthenticationMethod::PrivateKeyJwt,
            ) => {
                // Get the client JWKS
                let jwks = client
                    .jwks
                    .as_ref()
                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;

                // May hit the network if the client registered a jwks_uri.
                let jwks = fetch_jwks(http_client, jwks)
                    .await
                    .map_err(CredentialsVerificationError::JwksFetchFailed)?;

                jwt.verify_with_jwks(&jwks)
                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
            }

            (
                Credentials::ClientAssertionJwtBearer { jwt, .. },
                OAuthClientAuthenticationMethod::ClientSecretJwt,
            ) => {
                // Decrypt the client_secret
                let encrypted_client_secret = client
                    .encrypted_client_secret
                    .as_ref()
                    .ok_or(CredentialsVerificationError::InvalidClientConfig)?;

                let decrypted_client_secret = encrypter
                    .decrypt_string(encrypted_client_secret)
                    .map_err(|_e| CredentialsVerificationError::DecryptionError)?;

                // The assertion is HMAC-signed with the client_secret itself.
                jwt.verify_with_shared_secret(decrypted_client_secret)
                    .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
            }

            // Presented credentials don't match the registered auth method.
            (_, _) => {
                return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
            }
        }

        Ok(())
    }
}
/// Resolve a client's JWKS: either return the inline document directly, or
/// fetch it over HTTP when only a `jwks_uri` was registered.
///
/// # Errors
///
/// Returns an error if the HTTP request fails, returns a non-success status,
/// or the response body is not a valid JWKS document.
async fn fetch_jwks(
    http_client: &reqwest::Client,
    jwks: &JwksOrJwksUri,
) -> Result<PublicJsonWebKeySet, BoxError> {
    match jwks {
        // Inline key set: no network round-trip needed.
        JwksOrJwksUri::Jwks(keys) => Ok(keys.clone()),

        JwksOrJwksUri::JwksUri(uri) => {
            let fetched = http_client
                .get(uri.as_str())
                .send_traced()
                .await?
                .error_for_status()?
                .json()
                .await?;

            Ok(fetched)
        }
    }
}
/// Errors that can occur while verifying client credentials in
/// [`Credentials::verify`].
#[derive(Debug, Error)]
pub enum CredentialsVerificationError {
    /// The stored encrypted client_secret could not be decrypted.
    #[error("failed to decrypt client credentials")]
    DecryptionError,

    /// The client record is missing data required by its auth method
    /// (e.g. no stored secret or JWKS).
    #[error("invalid client configuration")]
    InvalidClientConfig,

    /// The presented client_secret does not match the stored one.
    #[error("client secret did not match")]
    ClientSecretMismatch,

    /// The presented credential kind doesn't match the registered auth method.
    #[error("authentication method mismatch")]
    AuthenticationMethodMismatch,

    /// The client_assertion JWT signature failed verification.
    #[error("invalid assertion signature")]
    InvalidAssertionSignature,

    /// The client's JWKS could not be fetched from its jwks_uri.
    #[error("failed to fetch jwks")]
    JwksFetchFailed(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl CredentialsVerificationError {
    /// Returns true if the error is an internal error, not caused by the client
    #[must_use]
    pub fn is_internal(&self) -> bool {
        // Exhaustive match so that adding a variant forces a decision here.
        match self {
            Self::DecryptionError | Self::InvalidClientConfig | Self::JwksFetchFailed(_) => true,
            Self::ClientSecretMismatch
            | Self::AuthenticationMethodMismatch
            | Self::InvalidAssertionSignature => false,
        }
    }
}
/// The result of extracting client authentication from a request: the
/// credentials that were presented plus the rest of the request body, if any.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientAuthorization<F = ()> {
    // How the client authenticated itself.
    pub credentials: Credentials,
    // The caller's form, `None` when the request had no form body.
    pub form: Option<F>,
}

impl<F> ClientAuthorization<F> {
    /// Get the `client_id` from the credentials.
    ///
    /// Returns `None` for bearer-token credentials.
    #[must_use]
    pub fn client_id(&self) -> Option<&str> {
        self.credentials.client_id()
    }
}
/// Rejection type for the [`ClientAuthorization`] extractor.
#[derive(Debug, Error)]
pub enum ClientAuthorizationError {
    /// The `Authorization` header was malformed or used an unknown scheme.
    #[error("Invalid Authorization header")]
    InvalidHeader,

    /// The form body was present but failed to deserialize.
    #[error("Could not deserialize request body")]
    BadForm(#[source] FailedToDeserializeForm),

    /// A client_id appeared both in the credentials and the form, and they
    /// disagree.
    #[error("client_id in form ({form:?}) does not match credential ({credential:?})")]
    ClientIdMismatch { credential: String, form: String },

    /// A `client_assertion_type` other than the JWT bearer type was sent.
    #[error("Unsupported client_assertion_type: {client_assertion_type}")]
    UnsupportedClientAssertion { client_assertion_type: String },

    /// Neither the header nor the body carried any credentials.
    #[error("No credentials were presented")]
    MissingCredentials,

    /// The combination of presented fields is not a valid auth method.
    #[error("Invalid request")]
    InvalidRequest,

    /// The client_assertion JWT could not be parsed or lacks a `sub` claim.
    #[error("Invalid client_assertion")]
    InvalidAssertion,

    /// Unexpected failure reading the request (e.g. broken body stream).
    #[error(transparent)]
    Internal(Box<dyn std::error::Error>),
}
impl IntoResponse for ClientAuthorizationError {
fn into_response(self) -> axum::response::Response {
let sentry_event_id = record_error!(self, Self::Internal(_));
match &self {
ClientAuthorizationError::InvalidHeader => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(ClientError::new(
ClientErrorCode::InvalidRequest,
"Invalid Authorization header",
)),
),
ClientAuthorizationError::BadForm(err) => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(
ClientError::from(ClientErrorCode::InvalidRequest)
.with_description(format!("{err}")),
),
),
ClientAuthorizationError::ClientIdMismatch { .. } => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(
ClientError::from(ClientErrorCode::InvalidGrant)
.with_description(format!("{self}")),
),
),
ClientAuthorizationError::UnsupportedClientAssertion { .. } => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(
ClientError::from(ClientErrorCode::InvalidRequest)
.with_description(format!("{self}")),
),
),
ClientAuthorizationError::MissingCredentials => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(ClientError::new(
ClientErrorCode::InvalidRequest,
"No credentials were presented",
)),
),
ClientAuthorizationError::InvalidRequest => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(ClientError::from(ClientErrorCode::InvalidRequest)),
),
ClientAuthorizationError::InvalidAssertion => (
StatusCode::BAD_REQUEST,
sentry_event_id,
Json(ClientError::new(
ClientErrorCode::InvalidRequest,
"Invalid client_assertion",
)),
),
ClientAuthorizationError::Internal(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
sentry_event_id,
Json(
ClientError::from(ClientErrorCode::ServerError)
.with_description(format!("{e}")),
),
),
}
.into_response()
}
}
impl<S, F> FromRequest<S> for ClientAuthorization<F>
where
    F: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = ClientAuthorizationError;

    /// Extract client credentials from the `Authorization` header and/or the
    /// form body, rejecting any ambiguous or unsupported combination.
    async fn from_request(
        req: Request<axum::body::Body>,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        // Local representation of what the Authorization header carried.
        enum Authorization {
            Basic(String, String),
            Bearer(String),
        }

        // Sadly, the typed-header 'Authorization' doesn't let us check for both
        // Basic and Bearer at the same time, so we need to parse them manually
        let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) {
            let bytes = header.as_bytes();
            // Scheme names are matched case-insensitively, per RFC 9110.
            if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
                let Some(decoded) = Basic::decode(header) else {
                    return Err(ClientAuthorizationError::InvalidHeader);
                };

                Some(Authorization::Basic(
                    decoded.username().to_owned(),
                    decoded.password().to_owned(),
                ))
            } else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
                let Some(decoded) = Bearer::decode(header) else {
                    return Err(ClientAuthorizationError::InvalidHeader);
                };

                Some(Authorization::Bearer(decoded.token().to_owned()))
            } else {
                // Any other scheme is rejected outright.
                return Err(ClientAuthorizationError::InvalidHeader);
            }
        } else {
            None
        };

        // Take the form value
        let (
            client_id_from_form,
            client_secret_from_form,
            client_assertion_type,
            client_assertion,
            form,
        ) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
            Ok(Form(form)) => (
                form.client_id,
                form.client_secret,
                form.client_assertion_type,
                form.client_assertion,
                Some(form.inner),
            ),
            // If it is not a form, continue
            Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
            // If the form could not be read, return a Bad Request error
            Err(FormRejection::FailedToDeserializeForm(err)) => {
                return Err(ClientAuthorizationError::BadForm(err));
            }
            // Other errors (body read twice, byte stream broke) return an internal error
            Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
        };

        // And now, figure out the actual auth method by matching on every
        // field that may carry credentials; only a closed set of
        // combinations is accepted.
        let credentials = match (
            authorization,
            client_id_from_form,
            client_secret_from_form,
            client_assertion_type,
            client_assertion,
        ) {
            (
                Some(Authorization::Basic(client_id, client_secret)),
                client_id_from_form,
                None,
                None,
                None,
            ) => {
                if let Some(client_id_from_form) = client_id_from_form {
                    // If the client_id was in the body, verify it matches with the header
                    if client_id != client_id_from_form {
                        return Err(ClientAuthorizationError::ClientIdMismatch {
                            credential: client_id,
                            form: client_id_from_form,
                        });
                    }
                }

                Credentials::ClientSecretBasic {
                    client_id,
                    client_secret,
                }
            }

            (None, Some(client_id), Some(client_secret), None, None) => {
                // Got both client_id and client_secret from the form
                Credentials::ClientSecretPost {
                    client_id,
                    client_secret,
                }
            }

            (None, Some(client_id), None, None, None) => {
                // Only got a client_id in the form
                Credentials::None { client_id }
            }

            (
                None,
                client_id_from_form,
                None,
                Some(client_assertion_type),
                Some(client_assertion),
            ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
                // Got a JWT bearer client_assertion
                let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
                    .map_err(|_| ClientAuthorizationError::InvalidAssertion)?;

                // The client_id is taken from the JWT's `sub` claim.
                let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
                    client_id.clone()
                } else {
                    return Err(ClientAuthorizationError::InvalidAssertion);
                };

                if let Some(client_id_from_form) = client_id_from_form {
                    // If the client_id was in the body, verify it matches the one in the JWT
                    if client_id != client_id_from_form {
                        return Err(ClientAuthorizationError::ClientIdMismatch {
                            credential: client_id,
                            form: client_id_from_form,
                        });
                    }
                }

                Credentials::ClientAssertionJwtBearer {
                    client_id,
                    jwt: Box::new(jwt),
                }
            }

            (None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
                // Got another unsupported client_assertion
                return Err(ClientAuthorizationError::UnsupportedClientAssertion {
                    client_assertion_type,
                });
            }

            (Some(Authorization::Bearer(token)), None, None, None, None) => {
                // Got a bearer token
                Credentials::BearerToken { token }
            }

            (None, None, None, None, None) => {
                // Special case when there are no credentials anywhere
                return Err(ClientAuthorizationError::MissingCredentials);
            }

            _ => {
                // Every other combination is an invalid request
                return Err(ClientAuthorizationError::InvalidRequest);
            }
        };

        Ok(ClientAuthorization { credentials, form })
    }
}
// Unit tests covering each supported credential combination, plus the
// mismatch/invalid-header rejection paths.
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http::{Method, Request};

    use super::*;

    // A form with only a client_id yields `Credentials::None`.
    #[tokio::test]
    async fn none_test() {
        let req = Request::builder()
            .method(Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
            )
            .body(Body::new("client_id=client-id&foo=bar".to_owned()))
            .unwrap();

        assert_eq!(
            ClientAuthorization::<serde_json::Value>::from_request(req, &())
                .await
                .unwrap(),
            ClientAuthorization {
                credentials: Credentials::None {
                    client_id: "client-id".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    // HTTP Basic auth: header-only, header+matching-form, mismatching form,
    // and malformed header.
    #[tokio::test]
    async fn client_secret_basic_test() {
        // "Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=" is base64("client-id:client-secret").
        let req = Request::builder()
            .method(Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
            )
            .header(
                http::header::AUTHORIZATION,
                "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
            )
            .body(Body::new("foo=bar".to_owned()))
            .unwrap();

        assert_eq!(
            ClientAuthorization::<serde_json::Value>::from_request(req, &())
                .await
                .unwrap(),
            ClientAuthorization {
                credentials: Credentials::ClientSecretBasic {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );

        // client_id in both header and body
        let req = Request::builder()
            .method(Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
            )
            .header(
                http::header::AUTHORIZATION,
                "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
            )
            .body(Body::new("client_id=client-id&foo=bar".to_owned()))
            .unwrap();

        assert_eq!(
            ClientAuthorization::<serde_json::Value>::from_request(req, &())
                .await
                .unwrap(),
            ClientAuthorization {
                credentials: Credentials::ClientSecretBasic {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );

        // client_id in both header and body mismatch
        let req = Request::builder()
            .method(Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
            )
            .header(
                http::header::AUTHORIZATION,
                "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
            )
            .body(Body::new("client_id=mismatch-id&foo=bar".to_owned()))
            .unwrap();

        assert!(matches!(
            ClientAuthorization::<serde_json::Value>::from_request(req, &()).await,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));

        // Invalid header
        let req = Request::builder()
            .method(Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
            )
            .header(http::header::AUTHORIZATION, "Basic invalid")
            .body(Body::new("foo=bar".to_owned()))
            .unwrap();

        assert!(matches!(
            ClientAuthorization::<serde_json::Value>::from_request(req, &()).await,
            Err(ClientAuthorizationError::InvalidHeader),
        ));
    }

    // client_id + client_secret as form fields yields `ClientSecretPost`.
    #[tokio::test]
    async fn client_secret_post_test() {
        let req = Request::builder()
            .method(Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
            )
            .body(Body::new(
                "client_id=client-id&client_secret=client-secret&foo=bar".to_owned(),
            ))
            .unwrap();

        assert_eq!(
            ClientAuthorization::<serde_json::Value>::from_request(req, &())
                .await
                .unwrap(),
            ClientAuthorization {
                credentials: Credentials::ClientSecretPost {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    // A JWT bearer client_assertion: the client_id comes from the `sub`
    // claim and the signature verifies against the shared secret.
    #[tokio::test]
    async fn client_assertion_test() {
        // Signed with client_secret = "client-secret"
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";

        let body = Body::new(format!(
            "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={jwt}&foo=bar",
        ));

        let req = Request::builder()
            .method(Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
            )
            .body(body)
            .unwrap();

        let authz = ClientAuthorization::<serde_json::Value>::from_request(req, &())
            .await
            .unwrap();
        assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));

        let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
            panic!("expected a JWT client_assertion");
        };

        assert_eq!(client_id, "client-id");

        jwt.verify_with_shared_secret(b"client-secret".to_vec())
            .unwrap();
    }

    // A Bearer token in the header yields `Credentials::BearerToken`.
    #[tokio::test]
    async fn bearer_token_test() {
        let req = Request::builder()
            .method(Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
            )
            .header(http::header::AUTHORIZATION, "Bearer token")
            .body(Body::new("foo=bar".to_owned()))
            .unwrap();

        assert_eq!(
            ClientAuthorization::<serde_json::Value>::from_request(req, &())
                .await
                .unwrap(),
            ClientAuthorization {
                credentials: Credentials::BearerToken {
                    token: "token".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }
}

View file

@ -0,0 +1,169 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
//! Private (encrypted) cookie jar, based on axum-extra's cookie jar
use std::convert::Infallible;
use axum::{
extract::{FromRef, FromRequestParts},
response::{IntoResponseParts, ResponseParts},
};
use axum_extra::extract::cookie::{Cookie, Key, PrivateCookieJar, SameSite};
use http::request::Parts;
use serde::{Serialize, de::DeserializeOwned};
use thiserror::Error;
use url::Url;
#[derive(Debug, Error)]
#[error("could not decode cookie")]
pub enum CookieDecodeError {
Deserialize(#[from] serde_json::Error),
}
/// Manages cookie options and encryption key
///
/// This is meant to be accessible through axum's state via the [`FromRef`]
/// trait
#[derive(Clone)]
pub struct CookieManager {
options: CookieOption,
key: Key,
}
impl CookieManager {
    /// Build a manager from the public base URL (which drives the cookie
    /// options) and the encryption key
    #[must_use]
    pub const fn new(base_url: Url, key: Key) -> Self {
        Self {
            options: CookieOption::new(base_url),
            key,
        }
    }

    /// Same as [`CookieManager::new`], but derives the encryption key from
    /// arbitrary secret bytes
    #[must_use]
    pub fn derive_from(base_url: Url, key: &[u8]) -> Self {
        Self::new(base_url, Key::derive_from(key))
    }

    /// Create an empty [`CookieJar`] tied to this manager's key and options
    #[must_use]
    pub fn cookie_jar(&self) -> CookieJar {
        CookieJar {
            inner: PrivateCookieJar::new(self.key.clone()),
            options: self.options.clone(),
        }
    }

    /// Create a [`CookieJar`] pre-populated from request headers
    #[must_use]
    pub fn cookie_jar_from_headers(&self, headers: &http::HeaderMap) -> CookieJar {
        CookieJar {
            inner: PrivateCookieJar::from_headers(headers, self.key.clone()),
            options: self.options.clone(),
        }
    }
}
impl<S> FromRequestParts<S> for CookieJar
where
CookieManager: FromRef<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let cookie_manager = CookieManager::from_ref(state);
Ok(cookie_manager.cookie_jar_from_headers(&parts.headers))
}
}
#[derive(Debug, Clone)]
struct CookieOption {
base_url: Url,
}
impl CookieOption {
    /// Remember the base URL; the other attributes are derived from it
    const fn new(base_url: Url) -> Self {
        Self { base_url }
    }

    /// Cookies are only marked `Secure` when the service is served over HTTPS
    fn secure(&self) -> bool {
        matches!(self.base_url.scheme(), "https")
    }

    /// Cookies are scoped to the path the service is mounted on
    fn path(&self) -> &str {
        self.base_url.path()
    }

    /// Stamp the hardening attributes (`HttpOnly`, `Secure`, `Path`,
    /// `SameSite=Lax`) onto a cookie
    fn apply<'a>(&self, cookie: Cookie<'a>) -> Cookie<'a> {
        let mut cookie = cookie;
        cookie.set_http_only(true);
        cookie.set_secure(self.secure());
        cookie.set_path(self.path().to_owned());
        cookie.set_same_site(SameSite::Lax);
        cookie
    }
}
/// A cookie jar which encrypts cookies & sets secure options
pub struct CookieJar {
inner: PrivateCookieJar<Key>,
options: CookieOption,
}
impl CookieJar {
    /// Save the given payload in a cookie, serialized as JSON and encrypted
    /// by the inner private jar; secure options are applied via
    /// [`CookieOption::apply`]
    ///
    /// If `permanent` is true, the cookie is made "permanent"
    /// (NOTE(review): the `cookie` crate's `make_permanent` documents a
    /// ~20-year expiry, not 10 years as previously stated — confirm intent)
    ///
    /// # Panics
    ///
    /// Panics if the payload cannot be serialized
    #[must_use]
    pub fn save<T: Serialize>(mut self, key: &str, payload: &T, permanent: bool) -> Self {
        let serialized =
            serde_json::to_string(payload).expect("failed to serialize cookie payload");
        let cookie = Cookie::new(key.to_owned(), serialized);
        // Apply HttpOnly/Secure/Path/SameSite options
        let mut cookie = self.options.apply(cookie);
        if permanent {
            // XXX: this should use a clock
            cookie.make_permanent();
        }
        self.inner = self.inner.add(cookie);
        self
    }

    /// Remove a cookie from the jar (queues a removal cookie in the response)
    #[must_use]
    pub fn remove(mut self, key: &str) -> Self {
        self.inner = self.inner.remove(key.to_owned());
        self
    }

    /// Load and deserialize a cookie from the jar
    ///
    /// Returns `None` if the cookie is not present
    ///
    /// # Errors
    ///
    /// Returns an error if the cookie cannot be deserialized
    pub fn load<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CookieDecodeError> {
        let Some(cookie) = self.inner.get(key) else {
            return Ok(None);
        };
        let decoded = serde_json::from_str(cookie.value())?;
        Ok(Some(decoded))
    }
}
impl IntoResponseParts for CookieJar {
type Error = Infallible;
fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
self.inner.into_response_parts(res)
}
}

View file

@ -0,0 +1,165 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Duration, Utc};
use mas_data_model::Clock;
use rand::{Rng, RngCore, distributions::Standard, prelude::Distribution as _};
use serde::{Deserialize, Serialize};
use serde_with::{TimestampSeconds, serde_as};
use thiserror::Error;
use crate::cookies::{CookieDecodeError, CookieJar};
/// Failed to validate CSRF token
#[derive(Debug, Error)]
pub enum CsrfError {
/// The token in the form did not match the token in the cookie
#[error("CSRF token mismatch")]
Mismatch,
/// The token in the form did not match the token in the cookie
#[error("Missing CSRF cookie")]
Missing,
/// Failed to decode the token
#[error("could not decode CSRF cookie")]
DecodeCookie(#[from] CookieDecodeError),
/// The token expired
#[error("CSRF token expired")]
Expired,
/// Failed to decode the token
#[error("could not decode CSRF token")]
Decode(#[from] base64ct::Error),
}
/// A CSRF token
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct CsrfToken {
#[serde_as(as = "TimestampSeconds<i64>")]
expiration: DateTime<Utc>,
token: [u8; 32],
}
impl CsrfToken {
    /// Create a new token from a defined value valid for a specified duration
    fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
        let expiration = now + ttl;
        Self { expiration, token }
    }

    /// Generate a new random token valid for a specified duration
    fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
        let token = Standard.sample(&mut rng);
        Self::new(token, now, ttl)
    }

    /// Generate a new token with the same value but an up to date expiration
    fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
        Self::new(self.token, now, ttl)
    }

    /// Get the value to include in HTML forms
    #[must_use]
    pub fn form_value(&self) -> String {
        Base64UrlUnpadded::encode_string(&self.token[..])
    }

    /// Verifies that the value got from an HTML form matches this token
    ///
    /// The byte comparison runs in constant time so the token value cannot
    /// leak through a timing side-channel.
    ///
    /// # Errors
    ///
    /// Returns an error if the value in the form does not match this token
    pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
        let form_value = Base64UrlUnpadded::decode_vec(form_value)?;

        // Rejecting on length early is fine: the token length (32 bytes) is
        // public, only its contents are secret
        if form_value.len() != self.token.len() {
            return Err(CsrfError::Mismatch);
        }

        // OR together the XOR of every byte pair: the accumulator is zero iff
        // the two values are equal, and the loop always runs to completion
        let diff = self
            .token
            .iter()
            .zip(form_value.iter())
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));

        if diff == 0 {
            Ok(())
        } else {
            Err(CsrfError::Mismatch)
        }
    }

    /// Consume the token, returning it only if it has not expired yet
    fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
        if now < self.expiration {
            Ok(self)
        } else {
            Err(CsrfError::Expired)
        }
    }
}
// A CSRF-protected form
#[derive(Deserialize)]
pub struct ProtectedForm<T> {
csrf: String,
#[serde(flatten)]
inner: T,
}
pub trait CsrfExt {
/// Get the current CSRF token out of the cookie jar, generating a new one
/// if necessary
fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
where
R: RngCore,
C: Clock;
/// Verify that the given CSRF-protected form is valid, returning the inner
/// value
///
/// # Errors
///
/// Returns an error if the CSRF cookie is missing or if the value in the
/// form is invalid
fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
where
C: Clock;
}
impl CsrfExt for CookieJar {
    fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
    where
        R: RngCore,
        C: Clock,
    {
        let now = clock.now();
        let ttl = Duration::try_hours(1).unwrap();

        // Try to reuse an existing token: a cookie that fails to decode is
        // logged and treated like a missing one, and an expired token is
        // silently discarded
        let existing = self
            .load::<CsrfToken>("csrf")
            .unwrap_or_else(|e| {
                tracing::warn!("Failed to decode CSRF cookie: {}", e);
                None
            })
            .and_then(|token| token.verify_expiration(now).ok());

        // Either refresh the surviving token's expiration, or mint a new one
        let token = match existing {
            Some(token) => token.refresh(now, ttl),
            None => CsrfToken::generate(now, rng, ttl),
        };

        let jar = self.save("csrf", &token, false);
        (token, jar)
    }

    fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
    where
        C: Clock,
    {
        // The cookie must exist, still be valid, and match the form value
        let token: CsrfToken = self.load("csrf")?.ok_or(CsrfError::Missing)?;
        token
            .verify_expiration(clock.now())?
            .verify_form_value(&form.csrf)?;
        Ok(form.inner)
    }
}

View file

@ -0,0 +1,23 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use axum::response::{IntoResponse, Response};
use crate::InternalError;
/// A simple wrapper around an error that implements [`IntoResponse`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ErrorWrapper<T>(#[from] pub T);
impl<T> IntoResponse for ErrorWrapper<T>
where
T: std::error::Error + 'static,
{
fn into_response(self) -> Response {
InternalError::from(self.0).into_response()
}
}

View file

@ -0,0 +1,105 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use axum::{
Extension,
http::StatusCode,
response::{IntoResponse, Response},
};
use axum_extra::typed_header::TypedHeader;
use headers::ContentType;
use mas_templates::ErrorContext;
use crate::sentry::SentryEventID;
fn build_context(mut err: &dyn std::error::Error) -> ErrorContext {
let description = err.to_string();
let mut details = Vec::new();
while let Some(source) = err.source() {
err = source;
details.push(err.to_string());
}
ErrorContext::new()
.with_description(description)
.with_details(details.join("\n"))
}
pub struct GenericError {
error: Box<dyn std::error::Error + 'static>,
code: StatusCode,
}
impl IntoResponse for GenericError {
    fn into_response(self) -> Response {
        // Warning level only: these carry a caller-chosen status code and are
        // expected errors (contrast with `InternalError`, which logs at error
        // level and records a Sentry event)
        tracing::warn!(message = &*self.error);
        let context = build_context(&*self.error);
        let context_text = format!("{context}");
        // The `ErrorContext` is attached as a response extension — presumably
        // so an outer layer can render a rich error page; the plain-text
        // rendering is the body fallback (NOTE(review): confirm the consumer)
        (
            self.code,
            TypedHeader(ContentType::text()),
            Extension(context),
            context_text,
        )
            .into_response()
    }
}
impl GenericError {
pub fn new(code: StatusCode, err: impl std::error::Error + 'static) -> Self {
Self {
error: Box::new(err),
code,
}
}
}
pub struct InternalError {
error: Box<dyn std::error::Error + 'static>,
}
impl IntoResponse for InternalError {
fn into_response(self) -> Response {
tracing::error!(message = &*self.error);
let event_id = SentryEventID::for_last_event();
let context = build_context(&*self.error);
let context_text = format!("{context}");
(
StatusCode::INTERNAL_SERVER_ERROR,
TypedHeader(ContentType::text()),
event_id,
Extension(context),
context_text,
)
.into_response()
}
}
impl<E: std::error::Error + 'static> From<E> for InternalError {
fn from(err: E) -> Self {
Self {
error: Box::new(err),
}
}
}
impl InternalError {
    /// Create a new error from a boxed error
    #[must_use]
    pub fn new(error: Box<dyn std::error::Error + 'static>) -> Self {
        Self { error }
    }

    /// Create a new error from an [`anyhow::Error`]
    ///
    /// Uses `into_boxed_dyn_error` so the whole error chain stays reachable
    /// through `source()`
    #[must_use]
    pub fn from_anyhow(err: anyhow::Error) -> Self {
        Self {
            error: err.into_boxed_dyn_error(),
        }
    }
}

View file

@ -0,0 +1,21 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use axum::response::{IntoResponse, Response};
use axum_extra::typed_header::TypedHeader;
use headers::ContentType;
use mas_jose::jwt::Jwt;
use mime::Mime;
pub struct JwtResponse<T>(pub Jwt<'static, T>);
impl<T> IntoResponse for JwtResponse<T> {
fn into_response(self) -> Response {
let application_jwt: Mime = "application/jwt".parse().unwrap();
let content_type = ContentType::from(application_jwt);
(TypedHeader(content_type), self.0.into_string()).into_response()
}
}

View file

@ -0,0 +1,280 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::cmp::Reverse;
use headers::{Error, Header};
use http::{HeaderName, HeaderValue, header::ACCEPT_LANGUAGE};
use icu_locid::Locale;
#[derive(PartialEq, Eq, Debug)]
struct AcceptLanguagePart {
// None means *
locale: Option<Locale>,
// Quality is between 0 and 1 with 3 decimal places
// Which we map from 0 to 1000, e.g. 0.5 becomes 500
quality: u16,
}
impl PartialOrd for AcceptLanguagePart {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for AcceptLanguagePart {
    /// Parts are ordered by quality alone, highest quality first; the locale
    /// itself never takes part in the comparison.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let lhs = Reverse(self.quality);
        let rhs = Reverse(other.quality);
        lhs.cmp(&rhs)
    }
}
/// A header that represents the `Accept-Language` header.
#[derive(PartialEq, Eq, Debug)]
pub struct AcceptLanguage {
parts: Vec<AcceptLanguagePart>,
}
impl AcceptLanguage {
pub fn iter(&self) -> impl Iterator<Item = &Locale> {
// This should stop when we hit the first None, aka the first *
self.parts.iter().map_while(|item| item.locale.as_ref())
}
}
/// Utility to trim ASCII whitespace from the start and end of a byte slice
///
/// Written with slice patterns rather than iterator helpers so it can stay a
/// `const fn`
const fn trim_bytes(mut bytes: &[u8]) -> &[u8] {
    // Drop leading ASCII-whitespace bytes
    while let [first, rest @ ..] = bytes {
        if first.is_ascii_whitespace() {
            bytes = rest;
        } else {
            break;
        }
    }

    // Drop trailing ASCII-whitespace bytes
    while let [rest @ .., last] = bytes {
        if last.is_ascii_whitespace() {
            bytes = rest;
        } else {
            break;
        }
    }

    bytes
}
impl Header for AcceptLanguage {
    fn name() -> &'static HeaderName {
        &ACCEPT_LANGUAGE
    }

    fn decode<'i, I>(values: &mut I) -> Result<Self, Error>
    where
        Self: Sized,
        I: Iterator<Item = &'i HeaderValue>,
    {
        let mut parts = Vec::new();
        for value in values {
            // Each header value is a comma-separated list of
            // `<locale>[;q=<quality>]` entries
            for part in value.as_bytes().split(|b| *b == b',') {
                let mut it = part.split(|b| *b == b';');
                let locale = it.next().ok_or(Error::invalid())?;
                let locale = trim_bytes(locale);
                let locale = match locale {
                    // `*` is the wildcard, represented as `None`
                    b"*" => None,
                    locale => {
                        let locale =
                            Locale::try_from_bytes(locale).map_err(|_e| Error::invalid())?;
                        Some(locale)
                    }
                };
                let quality = if let Some(quality) = it.next() {
                    let quality = trim_bytes(quality);
                    let quality = quality.strip_prefix(b"q=").ok_or(Error::invalid())?;
                    let quality = std::str::from_utf8(quality).map_err(|_e| Error::invalid())?;
                    let quality = quality.parse::<f64>().map_err(|_e| Error::invalid())?;
                    // Bound the quality between 0 and 1
                    let quality = quality.clamp(0_f64, 1_f64);
                    // Make sure the iterator is empty, i.e. there was at most
                    // one `;`-separated parameter per entry
                    if it.next().is_some() {
                        return Err(Error::invalid());
                    }
                    // Map 0.0..=1.0 to the internal 0..=1000 representation
                    // (3 decimal places of precision)
                    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
                    {
                        f64::round(quality * 1000_f64) as u16
                    }
                } else {
                    // No explicit `q=` parameter defaults to quality 1.0
                    1000
                };
                parts.push(AcceptLanguagePart { locale, quality });
            }
        }
        // Stable sort by descending quality, preserving the original order of
        // entries with equal quality
        parts.sort();
        Ok(AcceptLanguage { parts })
    }

    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
        // Render back to `loc1, loc2;q=0.9, *;q=0.5` form, omitting `q=` when
        // the quality is the default (1.0)
        let mut value = String::new();
        let mut first = true;
        for part in &self.parts {
            if first {
                first = false;
            } else {
                value.push_str(", ");
            }
            if let Some(locale) = &part.locale {
                value.push_str(&locale.to_string());
            } else {
                value.push('*');
            }
            if part.quality != 1000 {
                value.push_str(";q=");
                value.push_str(&(f64::from(part.quality) / 1000_f64).to_string());
            }
        }
        // We know this is safe because we only use ASCII characters
        values.extend(Some(HeaderValue::from_str(&value).unwrap()));
    }
}
#[cfg(test)]
mod tests {
use headers::HeaderMapExt;
use http::{HeaderMap, HeaderValue, header::ACCEPT_LANGUAGE};
use icu_locid::locale;
use super::*;
#[test]
fn test_decode() {
let headers = HeaderMap::from_iter([(
ACCEPT_LANGUAGE,
HeaderValue::from_str("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5").unwrap(),
)]);
let accept_language: Option<AcceptLanguage> = headers.typed_get();
assert!(accept_language.is_some());
let accept_language = accept_language.unwrap();
assert_eq!(
accept_language,
AcceptLanguage {
parts: vec![
AcceptLanguagePart {
locale: Some(locale!("fr-CH")),
quality: 1000,
},
AcceptLanguagePart {
locale: Some(locale!("fr")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("en")),
quality: 800,
},
AcceptLanguagePart {
locale: Some(locale!("de")),
quality: 700,
},
AcceptLanguagePart {
locale: None,
quality: 500,
},
]
}
);
}
#[test]
/// Test that we can decode a header with multiple values unordered, and
/// that the output is ordered by quality
fn test_decode_order() {
let headers = HeaderMap::from_iter([(
ACCEPT_LANGUAGE,
HeaderValue::from_str("*;q=0.5, fr-CH, en;q=0.8, fr;q=0.9, de;q=0.9").unwrap(),
)]);
let accept_language: Option<AcceptLanguage> = headers.typed_get();
assert!(accept_language.is_some());
let accept_language = accept_language.unwrap();
assert_eq!(
accept_language,
AcceptLanguage {
parts: vec![
AcceptLanguagePart {
locale: Some(locale!("fr-CH")),
quality: 1000,
},
AcceptLanguagePart {
locale: Some(locale!("fr")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("de")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("en")),
quality: 800,
},
AcceptLanguagePart {
locale: None,
quality: 500,
},
]
}
);
}
#[test]
fn test_encode() {
let accept_language = AcceptLanguage {
parts: vec![
AcceptLanguagePart {
locale: Some(locale!("fr-CH")),
quality: 1000,
},
AcceptLanguagePart {
locale: Some(locale!("fr")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("de")),
quality: 900,
},
AcceptLanguagePart {
locale: Some(locale!("en")),
quality: 800,
},
AcceptLanguagePart {
locale: None,
quality: 500,
},
],
};
let mut headers = HeaderMap::new();
headers.typed_insert(accept_language);
let header = headers.get(ACCEPT_LANGUAGE).unwrap();
assert_eq!(
header.to_str().unwrap(),
"fr-CH, fr;q=0.9, de;q=0.9, en;q=0.8, *;q=0.5"
);
}
}

View file

@ -0,0 +1,27 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
#![deny(clippy::future_not_send)]
#![allow(clippy::module_name_repetitions)]
pub mod client_authorization;
pub mod cookies;
pub mod csrf;
pub mod error_wrapper;
pub mod fancy_error;
pub mod jwt;
pub mod language_detection;
pub mod sentry;
pub mod session;
pub mod user_authorization;
pub use axum;
pub use self::{
error_wrapper::ErrorWrapper,
fancy_error::{GenericError, InternalError},
session::{SessionInfo, SessionInfoExt},
};

View file

@ -0,0 +1,65 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::convert::Infallible;
use axum::response::{IntoResponseParts, ResponseParts};
use sentry::types::Uuid;
/// A wrapper to include a Sentry event ID in the response headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SentryEventID(Uuid);
impl SentryEventID {
/// Create a new Sentry event ID header for the last event on the hub.
pub fn for_last_event() -> Option<Self> {
sentry::last_event_id().map(Self)
}
}
impl From<Uuid> for SentryEventID {
fn from(uuid: Uuid) -> Self {
Self(uuid)
}
}
impl IntoResponseParts for SentryEventID {
type Error = Infallible;
fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
res.headers_mut()
.insert("X-Sentry-Event-ID", self.0.to_string().parse().unwrap());
Ok(res)
}
}
/// Record an error. It will emit a tracing event with the error level if
/// matches the pattern, warning otherwise. It also returns the Sentry event ID
/// if the error was recorded.
#[macro_export]
macro_rules! record_error {
    // `record_error!(err, !)`: always log at warning level; never yields a
    // Sentry event ID
    ($error:expr, !) => {{
        tracing::warn!(message = &$error as &dyn std::error::Error);
        Option::<$crate::sentry::SentryEventID>::None
    }};
    // `record_error!(err)`: always log at error level and surface the
    // resulting Sentry event ID
    ($error:expr) => {{
        tracing::error!(message = &$error as &dyn std::error::Error);
        // With the `sentry-tracing` integration, Sentry should have
        // captured an error, so let's extract the last event ID from the
        // current hub
        $crate::sentry::SentryEventID::for_last_event()
    }};
    // `record_error!(err, Pattern)`: error level (and Sentry capture) only
    // when `err` matches `Pattern`, warning otherwise
    ($error:expr, $pattern:pat) => {
        if let $pattern = $error {
            record_error!($error)
        } else {
            record_error!($error, !)
        }
    };
}

View file

@ -0,0 +1,101 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use mas_data_model::BrowserSession;
use mas_storage::RepositoryAccess;
use serde::{Deserialize, Serialize};
use ulid::Ulid;
use crate::cookies::CookieJar;
/// An encrypted cookie to save the session ID
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub struct SessionInfo {
current: Option<Ulid>,
}
impl SessionInfo {
    /// Forge the cookie from a [`BrowserSession`]
    #[must_use]
    pub fn from_session(session: &BrowserSession) -> Self {
        Self {
            current: Some(session.id),
        }
    }

    /// Mark the session as ended
    ///
    /// This only clears the cookie value; revoking the session in the
    /// database is the caller's responsibility
    #[must_use]
    pub fn mark_session_ended(mut self) -> Self {
        self.current = None;
        self
    }

    /// Load the active [`BrowserSession`] from database
    ///
    /// Returns `None` when no session ID is recorded, when it is unknown to
    /// the repository, or when the session is no longer active
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying repository fails to load the session.
    pub async fn load_active_session<E>(
        &self,
        repo: &mut impl RepositoryAccess<Error = E>,
    ) -> Result<Option<BrowserSession>, E> {
        let Some(session_id) = self.current else {
            return Ok(None);
        };
        let maybe_session = repo
            .browser_session()
            .lookup(session_id)
            .await?
            // Ensure that the session is still active
            .filter(BrowserSession::active);
        Ok(maybe_session)
    }

    /// Get the current session ID, if any
    #[must_use]
    pub fn current_session_id(&self) -> Option<Ulid> {
        self.current
    }
}
pub trait SessionInfoExt {
#[must_use]
fn session_info(self) -> (SessionInfo, Self);
#[must_use]
fn update_session_info(self, info: &SessionInfo) -> Self;
#[must_use]
fn set_session(self, session: &BrowserSession) -> Self
where
Self: Sized,
{
let session_info = SessionInfo::from_session(session);
self.update_session_info(&session_info)
}
}
impl SessionInfoExt for CookieJar {
    fn session_info(self) -> (SessionInfo, Self) {
        // A missing cookie yields the default (logged-out) info; a cookie
        // that fails to decode is logged and treated the same way
        let info = self
            .load("session")
            .unwrap_or_else(|e| {
                tracing::error!("failed to load session cookie: {}", e);
                None
            })
            .unwrap_or_default();

        // Re-save so the cookie's expiration gets refreshed on each request
        let jar = self.update_session_info(&info);
        (info, jar)
    }

    fn update_session_info(self, info: &SessionInfo) -> Self {
        // Session cookies are long-lived ("permanent")
        self.save("session", info, true)
    }
}

View file

@ -0,0 +1,338 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{collections::HashMap, error::Error};
use axum::{
extract::{
Form, FromRequest, FromRequestParts,
rejection::{FailedToDeserializeForm, FormRejection},
},
response::{IntoResponse, Response},
};
use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
use headers::{Authorization, Header, HeaderMapExt, HeaderName, authorization::Bearer};
use http::{HeaderMap, HeaderValue, Request, StatusCode, header::WWW_AUTHENTICATE};
use mas_data_model::{Clock, Session};
use mas_storage::{
RepositoryAccess,
oauth2::{OAuth2AccessTokenRepository, OAuth2SessionRepository},
};
use serde::{Deserialize, de::DeserializeOwned};
use thiserror::Error;
#[derive(Debug, Deserialize)]
struct AuthorizedForm<F> {
#[serde(default)]
access_token: Option<String>,
#[serde(flatten)]
inner: F,
}
#[derive(Debug)]
enum AccessToken {
Form(String),
Header(String),
None,
}
impl AccessToken {
    /// Look up the access token and its associated OAuth 2.0 session in the
    /// repository
    ///
    /// # Errors
    ///
    /// - [`AuthorizationVerificationError::MissingToken`] when no token was
    ///   provided at all
    /// - [`AuthorizationVerificationError::InvalidToken`] when the token or
    ///   its session cannot be found
    async fn fetch<E>(
        &self,
        repo: &mut impl RepositoryAccess<Error = E>,
    ) -> Result<(mas_data_model::AccessToken, Session), AuthorizationVerificationError<E>> {
        // Whether the token came from the form or the header is irrelevant
        // for the lookup
        let token = match self {
            AccessToken::Form(t) | AccessToken::Header(t) => t,
            AccessToken::None => return Err(AuthorizationVerificationError::MissingToken),
        };
        let token = repo
            .oauth2_access_token()
            .find_by_token(token.as_str())
            .await?
            .ok_or(AuthorizationVerificationError::InvalidToken)?;
        let session = repo
            .oauth2_session()
            .lookup(token.session_id)
            .await?
            .ok_or(AuthorizationVerificationError::InvalidToken)?;
        Ok((token, session))
    }
}
#[derive(Debug)]
pub struct UserAuthorization<F = ()> {
access_token: AccessToken,
form: Option<F>,
}
impl<F: Send> UserAuthorization<F> {
    // TODO: take scopes to validate as parameter
    /// Verify a user authorization and return the session and the protected
    /// form value
    ///
    /// # Errors
    ///
    /// Returns an error if the token is invalid, if the user session ended or
    /// if the form is missing
    pub async fn protected_form<E>(
        self,
        repo: &mut impl RepositoryAccess<Error = E>,
        clock: &impl Clock,
    ) -> Result<(Session, F), AuthorizationVerificationError<E>> {
        let Some(form) = self.form else {
            return Err(AuthorizationVerificationError::MissingForm);
        };
        // NOTE(review): unlike `protected`, this path never marks the access
        // token as used — confirm whether that asymmetry is intentional
        let (token, session) = self.access_token.fetch(repo).await?;
        // Both the token and its session must still be valid
        if !token.is_valid(clock.now()) || !session.is_valid() {
            return Err(AuthorizationVerificationError::InvalidToken);
        }
        Ok((session, form))
    }

    // TODO: take scopes to validate as parameter
    /// Verify a user authorization and return the session
    ///
    /// # Errors
    ///
    /// Returns an error if the token is invalid or if the user session ended
    pub async fn protected<E>(
        self,
        repo: &mut impl RepositoryAccess<Error = E>,
        clock: &impl Clock,
    ) -> Result<Session, AuthorizationVerificationError<E>> {
        let (token, session) = self.access_token.fetch(repo).await?;
        // Both the token and its session must still be valid
        if !token.is_valid(clock.now()) || !session.is_valid() {
            return Err(AuthorizationVerificationError::InvalidToken);
        }
        if !token.is_used() {
            // Mark the token as used
            repo.oauth2_access_token().mark_used(clock, token).await?;
        }
        Ok(session)
    }
}
pub enum UserAuthorizationError {
InvalidHeader,
TokenInFormAndHeader,
BadForm(FailedToDeserializeForm),
Internal(Box<dyn Error>),
}
#[derive(Debug, Error)]
pub enum AuthorizationVerificationError<E> {
#[error("missing token")]
MissingToken,
#[error("invalid token")]
InvalidToken,
#[error("missing form")]
MissingForm,
#[error(transparent)]
Internal(#[from] E),
}
enum BearerError {
InvalidRequest,
InvalidToken,
#[allow(dead_code)]
InsufficientScope {
scope: Option<HeaderValue>,
},
}
impl BearerError {
    /// Value for the `error` attribute of the `WWW-Authenticate` challenge
    fn error(&self) -> HeaderValue {
        let code = match self {
            Self::InvalidRequest => "invalid_request",
            Self::InvalidToken => "invalid_token",
            Self::InsufficientScope { .. } => "insufficient_scope",
        };
        HeaderValue::from_static(code)
    }

    /// Extra challenge attributes: only `insufficient_scope` carries a
    /// `scope` attribute, and only when one is known
    fn params(&self) -> HashMap<&'static str, HeaderValue> {
        let mut params = HashMap::new();
        if let Self::InsufficientScope { scope: Some(scope) } = self {
            params.insert("scope", scope.clone());
        }
        params
    }
}
enum WwwAuthenticate {
#[allow(dead_code)]
Basic { realm: HeaderValue },
Bearer {
realm: Option<HeaderValue>,
error: BearerError,
error_description: Option<HeaderValue>,
},
}
impl Header for WwwAuthenticate {
    fn name() -> &'static HeaderName {
        &WWW_AUTHENTICATE
    }

    // Decoding is unsupported: this type only *produces* challenges
    fn decode<'i, I>(_values: &mut I) -> Result<Self, headers::Error>
    where
        Self: Sized,
        I: Iterator<Item = &'i http::HeaderValue>,
    {
        Err(headers::Error::invalid())
    }

    fn encode<E: Extend<http::HeaderValue>>(&self, values: &mut E) {
        // Gather the scheme name and its `key=value` parameters
        let (scheme, params) = match self {
            WwwAuthenticate::Basic { realm } => {
                let mut params = HashMap::new();
                params.insert("realm", realm.clone());
                ("Basic", params)
            }
            WwwAuthenticate::Bearer {
                realm,
                error,
                error_description,
            } => {
                let mut params = error.params();
                params.insert("error", error.error());
                if let Some(realm) = realm {
                    params.insert("realm", realm.clone());
                }
                if let Some(error_description) = error_description {
                    params.insert("error_description", error_description.clone());
                }
                ("Bearer", params)
            }
        };
        // Render as `Scheme key1="value1" key2="value2"`; the `{:?}` (Debug)
        // formatting of `HeaderValue` supplies the surrounding double quotes.
        // NOTE: `HashMap` iteration makes the parameter order unspecified
        let params = params.into_iter().map(|(k, v)| format!(" {k}={v:?}"));
        let value: String = std::iter::once(scheme.to_owned()).chain(params).collect();
        // Safe: every piece is ASCII
        let value = HeaderValue::from_str(&value).unwrap();
        values.extend(std::iter::once(value));
    }
}
impl IntoResponse for UserAuthorizationError {
    fn into_response(self) -> Response {
        // 400 with a `WWW-Authenticate: Bearer error="invalid_request"`
        // challenge, shared by every client-side failure mode
        fn bad_request() -> Response {
            let mut headers = HeaderMap::new();
            headers.typed_insert(WwwAuthenticate::Bearer {
                realm: None,
                error: BearerError::InvalidRequest,
                error_description: None,
            });
            (StatusCode::BAD_REQUEST, headers).into_response()
        }

        match self {
            Self::BadForm(_) | Self::InvalidHeader | Self::TokenInFormAndHeader => bad_request(),
            Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
        }
    }
}
impl<E> IntoResponse for AuthorizationVerificationError<E>
where
E: ToString,
{
fn into_response(self) -> Response {
match self {
Self::MissingForm | Self::MissingToken => {
let mut headers = HeaderMap::new();
headers.typed_insert(WwwAuthenticate::Bearer {
realm: None,
error: BearerError::InvalidRequest,
error_description: None,
});
(StatusCode::BAD_REQUEST, headers).into_response()
}
Self::InvalidToken => {
let mut headers = HeaderMap::new();
headers.typed_insert(WwwAuthenticate::Bearer {
realm: None,
error: BearerError::InvalidToken,
error_description: None,
});
(StatusCode::BAD_REQUEST, headers).into_response()
}
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
}
}
}
impl<S, F> FromRequest<S> for UserAuthorization<F>
where
    F: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = UserAuthorizationError;

    // Extract the access token from either the `Authorization: Bearer`
    // header or an `access_token` form field — but never both
    async fn from_request(
        req: Request<axum::body::Body>,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        // Split the request so the header extractor can run before the body
        // is consumed by the form extractor
        let (mut parts, body) = req.into_parts();
        let header =
            TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, state).await;
        // Take the Authorization header
        let token_from_header = match header {
            Ok(header) => Some(header.token().to_owned()),
            Err(err) => match err.reason() {
                // If it's missing it is fine
                TypedHeaderRejectionReason::Missing => None,
                // If the header could not be parsed, return the error
                _ => return Err(UserAuthorizationError::InvalidHeader),
            },
        };
        let req = Request::from_parts(parts, body);
        // Take the form value
        let (token_from_form, form) =
            match Form::<AuthorizedForm<F>>::from_request(req, state).await {
                Ok(Form(form)) => (form.access_token, Some(form.inner)),
                // If it is not a form, continue
                Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
                // If the form could not be read, return a Bad Request error
                Err(FormRejection::FailedToDeserializeForm(err)) => {
                    return Err(UserAuthorizationError::BadForm(err));
                }
                // Other errors (body read twice, byte stream broke) return an internal error
                Err(e) => return Err(UserAuthorizationError::Internal(Box::new(e))),
            };
        let access_token = match (token_from_header, token_from_form) {
            // Ensure the token should not be in both the form and the access token
            (Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader),
            (Some(t), None) => AccessToken::Header(t),
            (None, Some(t)) => AccessToken::Form(t),
            (None, None) => AccessToken::None,
        };
        Ok(UserAuthorization { access_token, form })
    }
}

View file

@ -0,0 +1,103 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
[package]
name = "mas-cli"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
publish.workspace = true
build = "build.rs"
[lints]
workspace = true
[dependencies]
anyhow.workspace = true
axum.workspace = true
bytes.workspace = true
camino.workspace = true
chrono.workspace = true
clap.workspace = true
console.workspace = true
dialoguer.workspace = true
dotenvy.workspace = true
figment.workspace = true
futures-util.workspace = true
headers.workspace = true
http-body-util.workspace = true
hyper.workspace = true
ipnetwork.workspace = true
itertools.workspace = true
listenfd.workspace = true
rand.workspace = true
rand_chacha.workspace = true
reqwest.workspace = true
rustls.workspace = true
sd-notify.workspace = true
serde_json.workspace = true
serde_yaml.workspace = true
sqlx.workspace = true
tokio.workspace = true
tokio-util.workspace = true
tower.workspace = true
tower-http.workspace = true
url.workspace = true
zeroize.workspace = true
tracing.workspace = true
tracing-appender.workspace = true
tracing-subscriber.workspace = true
tracing-opentelemetry.workspace = true
opentelemetry.workspace = true
opentelemetry-http.workspace = true
opentelemetry-instrumentation-process.workspace = true
opentelemetry-instrumentation-tokio.workspace = true
opentelemetry-jaeger-propagator.workspace = true
opentelemetry-otlp.workspace = true
opentelemetry-prometheus-text-exporter.workspace = true
opentelemetry-resource-detectors.workspace = true
opentelemetry-semantic-conventions.workspace = true
opentelemetry-stdout.workspace = true
opentelemetry_sdk.workspace = true
sentry.workspace = true
sentry-tracing.workspace = true
sentry-tower.workspace = true
mas-config.workspace = true
mas-context.workspace = true
mas-data-model.workspace = true
mas-email.workspace = true
mas-handlers.workspace = true
mas-http.workspace = true
mas-i18n.workspace = true
mas-keystore.workspace = true
mas-listener.workspace = true
mas-matrix.workspace = true
mas-matrix-synapse.workspace = true
mas-policy.workspace = true
mas-router.workspace = true
mas-storage.workspace = true
mas-storage-pg.workspace = true
mas-tasks.workspace = true
mas-templates.workspace = true
mas-tower.workspace = true
syn2mas.workspace = true
[build-dependencies]
anyhow.workspace = true
vergen-gitcl.workspace = true
[features]
# Features used for the prebuilt binaries
dist = ["mas-config/dist"]
# Features used in the Docker image
docker = ["mas-config/docker"]

View file

@ -0,0 +1,36 @@
// Copyright 2024, 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use vergen_gitcl::{Emitter, GitclBuilder, RustcBuilder};
fn main() -> anyhow::Result<()> {
// Instruct rustc that we'll be using #[cfg(tokio_unstable)]
println!("cargo::rustc-check-cfg=cfg(tokio_unstable)");
// At build time, we override the version through the environment variable
// VERGEN_GIT_DESCRIBE. In some contexts, it means this variable is set but
// empty, so we unset it here.
if let Ok(ver) = std::env::var("VERGEN_GIT_DESCRIBE")
&& ver.is_empty()
{
#[allow(unsafe_code)]
// SAFETY: This is safe because the build script is running a single thread
unsafe {
std::env::remove_var("VERGEN_GIT_DESCRIBE");
}
}
let gitcl = GitclBuilder::default()
.describe(true, false, Some("v*.*.*"))
.build()?;
let rustc = RustcBuilder::default().semver(true).build()?;
Emitter::default()
.add_instructions(&gitcl)?
.add_instructions(&rustc)?
.emit()?;
Ok(())
}

View file

@ -0,0 +1,374 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{convert::Infallible, net::IpAddr, sync::Arc};
use axum::extract::{FromRef, FromRequestParts};
use ipnetwork::IpNetwork;
use mas_context::LogContext;
use mas_data_model::{AppVersion, BoxClock, BoxRng, SiteConfig, SystemClock};
use mas_handlers::{
ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, GraphQLSchema, Limiter,
MetadataCache, RequesterFingerprint, passwords::PasswordManager,
};
use mas_i18n::Translator;
use mas_keystore::{Encrypter, Keystore};
use mas_matrix::HomeserverConnection;
use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxRepository, BoxRepositoryFactory, RepositoryFactory};
use mas_storage_pg::PgRepositoryFactory;
use mas_templates::Templates;
use opentelemetry::KeyValue;
use rand::SeedableRng;
use sqlx::PgPool;
use tracing::Instrument;
use crate::{VERSION, telemetry::METER};
/// Shared state for the HTTP server, cloned into every request handler.
///
/// Individual pieces are extracted by handlers through the `FromRef` and
/// `FromRequestParts` implementations below.
#[derive(Clone)]
pub struct AppState {
    /// Factory for PostgreSQL-backed repositories; also exposes the raw pool
    pub repository_factory: PgRepositoryFactory,
    /// Compiled templates (`mas_templates`), also the source of the translator
    pub templates: Templates,
    /// Key store (`mas_keystore`)
    pub key_store: Keystore,
    /// Manages cookies set and read by handlers
    pub cookie_manager: CookieManager,
    /// Encryption helper (`mas_keystore::Encrypter`)
    pub encrypter: Encrypter,
    /// Builds URLs pointing at this service
    pub url_builder: UrlBuilder,
    /// Connection to the Matrix homeserver
    pub homeserver_connection: Arc<dyn HomeserverConnection>,
    /// Instantiates policy engines (one per request, see the `Policy` extractor)
    pub policy_factory: Arc<PolicyFactory>,
    /// The GraphQL schema served by the API
    pub graphql_schema: GraphQLSchema,
    /// Shared outbound HTTP client
    pub http_client: reqwest::Client,
    /// Password hashing/verification manager
    pub password_manager: PasswordManager,
    /// Cache of upstream provider metadata, warmed up in the background
    pub metadata_cache: MetadataCache,
    /// Site-wide configuration values
    pub site_config: SiteConfig,
    /// Tracks activity, bound per-request to the inferred client IP
    pub activity_tracker: ActivityTracker,
    /// Networks of proxies trusted when inferring the client IP from
    /// `X-Forwarded-For` (see `infer_client_ip`)
    pub trusted_proxies: Vec<IpNetwork>,
    /// Rate limiter
    pub limiter: Limiter,
}
impl AppState {
/// Init the metrics for the app state.
pub fn init_metrics(&mut self) {
let pool = self.repository_factory.pool();
METER
.i64_observable_up_down_counter("db.connections.usage")
.with_description("The number of connections that are currently in `state` described by the state attribute.")
.with_unit("{connection}")
.with_callback(move |instrument| {
let idle = u32::try_from(pool.num_idle()).unwrap_or(u32::MAX);
let used = pool.size() - idle;
instrument.observe(i64::from(idle), &[KeyValue::new("state", "idle")]);
instrument.observe(i64::from(used), &[KeyValue::new("state", "used")]);
})
.build();
let pool = self.repository_factory.pool();
METER
.i64_observable_up_down_counter("db.connections.max")
.with_description("The maximum number of open connections allowed.")
.with_unit("{connection}")
.with_callback(move |instrument| {
let max_conn = pool.options().get_max_connections();
instrument.observe(i64::from(max_conn), &[]);
})
.build();
}
/// Init the metadata cache in the background
pub fn init_metadata_cache(&self) {
let factory = self.repository_factory.clone();
let metadata_cache = self.metadata_cache.clone();
let http_client = self.http_client.clone();
tokio::spawn(
LogContext::new("metadata-cache-warmup")
.run(async move || {
let mut repo = match factory.create().await {
Ok(conn) => conn,
Err(e) => {
tracing::error!(
error = &e as &dyn std::error::Error,
"Failed to acquire a database connection"
);
return;
}
};
if let Err(e) = metadata_cache
.warm_up_and_run(
&http_client,
std::time::Duration::from_secs(60 * 15),
&mut repo,
)
.await
{
tracing::error!(
error = &e as &dyn std::error::Error,
"Failed to warm up the metadata cache"
);
}
})
.instrument(tracing::info_span!("metadata_cache.background_warmup")),
);
}
}
// XXX(quenting): we only use this for the healthcheck endpoint, checking the db
// should be part of the repository
impl FromRef<AppState> for PgPool {
    fn from_ref(input: &AppState) -> Self {
        input.repository_factory.pool()
    }
}

// The impls below let axum handlers extract individual pieces of the shared
// state through `State<T>`; each derives its value from `AppState`, usually
// by cloning the corresponding field.

impl FromRef<AppState> for BoxRepositoryFactory {
    fn from_ref(input: &AppState) -> Self {
        input.repository_factory.clone().boxed()
    }
}

impl FromRef<AppState> for GraphQLSchema {
    fn from_ref(input: &AppState) -> Self {
        input.graphql_schema.clone()
    }
}

impl FromRef<AppState> for Templates {
    fn from_ref(input: &AppState) -> Self {
        input.templates.clone()
    }
}

impl FromRef<AppState> for Arc<Translator> {
    fn from_ref(input: &AppState) -> Self {
        input.templates.translator()
    }
}

impl FromRef<AppState> for Keystore {
    fn from_ref(input: &AppState) -> Self {
        input.key_store.clone()
    }
}

impl FromRef<AppState> for Encrypter {
    fn from_ref(input: &AppState) -> Self {
        input.encrypter.clone()
    }
}

impl FromRef<AppState> for UrlBuilder {
    fn from_ref(input: &AppState) -> Self {
        input.url_builder.clone()
    }
}

impl FromRef<AppState> for reqwest::Client {
    fn from_ref(input: &AppState) -> Self {
        input.http_client.clone()
    }
}

impl FromRef<AppState> for PasswordManager {
    fn from_ref(input: &AppState) -> Self {
        input.password_manager.clone()
    }
}

impl FromRef<AppState> for CookieManager {
    fn from_ref(input: &AppState) -> Self {
        input.cookie_manager.clone()
    }
}

impl FromRef<AppState> for MetadataCache {
    fn from_ref(input: &AppState) -> Self {
        input.metadata_cache.clone()
    }
}

impl FromRef<AppState> for SiteConfig {
    fn from_ref(input: &AppState) -> Self {
        input.site_config.clone()
    }
}

impl FromRef<AppState> for Limiter {
    fn from_ref(input: &AppState) -> Self {
        input.limiter.clone()
    }
}

impl FromRef<AppState> for Arc<PolicyFactory> {
    fn from_ref(input: &AppState) -> Self {
        input.policy_factory.clone()
    }
}

impl FromRef<AppState> for Arc<dyn HomeserverConnection> {
    fn from_ref(input: &AppState) -> Self {
        Arc::clone(&input.homeserver_connection)
    }
}

// The version is a compile-time constant, not part of the state itself
impl FromRef<AppState> for AppVersion {
    fn from_ref(_input: &AppState) -> Self {
        AppVersion(VERSION)
    }
}
/// Provide a boxed system (wall-clock) clock to handlers.
impl FromRequestParts<AppState> for BoxClock {
    type Rejection = Infallible;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        _state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        Ok(Box::new(SystemClock::default()))
    }
}
/// Provide a boxed ChaCha RNG to handlers, seeded from the thread-local RNG.
impl FromRequestParts<AppState> for BoxRng {
    type Rejection = Infallible;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        _state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        // The thread-local RNG is only used here as a seed source for the
        // ChaCha RNG that handlers actually consume
        #[allow(clippy::disallowed_methods)]
        let seed_source = rand::thread_rng();
        let chacha = rand_chacha::ChaChaRng::from_rng(seed_source).expect("Failed to seed RNG");
        Ok(Box::new(chacha))
    }
}
/// Instantiate a fresh policy engine from the shared factory for this
/// request; instantiation failures are surfaced as wrapped errors.
impl FromRequestParts<AppState> for Policy {
    type Rejection = ErrorWrapper<mas_policy::InstantiateError>;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        let policy = state.policy_factory.instantiate().await?;
        Ok(policy)
    }
}

/// Hand out a clone of the shared (unbound) activity tracker.
impl FromRequestParts<AppState> for ActivityTracker {
    type Rejection = Infallible;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        Ok(state.activity_tracker.clone())
    }
}
/// Infer the IP address of the client from the request parts.
///
/// Sources, in order of preference:
/// 1. the proxy-protocol source address, when present (always trusted);
/// 2. the socket peer address;
/// 3. the `X-Forwarded-For` header, walking the list from the back and
///    skipping any address that belongs to a trusted proxy.
fn infer_client_ip(
    parts: &axum::http::request::Parts,
    trusted_proxies: &[IpNetwork],
) -> Option<IpAddr> {
    let connection_info = parts.extensions.get::<mas_listener::ConnectionInfo>();

    let peer = if let Some(info) = connection_info {
        // We can always trust the proxy protocol to give us the correct IP address
        if let Some(proxy) = info.get_proxy_ref()
            && let Some(source) = proxy.source()
        {
            return Some(source.ip());
        }
        info.get_peer_addr().map(|addr| addr.ip())
    } else {
        None
    };

    // Get the list of IPs from the X-Forwarded-For header. Entries are
    // commonly separated by a comma *and a space*, so trim whitespace around
    // each entry before parsing, instead of silently dropping every entry
    // after the first.
    let peers_from_header = parts
        .headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .map(|value| value.split(',').filter_map(|v| v.trim().parse().ok()))
        .into_iter()
        .flatten();

    // This constructs a list of IP addresses that might be the client's IP address.
    // Each intermediate proxy is supposed to add the client's IP address to front
    // of the list. We are effectively adding the IP we got from the socket to the
    // front of the list.
    // We also call `to_canonical` so that IPv6-mapped IPv4 addresses
    // (::ffff:A.B.C.D) are converted to IPv4.
    let peer_list: Vec<IpAddr> = peer
        .into_iter()
        .chain(peers_from_header)
        .map(|ip| ip.to_canonical())
        .collect();

    // We'll fallback to the first IP in the list if all the IPs we got are trusted
    let fallback = peer_list.first().copied();

    // Now we go through the list, and the IP of the client is the first IP that is
    // not in the list of trusted proxies, starting from the back.
    let client_ip = peer_list
        .iter()
        .rfind(|ip| !trusted_proxies.iter().any(|network| network.contains(**ip)))
        .copied();

    client_ip.or(fallback)
}
/// Extract an activity tracker bound to the inferred client IP address.
impl FromRequestParts<AppState> for BoundActivityTracker {
    type Rejection = Infallible;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        // TODO: we may infer the IP twice, for the activity tracker and the limiter
        let ip = infer_client_ip(parts, &state.trusted_proxies);
        tracing::debug!(ip = ?ip, "Inferred client IP address");
        Ok(state.activity_tracker.clone().bind(ip))
    }
}

/// Extract a rate-limiting fingerprint derived from the inferred client IP;
/// falls back to an empty fingerprint (with a warning) when none can be
/// inferred.
impl FromRequestParts<AppState> for RequesterFingerprint {
    type Rejection = Infallible;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        // TODO: we may infer the IP twice, for the activity tracker and the limiter
        let ip = infer_client_ip(parts, &state.trusted_proxies);
        if let Some(ip) = ip {
            Ok(RequesterFingerprint::new(ip))
        } else {
            // If we can't infer the IP address, we'll just use an empty fingerprint and
            // warn about it
            tracing::warn!(
                "Could not infer client IP address for an operation which rate-limits based on IP addresses"
            );
            Ok(RequesterFingerprint::EMPTY)
        }
    }
}

/// Acquire a repository (backed by a database connection) from the shared
/// factory for this request.
impl FromRequestParts<AppState> for BoxRepository {
    type Rejection = ErrorWrapper<mas_storage::RepositoryError>;

    async fn from_request_parts(
        _parts: &mut axum::http::request::Parts,
        state: &AppState,
    ) -> Result<Self, Self::Rejection> {
        let repo = state.repository_factory.create().await?;
        Ok(repo)
    }
}

View file

@ -0,0 +1,151 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::process::ExitCode;
use anyhow::Context;
use camino::Utf8PathBuf;
use clap::Parser;
use figment::Figment;
use mas_config::{ConfigurationSection, RootConfig, SyncConfig};
use mas_data_model::{Clock as _, SystemClock};
use rand::SeedableRng;
use tokio::io::AsyncWriteExt;
use tracing::{info, info_span};
use crate::util::database_connection_from_config;
/// Options for the `config` family of subcommands.
#[derive(Parser, Debug)]
pub(super) struct Options {
    #[command(subcommand)]
    subcommand: Subcommand,
}

#[derive(Parser, Debug)]
enum Subcommand {
    /// Dump the current config as YAML
    Dump {
        /// The path to the config file to dump
        ///
        /// If not specified, the config will be written to stdout
        #[clap(short, long)]
        output: Option<Utf8PathBuf>,
    },

    /// Check a config file
    Check,

    /// Generate a new config file
    Generate {
        /// The path to the config file to generate
        ///
        /// If not specified, the config will be written to stdout
        #[clap(short, long)]
        output: Option<Utf8PathBuf>,

        /// Existing Synapse configuration used to generate the MAS config
        // May be passed multiple times; the files are merged by syn2mas
        #[arg(short, long, action = clap::ArgAction::Append)]
        synapse_config: Vec<Utf8PathBuf>,
    },

    /// Sync the clients and providers from the config file to the database
    Sync {
        /// Prune elements that are in the database but not in the config file
        /// anymore
        #[clap(long)]
        prune: bool,

        /// Do not actually write to the database
        #[clap(long)]
        dry_run: bool,
    },
}
impl Options {
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
use Subcommand as SC;
match self.subcommand {
SC::Dump { output } => {
let _span = info_span!("cli.config.dump").entered();
let config = RootConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let config = serde_yaml::to_string(&config)?;
if let Some(output) = output {
info!("Writing configuration to {output:?}");
let mut file = tokio::fs::File::create(output).await?;
file.write_all(config.as_bytes()).await?;
} else {
info!("Writing configuration to standard output");
tokio::io::stdout().write_all(config.as_bytes()).await?;
}
}
SC::Check => {
let _span = info_span!("cli.config.check").entered();
let _config = RootConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
info!("Configuration file looks good");
}
SC::Generate {
output,
synapse_config,
} => {
let _span = info_span!("cli.config.generate").entered();
let clock = SystemClock::default();
// XXX: we should disallow SeedableRng::from_entropy
let mut rng = rand_chacha::ChaChaRng::from_entropy();
let mut config = RootConfig::generate(&mut rng).await?;
if !synapse_config.is_empty() {
info!("Adjusting MAS config to match Synapse config from {synapse_config:?}");
let synapse_config = syn2mas::synapse_config::Config::load(&synapse_config)
.map_err(anyhow::Error::from_boxed)?;
config = synapse_config.adjust_mas_config(config, &mut rng, clock.now());
}
let config = serde_yaml::to_string(&config)?;
if let Some(output) = output {
info!("Writing configuration to {output:?}");
let mut file = tokio::fs::File::create(output).await?;
file.write_all(config.as_bytes()).await?;
} else {
info!("Writing configuration to standard output");
tokio::io::stdout().write_all(config.as_bytes()).await?;
}
}
SC::Sync { prune, dry_run } => {
let config = SyncConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
let clock = SystemClock::default();
let encrypter = config.secrets.encrypter().await?;
// Grab a connection to the database
let mut conn = database_connection_from_config(&config.database).await?;
mas_storage_pg::migrate(&mut conn)
.await
.context("could not run migrations")?;
crate::sync::config_sync(
config.upstream_oauth2,
config.clients,
&mut conn,
&encrypter,
&clock,
prune,
dry_run,
)
.await
.context("could not sync the configuration with the database")?;
}
}
Ok(ExitCode::SUCCESS)
}
}

View file

@ -0,0 +1,43 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::process::ExitCode;
use anyhow::Context;
use clap::Parser;
use figment::Figment;
use mas_config::{ConfigurationSectionExt, DatabaseConfig};
use tracing::info_span;
use crate::util::database_connection_from_config;
/// Options for the `database` family of subcommands.
#[derive(Parser, Debug)]
pub(super) struct Options {
    #[command(subcommand)]
    subcommand: Subcommand,
}

#[derive(Parser, Debug)]
enum Subcommand {
    /// Run database migrations
    Migrate,
}
impl Options {
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
let _span = info_span!("cli.database.migrate").entered();
let config =
DatabaseConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&config).await?;
// Run pending migrations
mas_storage_pg::migrate(&mut conn)
.await
.context("could not run migrations")?;
Ok(ExitCode::SUCCESS)
}
}

View file

@ -0,0 +1,70 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::process::ExitCode;
use clap::Parser;
use figment::Figment;
use mas_config::{
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, ExperimentalConfig,
MatrixConfig, PolicyConfig,
};
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span};
use crate::util::{
database_pool_from_config, load_policy_factory_dynamic_data, policy_factory_from_config,
};
/// Options for the (hidden) `debug` family of subcommands.
#[derive(Parser, Debug)]
pub(super) struct Options {
    #[command(subcommand)]
    subcommand: Subcommand,
}

#[derive(Parser, Debug)]
enum Subcommand {
    /// Check that the policies compile
    Policy {
        /// With dynamic data loaded
        #[arg(long)]
        with_dynamic_data: bool,
    },
}

impl Options {
    /// Run the selected `debug` subcommand. `debug policy` loads, compiles
    /// and instantiates the policy module, optionally loading dynamic data
    /// from the database first.
    #[tracing::instrument(skip_all)]
    pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
        use Subcommand as SC;
        match self.subcommand {
            SC::Policy { with_dynamic_data } => {
                let _span = info_span!("cli.debug.policy").entered();

                // Policy config is optional (defaults apply); the matrix and
                // experimental sections are extracted as-is
                let config =
                    PolicyConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
                let matrix_config =
                    MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
                let experimental_config =
                    ExperimentalConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
                info!("Loading and compiling the policy module");
                let policy_factory =
                    policy_factory_from_config(&config, &matrix_config, &experimental_config)
                        .await?;

                // Optionally hydrate the policy factory with data from the database
                if with_dynamic_data {
                    let database_config =
                        DatabaseConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
                    let pool = database_pool_from_config(&database_config).await?;
                    let repository_factory = PgRepositoryFactory::new(pool.clone());
                    load_policy_factory_dynamic_data(&policy_factory, &repository_factory).await?;
                }

                // Instantiating proves the compiled module is actually usable
                let _instance = policy_factory.instantiate().await?;
            }
        }

        Ok(ExitCode::SUCCESS)
    }
}

View file

@ -0,0 +1,410 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
//! Diagnostic utility to check the health of the deployment
//!
//! The code is quite repetitive for now, but we can refactor later with a
//! better check abstraction
use std::process::ExitCode;
use anyhow::Context;
use clap::Parser;
use figment::Figment;
use hyper::StatusCode;
use mas_config::{ConfigurationSection, RootConfig};
use mas_http::RequestBuilderExt;
use tracing::{error, info, info_span, warn};
use url::{Host, Url};
/// Base URL for the human-readable documentation
const DOCS_BASE: &str = "https://element-hq.github.io/matrix-authentication-service";
#[derive(Parser, Debug)]
pub(super) struct Options {}
impl Options {
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
let _span = info_span!("cli.doctor").entered();
info!(
"💡 Running diagnostics, make sure that both MAS and Synapse are running, and that MAS is using the same configuration files as this tool."
);
let config = RootConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
// We'll need an HTTP client
let http_client = mas_http::reqwest_client();
let base_url = config.http.public_base.as_str();
let issuer = config.http.issuer.as_ref().map(url::Url::as_str);
let issuer = issuer.unwrap_or(base_url);
let matrix_domain: Host = Host::parse(&config.matrix.homeserver).context(
r"The homeserver host in the config (`matrix.homeserver`) is not a valid domain.
See {DOCS_BASE}/setup/homeserver.html",
)?;
let secret = config.matrix.secret().await?;
let hs_api = config.matrix.endpoint;
if !issuer.starts_with("https://") {
warn!(
r"⚠️ The issuer in the config (`http.issuer`/`http.public_base`) is not an HTTPS URL.
This means some clients will refuse to use it."
);
}
let well_known_uri = format!("https://{matrix_domain}/.well-known/matrix/client");
let result = http_client.get(&well_known_uri).send_traced().await;
let expected_well_known = serde_json::json!({
"m.homeserver": {
"base_url": "...",
},
"org.matrix.msc2965.authentication": {
"issuer": issuer,
"account": format!("{base_url}account/"),
},
});
let discovered_cs_api = match result {
Ok(response) => {
// Make sure we got a 2xx response
let status = response.status();
if !status.is_success() {
warn!(
r#"⚠️ Matrix client well-known replied with {status}, expected 2xx.
Make sure the homeserver is reachable and the well-known document is available at "{well_known_uri}""#,
);
}
let result = response.json::<serde_json::Value>().await;
match result {
Ok(body) => {
if let Some(auth) = body.get("org.matrix.msc2965.authentication") {
if let Some(wk_issuer) =
auth.get("issuer").and_then(|issuer| issuer.as_str())
{
if issuer == wk_issuer {
info!(
r#"✅ Matrix client well-known at "{well_known_uri}" is valid"#
);
} else {
warn!(
r#"⚠️ Matrix client well-known has an "org.matrix.msc2965.authentication" section, but the issuer is not the same as the homeserver.
Check the well-known document at "{well_known_uri}"
This can happen because MAS parses the URL its config differently from the homeserver.
This means some OIDC-native clients might not work.
Make sure that the MAS config contains:
http:
public_base: {issuer:?}
And in the Synapse config:
matrix_authentication_service:
enabled: true
# This must point to where MAS is reachable by Synapse
endpoint: {issuer:?}
# ...
See {DOCS_BASE}/setup/homeserver.html
"#
);
}
} else {
error!(
r#"❌ Matrix client well-known "org.matrix.msc2965.authentication" does not have a valid "issuer" field.
Check the well-known document at "{well_known_uri}"
"#
);
}
} else {
warn!(
r#"Matrix client well-known is missing the "org.matrix.msc2965.authentication" section.
Check the well-known document at "{well_known_uri}"
Make sure Synapse has delegated auth enabled:
matrix_authentication_service:
enabled: true
endpoint: {issuer:?}
# ...
If it is not Synapse handling the well-known document, update it to include the following:
{expected_well_known:#}
See {DOCS_BASE}/setup/homeserver.html
"#
);
}
// Return the discovered homeserver base URL
body.get("m.homeserver")
.and_then(|hs| hs.get("base_url"))
.and_then(|base_url| base_url.as_str())
.and_then(|base_url| Url::parse(base_url).ok())
}
Err(e) => {
warn!(
r#"⚠️ Invalid JSON for the well-known document at "{well_known_uri}".
Make sure going to {well_known_uri:?} in a web browser returns a valid JSON document, similar to:
{expected_well_known:#}
See {DOCS_BASE}/setup/homeserver.html
Error details: {e}
"#
);
None
}
}
}
Err(e) => {
warn!(
r#"⚠️ Failed to fetch well-known document at "{well_known_uri}".
This means that the homeserver is not reachable, the well-known document is not available, or malformed.
Make sure your homeserver is running.
Make sure going to {well_known_uri:?} in a web browser returns a valid JSON document, similar to:
{expected_well_known:#}
See {DOCS_BASE}/setup/homeserver.html
Error details: {e}
"#
);
None
}
};
// Now try to reach the homeserver
let client_versions = hs_api.join("/_matrix/client/versions")?;
let result = http_client
.get(client_versions.as_str())
.send_traced()
.await;
let can_reach_cs = match result {
Ok(response) => {
let status = response.status();
if status.is_success() {
info!(r#"✅ Homeserver is reachable at "{client_versions}""#);
true
} else {
error!(
r#"❌Can't reach the homeserver at "{client_versions}", got {status}.
Make sure your homeserver is running.
This may be due to a misconfiguration in the `matrix` section of the config.
matrix:
homeserver: "{matrix_domain}"
# The homeserver should be reachable at this URL
endpoint: "{hs_api}"
See {DOCS_BASE}/setup/homeserver.html
"#
);
false
}
}
Err(e) => {
error!(
r#"❌ Can't reach the homeserver at "{client_versions}".
This may be due to a misconfiguration in the `matrix` section of the config.
matrix:
homeserver: "{matrix_domain}"
# The homeserver should be reachable at this URL
endpoint: "{hs_api}"
See {DOCS_BASE}/setup/homeserver.html
Error details: {e}
"#
);
false
}
};
if can_reach_cs {
// Try the whoami API. If it replies with `M_UNKNOWN` this is because Synapse
// couldn't reach MAS
let whoami = hs_api.join("/_matrix/client/v3/account/whoami")?;
let result = http_client
.get(whoami.as_str())
.bearer_auth("averyinvalidtokenireallyhopethisisnotvalid")
.send_traced()
.await;
match result {
Ok(response) => {
let status = response.status();
let body = response.text().await.unwrap_or("???".into());
match status.as_u16() {
401 => info!(
r#"✅ Homeserver at "{whoami}" is reachable, and it correctly rejected an invalid token."#
),
0..=399 => error!(
r#"❌ The homeserver at "{whoami}" replied with {status}.
This is *highly* unexpected, as this means that a fake token might have been accepted.
"#
),
503 => error!(
r#"❌ The homeserver at "{whoami}" replied with {status}.
This means probably means that the homeserver was unable to reach MAS to validate the token.
Make sure MAS is running and reachable from Synapse.
Check your homeserver logs.
This is what the homeserver told us about the error:
{body}
See {DOCS_BASE}/setup/homeserver.html
"#
),
_ => warn!(
r#"⚠️ The homeserver at "{whoami}" replied with {status}.
Check that the homeserver is running."#
),
}
}
Err(e) => error!(
r#"❌ Can't reach the homeserver at "{whoami}".
Error details: {e}
"#
),
}
// Try to reach an authenticated MAS API endpoint
let mas_api = hs_api.join("/_synapse/mas/is_localpart_available")?;
let result = http_client
.get(mas_api.as_str())
.bearer_auth(&secret)
.send_traced()
.await;
match result {
Ok(response) => {
let status = response.status();
// We intentionally omit the required 'localpart' parameter
// in this request. If authentication is successful, Synapse
// returns a 400 Bad Request because of the missing
// parameter. If authentication fails, Synapse will return a
// 403 Forbidden. If the MAS integration isn't enabled,
// Synapse will return a 404 Not found.
if status == StatusCode::BAD_REQUEST {
info!(
r#"✅ The Synapse MAS API is reachable with authentication at "{mas_api}"."#
);
} else {
error!(
r#"❌ A Synapse MAS API endpoint at "{mas_api}" replied with {status}.
Make sure the homeserver is running, and that the MAS config has the correct `matrix.secret`.
It should match the `secret` set in the Synapse config.
matrix_authentication_service:
enabled: true
endpoint: {issuer:?}
# This must exactly match the secret in the MAS config:
secret: {secret:?}
And in the MAS config:
matrix:
homeserver: "{matrix_domain}"
endpoint: "{hs_api}"
secret: {secret:?}
"#
);
}
}
Err(e) => error!(
r#"❌ Can't reach the Synapse MAS API at "{mas_api}".
Make sure the homeserver is running, and that the MAS config has the correct `matrix.secret`.
Error details: {e}
"#
),
}
}
let external_cs_api_endpoint = discovered_cs_api.as_ref().unwrap_or(&hs_api);
// Try to reach the legacy login API
let compat_login = external_cs_api_endpoint.join("/_matrix/client/v3/login")?;
let compat_login = compat_login.as_str();
let result = http_client.get(compat_login).send_traced().await;
match result {
Ok(response) => {
let status = response.status();
if status.is_success() {
// Now we need to inspect the body to figure out whether it's Synapse or MAS
// which handled the request
let body = response
.json::<serde_json::Value>()
.await
.unwrap_or_default();
let flows = body
.get("flows")
.and_then(|flows| flows.as_array())
.map(std::vec::Vec::as_slice)
.unwrap_or_default();
let has_compatibility_sso = flows.iter().any(|flow| {
flow.get("type").and_then(|t| t.as_str()) == Some("m.login.sso")
&& (flow
.get("oauth_aware_preferred")
.and_then(serde_json::Value::as_bool)
== Some(true)
// we check for the unstable name too:
|| flow
.get("org.matrix.msc3824.delegated_oidc_compatibility")
.and_then(serde_json::Value::as_bool)
== Some(true))
});
if has_compatibility_sso {
info!(
r#"✅ The legacy login API at "{compat_login}" is reachable and is handled by MAS."#
);
} else {
warn!(
r#"⚠️ The legacy login API at "{compat_login}" is reachable, but it doesn't look to be handled by MAS.
This means legacy clients won't be able to login.
Make sure MAS is running.
Check your reverse proxy settings to make sure that this API is handled by MAS, not by Synapse.
See {DOCS_BASE}/setup/reverse-proxy.html
"#
);
}
} else {
error!(
r#"The legacy login API at "{compat_login}" replied with {status}.
This means legacy clients won't be able to login.
Make sure MAS is running.
Check your reverse proxy settings to make sure that this API is handled by MAS, not by Synapse.
See {DOCS_BASE}/setup/reverse-proxy.html
"#
);
}
}
Err(e) => warn!(
r#"⚠️ Can't reach the legacy login API at "{compat_login}".
This means legacy clients won't be able to login.
Make sure MAS is running.
Check your reverse proxy settings to make sure that this API is handled by MAS, not by Synapse.
See {DOCS_BASE}/setup/reverse-proxy.html
Error details: {e}"#
),
}
Ok(ExitCode::SUCCESS)
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,109 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::process::ExitCode;
use camino::Utf8PathBuf;
use clap::Parser;
use figment::{
Figment,
providers::{Env, Format, Yaml},
};
mod config;
mod database;
mod debug;
mod doctor;
mod manage;
mod server;
mod syn2mas;
mod templates;
mod worker;
#[derive(Parser, Debug)]
// NOTE: the `///` doc comments on the variants double as the `--help` text for
// each subcommand, since this enum is consumed by the clap derive macro.
enum Subcommand {
    /// Configuration-related commands
    Config(self::config::Options),
    /// Manage the database
    Database(self::database::Options),
    /// Runs the web server
    Server(self::server::Options),
    /// Run the worker
    Worker(self::worker::Options),
    /// Manage the instance
    Manage(self::manage::Options),
    /// Templates-related commands
    Templates(self::templates::Options),
    /// Debug utilities
    // Hidden from `--help`, but still invokable.
    #[clap(hide = true)]
    Debug(self::debug::Options),
    /// Run diagnostics on the deployment
    Doctor(self::doctor::Options),
    /// Migrate from Synapse's built-in auth system to MAS.
    #[clap(name = "syn2mas")]
    // Box<> is to work around a 'large size difference between variants' lint
    Syn2Mas(Box<self::syn2mas::Options>),
}
#[derive(Parser, Debug)]
#[command(version = crate::VERSION)]
// Top-level CLI options; parsed by clap, then dispatched in `Options::run`.
pub struct Options {
    /// Path to the configuration file
    // `global = true` lets `--config` appear after any subcommand, and
    // `ArgAction::Append` allows it to be passed multiple times.
    #[arg(short, long, global = true, action = clap::ArgAction::Append)]
    config: Vec<Utf8PathBuf>,
    // The subcommand to run; when omitted, `Options::run` falls back to the
    // `server` subcommand with default options.
    #[command(subcommand)]
    subcommand: Option<Subcommand>,
}
impl Options {
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
use Subcommand as S;
// We Box the futures for each subcommand so that we avoid this function being
// big on the stack all the time
match self.subcommand {
Some(S::Config(c)) => Box::pin(c.run(figment)).await,
Some(S::Database(c)) => Box::pin(c.run(figment)).await,
Some(S::Server(c)) => Box::pin(c.run(figment)).await,
Some(S::Worker(c)) => Box::pin(c.run(figment)).await,
Some(S::Manage(c)) => Box::pin(c.run(figment)).await,
Some(S::Templates(c)) => Box::pin(c.run(figment)).await,
Some(S::Debug(c)) => Box::pin(c.run(figment)).await,
Some(S::Doctor(c)) => Box::pin(c.run(figment)).await,
Some(S::Syn2Mas(c)) => Box::pin(c.run(figment)).await,
None => Box::pin(self::server::Options::default().run(figment)).await,
}
}
/// Get a [`Figment`] instance with the configuration loaded
pub fn figment(&self) -> Figment {
let configs = if self.config.is_empty() {
// Read the MAS_CONFIG environment variable
std::env::var("MAS_CONFIG")
// Default to "config.yaml"
.unwrap_or_else(|_| "config.yaml".to_owned())
// Split the file list on `:`
.split(':')
.map(Utf8PathBuf::from)
.collect()
} else {
self.config.clone()
};
let base = Figment::new().merge(Env::prefixed("MAS_").split("_"));
configs
.into_iter()
.fold(base, |f, path| f.admerge(Yaml::file(path)))
}
}

View file

@ -0,0 +1,341 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{process::ExitCode, sync::Arc, time::Duration};
use anyhow::Context;
use clap::Parser;
use figment::Figment;
use itertools::Itertools;
use mas_config::{
AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config,
};
use mas_context::LogContext;
use mas_data_model::SystemClock;
use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
use mas_listener::server::Server;
use mas_router::UrlBuilder;
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span, warn};
use crate::{
app_state::AppState,
lifecycle::LifecycleManager,
util::{
database_pool_from_config, homeserver_connection_from_config,
load_policy_factory_dynamic_data_continuously, mailer_from_config,
password_manager_from_config, policy_factory_from_config, site_config_from_config,
templates_from_config, test_mailer_in_background,
},
};
#[allow(clippy::struct_excessive_bools)]
#[derive(Parser, Debug, Default)]
// CLI flags for the `server` subcommand; the `///` doc comments double as the
// `--help` text (clap derive).
pub(super) struct Options {
    /// Do not apply pending database migrations on start
    #[arg(long)]
    no_migrate: bool,
    /// DEPRECATED: default is to apply pending migrations, use `--no-migrate`
    /// to disable
    // Kept (hidden) so existing deployments passing `--migrate` don't break;
    // it only triggers a deprecation warning in `run`.
    #[arg(long, hide = true)]
    migrate: bool,
    /// Do not start the task worker
    #[arg(long)]
    no_worker: bool,
    /// Do not sync the configuration with the database
    #[arg(long)]
    no_sync: bool,
}
impl Options {
    /// Run the `server` subcommand: apply pending migrations (unless
    /// `--no-migrate`), sync the configuration with the database (unless
    /// `--no-sync`), start the in-process task worker (unless `--no-worker`),
    /// bind all configured listeners and serve until shutdown.
    pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
        let span = info_span!("cli.run.init").entered();
        // Installs the signal handlers early, so a Ctrl-C during startup is
        // still caught.
        let mut shutdown = LifecycleManager::new()?;
        let config = AppConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
        info!(version = crate::VERSION, "Starting up");
        if self.migrate {
            warn!(
                "The `--migrate` flag is deprecated and will be removed in a future release. Please use `--no-migrate` to disable automatic migrations on startup."
            );
        }
        // Connect to the database
        info!("Connecting to the database");
        let pool = database_pool_from_config(&config.database).await?;
        if self.no_migrate {
            let mut conn = pool.acquire().await?;
            let pending_migrations = mas_storage_pg::pending_migrations(&mut conn).await?;
            if !pending_migrations.is_empty() {
                // Refuse to start if there are pending migrations
                return Err(anyhow::anyhow!(
                    "The server is running with `--no-migrate` but there are pending migrations. Please run them first with `mas-cli database migrate`, or omit the `--no-migrate` flag to apply them automatically on startup."
                ));
            }
        } else {
            info!("Running pending database migrations");
            let mut conn = pool.acquire().await?;
            mas_storage_pg::migrate(&mut conn)
                .await
                .context("could not run migrations")?;
        }
        let encrypter = config.secrets.encrypter().await?;
        if self.no_sync {
            info!("Skipping configuration sync");
        } else {
            // Sync the configuration with the database
            let mut conn = pool.acquire().await?;
            let clients_config =
                ClientsConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
            let upstream_oauth2_config = UpstreamOAuth2Config::extract_or_default(figment)
                .map_err(anyhow::Error::from_boxed)?;
            crate::sync::config_sync(
                upstream_oauth2_config,
                clients_config,
                &mut conn,
                &encrypter,
                &SystemClock::default(),
                // prune: don't delete entries that are absent from the config
                false,
                // dry_run: actually apply the changes
                false,
            )
            .await
            .context("could not sync the configuration with the database")?;
        }
        // Initialize the key store
        let key_store = config
            .secrets
            .key_store()
            .await
            .context("could not import keys from config")?;
        let cookie_manager = CookieManager::derive_from(
            config.http.public_base.clone(),
            &config.secrets.encryption().await?,
        );
        // Load and compile the WASM policies (and fallback to the default embedded one)
        info!("Loading and compiling the policy module");
        let policy_factory =
            policy_factory_from_config(&config.policy, &config.matrix, &config.experimental)
                .await?;
        let policy_factory = Arc::new(policy_factory);
        // Keep refreshing the policy's dynamic data in the background until
        // soft shutdown.
        load_policy_factory_dynamic_data_continuously(
            &policy_factory,
            PgRepositoryFactory::new(pool.clone()).boxed(),
            shutdown.soft_shutdown_token(),
            shutdown.task_tracker(),
        )
        .await?;
        let url_builder = UrlBuilder::new(
            config.http.public_base.clone(),
            config.http.issuer.clone(),
            None,
        );
        // Load the site configuration
        let site_config = site_config_from_config(
            &config.branding,
            &config.matrix,
            &config.experimental,
            &config.passwords,
            &config.account,
            &config.captcha,
        )?;
        // Load and compile the templates
        let templates = templates_from_config(
            &config.templates,
            &site_config,
            &url_builder,
            // Don't use strict mode in production yet
            false,
            // Don't stabilise in production
            false,
        )
        .await?;
        // Templates are recompiled on SIGHUP
        shutdown.register_reloadable(&templates);
        let http_client = mas_http::reqwest_client();
        let homeserver_connection =
            homeserver_connection_from_config(&config.matrix, http_client.clone()).await?;
        if !self.no_worker {
            let mailer = mailer_from_config(&config.email, &templates)?;
            // Check the mailer in the background so startup isn't blocked on it
            test_mailer_in_background(&mailer, Duration::from_secs(30));
            info!("Starting task worker");
            mas_tasks::init_and_run(
                PgRepositoryFactory::new(pool.clone()),
                SystemClock::default(),
                &mailer,
                homeserver_connection.clone(),
                url_builder.clone(),
                &site_config,
                shutdown.soft_shutdown_token(),
                shutdown.task_tracker(),
            )
            .await?;
        }
        let listeners_config = config.http.listeners.clone();
        let password_manager = password_manager_from_config(&config.passwords).await?;
        // The upstream OIDC metadata cache
        let metadata_cache = MetadataCache::new();
        // Initialize the activity tracker
        // Activity is flushed every minute
        let activity_tracker = ActivityTracker::new(
            PgRepositoryFactory::new(pool.clone()).boxed(),
            Duration::from_secs(60),
            shutdown.task_tracker(),
            shutdown.soft_shutdown_token(),
        );
        // The tracker is flushed on SIGHUP
        shutdown.register_reloadable(&activity_tracker);
        let trusted_proxies = config.http.trusted_proxies.clone();
        // Build a rate limiter.
        // This should not raise an error here as the config should already have been
        // validated.
        let limiter = Limiter::new(&config.rate_limiting)
            .context("rate-limiting configuration is not valid")?;
        // Explicitly drop the config now that everything has been extracted
        // from it, so that the secret keys it holds get zeroized
        drop(config);
        limiter.start();
        let graphql_schema = mas_handlers::graphql_schema(
            PgRepositoryFactory::new(pool.clone()).boxed(),
            &policy_factory,
            homeserver_connection.clone(),
            site_config.clone(),
            password_manager.clone(),
            url_builder.clone(),
            limiter.clone(),
        );
        // Assemble the shared application state handed to every listener's
        // router.
        let state = {
            let mut s = AppState {
                repository_factory: PgRepositoryFactory::new(pool),
                templates,
                key_store,
                cookie_manager,
                encrypter,
                url_builder,
                homeserver_connection,
                policy_factory,
                graphql_schema,
                http_client,
                password_manager,
                metadata_cache,
                site_config,
                activity_tracker,
                trusted_proxies,
                limiter,
            };
            s.init_metrics();
            s.init_metadata_cache();
            s
        };
        // Sockets may be inherited from the service manager (systemd-style
        // socket activation) via listenfd.
        let mut fd_manager = listenfd::ListenFd::from_env();
        let servers: Vec<Server<_>> = listeners_config
            .into_iter()
            .map(|config| {
                // Let's first grab all the listeners
                let listeners = crate::server::build_listeners(&mut fd_manager, &config.binds)?;
                // Load the TLS config
                let tls_config = if let Some(tls_config) = config.tls.as_ref() {
                    let tls_config = crate::server::build_tls_server_config(tls_config)?;
                    Some(Arc::new(tls_config))
                } else {
                    None
                };
                // and build the router
                let router = crate::server::build_router(
                    state.clone(),
                    &config.resources,
                    config.prefix.as_deref(),
                    config.name.as_deref(),
                );
                // Display some information about where we'll be serving connections
                let proto = if config.tls.is_some() { "https" } else { "http" };
                let prefix = config.prefix.unwrap_or_default();
                let addresses= listeners
                    .iter()
                    .map(|listener| {
                        if let Ok(addr) = listener.local_addr() {
                            format!("{proto}://{addr:?}{prefix}")
                        } else {
                            warn!("Could not get local address for listener, something might be wrong!");
                            format!("{proto}://???{prefix}")
                        }
                    })
                    .join(", ");
                let additional = if config.proxy_protocol {
                    "(with Proxy Protocol)"
                } else {
                    ""
                };
                info!(
                    "Listening on {addresses} with resources {resources:?} {additional}",
                    resources = &config.resources
                );
                anyhow::Ok(listeners.into_iter().map(move |listener| {
                    let mut server = Server::new(listener, router.clone());
                    if let Some(tls_config) = &tls_config {
                        server = server.with_tls(tls_config.clone());
                    }
                    if config.proxy_protocol {
                        server = server.with_proxy();
                    }
                    server
                }))
            })
            .flatten_ok()
            .collect::<Result<Vec<_>, _>>()?;
        span.exit();
        // Serve in a tracked task; the lifecycle manager below blocks until
        // shutdown completes.
        shutdown
            .task_tracker()
            .spawn(LogContext::new("run-servers").run(|| {
                mas_listener::server::run_servers(
                    servers,
                    shutdown.soft_shutdown_token(),
                    shutdown.hard_shutdown_token(),
                )
            }));
        let exit_code = shutdown.run().await;
        Ok(exit_code)
    }
}

View file

@ -0,0 +1,319 @@
// Copyright 2024, 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{collections::HashMap, process::ExitCode, time::Duration};
use anyhow::Context;
use camino::Utf8PathBuf;
use clap::Parser;
use figment::Figment;
use mas_config::{
ConfigurationSection, ConfigurationSectionExt, DatabaseConfig, MatrixConfig, SyncConfig,
UpstreamOAuth2Config,
};
use mas_data_model::SystemClock;
use rand::thread_rng;
use sqlx::{Connection, Either, PgConnection, postgres::PgConnectOptions, types::Uuid};
use syn2mas::{
LockedMasDatabase, MasWriter, Progress, ProgressStage, SynapseReader, synapse_config,
};
use tracing::{Instrument, error, info};
use crate::util::{DatabaseConnectOptions, database_connection_from_config_with_options};
/// The exit code used by `syn2mas check` and `syn2mas migrate` when there are
/// errors preventing migration.
// NOTE(review): these are deliberately distinct values, presumably so wrapper
// scripts can tell the two check outcomes apart — confirm before relying on it.
const EXIT_CODE_CHECK_ERRORS: u8 = 10;
/// The exit code used by `syn2mas check` when there are warnings which should
/// be considered prior to migration.
const EXIT_CODE_CHECK_WARNINGS: u8 = 11;
#[derive(Parser, Debug)]
// CLI options for the `syn2mas` subcommand; the `///` doc comments double as
// the `--help` text (clap derive).
pub(super) struct Options {
    #[command(subcommand)]
    subcommand: Subcommand,
    /// Path to the Synapse configuration (in YAML format).
    /// May be specified multiple times if multiple Synapse configuration files
    /// are in use.
    // `global = true` lets the flag be given after `check`/`migrate` as well.
    #[clap(long = "synapse-config", global = true)]
    synapse_configuration_files: Vec<Utf8PathBuf>,
    /// Override the Synapse database URI.
    /// syn2mas normally loads the Synapse database connection details from the
    /// Synapse configuration. However, it may sometimes be necessary to
    /// override the database URI and in that case this flag can be used.
    ///
    /// Should be a connection URI of the following general form:
    /// ```text
    /// postgresql://[user[:password]@][host][:port][/dbname][?param1=value1&...]
    /// ```
    /// To use a UNIX socket at a custom path, the host should be a path to a
    /// socket, but in the URI string it must be URI-encoded by replacing
    /// `/` with `%2F`.
    ///
    /// Finally, any missing values will be loaded from the libpq-compatible
    /// environment variables `PGHOST`, `PGPORT`, `PGUSER`, `PGDATABASE`,
    /// `PGPASSWORD`, etc. It is valid to specify the URL `postgresql:` and
    /// configure all values through those environment variables.
    #[clap(long = "synapse-database-uri", global = true)]
    synapse_database_uri: Option<PgConnectOptions>,
    /// Make missing auth providers in Synapse config warnings instead of
    /// errors. If this flag is set, and we find `auth_provider` values in
    /// the Synapse `user_external_ids` table, that are not configured in
    /// the Synapse OIDC configuration, instead of erroring we will just
    /// output warnings.
    #[clap(long = "ignore-missing-auth-providers", global = true)]
    ignore_missing_auth_providers: bool,
}
#[derive(Parser, Debug)]
// The two phases of a syn2mas run; the `///` docs double as `--help` output.
enum Subcommand {
    /// Check the setup for potential problems before running a migration.
    ///
    /// It is OK for Synapse to be online during these checks.
    Check,
    /// Perform a migration. Synapse must be offline during this process.
    Migrate {
        /// Perform a dry-run migration, which is safe to run with Synapse
        /// running, and will restore the MAS database to an empty state.
        ///
        /// This still *does* write to the MAS database, making it more
        /// realistic compared to the final migration.
        #[clap(long)]
        dry_run: bool,
    },
}
/// The number of parallel writing transactions active against the MAS database.
// Used by the `Migrate` subcommand when opening the pool of MAS writer
// connections handed to `MasWriter`.
const NUM_WRITER_CONNECTIONS: usize = 8;
impl Options {
    /// Run the `syn2mas` subcommand: connect to both databases, run the
    /// pre-migration checks and, for `migrate`, perform the actual migration
    /// of Synapse's auth data into the MAS database.
    #[tracing::instrument("cli.syn2mas.run", skip_all)]
    pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
        if self.synapse_configuration_files.is_empty() {
            error!("Please specify the path to the Synapse configuration file(s).");
            return Ok(ExitCode::FAILURE);
        }
        let synapse_config = synapse_config::Config::load(&self.synapse_configuration_files)
            .map_err(anyhow::Error::from_boxed)
            .context("Failed to load Synapse configuration")?;
        // Establish a connection to Synapse's Postgres database
        let syn_connection_options = if let Some(db_override) = self.synapse_database_uri {
            db_override
        } else {
            synapse_config
                .database
                .to_sqlx_postgres()
                .context("Synapse database configuration is invalid, cannot migrate.")?
        };
        let mut syn_conn = PgConnection::connect_with(&syn_connection_options)
            .await
            .context("could not connect to Synapse Postgres database")?;
        let config =
            DatabaseConfig::extract_or_default(figment).map_err(anyhow::Error::from_boxed)?;
        let mut mas_connection = database_connection_from_config_with_options(
            &config,
            &DatabaseConnectOptions {
                // Migration queries can legitimately be slow; don't spam logs
                log_slow_statements: false,
            },
        )
        .await?;
        // Make sure the MAS schema is up to date before touching it
        mas_storage_pg::migrate(&mut mas_connection)
            .await
            .context("could not run migrations")?;
        if matches!(&self.subcommand, Subcommand::Migrate { .. }) {
            // First perform a config sync
            // This is crucial to ensure we register upstream OAuth providers
            // in the MAS database
            let config = SyncConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
            let clock = SystemClock::default();
            let encrypter = config.secrets.encrypter().await?;
            crate::sync::config_sync(
                config.upstream_oauth2,
                config.clients,
                &mut mas_connection,
                &encrypter,
                &clock,
                // Don't prune — we don't want to be unnecessarily destructive
                false,
                // Not a dry run — we do want to create the providers in the database
                false,
            )
            .await
            .context("could not sync the configuration with the database")?;
        }
        // Take an exclusive lock on the MAS database so two syn2mas instances
        // can't run concurrently.
        let Either::Left(mut mas_connection) = LockedMasDatabase::try_new(mas_connection)
            .await
            .context("failed to issue query to lock database")?
        else {
            error!("Failed to acquire syn2mas lock on the database.");
            error!("This likely means that another syn2mas instance is already running!");
            return Ok(ExitCode::FAILURE);
        };
        // Check configuration
        let (mut check_warnings, mut check_errors) = syn2mas::synapse_config_check(&synapse_config);
        {
            let (extra_warnings, extra_errors) =
                syn2mas::synapse_config_check_against_mas_config(&synapse_config, figment).await?;
            check_warnings.extend(extra_warnings);
            check_errors.extend(extra_errors);
        }
        // Check databases
        syn2mas::mas_pre_migration_checks(&mut mas_connection).await?;
        {
            let (extra_warnings, extra_errors) = syn2mas::synapse_database_check(
                &mut syn_conn,
                &synapse_config,
                figment,
                self.ignore_missing_auth_providers,
            )
            .await?;
            check_warnings.extend(extra_warnings);
            check_errors.extend(extra_errors);
        }
        // Display errors and warnings
        if !check_errors.is_empty() {
            eprintln!("\n\n===== Errors =====");
            eprintln!("These issues prevent migrating from Synapse to MAS right now:\n");
            for error in &check_errors {
                eprintln!("{error}\n");
            }
        }
        if !check_warnings.is_empty() {
            eprintln!("\n\n===== Warnings =====");
            eprintln!(
                "These potential issues should be considered before migrating from Synapse to MAS right now:\n"
            );
            for warning in &check_warnings {
                eprintln!("{warning}\n");
            }
        }
        // Do not proceed if there are any errors
        if !check_errors.is_empty() {
            return Ok(ExitCode::from(EXIT_CODE_CHECK_ERRORS));
        }
        match self.subcommand {
            Subcommand::Check => {
                // Warnings are non-fatal but surfaced via a dedicated exit code
                if !check_warnings.is_empty() {
                    return Ok(ExitCode::from(EXIT_CODE_CHECK_WARNINGS));
                }
                println!("Check completed successfully with no errors or warnings.");
                Ok(ExitCode::SUCCESS)
            }
            Subcommand::Migrate { dry_run } => {
                // Map each Synapse `auth_provider` ID to the corresponding MAS
                // upstream provider ULID, from the MAS configuration.
                let provider_id_mappings: HashMap<String, Uuid> = {
                    let mas_oauth2 = UpstreamOAuth2Config::extract_or_default(figment)
                        .map_err(anyhow::Error::from_boxed)?;
                    mas_oauth2
                        .providers
                        .iter()
                        .filter_map(|provider| {
                            let synapse_idp_id = provider.synapse_idp_id.clone()?;
                            Some((synapse_idp_id, Uuid::from(provider.id)))
                        })
                        .collect()
                };
                // TODO how should we handle warnings at this stage?
                let reader = SynapseReader::new(&mut syn_conn, dry_run).await?;
                // Open the pool of parallel writer connections
                let writer_mas_connections =
                    futures_util::future::try_join_all((0..NUM_WRITER_CONNECTIONS).map(|_| {
                        database_connection_from_config_with_options(
                            &config,
                            &DatabaseConnectOptions {
                                log_slow_statements: false,
                            },
                        )
                    }))
                    .instrument(tracing::info_span!("syn2mas.mas_writer_connections"))
                    .await?;
                let writer =
                    MasWriter::new(mas_connection, writer_mas_connections, dry_run).await?;
                let clock = SystemClock::default();
                // TODO is this rng ok?
                #[allow(clippy::disallowed_methods)]
                let mut rng = thread_rng();
                let progress = Progress::default();
                // Periodically log migration progress; aborted once done
                let occasional_progress_logger_task =
                    tokio::spawn(occasional_progress_logger(progress.clone()));
                let mas_matrix =
                    MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
                syn2mas::migrate(
                    reader,
                    writer,
                    mas_matrix.homeserver,
                    &clock,
                    &mut rng,
                    provider_id_mappings,
                    &progress,
                    self.ignore_missing_auth_providers,
                )
                .await?;
                occasional_progress_logger_task.abort();
                Ok(ExitCode::SUCCESS)
            }
        }
    }
}
/// Every 5 seconds, log what the migration is currently doing, as a
/// lightweight alternative to a progress bar. Most deployments migrate in
/// under 5 seconds, in which case nothing is ever printed; for longer runs
/// this gives the operator an idea of what's going on.
async fn occasional_progress_logger(progress: Progress) {
    loop {
        tokio::time::sleep(Duration::from_secs(5)).await;
        // NB: the format strings below capture the bindings by name.
        let stage = progress.get_current_stage();
        match &**stage {
            ProgressStage::SettingUp => {
                info!(name: "progress", "still setting up");
            }
            ProgressStage::MigratingData {
                entity,
                counter,
                approx_count,
            } => {
                let skipped = counter.skipped();
                let migrated = counter.migrated();
                let done = migrated + skipped;
                #[allow(clippy::cast_precision_loss)]
                let percent = (f64::from(done) / *approx_count as f64) * 100.0;
                info!(name: "progress", "migrating {entity}: {migrated} ({skipped} skipped) /~{approx_count} (~{percent:.1}%)");
            }
            ProgressStage::RebuildIndex { index_name } => {
                info!(name: "progress", "still waiting for rebuild of index {index_name}");
            }
            ProgressStage::RebuildConstraint { constraint_name } => {
                info!(name: "progress", "still waiting for rebuild of constraint {constraint_name}");
            }
        }
    }
}

View file

@ -0,0 +1,145 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{fmt::Write, process::ExitCode};
use anyhow::{Context as _, bail};
use camino::Utf8PathBuf;
use chrono::DateTime;
use clap::Parser;
use figment::Figment;
use mas_config::{
AccountConfig, BrandingConfig, CaptchaConfig, ConfigurationSection, ConfigurationSectionExt,
ExperimentalConfig, MatrixConfig, PasswordsConfig, TemplatesConfig,
};
use mas_data_model::{Clock, SystemClock};
use rand::SeedableRng;
use tracing::info_span;
use crate::util::{site_config_from_config, templates_from_config};
#[derive(Parser, Debug)]
// The `templates` subcommand has no options of its own; everything lives on
// the subcommand enum below.
pub(super) struct Options {
    #[clap(subcommand)]
    subcommand: Subcommand,
}
#[derive(Parser, Debug)]
// The `///` doc comments double as `--help` output (clap derive).
enum Subcommand {
    /// Check that the templates specified in the config are valid
    Check {
        /// If set, templates will be rendered to this directory.
        /// The directory must either not exist or be empty.
        #[arg(long = "out-dir")]
        out_dir: Option<Utf8PathBuf>,
        /// Attempt to remove 'unstable' template input data such as asset
        /// hashes, in order to make renders more reproducible between
        /// versions.
        #[arg(long = "stabilise")]
        stabilise: bool,
    },
}
impl Options {
    /// Run the `templates` subcommand: compile all templates in strict mode,
    /// render every sample, and optionally write the renders to `--out-dir`.
    pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
        use Subcommand as SC;
        match self.subcommand {
            SC::Check { out_dir, stabilise } => {
                let _span = info_span!("cli.templates.check").entered();
                // Extract every configuration section that influences template
                // rendering; most fall back to defaults, but the Matrix
                // section is required (plain `extract`).
                let template_config = TemplatesConfig::extract_or_default(figment)
                    .map_err(anyhow::Error::from_boxed)?;
                let branding_config = BrandingConfig::extract_or_default(figment)
                    .map_err(anyhow::Error::from_boxed)?;
                let matrix_config =
                    MatrixConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
                let experimental_config = ExperimentalConfig::extract_or_default(figment)
                    .map_err(anyhow::Error::from_boxed)?;
                let password_config = PasswordsConfig::extract_or_default(figment)
                    .map_err(anyhow::Error::from_boxed)?;
                let account_config = AccountConfig::extract_or_default(figment)
                    .map_err(anyhow::Error::from_boxed)?;
                let captcha_config = CaptchaConfig::extract_or_default(figment)
                    .map_err(anyhow::Error::from_boxed)?;
                // With --stabilise, use a fixed timestamp and a seeded RNG so
                // renders are reproducible between runs/versions.
                let now = if stabilise {
                    DateTime::from_timestamp_secs(1_446_823_992).unwrap()
                } else {
                    SystemClock::default().now()
                };
                let rng = if stabilise {
                    rand_chacha::ChaChaRng::from_seed([42; 32])
                } else {
                    // XXX: we should disallow SeedableRng::from_entropy
                    rand_chacha::ChaChaRng::from_entropy()
                };
                let url_builder =
                    mas_router::UrlBuilder::new("https://example.com/".parse()?, None, None);
                let site_config = site_config_from_config(
                    &branding_config,
                    &matrix_config,
                    &experimental_config,
                    &password_config,
                    &account_config,
                    &captcha_config,
                )?;
                let templates = templates_from_config(
                    &template_config,
                    &site_config,
                    &url_builder,
                    // Use strict mode in template checks
                    true,
                    stabilise,
                )
                .await?;
                let all_renders = templates.check_render(now, &rng)?;
                if let Some(out_dir) = out_dir {
                    // Save renders to disk.
                    if out_dir.exists() {
                        // Refuse to clobber a non-empty directory
                        let mut read_dir =
                            tokio::fs::read_dir(&out_dir).await.with_context(|| {
                                format!("could not read {out_dir} to check it's empty")
                            })?;
                        if read_dir.next_entry().await?.is_some() {
                            bail!("Render directory {out_dir} is not empty, refusing to write.");
                        }
                    } else {
                        tokio::fs::create_dir(&out_dir)
                            .await
                            .with_context(|| format!("could not create {out_dir}"))?;
                    }
                    for ((template, sample_identifier), template_render) in &all_renders {
                        // Derive a flat filename from the template path,
                        // defaulting the extension to "txt".
                        let (template_filename_base, template_ext) =
                            template.rsplit_once('.').unwrap_or((template, "txt"));
                        let template_filename_base = template_filename_base.replace('/', "_");
                        // Make a string like `-index=0-browser-session=0-locale=fr`
                        let sample_suffix = {
                            let mut s = String::new();
                            for (k, v) in &sample_identifier.components {
                                write!(s, "-{k}={v}")?;
                            }
                            s
                        };
                        let render_path = out_dir.join(format!(
                            "{template_filename_base}{sample_suffix}.{template_ext}"
                        ));
                        tokio::fs::write(&render_path, template_render.as_bytes())
                            .await
                            .with_context(|| format!("could not write render to {render_path}"))?;
                    }
                }
                Ok(ExitCode::SUCCESS)
            }
        }
    }
}

View file

@ -0,0 +1,93 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{process::ExitCode, time::Duration};
use clap::Parser;
use figment::Figment;
use mas_config::{AppConfig, ConfigurationSection};
use mas_data_model::SystemClock;
use mas_router::UrlBuilder;
use mas_storage_pg::PgRepositoryFactory;
use tracing::{info, info_span};
use crate::{
lifecycle::LifecycleManager,
util::{
database_pool_from_config, homeserver_connection_from_config, mailer_from_config,
site_config_from_config, templates_from_config, test_mailer_in_background,
},
};
#[derive(Parser, Debug, Default)]
// The `worker` subcommand takes no options of its own.
pub(super) struct Options {}
impl Options {
    /// Run the standalone task worker: connect to the database, build the
    /// mailer, templates and homeserver connection, then run the task
    /// scheduler until shutdown.
    pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
        // Installs the signal handlers early, so a Ctrl-C during startup is
        // still caught.
        let shutdown = LifecycleManager::new()?;
        let span = info_span!("cli.worker.init").entered();
        let config = AppConfig::extract(figment).map_err(anyhow::Error::from_boxed)?;
        // Connect to the database
        info!("Connecting to the database");
        let pool = database_pool_from_config(&config.database).await?;
        let url_builder = UrlBuilder::new(
            config.http.public_base.clone(),
            config.http.issuer.clone(),
            None,
        );
        // Load the site configuration
        let site_config = site_config_from_config(
            &config.branding,
            &config.matrix,
            &config.experimental,
            &config.passwords,
            &config.account,
            &config.captcha,
        )?;
        // Load and compile the templates
        let templates = templates_from_config(
            &config.templates,
            &site_config,
            &url_builder,
            // Don't use strict mode on task workers for now
            false,
            // Don't stabilise in production
            false,
        )
        .await?;
        let mailer = mailer_from_config(&config.email, &templates)?;
        // Check the mailer in the background so startup isn't blocked on it
        test_mailer_in_background(&mailer, Duration::from_secs(30));
        let http_client = mas_http::reqwest_client();
        let conn = homeserver_connection_from_config(&config.matrix, http_client).await?;
        // Drop the config now that everything has been extracted from it, so
        // that the secrets it holds get zeroized
        drop(config);
        info!("Starting task scheduler");
        mas_tasks::init_and_run(
            PgRepositoryFactory::new(pool.clone()),
            SystemClock::default(),
            &mailer,
            conn,
            url_builder,
            &site_config,
            shutdown.soft_shutdown_token(),
            shutdown.task_tracker(),
        )
        .await?;
        span.exit();
        // Block until the lifecycle manager decides we're done
        let exit_code = shutdown.run().await;
        Ok(exit_code)
    }
}

View file

@ -0,0 +1,239 @@
// Copyright 2024, 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{process::ExitCode, time::Duration};
use futures_util::future::{BoxFuture, Either};
use mas_handlers::ActivityTracker;
use mas_templates::Templates;
use tokio::signal::unix::{Signal, SignalKind};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
/// A helper to manage the lifecycle of the service, including handling
/// graceful shutdowns and configuration reloads.
///
/// It will listen for SIGTERM and SIGINT signals, and will trigger a soft
/// shutdown on the first signal, and a hard shutdown on the second signal or
/// after a timeout.
///
/// Users of this manager should use the `soft_shutdown_token` to react to a
/// soft shutdown, which should gracefully finish requests and close
/// connections, and the `hard_shutdown_token` to react to a hard shutdown,
/// which should drop all connections and finish all requests.
///
/// They should also register running tasks on the `task_tracker`, so that the
/// manager knows when the soft shutdown has completed.
///
/// It also integrates with [`sd_notify`] to notify the service manager of the
/// state of the service.
pub struct LifecycleManager {
    // Cancelled on the second signal or after `timeout`: drop everything now.
    hard_shutdown_token: CancellationToken,
    // Child of `hard_shutdown_token`; cancelled on the first signal to start
    // a graceful shutdown.
    soft_shutdown_token: CancellationToken,
    // Tracks spawned tasks so we know when a graceful shutdown has finished.
    task_tracker: TaskTracker,
    sigterm: Signal,
    sigint: Signal,
    sighup: Signal,
    // Grace period between the soft shutdown and the forced hard shutdown.
    timeout: Duration,
    // Factories producing the futures to run (concurrently) on SIGHUP.
    reload_handlers: Vec<Box<dyn Fn() -> BoxFuture<'static, ()>>>,
}
/// Represents a thing that can be reloaded with a SIGHUP
pub trait Reloadable: Clone + Send {
    // Invoked (via `LifecycleManager::register_reloadable`) when a SIGHUP is
    // received; implementations refresh or flush their internal state.
    fn reload(&self) -> impl Future<Output = ()> + Send;
}
impl Reloadable for ActivityTracker {
    // Reloading an activity tracker simply flushes its pending activity.
    async fn reload(&self) {
        self.flush().await;
    }
}
impl Reloadable for Templates {
    /// Recompile the templates on reload. Failures are logged rather than
    /// propagated, so a broken template tree does not take the service down.
    async fn reload(&self) {
        // NB: this resolves to the inherent `Templates::reload`, which
        // returns a `Result`, not to this trait method.
        match self.reload().await {
            Ok(()) => {}
            Err(err) => {
                tracing::error!(
                    error = &err as &dyn std::error::Error,
                    "Failed to reload templates"
                );
            }
        }
    }
}
/// A wrapper around [`sd_notify::notify`] that logs any errors
///
/// Failures to deliver the notification are logged and otherwise ignored.
fn notify(states: &[sd_notify::NotifyState]) {
    let Err(e) = sd_notify::notify(false, states) else {
        return;
    };
    tracing::error!(
        error = &e as &dyn std::error::Error,
        "Failed to notify service manager"
    );
}
impl LifecycleManager {
/// Create a new shutdown manager, installing the signal handlers
///
/// # Errors
///
/// Returns an error if the signal handler could not be installed
pub fn new() -> Result<Self, std::io::Error> {
let hard_shutdown_token = CancellationToken::new();
let soft_shutdown_token = hard_shutdown_token.child_token();
let sigterm = tokio::signal::unix::signal(SignalKind::terminate())?;
let sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
let sighup = tokio::signal::unix::signal(SignalKind::hangup())?;
let timeout = Duration::from_secs(60);
let task_tracker = TaskTracker::new();
notify(&[sd_notify::NotifyState::MainPid(std::process::id())]);
Ok(Self {
hard_shutdown_token,
soft_shutdown_token,
task_tracker,
sigterm,
sigint,
sighup,
timeout,
reload_handlers: Vec::new(),
})
}
/// Add a handler to be called when the server gets a SIGHUP
pub fn register_reloadable(&mut self, reloadable: &(impl Reloadable + 'static)) {
let reloadable = reloadable.clone();
self.reload_handlers.push(Box::new(move || {
let reloadable = reloadable.clone();
Box::pin(async move { reloadable.reload().await })
}));
}
    /// Get a reference to the task tracker
    ///
    /// Spawn long-running work through this so the manager can wait for it
    /// during a soft shutdown.
    #[must_use]
    pub fn task_tracker(&self) -> &TaskTracker {
        &self.task_tracker
    }
    /// Get a cancellation token that can be used to react to a hard shutdown
    // The returned clone observes the same underlying token as the manager's.
    #[must_use]
    pub fn hard_shutdown_token(&self) -> CancellationToken {
        self.hard_shutdown_token.clone()
    }
    /// Get a cancellation token that can be used to react to a soft shutdown
    #[must_use]
    pub fn soft_shutdown_token(&self) -> CancellationToken {
        self.soft_shutdown_token.clone()
    }
/// Run until we finish completely shutting down.
///
/// Waits for a first shutdown trigger — SIGTERM, SIGINT, or another task
/// cancelling the soft-shutdown token — while also feeding the systemd
/// watchdog (if enabled) and handling SIGHUP reloads. It then cancels the
/// soft-shutdown token, closes the task tracker and waits for tracked tasks
/// to finish; a second signal or the shutdown timeout escalates to a hard
/// shutdown (cancelling the hard-shutdown token).
///
/// Returns [`ExitCode::FAILURE`] if the shutdown was triggered by another
/// task cancelling the soft-shutdown token (a likely crash), and
/// [`ExitCode::SUCCESS`] otherwise.
pub async fn run(mut self) -> ExitCode {
    notify(&[sd_notify::NotifyState::Ready]);

    // This will be `Some` if we have the watchdog enabled, and `None` if not
    let mut watchdog_interval = {
        let mut watchdog_usec = 0;
        if sd_notify::watchdog_enabled(false, &mut watchdog_usec) {
            // Ping the watchdog at twice the rate it expects, to be safe
            Some(tokio::time::interval(Duration::from_micros(
                watchdog_usec / 2,
            )))
        } else {
            None
        }
    };

    // Wait for a first shutdown signal and trigger the soft shutdown
    let likely_crashed = loop {
        // This makes a Future that will either yield the watchdog tick if enabled, or a
        // pending Future if not
        let watchdog_tick = if let Some(watchdog_interval) = &mut watchdog_interval {
            Either::Left(watchdog_interval.tick())
        } else {
            Either::Right(futures_util::future::pending())
        };

        tokio::select! {
            () = self.soft_shutdown_token.cancelled() => {
                tracing::warn!("Another task triggered a shutdown, it likely crashed! Shutting down");
                break true;
            },

            _ = self.sigterm.recv() => {
                tracing::info!("Shutdown signal received (SIGTERM), shutting down");
                break false;
            },

            _ = self.sigint.recv() => {
                tracing::info!("Shutdown signal received (SIGINT), shutting down");
                break false;
            },

            _ = watchdog_tick => {
                notify(&[
                    sd_notify::NotifyState::Watchdog,
                ]);
            },

            _ = self.sighup.recv() => {
                tracing::info!("Reload signal received (SIGHUP), reloading");
                notify(&[
                    sd_notify::NotifyState::Reloading,
                    sd_notify::NotifyState::monotonic_usec_now()
                        .expect("Failed to read monotonic clock")
                ]);

                // XXX: if one handler takes a long time, it will block the
                // rest of the shutdown process, which is not ideal. We
                // should probably have a timeout here
                futures_util::future::join_all(
                    self.reload_handlers
                        .iter()
                        .map(|handler| handler())
                ).await;

                notify(&[sd_notify::NotifyState::Ready]);
                tracing::info!("Reloading done");
            },
        }
    };

    notify(&[sd_notify::NotifyState::Stopping]);

    // Trigger the soft shutdown and stop accepting new tasks on the tracker
    self.soft_shutdown_token.cancel();
    self.task_tracker.close();

    // Start the timeout
    let timeout = tokio::time::sleep(self.timeout);
    tokio::select! {
        _ = self.sigterm.recv() => {
            tracing::warn!("Second shutdown signal received (SIGTERM), abort");
        },
        _ = self.sigint.recv() => {
            tracing::warn!("Second shutdown signal received (SIGINT), abort");
        },
        () = timeout => {
            tracing::warn!("Shutdown timeout reached, abort");
        },
        () = self.task_tracker.wait() => {
            // This is the "happy path", we have gracefully shutdown
        },
    }

    self.hard_shutdown_token().cancel();

    // TODO: we may want to have a time out on the task tracker, in case we have
    // really stuck tasks on it
    self.task_tracker().wait().await;

    // Fixed typo in log message: "exitting" -> "exiting"
    tracing::info!("All tasks are done, exiting");

    if likely_crashed {
        ExitCode::FAILURE
    } else {
        ExitCode::SUCCESS
    }
}
}

View file

@ -0,0 +1,181 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
#![allow(clippy::module_name_repetitions)]
use std::{io::IsTerminal, process::ExitCode, sync::Arc};
use anyhow::Context;
use clap::Parser;
use mas_config::{ConfigurationSectionExt, TelemetryConfig};
use sentry_tracing::EventFilter;
use tracing_subscriber::{
EnvFilter, Layer, Registry,
filter::{LevelFilter, filter_fn},
layer::SubscriberExt,
util::SubscriberInitExt,
};
mod app_state;
mod commands;
mod lifecycle;
mod server;
mod sync;
mod telemetry;
mod util;
/// The application version, as reported by `git describe` at build time
static VERSION: &str = env!("VERGEN_GIT_DESCRIBE");
/// A Sentry transport factory that reuses our own `reqwest` HTTP client
/// (built by `mas_http::reqwest_client`) instead of letting the Sentry SDK
/// construct its own.
#[derive(Debug)]
struct SentryTransportFactory {
    // Shared HTTP client used for all Sentry event submissions
    client: reqwest::Client,
}
impl SentryTransportFactory {
    /// Create a new factory, backed by a client from
    /// [`mas_http::reqwest_client`].
    fn new() -> Self {
        Self {
            client: mas_http::reqwest_client(),
        }
    }
}
impl sentry::TransportFactory for SentryTransportFactory {
    /// Build a Sentry HTTP transport backed by our shared `reqwest` client.
    fn create_transport(&self, options: &sentry::ClientOptions) -> Arc<dyn sentry::Transport> {
        Arc::new(sentry::transports::ReqwestHttpTransport::with_client(
            options,
            self.client.clone(),
        ))
    }
}
/// Entry point: build the Tokio runtime by hand (so we can opt into unstable
/// runtime metrics when built with `--cfg tokio_unstable`) and run the async
/// main function on it.
fn main() -> anyhow::Result<ExitCode> {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();

    // Only available with the `tokio_unstable` cfg flag (set via RUSTFLAGS):
    // record a log-scale histogram of runtime poll times
    #[cfg(tokio_unstable)]
    builder
        .enable_metrics_poll_time_histogram()
        .metrics_poll_time_histogram_configuration(tokio::runtime::HistogramConfiguration::log(
            tokio::runtime::LogHistogram::default(),
        ));

    let runtime = builder.build()?;
    runtime.block_on(async_main())
}
/// Wrapper around [`try_main`] which makes sure the telemetry exporters are
/// shut down whether the command succeeded or failed.
async fn async_main() -> anyhow::Result<ExitCode> {
    let result = try_main().await;

    // Always attempt to shut down the telemetry exporters, even on error
    if let Err(err) = self::telemetry::shutdown() {
        eprintln!("Failed to shutdown telemetry exporters: {err}");
    }

    result
}
/// Load the `.env` file and the configuration, set up logging, Sentry and
/// OpenTelemetry, then parse and run the requested CLI command.
///
/// # Errors
///
/// Returns an error if the logging filter, crypto provider, telemetry
/// configuration or telemetry setup fails, or if the command itself fails.
async fn try_main() -> anyhow::Result<ExitCode> {
    // Load environment variables from .env files
    // We keep the path to log it afterwards
    let dotenv_path: Result<Option<_>, _> = dotenvy::dotenv()
        .map(Some)
        // Display the error if it is something other than the .env file not existing
        .or_else(|e| if e.not_found() { Ok(None) } else { Err(e) });

    // Setup logging
    // This writes logs to stderr
    let output = std::io::stderr();
    // Only emit ANSI colour codes when stderr is an interactive terminal
    let with_ansi = output.is_terminal();
    let (log_writer, _guard) = tracing_appender::non_blocking(output);
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_writer(log_writer)
        .event_format(mas_context::EventFormatter)
        .with_ansi(with_ansi);

    let filter_layer = EnvFilter::try_from_default_env()
        .or_else(|_| EnvFilter::try_new("info"))
        .context("could not setup logging filter")?;

    // Suppress the following warning from the Jaeger propagator:
    // Invalid jaeger header format header_value=""
    let suppress_layer = filter_fn(|metadata| metadata.name() != "JaegerPropagator.InvalidHeader");

    // Setup the rustls crypto provider
    rustls::crypto::aws_lc_rs::default_provider()
        .install_default()
        .map_err(|_| anyhow::anyhow!("could not install the AWS LC crypto provider"))?;

    // Parse the CLI arguments
    let opts = self::commands::Options::parse();

    // Load the base configuration files
    let figment = opts.figment();

    let telemetry_config = TelemetryConfig::extract_or_default(&figment)
        .map_err(anyhow::Error::from_boxed)
        .context("Failed to load telemetry config")?;

    // Setup Sentry
    let sentry = sentry::init((
        telemetry_config.sentry.dsn.as_deref(),
        sentry::ClientOptions {
            transport: Some(Arc::new(SentryTransportFactory::new())),
            environment: telemetry_config.sentry.environment.clone().map(Into::into),
            release: Some(VERSION.into()),
            sample_rate: telemetry_config.sentry.sample_rate.unwrap_or(1.0),
            traces_sample_rate: telemetry_config.sentry.traces_sample_rate.unwrap_or(0.0),
            ..Default::default()
        },
    ));

    let sentry_layer = sentry.is_enabled().then(|| {
        sentry_tracing::layer().event_filter(|md| {
            // By default, Sentry records all events as breadcrumbs, except errors.
            //
            // Because we're emitting error events for 5xx responses, we need to exclude
            // them and also record them as breadcrumbs.
            if md.name() == "http.server.response" {
                EventFilter::Breadcrumb
            } else {
                sentry_tracing::default_event_filter(md)
            }
        })
    });

    // Setup OpenTelemetry tracing and metrics
    self::telemetry::setup(&telemetry_config).context("failed to setup OpenTelemetry")?;

    let tracer = self::telemetry::TRACER
        .get()
        .context("TRACER was not set")?;

    let telemetry_layer = tracing_opentelemetry::layer()
        .with_tracer(tracer.clone())
        .with_tracked_inactivity(false)
        .with_filter(LevelFilter::INFO);

    let subscriber = Registry::default()
        .with(suppress_layer)
        .with(sentry_layer)
        .with(telemetry_layer)
        .with(filter_layer)
        .with(fmt_layer);
    subscriber
        .try_init()
        .context("could not initialize logging")?;

    // Log about the .env loading
    match dotenv_path {
        Ok(Some(path)) => tracing::info!(?path, "Loaded environment variables from .env file"),
        Ok(None) => {}
        Err(e) => tracing::warn!(?e, "Failed to load .env file"),
    }

    // And run the command
    tracing::trace!(?opts, "Running command");
    opts.run(&figment).await
}

View file

@ -0,0 +1,429 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
os::unix::net::UnixListener,
time::Duration,
};
use anyhow::Context;
use axum::{
Extension, Router,
extract::{FromRef, MatchedPath},
};
use headers::{CacheControl, HeaderMapExt as _, UserAgent};
use hyper::{Method, Request, Response, StatusCode, Version, header::USER_AGENT};
use listenfd::ListenFd;
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
use mas_context::LogContext;
use mas_listener::{ConnectionInfo, unix_or_tcp::UnixOrTcpListener};
use mas_router::Route;
use mas_templates::Templates;
use mas_tower::{
DurationRecorderLayer, InFlightCounterLayer, KV, TraceLayer, make_span_fn,
metrics_attributes_fn,
};
use opentelemetry::{Key, KeyValue};
use opentelemetry_http::HeaderExtractor;
use opentelemetry_semantic_conventions::trace::{
HTTP_REQUEST_METHOD, HTTP_RESPONSE_STATUS_CODE, HTTP_ROUTE, NETWORK_PROTOCOL_NAME,
NETWORK_PROTOCOL_VERSION, URL_PATH, URL_QUERY, URL_SCHEME, USER_AGENT_ORIGINAL,
};
use rustls::ServerConfig;
use sentry_tower::{NewSentryLayer, SentryHttpLayer};
use tower::Layer;
use tower_http::services::{ServeDir, fs::ServeFileSystemResponseBody};
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
use crate::app_state::AppState;
const MAS_LISTENER_NAME: Key = Key::from_static_str("mas.listener.name");
#[inline]
fn otel_http_method<B>(request: &Request<B>) -> &'static str {
match request.method() {
&Method::OPTIONS => "OPTIONS",
&Method::GET => "GET",
&Method::POST => "POST",
&Method::PUT => "PUT",
&Method::DELETE => "DELETE",
&Method::HEAD => "HEAD",
&Method::TRACE => "TRACE",
&Method::CONNECT => "CONNECT",
&Method::PATCH => "PATCH",
_other => "_OTHER",
}
}
/// Map the request's HTTP version to the `network.protocol.version` values
/// from the OTel semantic conventions; unknown versions become `_OTHER`.
#[inline]
fn otel_net_protocol_version<B>(request: &Request<B>) -> &'static str {
    let version = request.version();
    if version == Version::HTTP_09 {
        "0.9"
    } else if version == Version::HTTP_10 {
        "1.0"
    } else if version == Version::HTTP_11 {
        "1.1"
    } else if version == Version::HTTP_2 {
        "2.0"
    } else if version == Version::HTTP_3 {
        "3.0"
    } else {
        "_OTHER"
    }
}
fn otel_http_route<B>(request: &Request<B>) -> Option<&str> {
request
.extensions()
.get::<MatchedPath>()
.map(MatchedPath::as_str)
}
/// Determine the URL scheme ("http" or "https") from the connection info
/// stored in the request extensions; defaults to "http" when absent.
fn otel_url_scheme<B>(request: &Request<B>) -> &'static str {
    // XXX: maybe we should panic if the connection info was not injected in the
    // request extensions
    match request.extensions().get::<ConnectionInfo>() {
        Some(conn_info) if conn_info.get_tls_ref().is_some() => "https",
        _ => "http",
    }
}
/// Create the `http.server.request` tracing span for an incoming request,
/// populated with the OTel HTTP semantic-convention attributes, and link it
/// to the trace context propagated in the request headers (when the span is
/// enabled).
fn make_http_span<B>(req: &Request<B>) -> Span {
    let method = otel_http_method(req);
    let route = otel_http_route(req);

    // Per OTel conventions, the span name is "METHOD route" when a route
    // template is known, and just the method otherwise
    let span_name = if let Some(route) = route.as_ref() {
        format!("{method} {route}")
    } else {
        method.to_owned()
    };

    // Fields that are only known later (route, query, UA, status) start Empty
    // and are recorded below or by the response hook
    let span = tracing::info_span!(
        "http.server.request",
        "otel.kind" = "server",
        "otel.name" = span_name,
        "otel.status_code" = tracing::field::Empty,
        { NETWORK_PROTOCOL_NAME } = "http",
        { NETWORK_PROTOCOL_VERSION } = otel_net_protocol_version(req),
        { HTTP_REQUEST_METHOD } = method,
        { HTTP_ROUTE } = tracing::field::Empty,
        { HTTP_RESPONSE_STATUS_CODE } = tracing::field::Empty,
        { URL_PATH } = req.uri().path(),
        { URL_QUERY } = tracing::field::Empty,
        { URL_SCHEME } = otel_url_scheme(req),
        { USER_AGENT_ORIGINAL } = tracing::field::Empty,
    );

    if let Some(route) = route.as_ref() {
        span.record(HTTP_ROUTE, route);
    }

    if let Some(query) = req.uri().query() {
        span.record(URL_QUERY, query);
    }

    if let Some(user_agent) = req
        .headers()
        .get(USER_AGENT)
        .and_then(|ua| ua.to_str().ok())
    {
        span.record(USER_AGENT_ORIGINAL, user_agent);
    }

    // In case the span is disabled by any of tracing layers, e.g. if `RUST_LOG`
    // is set to `warn`, `set_parent` will fail. So we only try to set the
    // parent context if the span is not disabled.
    if !span.is_disabled() {
        // Extract the parent span context from the request headers
        let parent_context = opentelemetry::global::get_text_map_propagator(|propagator| {
            let extractor = HeaderExtractor(req.headers());
            let context = opentelemetry::Context::new();
            propagator.extract_with_context(&context, &extractor)
        });
        if let Err(err) = span.set_parent(parent_context) {
            tracing::error!(
                error = &err as &dyn std::error::Error,
                "Failed to set parent context on span"
            );
        }
    }

    span
}
/// Build the metric attributes recorded for each incoming HTTP request,
/// following the OTel HTTP semantic conventions. Requests that matched no
/// route get the `FALLBACK` route label.
fn on_http_request_labels<B>(request: &Request<B>) -> Vec<KeyValue> {
    let route = otel_http_route(request).unwrap_or("FALLBACK").to_owned();

    let mut labels = Vec::with_capacity(5);
    labels.push(KeyValue::new(NETWORK_PROTOCOL_NAME, "http"));
    labels.push(KeyValue::new(
        NETWORK_PROTOCOL_VERSION,
        otel_net_protocol_version(request),
    ));
    labels.push(KeyValue::new(HTTP_REQUEST_METHOD, otel_http_method(request)));
    labels.push(KeyValue::new(HTTP_ROUTE, route));
    labels.push(KeyValue::new(URL_SCHEME, otel_url_scheme(request)));
    labels
}
/// Build the metric attributes recorded for each HTTP response: just the
/// numeric status code.
fn on_http_response_labels<B>(res: &Response<B>) -> Vec<KeyValue> {
    let status = i64::from(res.status().as_u16());
    vec![KeyValue::new(HTTP_RESPONSE_STATUS_CODE, status)]
}
/// Middleware emitting an access-log style `http.server.response` event for
/// every request, at a level based on the response status class: `info` for
/// 1xx-3xx, `warn` for 4xx and `error` for 5xx. Per-request stats come from
/// the surrounding [`LogContext`]; if it is missing (a bug), an error is
/// logged instead.
async fn log_response_middleware(
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    // Capture request details before handing the request to the inner service
    let user_agent: Option<UserAgent> = request.headers().typed_get();
    let user_agent = user_agent.as_ref().map_or("-", |u| u.as_str());
    let method = otel_http_method(&request);
    let path = request.uri().path().to_owned();
    let version = otel_net_protocol_version(&request);
    let response = next.run(request).await;

    let Some(stats) = LogContext::maybe_with(LogContext::stats) else {
        tracing::error!("Missing log context for request, this is a bug!");
        return response;
    };

    let status_code = response.status();
    match status_code.as_u16() {
        100..=399 => tracing::info!(
            name: "http.server.response",
            "\"{method} {path} HTTP/{version}\" {status_code} {user_agent:?} [{stats}]",
        ),
        400..=499 => tracing::warn!(
            name: "http.server.response",
            "\"{method} {path} HTTP/{version}\" {status_code} {user_agent:?} [{stats}]",
        ),
        500..=599 => tracing::error!(
            name: "http.server.response",
            "\"{method} {path} HTTP/{version}\" {status_code} {user_agent:?} [{stats}]",
        ),
        _ => { /* This shouldn't happen */ }
    }

    response
}
/// Build the axum [`Router`] serving the given set of HTTP resources, with
/// logging, metrics, tracing and Sentry layers applied.
///
/// # Parameters
///
/// * `state`: the shared application state
/// * `resources`: the resources to mount on this router
/// * `prefix`: optional path prefix under which all resources are nested
/// * `name`: optional listener name, attached to metric and span attributes
pub fn build_router(
    state: AppState,
    resources: &[HttpResource],
    prefix: Option<&str>,
    name: Option<&str>,
) -> Router<()> {
    let templates = Templates::from_ref(&state);
    let mut router = Router::new();

    // Mount each configured resource onto the router
    for resource in resources {
        router = match resource {
            mas_config::HttpResource::Health => {
                router.merge(mas_handlers::healthcheck_router::<AppState>())
            }
            mas_config::HttpResource::Prometheus => {
                router.route_service("/metrics", crate::telemetry::prometheus_service())
            }
            mas_config::HttpResource::Discovery => {
                router.merge(mas_handlers::discovery_router::<AppState>())
            }
            mas_config::HttpResource::Human => {
                router.merge(mas_handlers::human_router::<AppState>(templates.clone()))
            }
            mas_config::HttpResource::GraphQL {
                playground,
                undocumented_oauth2_access,
            } => router.merge(mas_handlers::graphql_router::<AppState>(
                *playground,
                *undocumented_oauth2_access,
            )),
            mas_config::HttpResource::Assets { path } => {
                let static_service = ServeDir::new(path)
                    .append_index_html_on_directories(false)
                    // The vite build pre-compresses assets with brotli and gzip
                    .precompressed_br()
                    .precompressed_gzip();

                let add_cache_headers = axum::middleware::map_response(
                    async |mut res: Response<ServeFileSystemResponseBody>| {
                        let cache_control = if res.status() == StatusCode::NOT_FOUND {
                            // Cache 404s for 5 minutes
                            CacheControl::new()
                                .with_public()
                                .with_max_age(Duration::from_secs(5 * 60))
                        } else {
                            // Cache assets for 1 year
                            CacheControl::new()
                                .with_public()
                                .with_max_age(Duration::from_secs(365 * 24 * 60 * 60))
                                .with_immutable()
                        };

                        res.headers_mut().typed_insert(cache_control);
                        res
                    },
                );

                router.nest_service(
                    mas_router::StaticAsset::route(),
                    add_cache_headers.layer(static_service),
                )
            }
            mas_config::HttpResource::OAuth => router.merge(mas_handlers::api_router::<AppState>()),
            mas_config::HttpResource::Compat => {
                router.merge(mas_handlers::compat_router::<AppState>(templates.clone()))
            }
            mas_config::HttpResource::AdminApi => {
                let (_, api_router) = mas_handlers::admin_api_router::<AppState>();
                router.merge(api_router)
            }
            // TODO: do a better handler here
            mas_config::HttpResource::ConnectionInfo => router.route(
                "/connection-info",
                axum::routing::get(async |connection: Extension<ConnectionInfo>| {
                    format!("{connection:?}")
                }),
            ),
        }
    }

    // We normalize the prefix:
    // - if it's None, it becomes '/'
    // - if it's Some(..), any trailing '/' is first trimmed, then a '/' is added
    let prefix = format!("{}/", prefix.unwrap_or_default().trim_end_matches('/'));

    // Then we only nest the router if the prefix is not empty and not the root
    // If we blindly nest the router if the prefix is Some("/"), axum will panic as
    // we're not supposed to nest the router at the root
    if !prefix.is_empty() && prefix != "/" {
        router = Router::new().nest(&prefix, router);
    }

    router = router.fallback(mas_handlers::fallback);

    // Layer order matters here: axum wraps the existing layers with each new
    // `.layer(..)` call, so the last layer added runs first on a request
    router
        .layer(axum::middleware::from_fn(log_response_middleware))
        .layer(
            InFlightCounterLayer::new("http.server.active_requests").on_request((
                name.map(|name| KeyValue::new(MAS_LISTENER_NAME, name.to_owned())),
                metrics_attributes_fn(on_http_request_labels),
            )),
        )
        .layer(
            DurationRecorderLayer::new("http.server.duration")
                .on_request((
                    name.map(|name| KeyValue::new(MAS_LISTENER_NAME, name.to_owned())),
                    metrics_attributes_fn(on_http_request_labels),
                ))
                .on_response_fn(on_http_response_labels),
        )
        .layer(
            TraceLayer::new((
                make_span_fn(make_http_span),
                name.map(|name| KV("mas.listener.name", name.to_owned())),
            ))
            .on_response_fn(|span: &Span, response: &Response<_>| {
                let status_code = response.status().as_u16();
                span.record("http.response.status_code", status_code);
                span.record("otel.status_code", "OK");
            }),
        )
        .layer(mas_context::LogContextLayer::new(|req| {
            otel_http_method(req).into()
        }))
        // Careful about the order here: the `NewSentryLayer` must be around the
        // `SentryHttpLayer`. axum makes new layers wrap the existing ones,
        // which is the other way around compared to `tower::ServiceBuilder`.
        // So even if the Sentry docs has an example that does
        // 'NewSentryHttpLayer then SentryHttpLayer', we must do the opposite.
        .layer(SentryHttpLayer::new().enable_transaction())
        .layer(NewSentryLayer::new_from_top())
        .with_state(state)
}
pub fn build_tls_server_config(config: &HttpTlsConfig) -> Result<ServerConfig, anyhow::Error> {
let (key, chain) = config.load()?;
let mut config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(chain, key)
.context("failed to build TLS server config")?;
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Ok(config)
}
/// Create the listening sockets for the given bind configurations.
///
/// Sockets are either freshly bound (TCP by host/port or address, or a unix
/// socket by path) or taken over from inherited file descriptors through
/// `fd_manager`. TCP listeners are set to non-blocking mode.
///
/// # Errors
///
/// Returns an error if a host/address can't be parsed or resolved, a socket
/// can't be bound, or a configured file descriptor is missing.
pub fn build_listeners(
    fd_manager: &mut ListenFd,
    configs: &[HttpBindConfig],
) -> Result<Vec<UnixOrTcpListener>, anyhow::Error> {
    let mut listeners = Vec::with_capacity(configs.len());

    for bind in configs {
        let listener = match bind {
            HttpBindConfig::Listen { host, port } => {
                // Resolve the host if given, otherwise bind the wildcard v6
                // and v4 addresses on the configured port
                let addrs = match host.as_deref() {
                    Some(host) => (host, *port)
                        .to_socket_addrs()
                        .context("could not parse listener host")?
                        .collect(),

                    None => vec![
                        SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), *port),
                        SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), *port),
                    ],
                };

                let listener = TcpListener::bind(&addrs[..]).context("could not bind address")?;
                listener.set_nonblocking(true)?;
                listener.try_into()?
            }

            HttpBindConfig::Address { address } => {
                let addr: SocketAddr = address
                    .parse()
                    .context("could not parse listener address")?;
                let listener = TcpListener::bind(addr).context("could not bind address")?;
                listener.set_nonblocking(true)?;
                listener.try_into()?
            }

            HttpBindConfig::Unix { socket } => {
                let listener = UnixListener::bind(socket).context("could not bind socket")?;
                listener.try_into()?
            }

            HttpBindConfig::FileDescriptor {
                fd,
                kind: UnixOrTcp::Tcp,
            } => {
                // Take over a TCP listener inherited from the service manager
                let listener = fd_manager
                    .take_tcp_listener(*fd)?
                    .context("no listener found on file descriptor")?;
                listener.set_nonblocking(true)?;
                listener.try_into()?
            }

            HttpBindConfig::FileDescriptor {
                fd,
                kind: UnixOrTcp::Unix,
            } => {
                // Take over a unix socket inherited from the service manager
                let listener = fd_manager
                    .take_unix_listener(*fd)?
                    .context("no unix socket found on file descriptor")?;
                listener.set_nonblocking(true)?;
                listener.try_into()?
            }
        };

        listeners.push(listener);
    }

    Ok(listeners)
}

View file

@ -0,0 +1,430 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
//! Utilities to synchronize the configuration file with the database.
use std::collections::{BTreeMap, BTreeSet};
use mas_config::{ClientsConfig, UpstreamOAuth2Config};
use mas_data_model::Clock;
use mas_keystore::Encrypter;
use mas_storage::{
Pagination, RepositoryAccess,
upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams},
};
use mas_storage_pg::PgRepository;
use sqlx::{Connection, PgConnection, postgres::PgAdvisoryLock};
use tracing::{error, info, info_span, warn};
fn map_import_action(
config: mas_config::UpstreamOAuth2ImportAction,
) -> mas_data_model::UpstreamOAuthProviderImportAction {
match config {
mas_config::UpstreamOAuth2ImportAction::Ignore => {
mas_data_model::UpstreamOAuthProviderImportAction::Ignore
}
mas_config::UpstreamOAuth2ImportAction::Suggest => {
mas_data_model::UpstreamOAuthProviderImportAction::Suggest
}
mas_config::UpstreamOAuth2ImportAction::Force => {
mas_data_model::UpstreamOAuthProviderImportAction::Force
}
mas_config::UpstreamOAuth2ImportAction::Require => {
mas_data_model::UpstreamOAuthProviderImportAction::Require
}
}
}
fn map_import_on_conflict(
config: mas_config::UpstreamOAuth2OnConflict,
) -> mas_data_model::UpstreamOAuthProviderOnConflict {
match config {
mas_config::UpstreamOAuth2OnConflict::Add => {
mas_data_model::UpstreamOAuthProviderOnConflict::Add
}
mas_config::UpstreamOAuth2OnConflict::Replace => {
mas_data_model::UpstreamOAuthProviderOnConflict::Replace
}
mas_config::UpstreamOAuth2OnConflict::Set => {
mas_data_model::UpstreamOAuthProviderOnConflict::Set
}
mas_config::UpstreamOAuth2OnConflict::Fail => {
mas_data_model::UpstreamOAuthProviderOnConflict::Fail
}
}
}
fn map_claims_imports(
config: &mas_config::UpstreamOAuth2ClaimsImports,
) -> mas_data_model::UpstreamOAuthProviderClaimsImports {
mas_data_model::UpstreamOAuthProviderClaimsImports {
subject: mas_data_model::UpstreamOAuthProviderSubjectPreference {
template: config.subject.template.clone(),
},
skip_confirmation: config.skip_confirmation,
localpart: mas_data_model::UpstreamOAuthProviderLocalpartPreference {
action: map_import_action(config.localpart.action),
template: config.localpart.template.clone(),
on_conflict: map_import_on_conflict(config.localpart.on_conflict),
},
displayname: mas_data_model::UpstreamOAuthProviderImportPreference {
action: map_import_action(config.displayname.action),
template: config.displayname.template.clone(),
},
email: mas_data_model::UpstreamOAuthProviderImportPreference {
action: map_import_action(config.email.action),
template: config.email.template.clone(),
},
account_name: mas_data_model::UpstreamOAuthProviderSubjectPreference {
template: config.account_name.template.clone(),
},
}
}
/// Synchronise the upstream OAuth 2.0 providers and static OAuth 2.0 clients
/// defined in the configuration file with the database.
///
/// Everything runs in a single transaction, guarded by a Postgres advisory
/// lock, so concurrent invocations serialise against each other. Providers
/// missing from the config are disabled (soft-deleted); with `prune` they
/// are deleted outright, along with static clients absent from the config.
///
/// # Parameters
///
/// * `upstream_oauth2_config`: the upstream OAuth 2.0 section of the config
/// * `clients_config`: the static clients section of the config
/// * `connection`: the database connection to run the sync on
/// * `encrypter`: used to encrypt client secrets before storing them
/// * `clock`: clock used when recording provider changes
/// * `prune`: delete (rather than just warn about) entries missing from the config
/// * `dry_run`: log what would change, then roll the transaction back
///
/// # Errors
///
/// Returns an error if any database operation fails, or if a config value
/// (scope, secrets, key files) can't be read, parsed or encrypted.
#[tracing::instrument(name = "config.sync", skip_all)]
pub async fn config_sync(
    upstream_oauth2_config: UpstreamOAuth2Config,
    clients_config: ClientsConfig,
    connection: &mut PgConnection,
    encrypter: &Encrypter,
    clock: &dyn Clock,
    prune: bool,
    dry_run: bool,
) -> anyhow::Result<()> {
    // Start a transaction
    let txn = connection.begin().await?;

    // Grab a lock within the transaction
    tracing::info!("Acquiring configuration lock");
    let lock = PgAdvisoryLock::new("MAS config sync");
    let lock = lock.acquire(txn).await?;

    // Create a repository from the connection with the lock
    let mut repo = PgRepository::from_conn(lock);

    tracing::info!(
        prune,
        dry_run,
        "Syncing providers and clients defined in config to database"
    );

    {
        let _span = info_span!("cli.config.sync.providers").entered();
        // IDs of the providers that are present and enabled in the config
        let config_ids = upstream_oauth2_config
            .providers
            .iter()
            .filter(|p| p.enabled)
            .map(|p| p.id)
            .collect::<BTreeSet<_>>();

        // Let's assume we have less than 1000 providers
        let page = repo
            .upstream_oauth_provider()
            .list(
                UpstreamOAuthProviderFilter::default(),
                Pagination::first(1000),
            )
            .await?;

        // A warning is probably enough
        if page.has_next_page {
            warn!(
                "More than 1000 providers in the database, only the first 1000 will be considered"
            );
        }

        let mut existing_enabled_ids = BTreeSet::new();
        let mut existing_disabled = BTreeMap::new();

        // Process the existing providers
        for edge in page.edges {
            let provider = edge.node;
            if provider.enabled() {
                if config_ids.contains(&provider.id) {
                    existing_enabled_ids.insert(provider.id);
                } else {
                    // Provider is enabled in the database but not in the config
                    info!(%provider.id, "Disabling provider");
                    let provider = if dry_run {
                        provider
                    } else {
                        repo.upstream_oauth_provider()
                            .disable(clock, provider)
                            .await?
                    };
                    existing_disabled.insert(provider.id, provider);
                }
            } else {
                existing_disabled.insert(provider.id, provider);
            }
        }

        if prune {
            for provider_id in existing_disabled.keys().copied() {
                info!(provider.id = %provider_id, "Deleting provider");

                if dry_run {
                    continue;
                }

                repo.upstream_oauth_provider()
                    .delete_by_id(provider_id)
                    .await?;
            }
        } else {
            let len = existing_disabled.len();
            match len {
                0 => {}
                1 => warn!(
                    "A provider is soft-deleted in the database. Run `mas-cli config sync --prune` to delete it."
                ),
                n => warn!(
                    "{n} providers are soft-deleted in the database. Run `mas-cli config sync --prune` to delete them."
                ),
            }
        }

        for (index, provider) in upstream_oauth2_config.providers.into_iter().enumerate() {
            if !provider.enabled {
                continue;
            }

            // Use the position in the config of the provider as position in the UI
            let ui_order = index.try_into().unwrap_or(i32::MAX);

            let _span = info_span!("provider", %provider.id).entered();
            if existing_enabled_ids.contains(&provider.id) {
                info!(provider.id = %provider.id, "Updating provider");
            } else if existing_disabled.contains_key(&provider.id) {
                info!(provider.id = %provider.id, "Enabling and updating provider");
            } else {
                info!(provider.id = %provider.id, "Adding provider");
            }

            if dry_run {
                continue;
            }

            let encrypted_client_secret = if let Some(client_secret) = provider.client_secret {
                Some(encrypter.encrypt_to_string(client_secret.value().await?.as_bytes())?)
            } else if let Some(mut siwa) = provider.sign_in_with_apple.clone() {
                // if private key file is defined and not private key (raw), we populate the
                // private key to hold the content of the private key file.
                // private key (raw) takes precedence so both can be defined
                // without issues
                if siwa.private_key.is_none()
                    && let Some(private_key_file) = siwa.private_key_file.take()
                {
                    let key = tokio::fs::read_to_string(private_key_file).await?;
                    siwa.private_key = Some(key);
                }
                let encoded = serde_json::to_vec(&siwa)?;
                Some(encrypter.encrypt_to_string(&encoded)?)
            } else {
                None
            };

            let discovery_mode = match provider.discovery_mode {
                mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {
                    mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc
                }
                mas_config::UpstreamOAuth2DiscoveryMode::Insecure => {
                    mas_data_model::UpstreamOAuthProviderDiscoveryMode::Insecure
                }
                mas_config::UpstreamOAuth2DiscoveryMode::Disabled => {
                    mas_data_model::UpstreamOAuthProviderDiscoveryMode::Disabled
                }
            };

            let token_endpoint_auth_method = match provider.token_endpoint_auth_method {
                mas_config::UpstreamOAuth2TokenAuthMethod::None => {
                    mas_data_model::UpstreamOAuthProviderTokenAuthMethod::None
                }
                mas_config::UpstreamOAuth2TokenAuthMethod::ClientSecretBasic => {
                    mas_data_model::UpstreamOAuthProviderTokenAuthMethod::ClientSecretBasic
                }
                mas_config::UpstreamOAuth2TokenAuthMethod::ClientSecretPost => {
                    mas_data_model::UpstreamOAuthProviderTokenAuthMethod::ClientSecretPost
                }
                mas_config::UpstreamOAuth2TokenAuthMethod::ClientSecretJwt => {
                    mas_data_model::UpstreamOAuthProviderTokenAuthMethod::ClientSecretJwt
                }
                mas_config::UpstreamOAuth2TokenAuthMethod::PrivateKeyJwt => {
                    mas_data_model::UpstreamOAuthProviderTokenAuthMethod::PrivateKeyJwt
                }
                mas_config::UpstreamOAuth2TokenAuthMethod::SignInWithApple => {
                    mas_data_model::UpstreamOAuthProviderTokenAuthMethod::SignInWithApple
                }
            };

            let response_mode = provider
                .response_mode
                .map(|response_mode| match response_mode {
                    mas_config::UpstreamOAuth2ResponseMode::Query => {
                        mas_data_model::UpstreamOAuthProviderResponseMode::Query
                    }
                    mas_config::UpstreamOAuth2ResponseMode::FormPost => {
                        mas_data_model::UpstreamOAuthProviderResponseMode::FormPost
                    }
                });

            // With discovery disabled, the endpoints must be set explicitly
            if discovery_mode.is_disabled() {
                if provider.authorization_endpoint.is_none() {
                    error!(provider.id = %provider.id, "Provider has discovery disabled but no authorization endpoint set");
                }

                if provider.token_endpoint.is_none() {
                    error!(provider.id = %provider.id, "Provider has discovery disabled but no token endpoint set");
                }

                if provider.jwks_uri.is_none() {
                    warn!(provider.id = %provider.id, "Provider has discovery disabled but no JWKS URI set");
                }
            }

            let pkce_mode = match provider.pkce_method {
                mas_config::UpstreamOAuth2PkceMethod::Auto => {
                    mas_data_model::UpstreamOAuthProviderPkceMode::Auto
                }
                mas_config::UpstreamOAuth2PkceMethod::Always => {
                    mas_data_model::UpstreamOAuthProviderPkceMode::S256
                }
                mas_config::UpstreamOAuth2PkceMethod::Never => {
                    mas_data_model::UpstreamOAuthProviderPkceMode::Disabled
                }
            };

            let on_backchannel_logout = match provider.on_backchannel_logout {
                mas_config::UpstreamOAuth2OnBackchannelLogout::DoNothing => {
                    mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::DoNothing
                }
                mas_config::UpstreamOAuth2OnBackchannelLogout::LogoutBrowserOnly => {
                    mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::LogoutBrowserOnly
                }
                mas_config::UpstreamOAuth2OnBackchannelLogout::LogoutAll => {
                    mas_data_model::UpstreamOAuthProviderOnBackchannelLogout::LogoutAll
                }
            };

            repo.upstream_oauth_provider()
                .upsert(
                    clock,
                    provider.id,
                    UpstreamOAuthProviderParams {
                        issuer: provider.issuer,
                        human_name: provider.human_name,
                        brand_name: provider.brand_name,
                        scope: provider.scope.parse()?,
                        token_endpoint_auth_method,
                        token_endpoint_signing_alg: provider.token_endpoint_auth_signing_alg,
                        id_token_signed_response_alg: provider.id_token_signed_response_alg,
                        client_id: provider.client_id,
                        encrypted_client_secret,
                        claims_imports: map_claims_imports(&provider.claims_imports),
                        token_endpoint_override: provider.token_endpoint,
                        userinfo_endpoint_override: provider.userinfo_endpoint,
                        authorization_endpoint_override: provider.authorization_endpoint,
                        jwks_uri_override: provider.jwks_uri,
                        discovery_mode,
                        pkce_mode,
                        fetch_userinfo: provider.fetch_userinfo,
                        userinfo_signed_response_alg: provider.userinfo_signed_response_alg,
                        response_mode,
                        additional_authorization_parameters: provider
                            .additional_authorization_parameters
                            .into_iter()
                            .collect(),
                        forward_login_hint: provider.forward_login_hint,
                        ui_order,
                        on_backchannel_logout,
                    },
                )
                .await?;
        }
    }

    {
        let _span = info_span!("cli.config.sync.clients").entered();
        let config_ids = clients_config
            .iter()
            .map(|c| c.client_id)
            .collect::<BTreeSet<_>>();

        let existing = repo.oauth2_client().all_static().await?;
        let existing_ids = existing.iter().map(|p| p.id).collect::<BTreeSet<_>>();
        // Static clients present in the database but absent from the config
        let to_delete = existing.into_iter().filter(|p| !config_ids.contains(&p.id));

        if prune {
            for client in to_delete {
                info!(client.id = %client.client_id, "Deleting client");

                if dry_run {
                    continue;
                }

                repo.oauth2_client().delete(client).await?;
            }
        } else {
            let len = to_delete.count();
            match len {
                0 => {}
                1 => warn!(
                    "A static client in the database is not in the config. Run with `--prune` to delete it."
                ),
                n => warn!(
                    "{n} static clients in the database are not in the config. Run with `--prune` to delete them."
                ),
            }
        }

        for client in clients_config {
            let _span = info_span!("client", client.id = %client.client_id).entered();
            if existing_ids.contains(&client.client_id) {
                info!(client.id = %client.client_id, "Updating client");
            } else {
                info!(client.id = %client.client_id, "Adding client");
            }

            if dry_run {
                continue;
            }

            let client_secret = client.client_secret().await?;
            let client_name = client.client_name.as_ref();
            let client_auth_method = client.client_auth_method();
            let jwks = client.jwks.as_ref();
            let jwks_uri = client.jwks_uri.as_ref();

            // TODO: should be moved somewhere else
            let encrypted_client_secret = client_secret
                .map(|client_secret| encrypter.encrypt_to_string(client_secret.as_bytes()))
                .transpose()?;

            repo.oauth2_client()
                .upsert_static(
                    client.client_id,
                    client_name.cloned(),
                    client_auth_method,
                    encrypted_client_secret,
                    jwks.cloned(),
                    jwks_uri.cloned(),
                    client.redirect_uris,
                )
                .await?;
        }
    }

    // Get the lock and release it to commit the transaction
    let lock = repo.into_inner();
    let txn = lock.release_now().await?;

    if dry_run {
        info!("Dry run, rolling back changes");
        txn.rollback().await?;
    } else {
        txn.commit().await?;
    }

    Ok(())
}

View file

@ -0,0 +1,293 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::sync::{LazyLock, OnceLock};
use anyhow::Context as _;
use bytes::Bytes;
use http_body_util::Full;
use hyper::{Response, header::CONTENT_TYPE};
use mas_config::{
MetricsConfig, MetricsExporterKind, Propagator, TelemetryConfig, TracingConfig,
TracingExporterKind,
};
use opentelemetry::{
InstrumentationScope, KeyValue,
metrics::Meter,
propagation::{TextMapCompositePropagator, TextMapPropagator},
trace::TracerProvider as _,
};
use opentelemetry_otlp::{WithExportConfig, WithHttpConfig};
use opentelemetry_prometheus_text_exporter::PrometheusExporter;
use opentelemetry_sdk::{
Resource,
metrics::{ManualReader, SdkMeterProvider, periodic_reader_with_async_runtime::PeriodicReader},
propagation::{BaggagePropagator, TraceContextPropagator},
trace::{
IdGenerator, Sampler, SdkTracerProvider, Tracer,
span_processor_with_async_runtime::BatchSpanProcessor,
},
};
use opentelemetry_semantic_conventions as semcov;
// Instrumentation scope (name, version, schema URL) shared by the meter and
// tracer created in this module.
static SCOPE: LazyLock<InstrumentationScope> = LazyLock::new(|| {
    InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
        .with_version(env!("CARGO_PKG_VERSION"))
        .with_schema_url(semcov::SCHEMA_URL)
        .build()
});
/// Global meter used across the crate to create metric instruments.
pub static METER: LazyLock<Meter> =
    LazyLock::new(|| opentelemetry::global::meter_with_scope(SCOPE.clone()));
/// Tracer installed by `init_tracer`; unset until telemetry setup has run.
pub static TRACER: OnceLock<Tracer> = OnceLock::new();
// Providers are kept here so `shutdown` can flush and stop them at exit.
static METER_PROVIDER: OnceLock<SdkMeterProvider> = OnceLock::new();
static TRACER_PROVIDER: OnceLock<SdkTracerProvider> = OnceLock::new();
// Exporter handle queried by the Prometheus scrape endpoint, when configured.
static PROMETHEUS_EXPORTER: OnceLock<PrometheusExporter> = OnceLock::new();
/// Initialise the telemetry stack (propagators, traces, metrics, and
/// process/tokio instrumentation) from the given configuration.
pub fn setup(config: &TelemetryConfig) -> anyhow::Result<()> {
    let composite = propagator(&config.tracing.propagators);

    // Register the propagator with the HTTP layer first: the CORS filter
    // needs to know what headers it should whitelist for CORS-protected
    // requests.
    mas_http::set_propagator(&composite);
    opentelemetry::global::set_text_map_propagator(composite);

    init_tracer(&config.tracing).context("Failed to configure traces exporter")?;
    init_meter(&config.metrics).context("Failed to configure metrics exporter")?;

    opentelemetry_instrumentation_process::init()
        .context("Failed to configure process instrumentation")?;
    opentelemetry_instrumentation_tokio::observe_current_runtime();

    Ok(())
}
/// Shut down the tracer and meter providers, flushing pending telemetry.
///
/// Both providers are always attempted, even if the first one fails; the
/// first error encountered is returned. Previously an early `?` on the
/// tracer shutdown skipped the meter-provider shutdown entirely, losing
/// buffered metrics whenever trace shutdown errored.
pub fn shutdown() -> opentelemetry_sdk::error::OTelSdkResult {
    let tracer_result = TRACER_PROVIDER
        .get()
        .map_or(Ok(()), |provider| provider.shutdown());
    let meter_result = METER_PROVIDER
        .get()
        .map_or(Ok(()), |provider| provider.shutdown());
    tracer_result?;
    meter_result
}
fn match_propagator(propagator: Propagator) -> Box<dyn TextMapPropagator + Send + Sync> {
use Propagator as P;
match propagator {
P::TraceContext => Box::new(TraceContextPropagator::new()),
P::Baggage => Box::new(BaggagePropagator::new()),
P::Jaeger => Box::new(opentelemetry_jaeger_propagator::Propagator::new()),
}
}
/// Combine all configured propagators into a single composite propagator.
fn propagator(propagators: &[Propagator]) -> TextMapCompositePropagator {
    TextMapCompositePropagator::new(
        propagators.iter().copied().map(match_propagator).collect(),
    )
}
/// An [`IdGenerator`] which always returns an invalid trace ID and span ID
///
/// This is used when no exporter is being used, so that we don't log the trace
/// ID when we're not tracing.
#[derive(Debug, Clone, Copy)]
struct InvalidIdGenerator;
impl IdGenerator for InvalidIdGenerator {
    // Always the invalid (all-zero) trace ID.
    fn new_trace_id(&self) -> opentelemetry::TraceId {
        opentelemetry::TraceId::INVALID
    }
    // Always the invalid (all-zero) span ID.
    fn new_span_id(&self) -> opentelemetry::SpanId {
        opentelemetry::SpanId::INVALID
    }
}
/// Initialise the global tracer provider from the tracing configuration,
/// storing handles in `TRACER_PROVIDER`/`TRACER` for later use.
///
/// Errors if the OTLP exporter fails to build, or if this is called twice.
fn init_tracer(config: &TracingConfig) -> anyhow::Result<()> {
    let sample_rate = config.sample_rate.unwrap_or(1.0);
    // We sample traces based on the parent if we have one, and if not, we
    // sample a ratio based on the configured sample rate
    let sampler = Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(sample_rate)));
    let tracer_provider_builder = SdkTracerProvider::builder()
        .with_resource(resource())
        .with_sampler(sampler);
    let tracer_provider = match config.exporter {
        // No exporter: never sample, and generate invalid IDs so that log
        // lines don't carry a meaningless trace ID
        TracingExporterKind::None => tracer_provider_builder
            .with_id_generator(InvalidIdGenerator)
            .with_sampler(Sampler::AlwaysOff)
            .build(),
        TracingExporterKind::Stdout => {
            let exporter = opentelemetry_stdout::SpanExporter::default();
            tracer_provider_builder
                .with_simple_exporter(exporter)
                .build()
        }
        TracingExporterKind::Otlp => {
            let mut exporter = opentelemetry_otlp::SpanExporter::builder()
                .with_http()
                .with_http_client(mas_http::reqwest_client());
            if let Some(endpoint) = &config.endpoint {
                exporter = exporter.with_endpoint(endpoint.as_str());
            }
            let exporter = exporter
                .build()
                .context("Failed to configure OTLP trace exporter")?;
            // Spans are batched in the background on the Tokio runtime
            // before being exported
            let batch_processor =
                BatchSpanProcessor::builder(exporter, opentelemetry_sdk::runtime::Tokio).build();
            tracer_provider_builder
                .with_span_processor(batch_processor)
                .build()
        }
    };
    // Keep a handle so `shutdown` can flush the provider at exit
    TRACER_PROVIDER
        .set(tracer_provider.clone())
        .map_err(|_| anyhow::anyhow!("TRACER_PROVIDER was set twice"))?;
    let tracer = tracer_provider.tracer_with_scope(SCOPE.clone());
    TRACER
        .set(tracer)
        .map_err(|_| anyhow::anyhow!("TRACER was set twice"))?;
    opentelemetry::global::set_tracer_provider(tracer_provider);
    Ok(())
}
/// Build a periodic OTLP metric reader, optionally overriding the endpoint.
fn otlp_metric_reader(
    endpoint: Option<&url::Url>,
) -> anyhow::Result<PeriodicReader<opentelemetry_otlp::MetricExporter>> {
    let mut builder = opentelemetry_otlp::MetricExporter::builder()
        .with_http()
        .with_http_client(mas_http::reqwest_client());
    if let Some(url) = endpoint {
        builder = builder.with_endpoint(url.to_string());
    }
    let exporter = builder
        .build()
        .context("Failed to configure OTLP metric exporter")?;
    Ok(PeriodicReader::builder(exporter, opentelemetry_sdk::runtime::Tokio).build())
}
/// Build a periodic reader exporting metrics to stdout.
fn stdout_metric_reader() -> PeriodicReader<opentelemetry_stdout::MetricExporter> {
    PeriodicReader::builder(
        opentelemetry_stdout::MetricExporter::builder().build(),
        opentelemetry_sdk::runtime::Tokio,
    )
    .build()
}
type PromServiceFuture =
    std::future::Ready<Result<Response<Full<Bytes>>, std::convert::Infallible>>;

/// Serve one Prometheus scrape request: export the current metrics as text,
/// or reply with a 500 if the exporter failed or was never configured.
#[allow(clippy::needless_pass_by_value)]
fn prometheus_service_fn<T>(_req: T) -> PromServiceFuture {
    // Small helper for the two plain-text error replies below
    fn error_response(message: &'static [u8]) -> Response<Full<Bytes>> {
        Response::builder()
            .status(500)
            .header(CONTENT_TYPE, "text/plain")
            .body(Full::new(Bytes::from_static(message)))
            .unwrap()
    }

    let response = match PROMETHEUS_EXPORTER.get() {
        None => error_response(b"Prometheus exporter was not enabled in config"),
        Some(exporter) => {
            // We'll need some space for this, so we preallocate a bit
            let mut buffer = Vec::with_capacity(1024);
            match exporter.export(&mut buffer) {
                Err(err) => {
                    tracing::error!(
                        error = &err as &dyn std::error::Error,
                        "Failed to export Prometheus metrics"
                    );
                    error_response(b"Failed to export Prometheus metrics, see logs for details")
                }
                Ok(_) => Response::builder()
                    .status(200)
                    .header(CONTENT_TYPE, "text/plain;version=1.0.0")
                    .body(Full::new(Bytes::from(buffer)))
                    .unwrap(),
            }
        }
    };
    std::future::ready(Ok(response))
}
/// Build a tower service handler for the Prometheus scrape endpoint.
///
/// Warns when the exporter was not configured, since the endpoint will then
/// only ever return errors.
pub fn prometheus_service<T>() -> tower::util::ServiceFn<fn(T) -> PromServiceFuture> {
    let exporter_configured = PROMETHEUS_EXPORTER.get().is_some();
    if !exporter_configured {
        tracing::warn!(
            "A Prometheus resource was mounted on a listener, but the Prometheus exporter was not setup in the config"
        );
    }
    tower::service_fn(prometheus_service_fn as _)
}
/// Build the Prometheus exporter and stash a handle for the scrape endpoint.
fn prometheus_metric_reader() -> anyhow::Result<PrometheusExporter> {
    let exporter = PrometheusExporter::builder().without_scope_info().build();
    let for_endpoint = exporter.clone();
    PROMETHEUS_EXPORTER
        .set(for_endpoint)
        .map_err(|_| anyhow::anyhow!("PROMETHEUS_EXPORTER was set twice"))?;
    Ok(exporter)
}
/// Initialise the global meter provider from the metrics configuration,
/// storing a handle in `METER_PROVIDER` for later shutdown.
///
/// Errors if the configured exporter fails to build, or if called twice.
fn init_meter(config: &MetricsConfig) -> anyhow::Result<()> {
    let builder = SdkMeterProvider::builder();
    // Attach the reader matching the configured exporter kind. With no
    // exporter we still install a manual reader so instruments stay valid.
    let builder = match config.exporter {
        MetricsExporterKind::None => builder.with_reader(ManualReader::default()),
        MetricsExporterKind::Stdout => builder.with_reader(stdout_metric_reader()),
        MetricsExporterKind::Otlp => {
            builder.with_reader(otlp_metric_reader(config.endpoint.as_ref())?)
        }
        MetricsExporterKind::Prometheus => builder.with_reader(prometheus_metric_reader()?),
    };
    let meter_provider = builder.with_resource(resource()).build();
    // Keep a handle so `shutdown` can flush the provider at exit
    METER_PROVIDER
        .set(meter_provider.clone())
        .map_err(|_| anyhow::anyhow!("METER_PROVIDER was set twice"))?;
    // Last use of the provider: move it into the global registry instead of
    // cloning a second time (clippy::redundant_clone)
    opentelemetry::global::set_meter_provider(meter_provider);
    Ok(())
}
/// Build the OpenTelemetry [`Resource`] describing this service instance:
/// service name/version, host/OS/process detectors, and runtime attributes.
fn resource() -> Resource {
    let attributes = [
        KeyValue::new(semcov::resource::SERVICE_VERSION, crate::VERSION),
        KeyValue::new(semcov::resource::PROCESS_RUNTIME_NAME, "rust"),
        KeyValue::new(
            semcov::resource::PROCESS_RUNTIME_VERSION,
            env!("VERGEN_RUSTC_SEMVER"),
        ),
    ];
    Resource::builder()
        .with_service_name(env!("CARGO_PKG_NAME"))
        .with_detectors(&[
            Box::new(opentelemetry_resource_detectors::HostResourceDetector::default()),
            Box::new(opentelemetry_resource_detectors::OsResourceDetector),
            Box::new(opentelemetry_resource_detectors::ProcessResourceDetector),
        ])
        .with_attributes(attributes)
        .build()
}

View file

@ -0,0 +1,592 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{sync::Arc, time::Duration};
use anyhow::Context;
use mas_config::{
AccountConfig, BrandingConfig, CaptchaConfig, DatabaseConfig, EmailConfig, EmailSmtpMode,
EmailTransportKind, ExperimentalConfig, HomeserverKind, MatrixConfig, PasswordsConfig,
PolicyConfig, TemplatesConfig,
};
use mas_context::LogContext;
use mas_data_model::{SessionExpirationConfig, SessionLimitConfig, SiteConfig};
use mas_email::{MailTransport, Mailer};
use mas_handlers::passwords::PasswordManager;
use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
use mas_matrix_synapse::{LegacySynapseConnection, SynapseConnection};
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_storage::{BoxRepositoryFactory, RepositoryAccess, RepositoryFactory};
use mas_templates::{SiteConfigExt, Templates};
use sqlx::{
ConnectOptions, Executor, PgConnection, PgPool,
postgres::{PgConnectOptions, PgPoolOptions},
};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tracing::{Instrument, log::LevelFilter};
/// Build the [`PasswordManager`] from the passwords configuration section.
///
/// Returns a disabled manager when password authentication is turned off;
/// otherwise loads each configured scheme and maps it to a hasher.
pub async fn password_manager_from_config(
    config: &PasswordsConfig,
) -> Result<PasswordManager, anyhow::Error> {
    use mas_handlers::passwords::Hasher;

    if !config.enabled() {
        return Ok(PasswordManager::disabled());
    }

    let schemes = config.load().await?.into_iter().map(
        |(version, algorithm, cost, secret, unicode_normalization)| {
            // Translate the configured algorithm into the matching hasher
            let hasher = match algorithm {
                mas_config::PasswordAlgorithm::Pbkdf2 => {
                    Hasher::pbkdf2(secret, unicode_normalization)
                }
                mas_config::PasswordAlgorithm::Bcrypt => {
                    Hasher::bcrypt(cost, secret, unicode_normalization)
                }
                mas_config::PasswordAlgorithm::Argon2id => {
                    Hasher::argon2id(secret, unicode_normalization)
                }
            };
            (version, hasher)
        },
    );

    PasswordManager::new(config.minimum_complexity(), schemes)
}
/// Build the [`Mailer`] from the email configuration and loaded templates.
///
/// # Errors
///
/// Returns an error when the `from`/`reply_to` addresses are invalid, when
/// the SMTP configuration is incomplete, or when the SMTP transport fails
/// to build.
pub fn mailer_from_config(
    config: &EmailConfig,
    templates: &Templates,
) -> Result<Mailer, anyhow::Error> {
    let from = config
        .from
        .parse()
        .context("invalid email configuration: invalid 'from' address")?;
    let reply_to = config
        .reply_to
        .parse()
        .context("invalid email configuration: invalid 'reply_to' address")?;
    let transport = match config.transport() {
        // Black-hole transport: mail is accepted but not delivered anywhere
        EmailTransportKind::Blackhole => MailTransport::blackhole(),
        EmailTransportKind::Smtp => {
            // This should have been set ahead of time
            let hostname = config
                .hostname()
                .context("invalid email configuration: missing hostname")?;
            let mode = config
                .mode()
                .context("invalid email configuration: missing mode")?;
            // Username and password must be given together, or not at all
            let credentials = match (config.username(), config.password()) {
                (Some(username), Some(password)) => Some(mas_email::SmtpCredentials::new(
                    username.to_owned(),
                    password.to_owned(),
                )),
                (None, None) => None,
                _ => {
                    anyhow::bail!("invalid email configuration: missing username or password");
                }
            };
            let mode = match mode {
                EmailSmtpMode::Plain => mas_email::SmtpMode::Plain,
                EmailSmtpMode::StartTls => mas_email::SmtpMode::StartTls,
                EmailSmtpMode::Tls => mas_email::SmtpMode::Tls,
            };
            MailTransport::smtp(mode, hostname, config.port(), credentials)
                .context("failed to build SMTP transport")?
        }
        EmailTransportKind::Sendmail => MailTransport::sendmail(config.command()),
    };
    Ok(Mailer::new(templates.clone(), transport, from, reply_to))
}
/// Test the connection to the mailer in a background task
///
/// Failures (and timeouts) are only logged as warnings; this never blocks
/// or fails service startup.
pub fn test_mailer_in_background(mailer: &Mailer, timeout: Duration) {
    let mailer = mailer.clone();
    let span = tracing::info_span!("cli.test_mailer");
    tokio::spawn(
        LogContext::new("mailer-test").run(async move || {
            match tokio::time::timeout(timeout, mailer.test_connection()).await {
                // Connection test succeeded: nothing to report
                Ok(Ok(())) => {}
                Ok(Err(err)) => {
                    tracing::warn!(
                        error = &err as &dyn std::error::Error,
                        "Could not connect to the mail backend, tasks sending mails may fail!"
                    );
                }
                // The timeout elapsed before the test finished
                Err(_) => {
                    tracing::warn!("Timed out while testing the mail backend connection, tasks sending mails may fail!");
                }
            }
        })
        .instrument(span)
    );
}
/// Load the OPA policy factory from the configured WASM module, wiring in
/// the configured entrypoints and the initial policy data.
pub async fn policy_factory_from_config(
    config: &PolicyConfig,
    matrix_config: &MatrixConfig,
    experimental_config: &ExperimentalConfig,
) -> Result<PolicyFactory, anyhow::Error> {
    let entrypoints = mas_policy::Entrypoints {
        register: config.register_entrypoint.clone(),
        client_registration: config.client_registration_entrypoint.clone(),
        authorization_grant: config.authorization_grant_entrypoint.clone(),
        compat_login: config.compat_login_entrypoint.clone(),
        email: config.email_entrypoint.clone(),
    };

    let session_limit_config =
        experimental_config
            .session_limit
            .as_ref()
            .map(|limits| SessionLimitConfig {
                soft_limit: limits.soft_limit,
                hard_limit: limits.hard_limit,
            });

    let data = mas_policy::Data::new(matrix_config.homeserver.clone(), session_limit_config)
        .with_rest(config.data.clone());

    let policy_file = tokio::fs::File::open(&config.wasm_module)
        .await
        .context("failed to open OPA WASM policy file")?;

    PolicyFactory::load(policy_file, data, entrypoints)
        .await
        .context("failed to load the policy")
}
pub fn captcha_config_from_config(
captcha_config: &CaptchaConfig,
) -> Result<Option<mas_data_model::CaptchaConfig>, anyhow::Error> {
let Some(service) = captcha_config.service else {
return Ok(None);
};
let service = match service {
mas_config::CaptchaServiceKind::RecaptchaV2 => mas_data_model::CaptchaService::RecaptchaV2,
mas_config::CaptchaServiceKind::CloudflareTurnstile => {
mas_data_model::CaptchaService::CloudflareTurnstile
}
mas_config::CaptchaServiceKind::HCaptcha => mas_data_model::CaptchaService::HCaptcha,
};
Ok(Some(mas_data_model::CaptchaConfig {
service,
site_key: captcha_config
.site_key
.clone()
.context("missing site key")?,
secret_key: captcha_config
.secret_key
.clone()
.context("missing secret key")?,
}))
}
/// Build the [`SiteConfig`] from the various configuration sections which
/// feed into it.
///
/// # Errors
///
/// Returns an error if the captcha configuration is invalid.
pub fn site_config_from_config(
    branding_config: &BrandingConfig,
    matrix_config: &MatrixConfig,
    experimental_config: &ExperimentalConfig,
    password_config: &PasswordsConfig,
    account_config: &AccountConfig,
    captcha_config: &CaptchaConfig,
) -> Result<SiteConfig, anyhow::Error> {
    let captcha = captcha_config_from_config(captcha_config)?;
    // Per-session-type inactivity TTLs, only present when the experimental
    // session expiration feature is configured
    let session_expiration = experimental_config
        .inactive_session_expiration
        .as_ref()
        .map(|c| SessionExpirationConfig {
            oauth_session_inactivity_ttl: c.expire_oauth_sessions.then_some(c.ttl),
            compat_session_inactivity_ttl: c.expire_compat_sessions.then_some(c.ttl),
            user_session_inactivity_ttl: c.expire_user_sessions.then_some(c.ttl),
        });
    Ok(SiteConfig {
        access_token_ttl: experimental_config.access_token_ttl,
        compat_token_ttl: experimental_config.compat_token_ttl,
        server_name: matrix_config.homeserver.clone(),
        policy_uri: branding_config.policy_uri.clone(),
        tos_uri: branding_config.tos_uri.clone(),
        imprint: branding_config.imprint.clone(),
        password_login_enabled: password_config.enabled(),
        // Password-dependent features below are forced off when password
        // authentication itself is disabled
        password_registration_enabled: password_config.enabled()
            && account_config.password_registration_enabled,
        password_registration_email_required: account_config.password_registration_email_required,
        registration_token_required: account_config.registration_token_required,
        email_change_allowed: account_config.email_change_allowed,
        displayname_change_allowed: account_config.displayname_change_allowed,
        password_change_allowed: password_config.enabled()
            && account_config.password_change_allowed,
        account_recovery_allowed: password_config.enabled()
            && account_config.password_recovery_enabled,
        account_deactivation_allowed: account_config.account_deactivation_allowed,
        captcha,
        minimum_password_complexity: password_config.minimum_complexity(),
        session_expiration,
        login_with_email_allowed: account_config.login_with_email_allowed,
        plan_management_iframe_uri: experimental_config.plan_management_iframe_uri.clone(),
        session_limit: experimental_config
            .session_limit
            .as_ref()
            .map(|c| SessionLimitConfig {
                soft_limit: c.soft_limit,
                hard_limit: c.hard_limit,
            }),
    })
}
/// Load and compile the templates according to the configuration.
///
/// When `stabilise` is set, the assets manifest is not passed to the
/// template loader.
pub async fn templates_from_config(
    config: &TemplatesConfig,
    site_config: &SiteConfig,
    url_builder: &UrlBuilder,
    strict: bool,
    stabilise: bool,
) -> Result<Templates, anyhow::Error> {
    let assets_manifest = if stabilise {
        None
    } else {
        Some(config.assets_manifest.clone())
    };

    Templates::load(
        config.path.clone(),
        url_builder.clone(),
        assets_manifest,
        config.translations_path.clone(),
        site_config.templates_branding(),
        site_config.templates_features(),
        strict,
    )
    .await
    .with_context(|| format!("Failed to load the templates at {}", config.path))
}
/// Build [`PgConnectOptions`] from the database configuration section.
///
/// The connection can be defined either through a single `uri`, or through
/// individual fields; TLS material can be given inline (PEM) or as a file
/// path, but never both at once.
///
/// # Errors
///
/// Returns an error if the URI doesn't parse, or if mutually exclusive TLS
/// options are both set.
fn database_connect_options_from_config(
    config: &DatabaseConfig,
    opts: &DatabaseConnectOptions,
) -> Result<PgConnectOptions, anyhow::Error> {
    // Either parse the full connection URI, or build from individual fields
    let options = if let Some(uri) = config.uri.as_deref() {
        uri.parse()
            .context("could not parse database connection string")?
    } else {
        let mut opts = PgConnectOptions::new().application_name("matrix-authentication-service");
        if let Some(host) = config.host.as_deref() {
            opts = opts.host(host);
        }
        if let Some(port) = config.port {
            opts = opts.port(port);
        }
        if let Some(socket) = config.socket.as_deref() {
            opts = opts.socket(socket);
        }
        if let Some(username) = config.username.as_deref() {
            opts = opts.username(username);
        }
        if let Some(password) = config.password.as_deref() {
            opts = opts.password(password);
        }
        if let Some(database) = config.database.as_deref() {
            opts = opts.database(database);
        }
        opts
    };
    // CA certificate: inline PEM or file path, mutually exclusive
    let options = match (config.ssl_ca.as_deref(), config.ssl_ca_file.as_deref()) {
        (None, None) => options,
        (Some(pem), None) => options.ssl_root_cert_from_pem(pem.as_bytes().to_owned()),
        (None, Some(path)) => options.ssl_root_cert(path),
        (Some(_), Some(_)) => {
            anyhow::bail!("invalid database configuration: both `ssl_ca` and `ssl_ca_file` are set")
        }
    };
    // Client certificate: inline PEM or file path, mutually exclusive
    let options = match (
        config.ssl_certificate.as_deref(),
        config.ssl_certificate_file.as_deref(),
    ) {
        (None, None) => options,
        (Some(pem), None) => options.ssl_client_cert_from_pem(pem.as_bytes()),
        (None, Some(path)) => options.ssl_client_cert(path),
        (Some(_), Some(_)) => {
            anyhow::bail!(
                "invalid database configuration: both `ssl_certificate` and `ssl_certificate_file` are set"
            )
        }
    };
    // Client key: inline PEM or file path, mutually exclusive
    let options = match (config.ssl_key.as_deref(), config.ssl_key_file.as_deref()) {
        (None, None) => options,
        (Some(pem), None) => options.ssl_client_key_from_pem(pem.as_bytes()),
        (None, Some(path)) => options.ssl_client_key(path),
        (Some(_), Some(_)) => {
            anyhow::bail!(
                "invalid database configuration: both `ssl_key` and `ssl_key_file` are set"
            )
        }
    };
    // Map the configured SSL mode to the sqlx equivalent, if any was set
    let options = match &config.ssl_mode {
        Some(ssl_mode) => {
            let ssl_mode = match ssl_mode {
                mas_config::PgSslMode::Disable => sqlx::postgres::PgSslMode::Disable,
                mas_config::PgSslMode::Allow => sqlx::postgres::PgSslMode::Allow,
                mas_config::PgSslMode::Prefer => sqlx::postgres::PgSslMode::Prefer,
                mas_config::PgSslMode::Require => sqlx::postgres::PgSslMode::Require,
                mas_config::PgSslMode::VerifyCa => sqlx::postgres::PgSslMode::VerifyCa,
                mas_config::PgSslMode::VerifyFull => sqlx::postgres::PgSslMode::VerifyFull,
            };
            options.ssl_mode(ssl_mode)
        }
        None => options,
    };
    // Statements are logged at DEBUG; statements slower than 100ms are
    // logged at WARN when enabled through the options
    let mut options = options.log_statements(LevelFilter::Debug);
    if opts.log_slow_statements {
        options = options.log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
    }
    Ok(options)
}
/// Create a database connection pool from the configuration
///
/// # Errors
///
/// Returns an error if the configuration is invalid or the initial
/// connection fails.
#[tracing::instrument(name = "db.connect", skip_all)]
pub async fn database_pool_from_config(config: &DatabaseConfig) -> Result<PgPool, anyhow::Error> {
    let options = database_connect_options_from_config(config, &DatabaseConnectOptions::default())?;
    PgPoolOptions::new()
        .max_connections(config.max_connections.into())
        .min_connections(config.min_connections)
        .acquire_timeout(config.connect_timeout)
        .idle_timeout(config.idle_timeout)
        .max_lifetime(config.max_lifetime)
        .after_connect(|conn, _meta| {
            Box::pin(async move {
                // Unlisten from all channels, as we might be connected via a connection pooler
                // that doesn't clean up LISTEN/NOTIFY state when reusing connections.
                conn.execute("UNLISTEN *;").await?;
                Ok(())
            })
        })
        .connect_with(options)
        .await
        .context("could not connect to the database")
}
/// Extra options controlling how database connections are established,
/// beyond what [`DatabaseConfig`] itself provides.
pub struct DatabaseConnectOptions {
    /// Whether statements slower than a fixed threshold should be logged at
    /// WARN level. Enabled by default.
    pub log_slow_statements: bool,
}
impl Default for DatabaseConnectOptions {
    fn default() -> Self {
        Self {
            log_slow_statements: true,
        }
    }
}
/// Create a single database connection from the configuration, using the
/// default connect options (slow-statement logging enabled).
#[tracing::instrument(name = "db.connect", skip_all)]
pub async fn database_connection_from_config(
    config: &DatabaseConfig,
) -> Result<PgConnection, anyhow::Error> {
    let connect_options =
        database_connect_options_from_config(config, &DatabaseConnectOptions::default())?;
    connect_options
        .connect()
        .await
        .context("could not connect to the database")
}
/// Create a single database connection from the configuration,
/// with specific options.
///
/// # Errors
///
/// Returns an error if the configuration is invalid or the connection
/// fails.
#[tracing::instrument(name = "db.connect", skip_all)]
pub async fn database_connection_from_config_with_options(
    config: &DatabaseConfig,
    options: &DatabaseConnectOptions,
) -> Result<PgConnection, anyhow::Error> {
    database_connect_options_from_config(config, options)?
        .connect()
        .await
        .context("could not connect to the database")
}
/// Update the policy factory dynamic data from the database and spawn a task to
/// periodically update it
// XXX: this could be put somewhere else?
pub async fn load_policy_factory_dynamic_data_continuously(
    policy_factory: &Arc<PolicyFactory>,
    repository_factory: BoxRepositoryFactory,
    cancellation_token: CancellationToken,
    task_tracker: &TaskTracker,
) -> Result<(), anyhow::Error> {
    let policy_factory = policy_factory.clone();
    // Do an initial load before spawning, so callers start with fresh data
    load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await?;
    task_tracker.spawn(async move {
        // Refresh once a minute until the token is cancelled
        let mut interval = tokio::time::interval(Duration::from_secs(60));
        loop {
            tokio::select! {
                () = cancellation_token.cancelled() => {
                    return;
                }
                _ = interval.tick() => {}
            }
            // On refresh failure: log, cancel the shared token to signal the
            // failure to whatever else is listening on it, and stop the task
            if let Err(err) =
                load_policy_factory_dynamic_data(&policy_factory, &*repository_factory).await
            {
                tracing::error!(
                    error = ?err,
                    "Failed to load policy factory dynamic data"
                );
                cancellation_token.cancel();
                return;
            }
        }
    });
    Ok(())
}
/// Update the policy factory dynamic data from the database
#[tracing::instrument(name = "policy.load_dynamic_data", skip_all)]
pub async fn load_policy_factory_dynamic_data(
    policy_factory: &PolicyFactory,
    repository_factory: &(dyn RepositoryFactory + Send + Sync),
) -> Result<(), anyhow::Error> {
    let mut repo = repository_factory
        .create()
        .await
        .context("Failed to acquire database connection")?;

    let Some(data) = repo.policy_data().get().await? else {
        // No dynamic data stored: leave the factory as-is
        return Ok(());
    };

    let data_id = data.id;
    // Only log when the stored data actually replaced what was loaded
    if policy_factory.set_dynamic_data(data).await? {
        tracing::info!(policy_data.id = %data_id, "Loaded dynamic policy data from the database");
    }

    Ok(())
}
/// Create a clonable, type-erased [`HomeserverConnection`] from the
/// configuration
pub async fn homeserver_connection_from_config(
    config: &MatrixConfig,
    http_client: reqwest::Client,
) -> anyhow::Result<Arc<dyn HomeserverConnection>> {
    // All variants share the same constructor inputs
    let homeserver = config.homeserver.clone();
    let endpoint = config.endpoint.clone();
    let secret = config.secret().await?;

    let connection: Arc<dyn HomeserverConnection> = match config.kind {
        HomeserverKind::Synapse | HomeserverKind::SynapseModern => Arc::new(
            SynapseConnection::new(homeserver, endpoint, secret, http_client),
        ),
        HomeserverKind::SynapseLegacy => Arc::new(LegacySynapseConnection::new(
            homeserver, endpoint, secret, http_client,
        )),
        HomeserverKind::SynapseReadOnly => {
            // Wrap the modern connection in the read-only adapter
            let inner = SynapseConnection::new(homeserver, endpoint, secret, http_client);
            Arc::new(ReadOnlyHomeserverConnection::new(inner))
        }
    };

    Ok(connection)
}
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use zeroize::Zeroizing;
    use super::*;
    // Exercises password_manager_from_config over enabled, disabled and
    // invalid configurations.
    #[tokio::test]
    async fn test_password_manager_from_config() {
        // Fixed seed so the test is deterministic
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let password = Zeroizing::new("hunter2".to_owned());
        // Test a valid, enabled config
        let config = serde_json::from_value(serde_json::json!({
            "schemes": [{
                "version": 42,
                "algorithm": "argon2id"
            }, {
                "version": 10,
                "algorithm": "bcrypt"
            }]
        }))
        .unwrap();
        let manager = password_manager_from_config(&config).await;
        assert!(manager.is_ok());
        let manager = manager.unwrap();
        assert!(manager.is_enabled());
        let hashed = manager.hash(&mut rng, password.clone()).await;
        assert!(hashed.is_ok());
        let (version, hashed) = hashed.unwrap();
        // New hashes use the argon2id scheme declared with version 42
        assert_eq!(version, 42);
        assert!(hashed.starts_with("$argon2id$"));
        // Test a valid, disabled config
        let config = serde_json::from_value(serde_json::json!({
            "enabled": false,
            "schemes": []
        }))
        .unwrap();
        let manager = password_manager_from_config(&config).await;
        assert!(manager.is_ok());
        let manager = manager.unwrap();
        assert!(!manager.is_enabled());
        // A disabled manager refuses to hash passwords
        let res = manager.hash(&mut rng, password.clone()).await;
        assert!(res.is_err());
        // Test an invalid config
        // Repeat the same version twice
        let config = serde_json::from_value(serde_json::json!({
            "schemes": [{
                "version": 42,
                "algorithm": "argon2id"
            }, {
                "version": 42,
                "algorithm": "bcrypt"
            }]
        }))
        .unwrap();
        let manager = password_manager_from_config(&config).await;
        assert!(manager.is_err());
        // Empty schemes
        let config = serde_json::from_value(serde_json::json!({
            "schemes": []
        }))
        .unwrap();
        let manager = password_manager_from_config(&config).await;
        assert!(manager.is_err());
    }
}

View file

@ -0,0 +1,53 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
[package]
name = "mas-config"
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
publish.workspace = true
[lints]
workspace = true
[dependencies]
anyhow.workspace = true
camino.workspace = true
chrono.workspace = true
figment.workspace = true
futures-util.workspace = true
governor.workspace = true
hex.workspace = true
indoc.workspace = true
ipnetwork.workspace = true
lettre.workspace = true
pem-rfc7468.workspace = true
rand_chacha.workspace = true
rand.workspace = true
rustls-pki-types.workspace = true
schemars.workspace = true
serde_json.workspace = true
serde_with.workspace = true
serde.workspace = true
tokio.workspace = true
tracing.workspace = true
ulid.workspace = true
url.workspace = true
mas-jose.workspace = true
mas-keystore.workspace = true
mas-iana.workspace = true
[features]
docker = []
dist = []
[[bin]]
name = "schema"
doc = false

View file

@ -0,0 +1,14 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use schemars::generate::SchemaSettings;
/// Print the JSON Schema of the root configuration to stdout.
fn main() {
    // Generate a Draft-07 JSON Schema for the root configuration
    let schema = SchemaSettings::draft07()
        .into_generator()
        .into_root_schema_for::<mas_config::RootConfig>();
    serde_json::to_writer_pretty(std::io::stdout(), &schema).expect("Failed to serialize schema");
}

View file

@ -0,0 +1,24 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
#![deny(missing_docs, rustdoc::missing_crate_level_docs)]
#![allow(clippy::module_name_repetitions)]
// derive(JSONSchema) uses &str.to_string()
#![allow(clippy::str_to_string)]
//! Application configuration logic
#[cfg(all(feature = "docker", feature = "dist"))]
compile_error!("Only one of the `docker` and `dist` features can be enabled at once");
pub(crate) mod schema;
mod sections;
pub(crate) mod util;
pub use self::{
sections::*,
util::{ConfigurationSection, ConfigurationSectionExt},
};

View file

@ -0,0 +1,27 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
//! Useful JSON Schema definitions
use std::borrow::Cow;
use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
/// A network hostname
///
/// Unit type used only to inject a `hostname`-formatted string schema into
/// the generated configuration schema.
pub struct Hostname;
impl JsonSchema for Hostname {
    // Name under which this schema is registered in the generated document
    fn schema_name() -> Cow<'static, str> {
        Cow::Borrowed("Hostname")
    }
    // Plain string carrying the standard JSON Schema `hostname` format
    fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
        json_schema!({
            "type": "string",
            "format": "hostname",
        })
    }
}

View file

@ -0,0 +1,125 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::ConfigurationSection;
// Serde helpers: provide the per-field defaults and detect them again so
// that fields holding their default value are skipped when serializing.
const fn default_true() -> bool {
    true
}
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default_true(value: &bool) -> bool {
    *value == default_true()
}
const fn default_false() -> bool {
    false
}
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default_false(value: &bool) -> bool {
    *value == default_false()
}
/// Configuration section to configure features related to account management
///
/// Every field is optional in the configuration file; fields holding their
/// default value are skipped when the configuration is serialized back.
#[allow(clippy::struct_excessive_bools)]
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct AccountConfig {
    /// Whether users are allowed to change their email addresses. Defaults to
    /// `true`.
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub email_change_allowed: bool,
    /// Whether users are allowed to change their display names. Defaults to
    /// `true`.
    ///
    /// This should be in sync with the policy in the homeserver configuration.
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub displayname_change_allowed: bool,
    /// Whether to enable self-service password registration. Defaults to
    /// `false` if password authentication is enabled.
    ///
    /// This has no effect if password login is disabled.
    #[serde(default = "default_false", skip_serializing_if = "is_default_false")]
    pub password_registration_enabled: bool,
    /// Whether self-service password registrations require a valid email.
    /// Defaults to `true`.
    ///
    /// This has no effect if password registration is disabled.
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub password_registration_email_required: bool,
    /// Whether users are allowed to change their passwords. Defaults to `true`.
    ///
    /// This has no effect if password login is disabled.
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub password_change_allowed: bool,
    /// Whether email-based password recovery is enabled. Defaults to `false`.
    ///
    /// This has no effect if password login is disabled.
    #[serde(default = "default_false", skip_serializing_if = "is_default_false")]
    pub password_recovery_enabled: bool,
    /// Whether users are allowed to delete their own account. Defaults to
    /// `true`.
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub account_deactivation_allowed: bool,
    /// Whether users can log in with their email address. Defaults to `false`.
    ///
    /// This has no effect if password login is disabled.
    #[serde(default = "default_false", skip_serializing_if = "is_default_false")]
    pub login_with_email_allowed: bool,
    /// Whether registration tokens are required for password registrations.
    /// Defaults to `false`.
    ///
    /// When enabled, users must provide a valid registration token during
    /// password registration. This has no effect if password registration
    /// is disabled.
    #[serde(default = "default_false", skip_serializing_if = "is_default_false")]
    pub registration_token_required: bool,
}
impl Default for AccountConfig {
    fn default() -> Self {
        // Keep these in sync with the serde `default = "..."` attributes on
        // the struct: `is_default` compares fields against the same helpers.
        Self {
            email_change_allowed: default_true(),
            displayname_change_allowed: default_true(),
            password_registration_enabled: default_false(),
            password_registration_email_required: default_true(),
            password_change_allowed: default_true(),
            password_recovery_enabled: default_false(),
            account_deactivation_allowed: default_true(),
            login_with_email_allowed: default_false(),
            registration_token_required: default_false(),
        }
    }
}
impl AccountConfig {
    /// Returns true if the configuration is the default one
    ///
    /// Checks every field of the struct against its serde default.
    pub(crate) fn is_default(&self) -> bool {
        is_default_false(&self.password_registration_enabled)
            && is_default_true(&self.email_change_allowed)
            && is_default_true(&self.displayname_change_allowed)
            && is_default_true(&self.password_change_allowed)
            // This field was previously missing from the check, which made a
            // configuration differing only in
            // `password_registration_email_required` look like the default one
            && is_default_true(&self.password_registration_email_required)
            && is_default_false(&self.password_recovery_enabled)
            && is_default_true(&self.account_deactivation_allowed)
            && is_default_false(&self.login_with_email_allowed)
            && is_default_false(&self.registration_token_required)
    }
}
impl ConfigurationSection for AccountConfig {
    // Key under which this section lives in the configuration file
    const PATH: Option<&'static str> = Some("account");
}

View file

@ -0,0 +1,55 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use url::Url;
use crate::ConfigurationSection;
/// Configuration section for tweaking the branding of the service
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize, Default)]
pub struct BrandingConfig {
    /// A human-readable name. Defaults to the server's address.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_name: Option<String>,

    /// Link to a privacy policy, displayed in the footer of web pages and
    /// emails. It is also advertised to clients through the `op_policy_uri`
    /// OIDC provider metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub policy_uri: Option<Url>,

    /// Link to a terms of service document, displayed in the footer of web
    /// pages and emails. It is also advertised to clients through the
    /// `op_tos_uri` OIDC provider metadata.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tos_uri: Option<Url>,

    /// Legal imprint, displayed in the footer of web pages and emails.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub imprint: Option<String>,

    /// Logo displayed in some web pages.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logo_uri: Option<Url>,
}
impl BrandingConfig {
    /// Returns true if the configuration is the default one
    pub(crate) fn is_default(&self) -> bool {
        // Destructure so that adding a field to the struct turns into a
        // compile error until it is considered here as well
        let Self {
            service_name,
            policy_uri,
            tos_uri,
            imprint,
            logo_uri,
        } = self;

        service_name.is_none()
            && policy_uri.is_none()
            && tos_uri.is_none()
            && imprint.is_none()
            && logo_uri.is_none()
    }
}
impl ConfigurationSection for BrandingConfig {
    // Key under which this section lives in the configuration file
    const PATH: Option<&'static str> = Some("branding");
}

View file

@ -0,0 +1,83 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::Error};
use crate::ConfigurationSection;
/// Which service should be used for CAPTCHA protection
///
/// The serde renames below are the values expected in the configuration file.
#[derive(Clone, Copy, Debug, Deserialize, JsonSchema, Serialize)]
pub enum CaptchaServiceKind {
    /// Use Google's reCAPTCHA v2 API
    #[serde(rename = "recaptcha_v2")]
    RecaptchaV2,

    /// Use Cloudflare Turnstile
    #[serde(rename = "cloudflare_turnstile")]
    CloudflareTurnstile,

    /// Use ``HCaptcha``
    #[serde(rename = "hcaptcha")]
    HCaptcha,
}
/// Configuration section to setup CAPTCHA protection on a few operations
///
/// Coherence of the fields is checked by the `ConfigurationSection::validate`
/// implementation.
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize, Default)]
pub struct CaptchaConfig {
    /// Which service should be used for CAPTCHA protection
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service: Option<CaptchaServiceKind>,

    /// The site key to use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub site_key: Option<String>,

    /// The secret key to use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub secret_key: Option<String>,
}
impl CaptchaConfig {
    /// Returns true if the configuration is the default one
    pub(crate) fn is_default(&self) -> bool {
        // Destructure so that a new field cannot be forgotten here
        let Self {
            service,
            site_key,
            secret_key,
        } = self;

        service.is_none() && site_key.is_none() && secret_key.is_none()
    }
}
impl ConfigurationSection for CaptchaConfig {
    const PATH: Option<&'static str> = Some("captcha");

    /// Check that a site key and a secret key are provided whenever a CAPTCHA
    /// service is configured.
    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        let metadata = figment.find_metadata(Self::PATH.unwrap());
        // Attach the source metadata and the field path to an error so it
        // points at the offending field of the loaded configuration
        let error_on_field = |mut error: figment::error::Error, field: &'static str| {
            error.metadata = metadata.cloned();
            error.profile = Some(figment::Profile::Default);
            error.path = vec![Self::PATH.unwrap().to_owned(), field.to_owned()];
            error
        };
        let missing_field = |field: &'static str| {
            error_on_field(figment::error::Error::missing_field(field), field)
        };

        // All the supported services (reCAPTCHA v2, Cloudflare Turnstile and
        // hCaptcha) authenticate with a site key and a secret key, so require
        // both whenever any service is configured. Previously only
        // `recaptcha_v2` was checked, which let incomplete Turnstile/hCaptcha
        // configurations through validation.
        if self.service.is_some() {
            if self.site_key.is_none() {
                return Err(missing_field("site_key").into());
            }

            if self.secret_key.is_none() {
                return Err(missing_field("secret_key").into());
            }
        }

        Ok(())
    }
}

View file

@ -0,0 +1,353 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::ops::Deref;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::jwk::PublicJsonWebKeySet;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::Error};
use serde_with::serde_as;
use ulid::Ulid;
use url::Url;
use super::{ClientSecret, ClientSecretRaw, ConfigurationSection};
/// Authentication method used by clients
#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthMethodConfig {
    /// `none`: No authentication
    None,

    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
    /// authorization credentials
    ClientSecretBasic,

    /// `client_secret_post`: `client_id` and `client_secret` sent in the
    /// request body
    ClientSecretPost,

    /// `client_secret_jwt`: a `client_assertion` sent in the request body and
    /// signed using the `client_secret`
    ClientSecretJwt,

    /// `private_key_jwt`: a `client_assertion` sent in the request body and
    /// signed by an asymmetric key
    PrivateKeyJwt,
}
impl std::fmt::Display for ClientAuthMethodConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ClientAuthMethodConfig::None => write!(f, "none"),
ClientAuthMethodConfig::ClientSecretBasic => write!(f, "client_secret_basic"),
ClientAuthMethodConfig::ClientSecretPost => write!(f, "client_secret_post"),
ClientAuthMethodConfig::ClientSecretJwt => write!(f, "client_secret_jwt"),
ClientAuthMethodConfig::PrivateKeyJwt => write!(f, "private_key_jwt"),
}
}
}
/// An OAuth 2.0 client configuration
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ClientConfig {
    /// The client ID
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub client_id: Ulid,

    /// Authentication method used for this client
    client_auth_method: ClientAuthMethodConfig,

    /// Name of the `OAuth2` client
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,

    /// The client secret, used by the `client_secret_basic`,
    /// `client_secret_post` and `client_secret_jwt` authentication methods
    ///
    /// Flattened: configured through the `client_secret` or
    /// `client_secret_file` keys.
    #[schemars(with = "ClientSecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
    #[serde(flatten)]
    pub client_secret: Option<ClientSecret>,

    /// The JSON Web Key Set (JWKS) used by the `private_key_jwt` authentication
    /// method. Mutually exclusive with `jwks_uri`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks: Option<PublicJsonWebKeySet>,

    /// The URL of the JSON Web Key Set (JWKS) used by the `private_key_jwt`
    /// authentication method. Mutually exclusive with `jwks`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// List of allowed redirect URIs
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub redirect_uris: Vec<Url>,
}
impl ClientConfig {
    /// Check that the combination of fields is coherent with the configured
    /// client authentication method.
    ///
    /// # Errors
    ///
    /// Returns a [`figment::error::Error`] with a path pointing at the
    /// offending field when the configuration is invalid.
    fn validate(&self) -> Result<(), Box<figment::error::Error>> {
        let auth_method = self.client_auth_method;
        match self.client_auth_method {
            ClientAuthMethodConfig::PrivateKeyJwt => {
                // private_key_jwt needs exactly one source of public keys…
                if self.jwks.is_none() && self.jwks_uri.is_none() {
                    let error = figment::error::Error::custom(
                        "jwks or jwks_uri is required for private_key_jwt",
                    );
                    return Err(Box::new(error.with_path("client_auth_method")));
                }

                if self.jwks.is_some() && self.jwks_uri.is_some() {
                    let error =
                        figment::error::Error::custom("jwks and jwks_uri are mutually exclusive");
                    return Err(Box::new(error.with_path("jwks")));
                }

                // …and no shared secret
                if self.client_secret.is_some() {
                    let error = figment::error::Error::custom(
                        "client_secret is not allowed with private_key_jwt",
                    );
                    return Err(Box::new(error.with_path("client_secret")));
                }
            }

            ClientAuthMethodConfig::ClientSecretPost
            | ClientAuthMethodConfig::ClientSecretBasic
            | ClientAuthMethodConfig::ClientSecretJwt => {
                // All client_secret_* methods need a secret and no JWKS
                if self.client_secret.is_none() {
                    let error = figment::error::Error::custom(format!(
                        "client_secret is required for {auth_method}"
                    ));
                    return Err(Box::new(error.with_path("client_auth_method")));
                }

                if self.jwks.is_some() {
                    let error = figment::error::Error::custom(format!(
                        "jwks is not allowed with {auth_method}"
                    ));
                    return Err(Box::new(error.with_path("jwks")));
                }

                if self.jwks_uri.is_some() {
                    let error = figment::error::Error::custom(format!(
                        "jwks_uri is not allowed with {auth_method}"
                    ));
                    return Err(Box::new(error.with_path("jwks_uri")));
                }
            }

            ClientAuthMethodConfig::None => {
                // Unauthenticated clients must not carry any credential
                // material
                if self.client_secret.is_some() {
                    let error = figment::error::Error::custom(
                        "client_secret is not allowed with none authentication method",
                    );
                    return Err(Box::new(error.with_path("client_secret")));
                }

                if self.jwks.is_some() {
                    let error = figment::error::Error::custom(
                        "jwks is not allowed with none authentication method",
                    );
                    // Attach the field path, like every other branch does
                    return Err(Box::new(error.with_path("jwks")));
                }

                if self.jwks_uri.is_some() {
                    let error = figment::error::Error::custom(
                        "jwks_uri is not allowed with none authentication method",
                    );
                    return Err(Box::new(error.with_path("jwks_uri")));
                }
            }
        }

        Ok(())
    }

    /// Authentication method used for this client
    #[must_use]
    pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
        match self.client_auth_method {
            ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
            ClientAuthMethodConfig::ClientSecretBasic => {
                OAuthClientAuthenticationMethod::ClientSecretBasic
            }
            ClientAuthMethodConfig::ClientSecretPost => {
                OAuthClientAuthenticationMethod::ClientSecretPost
            }
            ClientAuthMethodConfig::ClientSecretJwt => {
                OAuthClientAuthenticationMethod::ClientSecretJwt
            }
            ClientAuthMethodConfig::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
        }
    }

    /// Returns the client secret.
    ///
    /// If `client_secret_file` was given, the secret is read from that file.
    ///
    /// # Errors
    ///
    /// Returns an error when the client secret could not be read from file.
    pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
        Ok(match &self.client_secret {
            Some(client_secret) => Some(client_secret.value().await?),
            None => None,
        })
    }
}
/// List of OAuth 2.0/OIDC clients config
///
/// Serialized transparently: the section is a plain list of [`ClientConfig`]
/// entries.
#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
#[serde(transparent)]
pub struct ClientsConfig(#[schemars(with = "Vec::<ClientConfig>")] Vec<ClientConfig>);
impl ClientsConfig {
/// Returns true if all fields are at their default values
pub(crate) fn is_default(&self) -> bool {
self.0.is_empty()
}
}
// Allow borrowing the inner list of clients directly from a `ClientsConfig`
impl Deref for ClientsConfig {
    type Target = Vec<ClientConfig>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl IntoIterator for ClientsConfig {
type Item = ClientConfig;
type IntoIter = std::vec::IntoIter<ClientConfig>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl ConfigurationSection for ClientsConfig {
    const PATH: Option<&'static str> = Some("clients");

    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        // Validate each client individually, prefixing any error path with
        // `clients.<index>` so it points at the offending entry
        for (index, client) in self.0.iter().enumerate() {
            client.validate().map_err(|mut err| {
                // Save the error location information in the error
                err.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
                err.profile = Some(figment::Profile::Default);
                err.path.insert(0, Self::PATH.unwrap().to_owned());
                err.path.insert(1, format!("{index}"));
                err
            })?;
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    #[tokio::test]
    async fn load_config() {
        // `Jail` does blocking filesystem work, so run it off the async
        // executor
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                // One client per supported authentication method
                jail.create_file(
                    "config.yaml",
                    r#"
                      clients:
                        - client_id: 01GFWR28C4KNE04WG3HKXB7C9R
                          client_auth_method: none
                          redirect_uris:
                            - https://exemple.fr/callback

                        - client_id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                          client_auth_method: client_secret_basic
                          client_secret_file: secret

                        - client_id: 01GFWR3WHR93Y5HK389H28VHZ9
                          client_auth_method: client_secret_post
                          client_secret: c1!3n753c237

                        - client_id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                          client_auth_method: client_secret_jwt
                          client_secret_file: secret

                        - client_id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                          client_auth_method: private_key_jwt
                          jwks:
                            keys:
                              - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                                kty: "RSA"
                                alg: "RS256"
                                use: "sig"
                                e: "AQAB"
                                n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                              - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                                kty: "RSA"
                                alg: "RS256"
                                use: "sig"
                                e: "AQAB"
                                n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                    "#,
                )?;
                // Referenced by the `client_secret_file` entries above
                jail.create_file("secret", r"c1!3n753c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<ClientsConfig>("clients")?;

                assert_eq!(config.0.len(), 5);

                assert_eq!(
                    config.0[0].client_id,
                    Ulid::from_str("01GFWR28C4KNE04WG3HKXB7C9R").unwrap()
                );
                assert_eq!(
                    config.0[0].redirect_uris,
                    vec!["https://exemple.fr/callback".parse().unwrap()]
                );

                assert_eq!(
                    config.0[1].client_id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );
                // `redirect_uris` defaults to an empty list when absent
                assert_eq!(config.0[1].redirect_uris, Vec::new());

                // Check how the flattened client_secret/client_secret_file
                // keys were decoded
                assert!(config.0[0].client_secret.is_none());
                assert!(matches!(config.0[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.0[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.0[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.0[4].client_secret.is_none());

                // `client_secret()` is async (it may read from a file), so
                // drive it on the current runtime
                Handle::current().block_on(async move {
                    assert_eq!(config.0[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.0[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        }).await.unwrap();
    }
}

View file

@ -0,0 +1,319 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{num::NonZeroU32, time::Duration};
use camino::Utf8PathBuf;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use super::ConfigurationSection;
use crate::schema;
/// Serde default for the `uri` field: a connection string which lets libpq
/// defaults/environment fill in the details.
#[allow(clippy::unnecessary_wraps)]
fn default_connection_string() -> Option<String> {
    Some(String::from("postgresql://"))
}
/// Serde default: the pool keeps at most 10 connections.
fn default_max_connections() -> NonZeroU32 {
    NonZeroU32::new(10).expect("10 is non-zero")
}
/// Serde default: attempt connecting for 30 seconds before giving up.
fn default_connect_timeout() -> Duration {
    const SECONDS: u64 = 30;
    Duration::from_secs(SECONDS)
}
/// Serde default: close connections idle for more than 10 minutes.
#[allow(clippy::unnecessary_wraps)]
fn default_idle_timeout() -> Option<Duration> {
    Some(Duration::from_secs(600))
}
/// Serde default: recycle connections after 30 minutes.
#[allow(clippy::unnecessary_wraps)]
fn default_max_lifetime() -> Option<Duration> {
    Some(Duration::from_secs(1800))
}
impl Default for DatabaseConfig {
    fn default() -> Self {
        // Keep these in sync with the serde `default = "..."` attributes on
        // the struct fields
        Self {
            uri: default_connection_string(),
            host: None,
            port: None,
            socket: None,
            username: None,
            password: None,
            database: None,
            ssl_mode: None,
            ssl_ca: None,
            ssl_ca_file: None,
            ssl_certificate: None,
            ssl_certificate_file: None,
            ssl_key: None,
            ssl_key_file: None,
            max_connections: default_max_connections(),
            min_connections: Default::default(),
            connect_timeout: default_connect_timeout(),
            idle_timeout: default_idle_timeout(),
            max_lifetime: default_max_lifetime(),
        }
    }
}
/// Options for controlling the level of protection provided for PostgreSQL SSL
/// connections.
///
/// Serialized in kebab-case (`disable`, `allow`, `prefer`, `require`,
/// `verify-ca`, `verify-full`), mirroring the usual PostgreSQL `sslmode`
/// values.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "kebab-case")]
pub enum PgSslMode {
    /// Only try a non-SSL connection.
    Disable,

    /// First try a non-SSL connection; if that fails, try an SSL connection.
    Allow,

    /// First try an SSL connection; if that fails, try a non-SSL connection.
    Prefer,

    /// Only try an SSL connection. If a root CA file is present, verify the
    /// connection in the same way as if `VerifyCa` was specified.
    Require,

    /// Only try an SSL connection, and verify that the server certificate is
    /// issued by a trusted certificate authority (CA).
    VerifyCa,

    /// Only try an SSL connection; verify that the server certificate is issued
    /// by a trusted CA and that the requested server host name matches that
    /// in the certificate.
    VerifyFull,
}
/// Database connection configuration
#[serde_as]
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DatabaseConfig {
    /// Connection URI
    ///
    /// This must not be specified if `host`, `port`, `socket`, `username`,
    /// `password`, or `database` are specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(url, default = "default_connection_string")]
    pub uri: Option<String>,

    /// Name of host to connect to
    ///
    /// This must not be specified if `uri` is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option::<schema::Hostname>")]
    pub host: Option<String>,

    /// Port number to connect at the server host
    ///
    /// This must not be specified if `uri` is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(range(min = 1, max = 65535))]
    pub port: Option<u16>,

    /// Directory containing the UNIX socket to connect to
    ///
    /// This must not be specified if `uri` is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub socket: Option<Utf8PathBuf>,

    /// PostgreSQL user name to connect as
    ///
    /// This must not be specified if `uri` is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub username: Option<String>,

    /// Password to be used if the server demands password authentication
    ///
    /// This must not be specified if `uri` is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub password: Option<String>,

    /// The database name
    ///
    /// This must not be specified if `uri` is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub database: Option<String>,

    /// How to handle SSL connections
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssl_mode: Option<PgSslMode>,

    /// The PEM-encoded root certificate for SSL connections
    ///
    /// This must not be specified if the `ssl_ca_file` option is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssl_ca: Option<String>,

    /// Path to the root certificate for SSL connections
    ///
    /// This must not be specified if the `ssl_ca` option is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub ssl_ca_file: Option<Utf8PathBuf>,

    /// The PEM-encoded client certificate for SSL connections
    ///
    /// This must not be specified if the `ssl_certificate_file` option is
    /// specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssl_certificate: Option<String>,

    /// Path to the client certificate for SSL connections
    ///
    /// This must not be specified if the `ssl_certificate` option is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub ssl_certificate_file: Option<Utf8PathBuf>,

    /// The PEM-encoded client key for SSL connections
    ///
    /// This must not be specified if the `ssl_key_file` option is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ssl_key: Option<String>,

    /// Path to the client key for SSL connections
    ///
    /// This must not be specified if the `ssl_key` option is specified.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub ssl_key_file: Option<Utf8PathBuf>,

    /// Set the maximum number of connections the pool should maintain
    #[serde(default = "default_max_connections")]
    pub max_connections: NonZeroU32,

    /// Set the minimum number of connections the pool should maintain
    #[serde(default)]
    pub min_connections: u32,

    /// Set the amount of time to attempt connecting to the database
    #[schemars(with = "u64")]
    #[serde(default = "default_connect_timeout")]
    #[serde_as(as = "serde_with::DurationSeconds<u64>")]
    pub connect_timeout: Duration,

    /// Set a maximum idle duration for individual connections
    #[schemars(with = "Option<u64>")]
    #[serde(
        default = "default_idle_timeout",
        skip_serializing_if = "Option::is_none"
    )]
    #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
    pub idle_timeout: Option<Duration>,

    /// Set the maximum lifetime of individual connections
    // The field is optional and skippable like `idle_timeout`, so the schema
    // override is `Option<u64>` — it previously claimed a plain `u64`, which
    // generated a schema inconsistent with the actual (de)serialization.
    #[schemars(with = "Option<u64>")]
    #[serde(
        default = "default_max_lifetime",
        skip_serializing_if = "Option::is_none"
    )]
    #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
    pub max_lifetime: Option<Duration>,
}
impl ConfigurationSection for DatabaseConfig {
    const PATH: Option<&'static str> = Some("database");

    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        let metadata = figment.find_metadata(Self::PATH.unwrap());
        // Attach the source metadata and the section path to an error so it
        // points at this section of the loaded configuration
        let annotate = |mut error: figment::Error| {
            error.metadata = metadata.cloned();
            error.profile = Some(figment::Profile::Default);
            error.path = vec![Self::PATH.unwrap().to_owned()];
            error
        };

        // Check that the user did not specify both `uri` and the split options at the
        // same time
        let has_split_options = self.host.is_some()
            || self.port.is_some()
            || self.socket.is_some()
            || self.username.is_some()
            || self.password.is_some()
            || self.database.is_some();

        if self.uri.is_some() && has_split_options {
            return Err(annotate(figment::error::Error::from(
                "uri must not be specified if host, port, socket, username, password, or database are specified".to_owned(),
            )).into());
        }

        // Each SSL option has an inline and a file-based variant; they are
        // mutually exclusive
        if self.ssl_ca.is_some() && self.ssl_ca_file.is_some() {
            return Err(annotate(figment::error::Error::from(
                "ssl_ca must not be specified if ssl_ca_file is specified".to_owned(),
            ))
            .into());
        }

        if self.ssl_certificate.is_some() && self.ssl_certificate_file.is_some() {
            return Err(annotate(figment::error::Error::from(
                "ssl_certificate must not be specified if ssl_certificate_file is specified"
                    .to_owned(),
            ))
            .into());
        }

        if self.ssl_key.is_some() && self.ssl_key_file.is_some() {
            return Err(annotate(figment::error::Error::from(
                "ssl_key must not be specified if ssl_key_file is specified".to_owned(),
            ))
            .into());
        }

        // XOR: a client key without a certificate (or vice-versa) is invalid —
        // they must be provided together, or not at all
        if (self.ssl_key.is_some() || self.ssl_key_file.is_some())
            ^ (self.ssl_certificate.is_some() || self.ssl_certificate_file.is_some())
        {
            return Err(annotate(figment::error::Error::from(
                "both a ssl_certificate and a ssl_key must be set at the same time or none of them"
                    .to_owned(),
            ))
            .into());
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };

    use super::*;

    #[test]
    fn load_config() {
        Jail::expect_with(|jail| {
            // Minimal configuration: everything comes from the URI
            jail.create_file(
                "config.yaml",
                r"
                  database:
                    uri: postgresql://user:password@host/database
                ",
            )?;

            let config = Figment::new()
                .merge(Yaml::file("config.yaml"))
                .extract_inner::<DatabaseConfig>("database")?;

            assert_eq!(
                config.uri.as_deref(),
                Some("postgresql://user:password@host/database")
            );

            Ok(())
        });
    }
}

View file

@ -0,0 +1,280 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
#![allow(deprecated)]
use std::{num::NonZeroU16, str::FromStr};
use lettre::message::Mailbox;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::Error};
use super::ConfigurationSection;
/// Encryption mode to use
///
/// Serialized in lowercase: `plain`, `starttls`, `tls`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum EmailSmtpMode {
    /// Plain text
    Plain,

    /// `StartTLS` (starts as plain text then upgrade to TLS)
    StartTls,

    /// TLS
    Tls,
}
/// What backend should be used when sending emails
///
/// Defaults to [`EmailTransportKind::Blackhole`], which discards all emails.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum EmailTransportKind {
    /// Don't send emails anywhere
    #[default]
    Blackhole,

    /// Send emails via an SMTP relay
    Smtp,

    /// Send emails by calling sendmail
    Sendmail,
}
/// Serde default for the `from` and `reply_to` addresses.
fn default_email() -> String {
    String::from(r#""Authentication Service" <root@localhost>"#)
}
/// Serde default for the sendmail `command` field.
#[allow(clippy::unnecessary_wraps)]
fn default_sendmail_command() -> Option<String> {
    Some(String::from("sendmail"))
}
/// Configuration related to sending emails
///
/// Which fields are required/forbidden depends on the selected `transport`;
/// this is enforced by the `ConfigurationSection::validate` implementation.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EmailConfig {
    /// Email address to use as From when sending emails
    ///
    /// Parsed as a `Mailbox` during validation, e.g. `"Name" <user@host>`.
    #[serde(default = "default_email")]
    #[schemars(email)]
    pub from: String,

    /// Email address to use as Reply-To when sending emails
    ///
    /// Parsed as a `Mailbox` during validation, e.g. `"Name" <user@host>`.
    #[serde(default = "default_email")]
    #[schemars(email)]
    pub reply_to: String,

    /// What backend should be used when sending emails
    transport: EmailTransportKind,

    /// SMTP transport: Connection mode to the relay
    #[serde(skip_serializing_if = "Option::is_none")]
    mode: Option<EmailSmtpMode>,

    /// SMTP transport: Hostname to connect to
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<crate::schema::Hostname>")]
    hostname: Option<String>,

    /// SMTP transport: Port to connect to. Default is 25 for plain, 465 for TLS
    /// and 587 for `StartTLS`
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(range(min = 1, max = 65535))]
    port: Option<NonZeroU16>,

    /// SMTP transport: Username for use to authenticate when connecting to the
    /// SMTP server
    ///
    /// Must be set if the `password` field is set
    #[serde(skip_serializing_if = "Option::is_none")]
    username: Option<String>,

    /// SMTP transport: Password for use to authenticate when connecting to the
    /// SMTP server
    ///
    /// Must be set if the `username` field is set
    #[serde(skip_serializing_if = "Option::is_none")]
    password: Option<String>,

    /// Sendmail transport: Command to use to send emails
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(default = "default_sendmail_command")]
    command: Option<String>,
}
impl EmailConfig {
    /// What backend should be used when sending emails
    #[must_use]
    pub fn transport(&self) -> EmailTransportKind {
        self.transport
    }

    /// SMTP transport: connection mode to the relay
    #[must_use]
    pub fn mode(&self) -> Option<EmailSmtpMode> {
        self.mode
    }

    /// SMTP transport: hostname of the relay to connect to
    #[must_use]
    pub fn hostname(&self) -> Option<&str> {
        self.hostname.as_ref().map(String::as_str)
    }

    /// SMTP transport: port of the relay to connect to
    #[must_use]
    pub fn port(&self) -> Option<NonZeroU16> {
        self.port
    }

    /// SMTP transport: username used to authenticate against the relay
    #[must_use]
    pub fn username(&self) -> Option<&str> {
        self.username.as_ref().map(String::as_str)
    }

    /// SMTP transport: password used to authenticate against the relay
    #[must_use]
    pub fn password(&self) -> Option<&str> {
        self.password.as_ref().map(String::as_str)
    }

    /// Sendmail transport: command used to send emails
    #[must_use]
    pub fn command(&self) -> Option<&str> {
        self.command.as_ref().map(String::as_str)
    }
}
impl Default for EmailConfig {
    fn default() -> Self {
        // Matches the serde defaults: blackhole transport with no
        // transport-specific option set
        Self {
            from: default_email(),
            reply_to: default_email(),
            transport: EmailTransportKind::Blackhole,
            mode: None,
            hostname: None,
            port: None,
            username: None,
            password: None,
            command: None,
        }
    }
}
impl ConfigurationSection for EmailConfig {
    const PATH: Option<&'static str> = Some("email");

    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        let metadata = figment.find_metadata(Self::PATH.unwrap());
        // Attach the source metadata and the field path to an error so it
        // points at the offending field of the loaded configuration
        let error_on_field = |mut error: figment::error::Error, field: &'static str| {
            error.metadata = metadata.cloned();
            error.profile = Some(figment::Profile::Default);
            error.path = vec![Self::PATH.unwrap().to_owned(), field.to_owned()];
            error
        };
        let missing_field = |field: &'static str| {
            error_on_field(figment::error::Error::missing_field(field), field)
        };
        let unexpected_field = |field: &'static str, expected_fields: &'static [&'static str]| {
            error_on_field(
                figment::error::Error::unknown_field(field, expected_fields),
                field,
            )
        };

        match self.transport {
            // The blackhole transport accepts anything
            EmailTransportKind::Blackhole => {}

            EmailTransportKind::Smtp => {
                // Addresses must parse as a lettre `Mailbox`
                if let Err(e) = Mailbox::from_str(&self.from) {
                    return Err(error_on_field(figment::error::Error::custom(e), "from").into());
                }

                if let Err(e) = Mailbox::from_str(&self.reply_to) {
                    return Err(error_on_field(figment::error::Error::custom(e), "reply_to").into());
                }

                // Credentials must be given as a pair, or not at all
                match (self.username.is_some(), self.password.is_some()) {
                    (true, true) | (false, false) => {}
                    (true, false) => {
                        return Err(missing_field("password").into());
                    }
                    (false, true) => {
                        return Err(missing_field("username").into());
                    }
                }

                if self.mode.is_none() {
                    return Err(missing_field("mode").into());
                }

                if self.hostname.is_none() {
                    return Err(missing_field("hostname").into());
                }

                // `command` belongs to the sendmail transport only
                if self.command.is_some() {
                    return Err(unexpected_field(
                        "command",
                        &[
                            "from",
                            "reply_to",
                            "transport",
                            "mode",
                            "hostname",
                            "port",
                            "username",
                            "password",
                        ],
                    )
                    .into());
                }
            }

            EmailTransportKind::Sendmail => {
                let expected_fields = &["from", "reply_to", "transport", "command"];

                // Addresses must parse as a lettre `Mailbox`
                if let Err(e) = Mailbox::from_str(&self.from) {
                    return Err(error_on_field(figment::error::Error::custom(e), "from").into());
                }

                if let Err(e) = Mailbox::from_str(&self.reply_to) {
                    return Err(error_on_field(figment::error::Error::custom(e), "reply_to").into());
                }

                if self.command.is_none() {
                    return Err(missing_field("command").into());
                }

                // Reject all the SMTP-only options
                if self.mode.is_some() {
                    return Err(unexpected_field("mode", expected_fields).into());
                }

                if self.hostname.is_some() {
                    return Err(unexpected_field("hostname", expected_fields).into());
                }

                if self.port.is_some() {
                    return Err(unexpected_field("port", expected_fields).into());
                }

                if self.username.is_some() {
                    return Err(unexpected_field("username", expected_fields).into());
                }

                if self.password.is_some() {
                    return Err(unexpected_field("password", expected_fields).into());
                }
            }
        }

        Ok(())
    }
}

View file

@ -0,0 +1,126 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::num::NonZeroU64;
use chrono::Duration;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use crate::ConfigurationSection;
/// Serde default helper: returns `true`.
fn default_true() -> bool {
    true
}
/// Default time-to-live for access and compatibility tokens: five minutes.
fn default_token_ttl() -> Duration {
    // Exactly 300 seconds, same value as
    // `Duration::microseconds(5 * 60 * 1000 * 1000)`
    Duration::minutes(5)
}
/// Serde skip-helper: true when `value` equals [`default_token_ttl`], so the
/// field can be omitted from serialized output.
fn is_default_token_ttl(value: &Duration) -> bool {
    *value == default_token_ttl()
}
/// Configuration options for the inactive session expiration feature
#[serde_as]
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct InactiveSessionExpirationConfig {
    /// Time after which an inactive session is automatically finished
    // Serialized as a whole number of seconds; the schema constrains it to
    // 10 minutes..=90 days
    #[schemars(with = "u64", range(min = 600, max = 7_776_000))]
    #[serde_as(as = "serde_with::DurationSeconds<i64>")]
    pub ttl: Duration,

    /// Should compatibility sessions expire after inactivity
    #[serde(default = "default_true")]
    pub expire_compat_sessions: bool,

    /// Should OAuth 2.0 sessions expire after inactivity
    #[serde(default = "default_true")]
    pub expire_oauth_sessions: bool,

    /// Should user sessions expire after inactivity
    #[serde(default = "default_true")]
    pub expire_user_sessions: bool,
}
/// Configuration sections for experimental options
///
/// Do not change these options unless you know what you are doing.
#[serde_as]
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct ExperimentalConfig {
    /// Time-to-live of access tokens in seconds. Defaults to 5 minutes.
    // Serialized as seconds, constrained to 1 minute..=1 day; omitted from
    // output when left at the default
    #[schemars(with = "u64", range(min = 60, max = 86400))]
    #[serde(
        default = "default_token_ttl",
        skip_serializing_if = "is_default_token_ttl"
    )]
    #[serde_as(as = "serde_with::DurationSeconds<i64>")]
    pub access_token_ttl: Duration,

    /// Time-to-live of compatibility access tokens in seconds. Defaults to 5
    /// minutes.
    #[schemars(with = "u64", range(min = 60, max = 86400))]
    #[serde(
        default = "default_token_ttl",
        skip_serializing_if = "is_default_token_ttl"
    )]
    #[serde_as(as = "serde_with::DurationSeconds<i64>")]
    pub compat_token_ttl: Duration,

    /// Experimental feature to automatically expire inactive sessions
    ///
    /// Disabled by default
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inactive_session_expiration: Option<InactiveSessionExpirationConfig>,

    /// Experimental feature to show a plan management tab and iframe.
    /// This value is passed through "as is" to the client without any
    /// validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub plan_management_iframe_uri: Option<String>,

    /// Experimental feature to limit the number of application sessions per
    /// user.
    ///
    /// Disabled by default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub session_limit: Option<SessionLimitConfig>,
}
impl Default for ExperimentalConfig {
    /// Both token TTLs at five minutes, every optional feature disabled.
    fn default() -> Self {
        // `chrono::Duration` is `Copy`, so one value serves both fields
        let token_ttl = default_token_ttl();
        Self {
            access_token_ttl: token_ttl,
            compat_token_ttl: token_ttl,
            inactive_session_expiration: None,
            plan_management_iframe_uri: None,
            session_limit: None,
        }
    }
}
impl ExperimentalConfig {
    /// Whether every field still holds its default value, so the whole
    /// `experimental` section can be skipped during serialization.
    pub(crate) fn is_default(&self) -> bool {
        let ttls_are_default = is_default_token_ttl(&self.access_token_ttl)
            && is_default_token_ttl(&self.compat_token_ttl);
        let features_are_off = self.inactive_session_expiration.is_none()
            && self.plan_management_iframe_uri.is_none()
            && self.session_limit.is_none();
        ttls_are_default && features_are_off
    }
}
impl ConfigurationSection for ExperimentalConfig {
    // This section lives under the `experimental` key of the root config
    const PATH: Option<&'static str> = Some("experimental");
}
/// Configuration options for the session limit feature
#[derive(Clone, Debug, Deserialize, JsonSchema, Serialize)]
pub struct SessionLimitConfig {
    // Lower threshold — presumably the point where users get warned;
    // TODO confirm against the consumer of this config
    pub soft_limit: NonZeroU64,
    // Upper threshold — presumably the point where new sessions are refused;
    // TODO confirm against the consumer of this config
    pub hard_limit: NonZeroU64,
}

View file

@ -0,0 +1,473 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
#![allow(deprecated)]
use std::borrow::Cow;
use anyhow::bail;
use camino::Utf8PathBuf;
use ipnetwork::IpNetwork;
use mas_keystore::PrivateKey;
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, pem::PemObject};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use url::Url;
use super::ConfigurationSection;
/// Default value for `HttpConfig::public_base`: port 8080 on the
/// all-interfaces IPv6 address.
fn default_public_base() -> Url {
    // The literal is a valid URL, so this parse can never fail
    "http://[::]:8080".parse().unwrap()
}
// The default assets path depends on how the binary is packaged:
//  - plain cargo builds serve the locally-built frontend,
//  - `docker` builds use the image's shared assets directory,
//  - `dist` builds use a path relative to the unpacked archive.
#[cfg(not(any(feature = "docker", feature = "dist")))]
fn http_listener_assets_path_default() -> Utf8PathBuf {
    "./frontend/dist/".into()
}

#[cfg(feature = "docker")]
fn http_listener_assets_path_default() -> Utf8PathBuf {
    "/usr/local/share/mas-cli/assets/".into()
}

#[cfg(feature = "dist")]
fn http_listener_assets_path_default() -> Utf8PathBuf {
    "./share/assets/".into()
}
/// Serde skip-helper: true when `value` is the compile-time default assets
/// path.
fn is_default_http_listener_assets_path(value: &Utf8PathBuf) -> bool {
    *value == http_listener_assets_path_default()
}
/// Default set of proxies trusted to set `X-Forwarded-For`: the RFC 1918
/// private IPv4 ranges, the IPv6 unique-local range and both loopbacks.
fn default_trusted_proxies() -> Vec<IpNetwork> {
    vec![
        // RFC 1918 private IPv4 ranges
        IpNetwork::new([192, 168, 0, 0].into(), 16).unwrap(),
        IpNetwork::new([172, 16, 0, 0].into(), 12).unwrap(),
        // The 10.0.0.0 private block is a /8 (RFC 1918 §3); the previous /10
        // prefix only covered 10.0.0.0–10.63.255.255
        IpNetwork::new([10, 0, 0, 0].into(), 8).unwrap(),
        // IPv4 loopback (127.0.0.0/8)
        IpNetwork::new(std::net::Ipv4Addr::LOCALHOST.into(), 8).unwrap(),
        // IPv6 unique-local addresses (fd00::/8)
        IpNetwork::new([0xfd00, 0, 0, 0, 0, 0, 0, 0].into(), 8).unwrap(),
        // IPv6 loopback (::1/128)
        IpNetwork::new(std::net::Ipv6Addr::LOCALHOST.into(), 128).unwrap(),
    ]
}
/// Kind of socket
// Serialized in lowercase (`"unix"` / `"tcp"`) per the serde attribute
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum UnixOrTcp {
    /// UNIX domain socket
    Unix,

    /// TCP socket
    Tcp,
}
impl UnixOrTcp {
    /// UNIX domain socket
    #[must_use]
    pub const fn unix() -> Self {
        Self::Unix
    }

    /// TCP socket
    // Also used as the serde default for `BindConfig::FileDescriptor::kind`
    #[must_use]
    pub const fn tcp() -> Self {
        Self::Tcp
    }
}
/// Configuration of a single listener
// `untagged`: the variant is inferred from which fields are present in the
// config document
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(untagged)]
pub enum BindConfig {
    /// Listen on the specified host and port
    Listen {
        /// Host on which to listen.
        ///
        /// Defaults to listening on all addresses
        #[serde(skip_serializing_if = "Option::is_none")]
        host: Option<String>,

        /// Port on which to listen.
        port: u16,
    },

    /// Listen on the specified address
    Address {
        /// Host and port on which to listen
        #[schemars(
            example = &"[::1]:8080",
            example = &"[::]:8080",
            example = &"127.0.0.1:8080",
            example = &"0.0.0.0:8080",
        )]
        address: String,
    },

    /// Listen on a UNIX domain socket
    Unix {
        /// Path to the socket
        #[schemars(with = "String")]
        socket: Utf8PathBuf,
    },

    /// Accept connections on file descriptors passed by the parent process.
    ///
    /// This is useful for grabbing sockets passed by systemd.
    ///
    /// See <https://www.freedesktop.org/software/systemd/man/sd_listen_fds.html>
    FileDescriptor {
        /// Index of the file descriptor. Note that this is offset by 3
        /// because of the standard input/output sockets, so setting
        /// here a value of `0` will grab the file descriptor `3`
        #[serde(default)]
        fd: usize,

        /// Whether the socket is a TCP socket or a UNIX domain socket. Defaults
        /// to TCP.
        #[serde(default = "UnixOrTcp::tcp")]
        kind: UnixOrTcp,
    },
}
/// Configuration related to TLS on a listener
// The inline/file exclusivity rules documented below are enforced both by
// `HttpConfig::validate` and by `TlsConfig::load`
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct TlsConfig {
    /// PEM-encoded X509 certificate chain
    ///
    /// Exactly one of `certificate` or `certificate_file` must be set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub certificate: Option<String>,

    /// File containing the PEM-encoded X509 certificate chain
    ///
    /// Exactly one of `certificate` or `certificate_file` must be set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub certificate_file: Option<Utf8PathBuf>,

    /// PEM-encoded private key
    ///
    /// Exactly one of `key` or `key_file` must be set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub key: Option<String>,

    /// File containing a PEM or DER-encoded private key
    ///
    /// Exactly one of `key` or `key_file` must be set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub key_file: Option<Utf8PathBuf>,

    /// Password used to decode the private key
    ///
    /// One of `password` or `password_file` must be set if the key is
    /// encrypted.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub password: Option<String>,

    /// Password file used to decode the private key
    ///
    /// One of `password` or `password_file` must be set if the key is
    /// encrypted.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub password_file: Option<Utf8PathBuf>,
}
impl TlsConfig {
    /// Load the TLS certificate chain and key file from disk
    ///
    /// # Errors
    ///
    /// Returns an error if an error was encountered either while:
    /// - reading the certificate, key or password files
    /// - decoding the key as PEM or DER
    /// - decrypting the key if encrypted
    /// - a password was provided but the key was not encrypted
    /// - decoding the certificate chain as PEM
    /// - the certificate chain is empty
    pub fn load(
        &self,
    ) -> Result<(PrivateKeyDer<'static>, Vec<CertificateDer<'static>>), anyhow::Error> {
        // Resolve the optional key password, either inline or read from a file
        let password = match (&self.password, &self.password_file) {
            (None, None) => None,
            (Some(_), Some(_)) => {
                bail!("Only one of `password` or `password_file` can be set at a time")
            }
            (Some(password), None) => Some(Cow::Borrowed(password)),
            (None, Some(path)) => Some(Cow::Owned(std::fs::read_to_string(path)?)),
        };

        // Read the key either embedded in the config file or on disk
        let key = match (&self.key, &self.key_file) {
            (None, None) => bail!("Either `key` or `key_file` must be set"),
            (Some(_), Some(_)) => bail!("Only one of `key` or `key_file` can be set at a time"),
            (Some(key), None) => {
                // If the key was embedded in the config file, assume it is formatted as PEM
                if let Some(password) = password {
                    PrivateKey::load_encrypted_pem(key, password.as_bytes())?
                } else {
                    PrivateKey::load_pem(key)?
                }
            }
            (None, Some(path)) => {
                // When reading from disk, it might be either PEM or DER. `PrivateKey::load*`
                // will try both.
                let key = std::fs::read(path)?;
                if let Some(password) = password {
                    PrivateKey::load_encrypted(&key, password.as_bytes())?
                } else {
                    PrivateKey::load(&key)?
                }
            }
        };

        // Re-serialize the key to PKCS#8 DER, so rustls can consume it
        let key = key.to_pkcs8_der()?;
        let key = PrivatePkcs8KeyDer::from(key.to_vec()).into();

        // Read the certificate chain, either inline or from a file
        let certificate_chain_pem = match (&self.certificate, &self.certificate_file) {
            (None, None) => bail!("Either `certificate` or `certificate_file` must be set"),
            (Some(_), Some(_)) => {
                bail!("Only one of `certificate` or `certificate_file` can be set at a time")
            }
            (Some(certificate), None) => Cow::Borrowed(certificate),
            (None, Some(path)) => Cow::Owned(std::fs::read_to_string(path)?),
        };

        // Parse every PEM block of the chain into a DER certificate, failing
        // on the first invalid block
        let certificate_chain = CertificateDer::pem_slice_iter(certificate_chain_pem.as_bytes())
            .collect::<Result<Vec<_>, _>>()?;

        if certificate_chain.is_empty() {
            bail!("TLS certificate chain is empty (or invalid)")
        }

        Ok((key, certificate_chain))
    }
}
/// HTTP resources to mount
// Internally tagged on the `name` field, lowercased (`name: health`, …)
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
#[serde(tag = "name", rename_all = "lowercase")]
pub enum Resource {
    /// Healthcheck endpoint (/health)
    Health,

    /// Prometheus metrics endpoint (/metrics)
    Prometheus,

    /// OIDC discovery endpoints
    Discovery,

    /// Pages destined to be viewed by humans
    Human,

    /// GraphQL endpoint
    GraphQL {
        /// Enable the GraphQL playground
        #[serde(default, skip_serializing_if = "std::ops::Not::not")]
        playground: bool,

        /// Allow access for OAuth 2.0 clients (undocumented)
        #[serde(default, skip_serializing_if = "std::ops::Not::not")]
        undocumented_oauth2_access: bool,
    },

    /// OAuth-related APIs
    OAuth,

    /// Matrix compatibility API
    Compat,

    /// Static files
    Assets {
        /// Path to the directory to serve.
        #[serde(
            default = "http_listener_assets_path_default",
            skip_serializing_if = "is_default_http_listener_assets_path"
        )]
        #[schemars(with = "String")]
        path: Utf8PathBuf,
    },

    /// Admin API, served at `/api/admin/v1`
    AdminApi,

    /// Mount a "/connection-info" handler which helps debugging information on
    /// the upstream connection
    #[serde(rename = "connection-info")]
    ConnectionInfo,
}
/// Configuration of a listener
// Validated by `HttpConfig::validate`: must have at least one resource and
// at least one bind
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct ListenerConfig {
    /// A unique name for this listener which will be shown in traces and in
    /// metrics labels
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// List of resources to mount
    pub resources: Vec<Resource>,

    /// HTTP prefix to mount the resources on
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefix: Option<String>,

    /// List of sockets to bind
    pub binds: Vec<BindConfig>,

    /// Accept `HAProxy`'s Proxy Protocol V1
    #[serde(default)]
    pub proxy_protocol: bool,

    /// If set, makes the listener use TLS with the provided certificate and key
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls: Option<TlsConfig>,
}
/// Configuration related to the web server
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct HttpConfig {
    /// List of listeners to run
    #[serde(default)]
    pub listeners: Vec<ListenerConfig>,

    /// List of trusted reverse proxies that can set the `X-Forwarded-For`
    /// header
    #[serde(default = "default_trusted_proxies")]
    #[schemars(with = "Vec<String>", inner(ip))]
    pub trusted_proxies: Vec<IpNetwork>,

    /// Public URL base from where the authentication service is reachable
    // Mandatory: no serde default, although `Default::default()` fills one in
    pub public_base: Url,

    /// OIDC issuer URL. Defaults to `public_base` if not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<Url>,
}
impl Default for HttpConfig {
    /// Two listeners: a public-facing one on `[::]:8080` serving every
    /// user-visible resource, and an internal healthcheck-only one on
    /// `localhost:8081`.
    fn default() -> Self {
        // Public listener: discovery, human pages, OAuth, compat, GraphQL
        // (playground disabled) and the static assets
        let web = ListenerConfig {
            name: Some("web".to_owned()),
            resources: vec![
                Resource::Discovery,
                Resource::Human,
                Resource::OAuth,
                Resource::Compat,
                Resource::GraphQL {
                    playground: false,
                    undocumented_oauth2_access: false,
                },
                Resource::Assets {
                    path: http_listener_assets_path_default(),
                },
            ],
            prefix: None,
            tls: None,
            proxy_protocol: false,
            binds: vec![BindConfig::Address {
                address: "[::]:8080".into(),
            }],
        };

        // Internal listener: healthcheck only, bound to localhost
        let internal = ListenerConfig {
            name: Some("internal".to_owned()),
            resources: vec![Resource::Health],
            prefix: None,
            tls: None,
            proxy_protocol: false,
            binds: vec![BindConfig::Listen {
                host: Some("localhost".to_owned()),
                port: 8081,
            }],
        };

        Self {
            listeners: vec![web, internal],
            trusted_proxies: default_trusted_proxies(),
            issuer: Some(default_public_base()),
            public_base: default_public_base(),
        }
    }
}
impl ConfigurationSection for HttpConfig {
    const PATH: Option<&'static str> = Some("http");

    /// Checks every listener: it must mount at least one resource, bind at
    /// least one socket, and any TLS block must reference exactly one
    /// certificate and exactly one key, with at most one password source.
    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        for (index, listener) in self.listeners.iter().enumerate() {
            // Attach figment metadata and the `http.listeners.<index>` path to
            // the error so the report points at the offending listener
            let annotate = |mut error: figment::Error| {
                error.metadata = figment
                    .find_metadata(&format!("{root}.listeners", root = Self::PATH.unwrap()))
                    .cloned();
                error.profile = Some(figment::Profile::Default);
                error.path = vec![
                    Self::PATH.unwrap().to_owned(),
                    "listeners".to_owned(),
                    index.to_string(),
                ];
                error
            };

            if listener.resources.is_empty() {
                return Err(
                    annotate(figment::Error::from("listener has no resources".to_owned())).into(),
                );
            }

            if listener.binds.is_empty() {
                return Err(annotate(figment::Error::from(
                    "listener does not bind to any address".to_owned(),
                ))
                .into());
            }

            if let Some(tls_config) = &listener.tls {
                // Exactly one of `certificate`/`certificate_file`…
                if tls_config.certificate.is_some() && tls_config.certificate_file.is_some() {
                    return Err(annotate(figment::Error::from(
                        "Only one of `certificate` or `certificate_file` can be set at a time"
                            .to_owned(),
                    ))
                    .into());
                }
                if tls_config.certificate.is_none() && tls_config.certificate_file.is_none() {
                    return Err(annotate(figment::Error::from(
                        "TLS configuration is missing a certificate".to_owned(),
                    ))
                    .into());
                }

                // …exactly one of `key`/`key_file`…
                if tls_config.key.is_some() && tls_config.key_file.is_some() {
                    return Err(annotate(figment::Error::from(
                        "Only one of `key` or `key_file` can be set at a time".to_owned(),
                    ))
                    .into());
                }
                if tls_config.key.is_none() && tls_config.key_file.is_none() {
                    return Err(annotate(figment::Error::from(
                        "TLS configuration is missing a private key".to_owned(),
                    ))
                    .into());
                }

                // …and at most one of `password`/`password_file` (both absent
                // is fine: the key may be unencrypted)
                if tls_config.password.is_some() && tls_config.password_file.is_some() {
                    return Err(annotate(figment::Error::from(
                        "Only one of `password` or `password_file` can be set at a time".to_owned(),
                    ))
                    .into());
                }
            }
        }

        Ok(())
    }
}

View file

@ -0,0 +1,235 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use anyhow::bail;
use camino::Utf8PathBuf;
use rand::{
Rng,
distributions::{Alphanumeric, DistString},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use url::Url;
use super::ConfigurationSection;
/// Default value for `MatrixConfig::homeserver`.
fn default_homeserver() -> String {
    "localhost:8008".to_owned()
}
/// Default value for `MatrixConfig::endpoint`.
fn default_endpoint() -> Url {
    // The literal is a valid URL, so this parse can never fail
    Url::parse("http://localhost:8008/").unwrap()
}
/// The kind of homeserver it is.
// Serialized in snake_case (`synapse`, `synapse_read_only`, …)
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum HomeserverKind {
    /// Homeserver is Synapse, version 1.135.0 or newer
    #[default]
    Synapse,

    /// Homeserver is Synapse, version 1.135.0 or newer, in read-only mode
    ///
    /// This is meant for testing rolling out Matrix Authentication Service with
    /// no risk of writing data to the homeserver.
    SynapseReadOnly,

    /// Homeserver is Synapse, using the legacy API
    SynapseLegacy,

    /// Homeserver is Synapse, with the modern API available (>= 1.135.0)
    SynapseModern,
}
/// Shared secret between MAS and the homeserver.
///
/// It either holds the secret value directly or references a file where the
/// secret is stored.
#[derive(Clone, Debug)]
pub enum Secret {
    /// Path to the file containing the secret.
    File(Utf8PathBuf),
    /// The secret value itself.
    Value(String),
}
/// Secret fields as serialized in JSON.
// Wire representation of `Secret`; the exactly-one-field rule is enforced by
// the `TryFrom<SecretRaw> for Secret` conversion
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
struct SecretRaw {
    // Path to a file containing the secret
    #[schemars(with = "Option<String>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    secret_file: Option<Utf8PathBuf>,

    // Inline secret value
    #[serde(skip_serializing_if = "Option::is_none")]
    secret: Option<String>,
}
impl TryFrom<SecretRaw> for Secret {
    type Error = anyhow::Error;

    /// Converts the raw serialized form, enforcing that exactly one of
    /// `secret` / `secret_file` was provided.
    fn try_from(value: SecretRaw) -> Result<Self, Self::Error> {
        if let Some(secret) = value.secret {
            if value.secret_file.is_some() {
                bail!("Cannot specify both `secret` and `secret_file`")
            }
            Ok(Secret::Value(secret))
        } else if let Some(path) = value.secret_file {
            Ok(Secret::File(path))
        } else {
            bail!("Missing `secret` or `secret_file`")
        }
    }
}
impl From<Secret> for SecretRaw {
    /// Converts back to the wire form; exactly one raw field is populated,
    /// mirroring the enum variant.
    fn from(value: Secret) -> Self {
        let (secret_file, secret) = match value {
            Secret::File(path) => (Some(path), None),
            Secret::Value(value) => (None, Some(value)),
        };
        SecretRaw {
            secret_file,
            secret,
        }
    }
}
/// Configuration related to the Matrix homeserver
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MatrixConfig {
    /// The kind of homeserver it is.
    #[serde(default)]
    pub kind: HomeserverKind,

    /// The server name of the homeserver.
    #[serde(default = "default_homeserver")]
    pub homeserver: String,

    /// Shared secret to use for calls to the admin API
    // Flattened: the config document carries `secret`/`secret_file` directly,
    // converted through `SecretRaw`
    #[schemars(with = "SecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<SecretRaw>")]
    #[serde(flatten)]
    pub secret: Secret,

    /// The base URL of the homeserver's client API
    #[serde(default = "default_endpoint")]
    pub endpoint: Url,
}
impl ConfigurationSection for MatrixConfig {
    // This section lives under the `matrix` key of the root config
    const PATH: Option<&'static str> = Some("matrix");
}
impl MatrixConfig {
    /// Returns the shared secret.
    ///
    /// If `secret_file` was given, the secret is read from that file.
    ///
    /// # Errors
    ///
    /// Returns an error when the shared secret could not be read from file.
    pub async fn secret(&self) -> anyhow::Result<String> {
        match &self.secret {
            Secret::Value(secret) => Ok(secret.clone()),
            Secret::File(path) => {
                let contents = tokio::fs::read_to_string(path).await?;
                // Trim the secret when read from file to match Synapse's
                // behaviour
                Ok(contents.trim().to_owned())
            }
        }
    }

    /// Builds a config with a freshly generated 32-character alphanumeric
    /// shared secret; everything else takes its default value.
    pub(crate) fn generate<R>(mut rng: R) -> Self
    where
        R: Rng + Send,
    {
        let secret = Alphanumeric.sample_string(&mut rng, 32);
        Self {
            kind: HomeserverKind::default(),
            homeserver: default_homeserver(),
            secret: Secret::Value(secret),
            endpoint: default_endpoint(),
        }
    }

    /// Fixed configuration used by the test suite.
    pub(crate) fn test() -> Self {
        Self {
            kind: HomeserverKind::default(),
            homeserver: default_homeserver(),
            secret: Secret::Value("test".to_owned()),
            endpoint: default_endpoint(),
        }
    }
}
#[cfg(test)]
mod tests {
    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    // Loads a config whose secret comes from a file; the file contents must be
    // trimmed. `Jail` does blocking filesystem work, so it runs inside
    // `spawn_blocking`, and the async assertions re-enter the runtime via
    // `Handle::block_on`.
    #[tokio::test]
    async fn load_config() {
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    r"
                        matrix:
                          homeserver: matrix.org
                          secret_file: secret
                    ",
                )?;
                jail.create_file("secret", r"m472!x53c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<MatrixConfig>("matrix")?;

                Handle::current().block_on(async move {
                    assert_eq!(&config.homeserver, "matrix.org");
                    assert!(matches!(config.secret, Secret::File(ref p) if p == "secret"));
                    assert_eq!(config.secret().await.unwrap(), "m472!x53c237");
                });

                Ok(())
            });
        })
        .await
        .unwrap();
    }

    // Same as above, but with the secret inlined in the config document.
    #[tokio::test]
    async fn load_config_inline_secrets() {
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    r"
                        matrix:
                          homeserver: matrix.org
                          secret: m472!x53c237
                    ",
                )?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<MatrixConfig>("matrix")?;

                Handle::current().block_on(async move {
                    assert_eq!(&config.homeserver, "matrix.org");
                    assert!(matches!(config.secret, Secret::Value(ref v) if v == "m472!x53c237"));
                    assert_eq!(config.secret().await.unwrap(), "m472!x53c237");
                });

                Ok(())
            });
        })
        .await
        .unwrap();
    }
}

View file

@ -0,0 +1,386 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use anyhow::bail;
use camino::Utf8PathBuf;
use rand::Rng;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
mod account;
mod branding;
mod captcha;
mod clients;
mod database;
mod email;
mod experimental;
mod http;
mod matrix;
mod passwords;
mod policy;
mod rate_limiting;
mod secrets;
mod telemetry;
mod templates;
mod upstream_oauth2;
pub use self::{
account::AccountConfig,
branding::BrandingConfig,
captcha::{CaptchaConfig, CaptchaServiceKind},
clients::{ClientAuthMethodConfig, ClientConfig, ClientsConfig},
database::{DatabaseConfig, PgSslMode},
email::{EmailConfig, EmailSmtpMode, EmailTransportKind},
experimental::ExperimentalConfig,
http::{
BindConfig as HttpBindConfig, HttpConfig, ListenerConfig as HttpListenerConfig,
Resource as HttpResource, TlsConfig as HttpTlsConfig, UnixOrTcp,
},
matrix::{HomeserverKind, MatrixConfig},
passwords::{
Algorithm as PasswordAlgorithm, HashingScheme as PasswordHashingScheme, PasswordsConfig,
},
policy::PolicyConfig,
rate_limiting::RateLimitingConfig,
secrets::SecretsConfig,
telemetry::{
MetricsConfig, MetricsExporterKind, Propagator, TelemetryConfig, TracingConfig,
TracingExporterKind,
},
templates::TemplatesConfig,
upstream_oauth2::{
ClaimsImports as UpstreamOAuth2ClaimsImports, DiscoveryMode as UpstreamOAuth2DiscoveryMode,
EmailImportPreference as UpstreamOAuth2EmailImportPreference,
ImportAction as UpstreamOAuth2ImportAction,
OnBackchannelLogout as UpstreamOAuth2OnBackchannelLogout,
OnConflict as UpstreamOAuth2OnConflict, PkceMethod as UpstreamOAuth2PkceMethod,
Provider as UpstreamOAuth2Provider, ResponseMode as UpstreamOAuth2ResponseMode,
TokenAuthMethod as UpstreamOAuth2TokenAuthMethod, UpstreamOAuth2Config,
},
};
use crate::util::ConfigurationSection;
/// Application configuration root
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RootConfig {
    /// List of OAuth 2.0/OIDC clients config
    #[serde(default, skip_serializing_if = "ClientsConfig::is_default")]
    pub clients: ClientsConfig,

    /// Configuration of the HTTP server
    #[serde(default)]
    pub http: HttpConfig,

    /// Database connection configuration
    #[serde(default)]
    pub database: DatabaseConfig,

    /// Configuration related to sending monitoring data
    #[serde(default, skip_serializing_if = "TelemetryConfig::is_default")]
    pub telemetry: TelemetryConfig,

    /// Configuration related to templates
    #[serde(default, skip_serializing_if = "TemplatesConfig::is_default")]
    pub templates: TemplatesConfig,

    /// Configuration related to sending emails
    #[serde(default)]
    pub email: EmailConfig,

    /// Application secrets
    // No `#[serde(default)]`: this section is mandatory in the config file
    pub secrets: SecretsConfig,

    /// Configuration related to user passwords
    #[serde(default)]
    pub passwords: PasswordsConfig,

    /// Configuration related to the homeserver
    // Mandatory, like `secrets`
    pub matrix: MatrixConfig,

    /// Configuration related to the OPA policies
    #[serde(default, skip_serializing_if = "PolicyConfig::is_default")]
    pub policy: PolicyConfig,

    /// Configuration related to limiting the rate of user actions to prevent
    /// abuse
    #[serde(default, skip_serializing_if = "RateLimitingConfig::is_default")]
    pub rate_limiting: RateLimitingConfig,

    /// Configuration related to upstream OAuth providers
    #[serde(default, skip_serializing_if = "UpstreamOAuth2Config::is_default")]
    pub upstream_oauth2: UpstreamOAuth2Config,

    /// Configuration section for tweaking the branding of the service
    #[serde(default, skip_serializing_if = "BrandingConfig::is_default")]
    pub branding: BrandingConfig,

    /// Configuration section to setup CAPTCHA protection on a few operations
    #[serde(default, skip_serializing_if = "CaptchaConfig::is_default")]
    pub captcha: CaptchaConfig,

    /// Configuration section to configure features related to account
    /// management
    #[serde(default, skip_serializing_if = "AccountConfig::is_default")]
    pub account: AccountConfig,

    /// Experimental configuration options
    #[serde(default, skip_serializing_if = "ExperimentalConfig::is_default")]
    pub experimental: ExperimentalConfig,
}
impl ConfigurationSection for RootConfig {
    /// Validates every section in turn.
    ///
    /// The first failing section short-circuits via `?`, so only one error is
    /// reported per run.
    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        self.clients.validate(figment)?;
        self.http.validate(figment)?;
        self.database.validate(figment)?;
        self.telemetry.validate(figment)?;
        self.templates.validate(figment)?;
        self.email.validate(figment)?;
        self.passwords.validate(figment)?;
        self.secrets.validate(figment)?;
        self.matrix.validate(figment)?;
        self.policy.validate(figment)?;
        self.rate_limiting.validate(figment)?;
        self.upstream_oauth2.validate(figment)?;
        self.branding.validate(figment)?;
        self.captcha.validate(figment)?;
        self.account.validate(figment)?;
        self.experimental.validate(figment)?;

        Ok(())
    }
}
impl RootConfig {
    /// Generate a new configuration with random secrets
    ///
    /// # Errors
    ///
    /// Returns an error if the secrets could not be generated
    // Async because generating the secrets section is async; every other
    // section simply takes its `Default` value
    pub async fn generate<R>(mut rng: R) -> anyhow::Result<Self>
    where
        R: Rng + Send,
    {
        Ok(Self {
            clients: ClientsConfig::default(),
            http: HttpConfig::default(),
            database: DatabaseConfig::default(),
            telemetry: TelemetryConfig::default(),
            templates: TemplatesConfig::default(),
            email: EmailConfig::default(),
            passwords: PasswordsConfig::default(),
            secrets: SecretsConfig::generate(&mut rng).await?,
            matrix: MatrixConfig::generate(&mut rng),
            policy: PolicyConfig::default(),
            rate_limiting: RateLimitingConfig::default(),
            upstream_oauth2: UpstreamOAuth2Config::default(),
            branding: BrandingConfig::default(),
            captcha: CaptchaConfig::default(),
            account: AccountConfig::default(),
            experimental: ExperimentalConfig::default(),
        })
    }

    /// Configuration used in tests
    // Same shape as `generate`, but with fixed (non-random) secrets
    #[must_use]
    pub fn test() -> Self {
        Self {
            clients: ClientsConfig::default(),
            http: HttpConfig::default(),
            database: DatabaseConfig::default(),
            telemetry: TelemetryConfig::default(),
            templates: TemplatesConfig::default(),
            passwords: PasswordsConfig::default(),
            email: EmailConfig::default(),
            secrets: SecretsConfig::test(),
            matrix: MatrixConfig::test(),
            policy: PolicyConfig::default(),
            rate_limiting: RateLimitingConfig::default(),
            upstream_oauth2: UpstreamOAuth2Config::default(),
            branding: BrandingConfig::default(),
            captcha: CaptchaConfig::default(),
            account: AccountConfig::default(),
            experimental: ExperimentalConfig::default(),
        }
    }
}
/// Partial configuration actually used by the server
// Subset of `RootConfig` without the `clients`, `telemetry`,
// `upstream_oauth2` and `policy`-adjacent CLI-only sections; `secrets` and
// `matrix` remain mandatory
#[allow(missing_docs)]
#[derive(Debug, Deserialize)]
pub struct AppConfig {
    #[serde(default)]
    pub http: HttpConfig,

    #[serde(default)]
    pub database: DatabaseConfig,

    #[serde(default)]
    pub templates: TemplatesConfig,

    #[serde(default)]
    pub email: EmailConfig,

    pub secrets: SecretsConfig,

    #[serde(default)]
    pub passwords: PasswordsConfig,

    pub matrix: MatrixConfig,

    #[serde(default)]
    pub policy: PolicyConfig,

    #[serde(default)]
    pub rate_limiting: RateLimitingConfig,

    #[serde(default)]
    pub branding: BrandingConfig,

    #[serde(default)]
    pub captcha: CaptchaConfig,

    #[serde(default)]
    pub account: AccountConfig,

    #[serde(default)]
    pub experimental: ExperimentalConfig,
}
impl ConfigurationSection for AppConfig {
    /// Validates every section in turn; the first failure short-circuits.
    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        self.http.validate(figment)?;
        self.database.validate(figment)?;
        self.templates.validate(figment)?;
        self.email.validate(figment)?;
        self.passwords.validate(figment)?;
        self.secrets.validate(figment)?;
        self.matrix.validate(figment)?;
        self.policy.validate(figment)?;
        self.rate_limiting.validate(figment)?;
        self.branding.validate(figment)?;
        self.captcha.validate(figment)?;
        self.account.validate(figment)?;
        self.experimental.validate(figment)?;

        Ok(())
    }
}
/// Partial config used by the `mas-cli config sync` command
// Only the sections needed to sync clients and upstream providers into the
// database; `secrets` remains mandatory
#[allow(missing_docs)]
#[derive(Debug, Deserialize)]
pub struct SyncConfig {
    #[serde(default)]
    pub database: DatabaseConfig,

    pub secrets: SecretsConfig,

    #[serde(default)]
    pub clients: ClientsConfig,

    #[serde(default)]
    pub upstream_oauth2: UpstreamOAuth2Config,
}
impl ConfigurationSection for SyncConfig {
    /// Validates every section in turn; the first failure short-circuits.
    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        self.database.validate(figment)?;
        self.secrets.validate(figment)?;
        self.clients.validate(figment)?;
        self.upstream_oauth2.validate(figment)?;

        Ok(())
    }
}
/// Client secret config option.
///
/// It either holds the client secret value directly or references a file where
/// the client secret is stored.
#[derive(Clone, Debug)]
pub enum ClientSecret {
    /// Path to the file containing the client secret.
    File(Utf8PathBuf),

    /// Client secret value.
    Value(String),
}
/// Client secret fields as serialized in JSON.
// Wire representation of `Option<ClientSecret>`; the at-most-one-field rule
// is enforced by the `TryFrom` conversion below
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
pub struct ClientSecretRaw {
    /// Path to the file containing the client secret. The client secret is used
    /// by the `client_secret_basic`, `client_secret_post` and
    /// `client_secret_jwt` authentication methods.
    #[schemars(with = "Option<String>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    client_secret_file: Option<Utf8PathBuf>,

    /// Alternative to `client_secret_file`: Reads the client secret directly
    /// from the config.
    #[serde(skip_serializing_if = "Option::is_none")]
    client_secret: Option<String>,
}
impl ClientSecret {
    /// Returns the client secret.
    ///
    /// If `client_secret_file` was given, the secret is read from that file.
    ///
    /// # Errors
    ///
    /// Returns an error when the client secret could not be read from file.
    pub async fn value(&self) -> anyhow::Result<String> {
        Ok(match self {
            // NOTE(review): unlike `MatrixConfig::secret`, the file contents
            // are not trimmed here, so a trailing newline in the file becomes
            // part of the secret — confirm this is intentional
            ClientSecret::File(path) => tokio::fs::read_to_string(path).await?,
            ClientSecret::Value(client_secret) => client_secret.clone(),
        })
    }
}
impl TryFrom<ClientSecretRaw> for Option<ClientSecret> {
    type Error = anyhow::Error;

    /// Converts the raw serialized form; both fields absent means "no
    /// secret", both present is an error.
    fn try_from(value: ClientSecretRaw) -> Result<Self, Self::Error> {
        if value.client_secret.is_some() && value.client_secret_file.is_some() {
            bail!("Cannot specify both `client_secret` and `client_secret_file`")
        }

        if let Some(client_secret) = value.client_secret {
            Ok(Some(ClientSecret::Value(client_secret)))
        } else if let Some(path) = value.client_secret_file {
            Ok(Some(ClientSecret::File(path)))
        } else {
            Ok(None)
        }
    }
}
impl From<Option<ClientSecret>> for ClientSecretRaw {
    /// Converts back to the wire form; at most one raw field is populated,
    /// mirroring the variant.
    fn from(value: Option<ClientSecret>) -> Self {
        let (client_secret_file, client_secret) = match value {
            Some(ClientSecret::File(path)) => (Some(path), None),
            Some(ClientSecret::Value(secret)) => (None, Some(secret)),
            None => (None, None),
        };
        ClientSecretRaw {
            client_secret_file,
            client_secret,
        }
    }
}

View file

@ -0,0 +1,230 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::cmp::Reverse;
use anyhow::bail;
use camino::Utf8PathBuf;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::ConfigurationSection;
/// Default scheme list: a single version-1 scheme using the default algorithm
/// (argon2id), with no cost override, no pepper and no Unicode normalization.
fn default_schemes() -> Vec<HashingScheme> {
    vec![HashingScheme {
        version: 1,
        algorithm: Algorithm::default(),
        cost: None,
        secret: None,
        secret_file: None,
        unicode_normalization: false,
    }]
}
/// Password-based authentication is enabled by default.
fn default_enabled() -> bool {
    true
}
/// Default minimum password complexity score (0-4 scale).
fn default_minimum_complexity() -> u8 {
    3
}
/// User password hashing config
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PasswordsConfig {
    /// Whether password-based authentication is enabled
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// The hashing schemes to use for hashing and validating passwords
    ///
    /// The hashing scheme with the highest version number will be used for
    /// hashing new passwords.
    #[serde(default = "default_schemes")]
    pub schemes: Vec<HashingScheme>,
    /// Score between 0 and 4 determining the minimum allowed password
    /// complexity. Scores are based on the ESTIMATED number of guesses
    /// needed to guess the password.
    ///
    /// - 0: less than 10^2 (100)
    /// - 1: less than 10^4 (10'000)
    /// - 2: less than 10^6 (1'000'000)
    /// - 3: less than 10^8 (100'000'000)
    /// - 4: any more than that
    ///
    /// Kept private; read it through [`PasswordsConfig::minimum_complexity`].
    #[serde(default = "default_minimum_complexity")]
    minimum_complexity: u8,
}
// The hand-written `Default` mirrors the per-field serde defaults above
impl Default for PasswordsConfig {
    fn default() -> Self {
        Self {
            enabled: default_enabled(),
            schemes: default_schemes(),
            minimum_complexity: default_minimum_complexity(),
        }
    }
}
impl ConfigurationSection for PasswordsConfig {
    const PATH: Option<&'static str> = Some("passwords");
    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        // Attach the source metadata and the `passwords` path to errors so
        // they point at the right section of the loaded config
        let annotate = |mut error: figment::Error| {
            error.metadata = figment.find_metadata(Self::PATH.unwrap()).cloned();
            error.profile = Some(figment::Profile::Default);
            error.path = vec![Self::PATH.unwrap().to_owned()];
            error
        };
        if !self.enabled {
            // Skip validation if password-based authentication is disabled
            return Ok(());
        }
        if self.schemes.is_empty() {
            return Err(annotate(figment::Error::from(
                "Requires at least one password scheme in the config".to_owned(),
            ))
            .into());
        }
        // Detect duplicated scheme versions at validation time; `load()`
        // treats them as fatal, so catching them here fails fast at startup
        // instead of at first use
        let mut versions: Vec<u16> = self.schemes.iter().map(|s| s.version).collect();
        versions.sort_unstable();
        versions.dedup();
        if versions.len() != self.schemes.len() {
            return Err(annotate(figment::Error::from(
                "Multiple password schemes have the same versions".to_owned(),
            ))
            .into());
        }
        for scheme in &self.schemes {
            // The pepper may be given inline or as a file, but not both
            if scheme.secret.is_some() && scheme.secret_file.is_some() {
                return Err(annotate(figment::Error::from(
                    "Cannot specify both `secret` and `secret_file`".to_owned(),
                ))
                .into());
            }
        }
        Ok(())
    }
}
impl PasswordsConfig {
    /// Whether password-based authentication is enabled
    #[must_use]
    pub fn enabled(&self) -> bool {
        self.enabled
    }
    /// Minimum complexity of passwords, from 0 to 4, according to the zxcvbn
    /// scorer.
    #[must_use]
    pub fn minimum_complexity(&self) -> u8 {
        self.minimum_complexity
    }
    /// Load the password hashing schemes defined by the config
    ///
    /// Returns one `(version, algorithm, cost, secret, unicode_normalization)`
    /// tuple per scheme, ordered by descending version, so the first entry is
    /// the scheme used for hashing new passwords.
    ///
    /// # Errors
    ///
    /// Returns an error if the config is invalid, or if the secret file could
    /// not be read.
    pub async fn load(
        &self,
    ) -> Result<Vec<(u16, Algorithm, Option<u32>, Option<Vec<u8>>, bool)>, anyhow::Error> {
        let mut schemes: Vec<&HashingScheme> = self.schemes.iter().collect();
        // Sort by descending version number (highest first)
        schemes.sort_unstable_by_key(|a| Reverse(a.version));
        // The list is sorted, so duplicated versions are adjacent and
        // `dedup_by_key` removes them; a length change reveals duplicates
        schemes.dedup_by_key(|a| a.version);
        if schemes.len() != self.schemes.len() {
            // Some schemes had duplicated versions
            bail!("Multiple password schemes have the same versions");
        }
        if schemes.is_empty() {
            bail!("Requires at least one password scheme in the config");
        }
        let mut mapped_result = Vec::with_capacity(schemes.len());
        for scheme in schemes {
            // The secret (pepper) may be inline or read from a file, not both
            let secret = match (&scheme.secret, &scheme.secret_file) {
                (Some(secret), None) => Some(secret.clone().into_bytes()),
                (None, Some(secret_file)) => {
                    let secret = tokio::fs::read(secret_file).await?;
                    Some(secret)
                }
                (Some(_), Some(_)) => bail!("Cannot specify both `secret` and `secret_file`"),
                (None, None) => None,
            };
            mapped_result.push((
                scheme.version,
                scheme.algorithm,
                scheme.cost,
                secret,
                scheme.unicode_normalization,
            ));
        }
        Ok(mapped_result)
    }
}
/// Serde helper: skip serializing a `bool` field that still holds the
/// default `false` value.
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default_false(value: &bool) -> bool {
    match *value {
        false => true,
        true => false,
    }
}
/// Parameters for a password hashing scheme
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct HashingScheme {
    /// The version of the hashing scheme. They must be unique, and the highest
    /// version will be used for hashing new passwords.
    pub version: u16,
    /// The hashing algorithm to use
    pub algorithm: Algorithm,
    /// Whether to apply Unicode normalization to the password before hashing
    ///
    /// Defaults to `false`, and it is generally recommended to stay `false`.
    /// It is, however, recommended when importing password hashes from
    /// Synapse, as Synapse applies NFKC normalization to passwords before
    /// hashing them.
    #[serde(default, skip_serializing_if = "is_default_false")]
    pub unicode_normalization: bool,
    /// Cost for the bcrypt algorithm
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(default = "default_bcrypt_cost")]
    pub cost: Option<u32>,
    /// An optional secret to use when hashing passwords. This makes it harder
    /// to brute-force the passwords in case of a database leak.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub secret: Option<String>,
    /// Same as `secret`, but read from a file.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub secret_file: Option<Utf8PathBuf>,
}
/// Default cost advertised in the JSON schema for the bcrypt algorithm.
#[allow(clippy::unnecessary_wraps)]
fn default_bcrypt_cost() -> Option<u32> {
    const DEFAULT_BCRYPT_COST: u32 = 12;
    Some(DEFAULT_BCRYPT_COST)
}
/// A hashing algorithm
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Algorithm {
    /// bcrypt
    Bcrypt,
    /// argon2id (the default)
    #[default]
    Argon2id,
    /// PBKDF2
    Pbkdf2,
}

View file

@ -0,0 +1,178 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use camino::Utf8PathBuf;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use super::ConfigurationSection;
// The default location of the policy WASM module depends on how the binary
// was packaged: plain builds look relative to the working directory, while
// the `docker` and `dist` features point at the installed location.
#[cfg(not(any(feature = "docker", feature = "dist")))]
fn default_policy_path() -> Utf8PathBuf {
    "./policies/policy.wasm".into()
}
/// Default policy path for Docker images.
#[cfg(feature = "docker")]
fn default_policy_path() -> Utf8PathBuf {
    "/usr/local/share/mas-cli/policy.wasm".into()
}
/// Default policy path for binary distributions.
#[cfg(feature = "dist")]
fn default_policy_path() -> Utf8PathBuf {
    "./share/policy.wasm".into()
}
/// Whether `wasm_module` still holds the default path.
fn is_default_policy_path(value: &Utf8PathBuf) -> bool {
    *value == default_policy_path()
}
/// Default entrypoint evaluated for client registrations.
fn default_client_registration_entrypoint() -> String {
    String::from("client_registration/violation")
}
/// Whether `client_registration_entrypoint` still holds its default value.
fn is_default_client_registration_entrypoint(value: &String) -> bool {
    value == &default_client_registration_entrypoint()
}
/// Default entrypoint evaluated for user registrations.
fn default_register_entrypoint() -> String {
    String::from("register/violation")
}
/// Whether `register_entrypoint` still holds its default value.
fn is_default_register_entrypoint(value: &String) -> bool {
    value == &default_register_entrypoint()
}
/// Default entrypoint evaluated for authorization grants.
fn default_authorization_grant_entrypoint() -> String {
    String::from("authorization_grant/violation")
}
/// Whether `authorization_grant_entrypoint` still holds its default value.
fn is_default_authorization_grant_entrypoint(value: &String) -> bool {
    value == &default_authorization_grant_entrypoint()
}
/// Default entrypoint evaluated for password changes.
fn default_password_entrypoint() -> String {
    String::from("password/violation")
}
/// Whether `password_entrypoint` still holds its default value.
fn is_default_password_entrypoint(value: &String) -> bool {
    value == &default_password_entrypoint()
}
/// Default entrypoint evaluated for compatibility logins.
fn default_compat_login_entrypoint() -> String {
    String::from("compat_login/violation")
}
/// Whether `compat_login_entrypoint` still holds its default value.
fn is_default_compat_login_entrypoint(value: &String) -> bool {
    value == &default_compat_login_entrypoint()
}
/// Default entrypoint evaluated when adding an email address.
fn default_email_entrypoint() -> String {
    String::from("email/violation")
}
/// Whether `email_entrypoint` still holds its default value.
fn is_default_email_entrypoint(value: &String) -> bool {
    value == &default_email_entrypoint()
}
/// Default `data` passed to the policy: an empty JSON object.
fn default_data() -> serde_json::Value {
    serde_json::json!({})
}
/// Whether `data` still holds the default (empty object) value.
fn is_default_data(value: &serde_json::Value) -> bool {
    *value == default_data()
}
/// Configuration of the policy engine: the WASM module to load and the
/// entrypoints to evaluate for each kind of decision
// NOTE: the previous doc comment ("Application secrets") was a copy-paste
// from the secrets config section
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct PolicyConfig {
    /// Path to the WASM module
    #[serde(
        default = "default_policy_path",
        skip_serializing_if = "is_default_policy_path"
    )]
    #[schemars(with = "String")]
    pub wasm_module: Utf8PathBuf,
    /// Entrypoint to use when evaluating client registrations
    #[serde(
        default = "default_client_registration_entrypoint",
        skip_serializing_if = "is_default_client_registration_entrypoint"
    )]
    pub client_registration_entrypoint: String,
    /// Entrypoint to use when evaluating user registrations
    #[serde(
        default = "default_register_entrypoint",
        skip_serializing_if = "is_default_register_entrypoint"
    )]
    pub register_entrypoint: String,
    /// Entrypoint to use when evaluating authorization grants
    #[serde(
        default = "default_authorization_grant_entrypoint",
        skip_serializing_if = "is_default_authorization_grant_entrypoint"
    )]
    pub authorization_grant_entrypoint: String,
    /// Entrypoint to use when evaluating compatibility logins
    #[serde(
        default = "default_compat_login_entrypoint",
        skip_serializing_if = "is_default_compat_login_entrypoint"
    )]
    pub compat_login_entrypoint: String,
    /// Entrypoint to use when changing password
    #[serde(
        default = "default_password_entrypoint",
        skip_serializing_if = "is_default_password_entrypoint"
    )]
    pub password_entrypoint: String,
    /// Entrypoint to use when adding an email address
    #[serde(
        default = "default_email_entrypoint",
        skip_serializing_if = "is_default_email_entrypoint"
    )]
    pub email_entrypoint: String,
    /// Arbitrary data to pass to the policy
    #[serde(default = "default_data", skip_serializing_if = "is_default_data")]
    pub data: serde_json::Value,
}
// The hand-written `Default` mirrors the per-field serde defaults
impl Default for PolicyConfig {
    fn default() -> Self {
        Self {
            wasm_module: default_policy_path(),
            data: default_data(),
            client_registration_entrypoint: default_client_registration_entrypoint(),
            register_entrypoint: default_register_entrypoint(),
            authorization_grant_entrypoint: default_authorization_grant_entrypoint(),
            compat_login_entrypoint: default_compat_login_entrypoint(),
            password_entrypoint: default_password_entrypoint(),
            email_entrypoint: default_email_entrypoint(),
        }
    }
}
impl PolicyConfig {
/// Returns true if the configuration is the default one
pub(crate) fn is_default(&self) -> bool {
is_default_policy_path(&self.wasm_module)
&& is_default_client_registration_entrypoint(&self.client_registration_entrypoint)
&& is_default_register_entrypoint(&self.register_entrypoint)
&& is_default_authorization_grant_entrypoint(&self.authorization_grant_entrypoint)
&& is_default_password_entrypoint(&self.password_entrypoint)
&& is_default_email_entrypoint(&self.email_entrypoint)
&& is_default_data(&self.data)
}
}
impl ConfigurationSection for PolicyConfig {
    // This section lives under the `policy` key of the configuration
    const PATH: Option<&'static str> = Some("policy");
}

View file

@ -0,0 +1,298 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{num::NonZeroU32, time::Duration};
use governor::Quota;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::Error as _};
use crate::ConfigurationSection;
/// Configuration of the request rate limiters (login, registration, account
/// recovery and email authentication)
// NOTE: the previous doc comment ("Configuration related to sending emails")
// was a copy-paste from another config section
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimitingConfig {
    /// Account Recovery-specific rate limits
    #[serde(default)]
    pub account_recovery: AccountRecoveryRateLimitingConfig,
    /// Login-specific rate limits
    #[serde(default)]
    pub login: LoginRateLimitingConfig,
    /// Controls how many registrations attempts are permitted
    /// based on source address.
    #[serde(default = "default_registration")]
    pub registration: RateLimiterConfiguration,
    /// Email authentication-specific rate limits
    #[serde(default)]
    pub email_authentication: EmailauthenticationRateLimitingConfig,
}
/// Login-specific rate limits
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct LoginRateLimitingConfig {
    /// Controls how many login attempts are permitted
    /// based on source IP address.
    /// This can protect against brute force login attempts.
    ///
    /// Note: this limit also applies to password checks when a user attempts to
    /// change their own password.
    #[serde(default = "default_login_per_ip")]
    pub per_ip: RateLimiterConfiguration,
    /// Controls how many login attempts are permitted
    /// based on the account that is being attempted to be logged into.
    /// This can protect against a distributed brute force attack
    /// but should be set high enough to prevent someone's account being
    /// casually locked out.
    ///
    /// Note: this limit also applies to password checks when a user attempts to
    /// change their own password.
    #[serde(default = "default_login_per_account")]
    pub per_account: RateLimiterConfiguration,
}
/// Account recovery-specific rate limits
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AccountRecoveryRateLimitingConfig {
    /// Controls how many account recovery attempts are permitted
    /// based on source IP address.
    /// This can protect against causing e-mail spam to many targets.
    ///
    /// Note: this limit also applies to re-sends.
    #[serde(default = "default_account_recovery_per_ip")]
    pub per_ip: RateLimiterConfiguration,
    /// Controls how many account recovery attempts are permitted
    /// based on the e-mail address entered into the recovery form.
    /// This can protect against causing e-mail spam to one target.
    ///
    /// Note: this limit also applies to re-sends.
    #[serde(default = "default_account_recovery_per_address")]
    pub per_address: RateLimiterConfiguration,
}
/// Email authentication-specific rate limits
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct EmailauthenticationRateLimitingConfig {
    /// Controls how many email authentication attempts are permitted
    /// based on the source IP address.
    /// This can protect against causing e-mail spam to many targets.
    #[serde(default = "default_email_authentication_per_ip")]
    pub per_ip: RateLimiterConfiguration,
    /// Controls how many email authentication attempts are permitted
    /// based on the e-mail address entered into the authentication form.
    /// This can protect against causing e-mail spam to one target.
    ///
    /// Note: this limit also applies to re-sends.
    #[serde(default = "default_email_authentication_per_address")]
    pub per_address: RateLimiterConfiguration,
    /// Controls how many authentication emails are permitted to be sent per
    /// authentication session. This ensures not too many authentication codes
    /// are created for the same authentication session.
    #[serde(default = "default_email_authentication_emails_per_session")]
    pub emails_per_session: RateLimiterConfiguration,
    /// Controls how many code authentication attempts are permitted per
    /// authentication session. This can protect against brute-forcing the
    /// code.
    #[serde(default = "default_email_authentication_attempt_per_session")]
    pub attempt_per_session: RateLimiterConfiguration,
}
/// Parameters of a single rate limiter: a one-off burst allowance plus a
/// replenish rate
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimiterConfiguration {
    /// A one-off burst of actions that the user can perform
    /// in one go without waiting.
    pub burst: NonZeroU32,
    /// How quickly the allowance replenishes, in number of actions per second.
    /// Can be fractional to replenish slower.
    pub per_second: f64,
}
impl ConfigurationSection for RateLimitingConfig {
    const PATH: Option<&'static str> = Some("rate_limiting");
    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        let metadata = figment.find_metadata(Self::PATH.unwrap());
        // Annotate an error with a `rate_limiting.<field>` path
        let error_on_field = |mut error: figment::error::Error, field: &'static str| {
            error.metadata = metadata.cloned();
            error.profile = Some(figment::Profile::Default);
            error.path = vec![Self::PATH.unwrap().to_owned(), field.to_owned()];
            error
        };
        // Annotate an error with a `rate_limiting.<container>.<field>` path
        let error_on_nested_field =
            |mut error: figment::error::Error, container: &'static str, field: &'static str| {
                error.metadata = metadata.cloned();
                error.profile = Some(figment::Profile::Default);
                error.path = vec![
                    Self::PATH.unwrap().to_owned(),
                    container.to_owned(),
                    field.to_owned(),
                ];
                error
            };
        // Check one limiter's configuration for errors
        let error_on_limiter =
            |limiter: &RateLimiterConfiguration| -> Option<figment::error::Error> {
                let recip = limiter.per_second.recip();
                // period must be at least 1 nanosecond according to the governor library
                if recip < 1.0e-9 || !recip.is_finite() {
                    return Some(figment::error::Error::custom(
                        "`per_second` must be a number that is more than zero and less than 1_000_000_000 (1e9)",
                    ));
                }
                None
            };
        if let Some(error) = error_on_limiter(&self.account_recovery.per_ip) {
            return Err(error_on_nested_field(error, "account_recovery", "per_ip").into());
        }
        if let Some(error) = error_on_limiter(&self.account_recovery.per_address) {
            return Err(error_on_nested_field(error, "account_recovery", "per_address").into());
        }
        if let Some(error) = error_on_limiter(&self.registration) {
            return Err(error_on_field(error, "registration").into());
        }
        if let Some(error) = error_on_limiter(&self.login.per_ip) {
            return Err(error_on_nested_field(error, "login", "per_ip").into());
        }
        if let Some(error) = error_on_limiter(&self.login.per_account) {
            return Err(error_on_nested_field(error, "login", "per_account").into());
        }
        // These four checks were previously missing: the email authentication
        // limiters were never validated, so bad values only surfaced at runtime
        if let Some(error) = error_on_limiter(&self.email_authentication.per_ip) {
            return Err(error_on_nested_field(error, "email_authentication", "per_ip").into());
        }
        if let Some(error) = error_on_limiter(&self.email_authentication.per_address) {
            return Err(error_on_nested_field(error, "email_authentication", "per_address").into());
        }
        if let Some(error) = error_on_limiter(&self.email_authentication.emails_per_session) {
            return Err(
                error_on_nested_field(error, "email_authentication", "emails_per_session").into(),
            );
        }
        if let Some(error) = error_on_limiter(&self.email_authentication.attempt_per_session) {
            return Err(
                error_on_nested_field(error, "email_authentication", "attempt_per_session").into(),
            );
        }
        Ok(())
    }
}
impl RateLimitingConfig {
    /// Whether the given config is exactly the default configuration.
    pub(crate) fn is_default(config: &RateLimitingConfig) -> bool {
        *config == RateLimitingConfig::default()
    }
}
impl RateLimiterConfiguration {
    /// Convert this configuration into a governor [`Quota`].
    ///
    /// Returns `None` when the `per_second` rate cannot be represented as a
    /// valid replenish period (zero, negative, NaN, or a sub-nanosecond
    /// period).
    pub fn to_quota(self) -> Option<Quota> {
        let reciprocal = self.per_second.recip();
        // Guard against non-finite AND non-positive values: a negative or NaN
        // reciprocal would make Duration::from_secs_f64 panic
        if !reciprocal.is_finite() || reciprocal <= 0.0 {
            return None;
        }
        Some(Quota::with_period(Duration::from_secs_f64(reciprocal))?.allow_burst(self.burst))
    }
}
/// Default login-per-IP limit: burst of 3, replenishing at 3 per minute.
fn default_login_per_ip() -> RateLimiterConfiguration {
    RateLimiterConfiguration {
        burst: NonZeroU32::new(3).unwrap(),
        per_second: 3.0 / 60.0,
    }
}
/// Default login-per-account limit: burst of 1800, replenishing at 1800 per
/// hour (0.5 per second).
fn default_login_per_account() -> RateLimiterConfiguration {
    RateLimiterConfiguration {
        burst: NonZeroU32::new(1800).unwrap(),
        per_second: 1800.0 / 3600.0,
    }
}
/// Default registration limit: burst of 3, replenishing at 3 per hour.
fn default_registration() -> RateLimiterConfiguration {
    RateLimiterConfiguration {
        burst: NonZeroU32::new(3).unwrap(),
        per_second: 3.0 / 3600.0,
    }
}
/// Default recovery-per-IP limit: burst of 3, replenishing at 3 per hour.
fn default_account_recovery_per_ip() -> RateLimiterConfiguration {
    RateLimiterConfiguration {
        burst: NonZeroU32::new(3).unwrap(),
        per_second: 3.0 / 3600.0,
    }
}
/// Default recovery-per-address limit: burst of 3, replenishing at 1 per hour.
fn default_account_recovery_per_address() -> RateLimiterConfiguration {
    RateLimiterConfiguration {
        burst: NonZeroU32::new(3).unwrap(),
        per_second: 1.0 / 3600.0,
    }
}
/// Default email-auth-per-IP limit: burst of 5, replenishing at 1 per minute.
fn default_email_authentication_per_ip() -> RateLimiterConfiguration {
    RateLimiterConfiguration {
        burst: NonZeroU32::new(5).unwrap(),
        per_second: 1.0 / 60.0,
    }
}
/// Default email-auth-per-address limit: burst of 3, replenishing at 1 per
/// hour.
fn default_email_authentication_per_address() -> RateLimiterConfiguration {
    RateLimiterConfiguration {
        burst: NonZeroU32::new(3).unwrap(),
        per_second: 1.0 / 3600.0,
    }
}
/// Default emails-per-session limit: burst of 2, replenishing at 1 per 5
/// minutes.
fn default_email_authentication_emails_per_session() -> RateLimiterConfiguration {
    RateLimiterConfiguration {
        burst: NonZeroU32::new(2).unwrap(),
        per_second: 1.0 / 300.0,
    }
}
/// Default code-attempts-per-session limit: burst of 10, replenishing at 1
/// per minute.
fn default_email_authentication_attempt_per_session() -> RateLimiterConfiguration {
    RateLimiterConfiguration {
        burst: NonZeroU32::new(10).unwrap(),
        per_second: 1.0 / 60.0,
    }
}
// These hand-written `Default` impls mirror the per-field serde defaults, so
// an absent config section behaves exactly like an empty one
impl Default for RateLimitingConfig {
    fn default() -> Self {
        RateLimitingConfig {
            login: LoginRateLimitingConfig::default(),
            registration: default_registration(),
            account_recovery: AccountRecoveryRateLimitingConfig::default(),
            email_authentication: EmailauthenticationRateLimitingConfig::default(),
        }
    }
}
impl Default for LoginRateLimitingConfig {
    fn default() -> Self {
        LoginRateLimitingConfig {
            per_ip: default_login_per_ip(),
            per_account: default_login_per_account(),
        }
    }
}
impl Default for AccountRecoveryRateLimitingConfig {
    fn default() -> Self {
        AccountRecoveryRateLimitingConfig {
            per_ip: default_account_recovery_per_ip(),
            per_address: default_account_recovery_per_address(),
        }
    }
}
impl Default for EmailauthenticationRateLimitingConfig {
    fn default() -> Self {
        EmailauthenticationRateLimitingConfig {
            per_ip: default_email_authentication_per_ip(),
            per_address: default_email_authentication_per_address(),
            emails_per_session: default_email_authentication_emails_per_session(),
            attempt_per_session: default_email_authentication_attempt_per_session(),
        }
    }
}

View file

@ -0,0 +1,709 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::borrow::Cow;
use anyhow::{Context, bail};
use camino::Utf8PathBuf;
use futures_util::future::{try_join, try_join_all};
use mas_jose::jwk::{JsonWebKey, JsonWebKeySet, Thumbprint};
use mas_keystore::{Encrypter, Keystore, PrivateKey};
use rand::{Rng, SeedableRng, distributions::Standard, prelude::Distribution as _};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use tokio::task;
use tracing::info;
use super::ConfigurationSection;
/// Password config option.
///
/// It either holds the password value directly or references a file where the
/// password is stored.
#[derive(Clone, Debug)]
pub enum Password {
    /// Path to the file containing the password.
    File(Utf8PathBuf),
    /// Password value.
    Value(String),
}
/// Password fields as serialized in JSON.
///
/// At most one of the two fields may be set; setting both is rejected when
/// converting to an [`Option<Password>`].
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
struct PasswordRaw {
    /// Path to the file containing the password.
    #[schemars(with = "Option<String>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    password_file: Option<Utf8PathBuf>,
    /// Password value, as an alternative to `password_file`.
    #[serde(skip_serializing_if = "Option::is_none")]
    password: Option<String>,
}
impl TryFrom<PasswordRaw> for Option<Password> {
    type Error = anyhow::Error;
    fn try_from(value: PasswordRaw) -> Result<Self, Self::Error> {
        // The two representations are mutually exclusive
        match (value.password, value.password_file) {
            (Some(_), Some(_)) => bail!("Cannot specify both `password` and `password_file`"),
            (Some(password), None) => Ok(Some(Password::Value(password))),
            (None, Some(path)) => Ok(Some(Password::File(path))),
            (None, None) => Ok(None),
        }
    }
}
impl From<Option<Password>> for PasswordRaw {
    fn from(value: Option<Password>) -> Self {
        // Project the enum back onto the two optional raw fields
        let (password_file, password) = match value {
            Some(Password::File(path)) => (Some(path), None),
            Some(Password::Value(password)) => (None, Some(password)),
            None => (None, None),
        };
        PasswordRaw {
            password_file,
            password,
        }
    }
}
/// Key config option.
///
/// It either holds the key value directly or references a file where the key is
/// stored.
#[derive(Clone, Debug)]
pub enum Key {
    /// Path to the file containing the key.
    File(Utf8PathBuf),
    /// Key value.
    Value(String),
}
/// Key fields as serialized in JSON.
///
/// Exactly one of the two fields must be set; the conversion to [`Key`]
/// rejects both none and both.
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
struct KeyRaw {
    /// Path to the file containing the key.
    #[schemars(with = "Option<String>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    key_file: Option<Utf8PathBuf>,
    /// Key value, as an alternative to `key_file`.
    #[serde(skip_serializing_if = "Option::is_none")]
    key: Option<String>,
}
impl TryFrom<KeyRaw> for Key {
    type Error = anyhow::Error;
    fn try_from(value: KeyRaw) -> Result<Key, Self::Error> {
        // Exactly one of the two fields must be provided
        match (value.key, value.key_file) {
            (Some(_), Some(_)) => bail!("Cannot specify both `key` and `key_file`"),
            (Some(key), None) => Ok(Key::Value(key)),
            (None, Some(path)) => Ok(Key::File(path)),
            (None, None) => bail!("Missing `key` or `key_file`"),
        }
    }
}
impl From<Key> for KeyRaw {
    fn from(value: Key) -> Self {
        // Project the enum back onto the two optional raw fields
        let (key_file, key) = match value {
            Key::File(path) => (Some(path), None),
            Key::Value(key) => (None, Some(key)),
        };
        KeyRaw { key_file, key }
    }
}
/// A single key with its key ID and optional password.
#[serde_as]
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
pub struct KeyConfig {
    /// The key ID `kid` of the key as used by JWKs.
    ///
    /// If not given, `kid` will be the key's RFC 7638 JWK Thumbprint.
    #[serde(skip_serializing_if = "Option::is_none")]
    kid: Option<String>,
    /// Optional password protecting the key, inline or read from a file.
    #[schemars(with = "PasswordRaw")]
    #[serde_as(as = "serde_with::TryFromInto<PasswordRaw>")]
    #[serde(flatten)]
    password: Option<Password>,
    /// The key material itself, inline or read from a file.
    #[schemars(with = "KeyRaw")]
    #[serde_as(as = "serde_with::TryFromInto<KeyRaw>")]
    #[serde(flatten)]
    key: Key,
}
impl KeyConfig {
    /// Returns the password in case any is provided.
    ///
    /// If `password_file` was given, the password is read from that file.
    ///
    /// # Errors
    ///
    /// Returns an error when the password file could not be read.
    async fn password(&self) -> anyhow::Result<Option<Cow<'_, [u8]>>> {
        Ok(match &self.password {
            Some(Password::File(path)) => Some(Cow::Owned(tokio::fs::read(path).await?)),
            Some(Password::Value(password)) => Some(Cow::Borrowed(password.as_bytes())),
            None => None,
        })
    }
    /// Returns the key.
    ///
    /// If `key_file` was given, the key is read from that file.
    ///
    /// # Errors
    ///
    /// Returns an error when the key file could not be read.
    async fn key(&self) -> anyhow::Result<Cow<'_, [u8]>> {
        Ok(match &self.key {
            Key::File(path) => Cow::Owned(tokio::fs::read(path).await?),
            Key::Value(key) => Cow::Borrowed(key.as_bytes()),
        })
    }
    /// Returns the JSON Web Key derived from this key config.
    ///
    /// Password and/or key are read from file if they're given as path. The
    /// loaded key is tagged with the configured `kid` — or with its SHA-256
    /// thumbprint when no `kid` is configured — and marked for signature use.
    async fn json_web_key(&self) -> anyhow::Result<JsonWebKey<mas_keystore::PrivateKey>> {
        // Read key material and password concurrently
        let (key, password) = try_join(self.key(), self.password()).await?;
        let private_key = match password {
            Some(password) => PrivateKey::load_encrypted(&key, password)?,
            None => PrivateKey::load(&key)?,
        };
        let kid = match self.kid.clone() {
            Some(kid) => kid,
            None => private_key.thumbprint_sha256_base64(),
        };
        Ok(JsonWebKey::new(private_key)
            .with_kid(kid)
            .with_use(mas_iana::jose::JsonWebKeyUse::Sig))
    }
}
/// Encryption config option.
///
/// It either holds the 32-byte encryption key directly or references a file
/// where the hex-encoded key is stored.
#[derive(Debug, Clone)]
pub enum Encryption {
    /// Path to the file containing the hex-encoded encryption key.
    File(Utf8PathBuf),
    /// Raw 32-byte encryption key.
    Value([u8; 32]),
}
/// Encryption fields as serialized in JSON.
///
/// Exactly one of the two fields must be set; the conversion to
/// [`Encryption`] rejects both none and both.
#[serde_as]
#[derive(JsonSchema, Serialize, Deserialize, Debug, Clone)]
struct EncryptionRaw {
    /// File containing the encryption key for secure cookies.
    #[schemars(with = "Option<String>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    encryption_file: Option<Utf8PathBuf>,
    /// Encryption key for secure cookies.
    #[schemars(
        with = "Option<String>",
        regex(pattern = r"[0-9a-fA-F]{64}"),
        example = &"0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff"
    )]
    #[serde_as(as = "Option<serde_with::hex::Hex>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    encryption: Option<[u8; 32]>,
}
impl TryFrom<EncryptionRaw> for Encryption {
    type Error = anyhow::Error;
    fn try_from(value: EncryptionRaw) -> Result<Encryption, Self::Error> {
        // Exactly one of the two fields must be provided
        match (value.encryption, value.encryption_file) {
            (Some(_), Some(_)) => {
                bail!("Cannot specify both `encryption` and `encryption_file`")
            }
            (Some(encryption), None) => Ok(Encryption::Value(encryption)),
            (None, Some(path)) => Ok(Encryption::File(path)),
            (None, None) => bail!("Missing `encryption` or `encryption_file`"),
        }
    }
}
impl From<Encryption> for EncryptionRaw {
    fn from(value: Encryption) -> Self {
        // Project the enum back onto the two optional raw fields
        let (encryption_file, encryption) = match value {
            Encryption::File(path) => (Some(path), None),
            Encryption::Value(key) => (None, Some(key)),
        };
        EncryptionRaw {
            encryption_file,
            encryption,
        }
    }
}
/// Reads all keys from the given directory.
///
/// Every regular file (including targets of symlinks) in the directory
/// becomes a [`KeyConfig`] with no `kid` and no password; subdirectories are
/// skipped.
///
/// # Errors
///
/// Returns an error when the directory could not be read, or when an entry
/// path is not valid UTF-8.
async fn key_configs_from_path(path: &Utf8PathBuf) -> anyhow::Result<Vec<KeyConfig>> {
    let mut result = vec![];
    let mut read_dir = tokio::fs::read_dir(path).await?;
    while let Some(dir_entry) = read_dir.next_entry().await? {
        let entry_path = dir_entry.path();
        // Use async metadata instead of the blocking `Path::is_file` syscall;
        // like `is_file`, this follows symlinks and treats metadata errors
        // (e.g. broken symlinks) as "not a file"
        let is_file = tokio::fs::metadata(&entry_path)
            .await
            .map(|metadata| metadata.is_file())
            .unwrap_or(false);
        if !is_file {
            continue;
        }
        result.push(KeyConfig {
            kid: None,
            password: None,
            key: Key::File(entry_path.try_into()?),
        });
    }
    Ok(result)
}
/// Application secrets
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SecretsConfig {
    /// Encryption key for secure cookies
    #[schemars(with = "EncryptionRaw")]
    #[serde_as(as = "serde_with::TryFromInto<EncryptionRaw>")]
    #[serde(flatten)]
    encryption: Encryption,
    /// List of private keys to use for signing and encrypting payloads.
    ///
    /// Combined with the keys found in `keys_dir`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    keys: Option<Vec<KeyConfig>>,
    /// Directory of private keys to use for signing and encrypting payloads.
    ///
    /// Combined with the inline `keys` list, if any.
    #[schemars(with = "Option<String>")]
    #[serde(skip_serializing_if = "Option::is_none")]
    keys_dir: Option<Utf8PathBuf>,
}
impl SecretsConfig {
    /// Derive a signing and verifying keystore out of the config
    ///
    /// # Errors
    ///
    /// Returns an error when a key could not be imported
    #[tracing::instrument(name = "secrets.load", skip_all)]
    pub async fn key_store(&self) -> anyhow::Result<Keystore> {
        let key_configs = self.key_configs().await?;
        // Load every key concurrently, failing on the first error
        let web_keys = try_join_all(key_configs.iter().map(KeyConfig::json_web_key)).await?;
        Ok(Keystore::new(JsonWebKeySet::new(web_keys)))
    }
    /// Derive an [`Encrypter`] out of the config
    ///
    /// # Errors
    ///
    /// Returns an error when the Encryptor can not be created.
    pub async fn encrypter(&self) -> anyhow::Result<Encrypter> {
        Ok(Encrypter::new(&self.encryption().await?))
    }
    /// Returns the encryption secret.
    ///
    /// # Errors
    ///
    /// Returns an error when the encryption secret could not be read from file.
    pub async fn encryption(&self) -> anyhow::Result<[u8; 32]> {
        // Read the encryption secret either embedded in the config file or on disk
        match self.encryption {
            Encryption::Value(encryption) => Ok(encryption),
            Encryption::File(ref path) => {
                let mut bytes = [0; 32];
                let content = tokio::fs::read(path).await?;
                // The file must hold exactly 64 hex characters (32 bytes)
                hex::decode_to_slice(content, &mut bytes).context(
                    "Content of `encryption_file` must contain hex characters \
                    encoding exactly 32 bytes",
                )?;
                Ok(bytes)
            }
        }
    }
    /// Returns a combined list of key configs given inline and from files.
    ///
    /// If `keys_dir` was given, the keys are read from file.
    async fn key_configs(&self) -> anyhow::Result<Vec<KeyConfig>> {
        // Keys discovered in `keys_dir` come first, followed by inline `keys`
        let mut key_configs = match &self.keys_dir {
            Some(keys_dir) => key_configs_from_path(keys_dir).await?,
            None => vec![],
        };
        let inline_key_configs = self.keys.as_deref().unwrap_or_default();
        key_configs.extend(inline_key_configs.iter().cloned());
        Ok(key_configs)
    }
}
impl ConfigurationSection for SecretsConfig {
    // This section lives under the `secrets` key of the configuration
    const PATH: Option<&'static str> = Some("secrets");
}
impl SecretsConfig {
#[expect(clippy::similar_names, reason = "Key type names are very similar")]
#[tracing::instrument(skip_all)]
pub(crate) async fn generate<R>(mut rng: R) -> anyhow::Result<Self>
where
R: Rng + Send,
{
info!("Generating keys...");
let span = tracing::info_span!("rsa");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let rsa_key = task::spawn_blocking(move || {
let _entered = span.enter();
let ret = PrivateKey::generate_rsa(key_rng).unwrap();
info!("Done generating RSA key");
ret
})
.await
.context("could not join blocking task")?;
let rsa_key = KeyConfig {
kid: None,
password: None,
key: Key::Value(rsa_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
};
let span = tracing::info_span!("ec_p256");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let ec_p256_key = task::spawn_blocking(move || {
let _entered = span.enter();
let ret = PrivateKey::generate_ec_p256(key_rng);
info!("Done generating EC P-256 key");
ret
})
.await
.context("could not join blocking task")?;
let ec_p256_key = KeyConfig {
kid: None,
password: None,
key: Key::Value(ec_p256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
};
let span = tracing::info_span!("ec_p384");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let ec_p384_key = task::spawn_blocking(move || {
let _entered = span.enter();
let ret = PrivateKey::generate_ec_p384(key_rng);
info!("Done generating EC P-384 key");
ret
})
.await
.context("could not join blocking task")?;
let ec_p384_key = KeyConfig {
kid: None,
password: None,
key: Key::Value(ec_p384_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
};
let span = tracing::info_span!("ec_k256");
let key_rng = rand_chacha::ChaChaRng::from_rng(&mut rng)?;
let ec_k256_key = task::spawn_blocking(move || {
let _entered = span.enter();
let ret = PrivateKey::generate_ec_k256(key_rng);
info!("Done generating EC secp256k1 key");
ret
})
.await
.context("could not join blocking task")?;
let ec_k256_key = KeyConfig {
kid: None,
password: None,
key: Key::Value(ec_k256_key.to_pem(pem_rfc7468::LineEnding::LF)?.to_string()),
};
Ok(Self {
encryption: Encryption::Value(Standard.sample(&mut rng)),
keys: Some(vec![rsa_key, ec_p256_key, ec_p384_key, ec_k256_key]),
keys_dir: None,
})
}
pub(crate) fn test() -> Self {
let rsa_key = KeyConfig {
kid: None,
password: None,
key: Key::Value(
indoc::indoc! {r"
-----BEGIN PRIVATE KEY-----
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAymS2RkeIZo7pUeEN
QUGCG4GLJru5jzxomO9jiNr5D/oRcerhpQVc9aCpBfAAg4l4a1SmYdBzWqX0X5pU
scgTtQIDAQABAkEArNIMlrxUK4bSklkCcXtXdtdKE9vuWfGyOw0GyAB69fkEUBxh
3j65u+u3ZmW+bpMWHgp1FtdobE9nGwb2VBTWAQIhAOyU1jiUEkrwKK004+6b5QRE
vC9UI2vDWy5vioMNx5Y1AiEA2wGAJ6ETF8FF2Vd+kZlkKK7J0em9cl0gbJDsWIEw
N4ECIEyWYkMurD1WQdTQqnk0Po+DMOihdFYOiBYgRdbnPxWBAiEAmtd0xJAd7622
tPQniMnrBtiN2NxqFXHCev/8Gpc8gAECIBcaPcF59qVeRmYrfqzKBxFm7LmTwlAl
Gh7BNzCeN+D6
-----END PRIVATE KEY-----
"}
.to_owned(),
),
};
let ecdsa_key = KeyConfig {
kid: None,
password: None,
key: Key::Value(
indoc::indoc! {r"
-----BEGIN PRIVATE KEY-----
MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgqfn5mYO/5Qq/wOOiWgHA
NaiDiepgUJ2GI5eq2V8D8nahRANCAARMK9aKUd/H28qaU+0qvS6bSJItzAge1VHn
OhBAAUVci1RpmUA+KdCL5sw9nadAEiONeiGr+28RYHZmlB9qXnjC
-----END PRIVATE KEY-----
"}
.to_owned(),
),
};
Self {
encryption: Encryption::Value([0xEA; 32]),
keys: Some(vec![rsa_key, ecdsa_key]),
keys_dir: None,
}
}
}
#[cfg(test)]
mod tests {
    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use mas_jose::constraints::Constrainable;
    use tokio::{runtime::Handle, task};

    use super::*;

    // Secrets loaded entirely from files on disk: `encryption_file` for the
    // encryption secret and `keys_dir` for the signing keys.
    #[tokio::test]
    async fn load_config() {
        // `Jail::expect_with` is blocking (it sandboxes the process in a
        // temporary directory), so run it on the blocking pool and hop back
        // onto the runtime with `Handle::block_on` for the async assertions.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    indoc::indoc! {r"
                        secrets:
                          encryption_file: encryption
                          keys_dir: keys
                    "},
                )?;
                // 32 bytes, hex-encoded
                jail.create_file(
                    "encryption",
                    "0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff",
                )?;
                jail.create_dir("keys")?;
                jail.create_file(
                    "keys/key1",
                    indoc::indoc! {r"
                        -----BEGIN RSA PRIVATE KEY-----
                        MIIJKQIBAAKCAgEA6oR6LXzJOziUxcRryonLTM5Xkfr9cYPCKvnwsWoAHfd2MC6Q
                        OCAWSQnNcNz5RTeQUcLEaA8sxQi64zpCwO9iH8y8COCaO8u9qGkOOuJwWnmPfeLs
                        cEwALEp0LZ67eSUPsMaz533bs4C8p+2UPMd+v7Td8TkkYoqgUrfYuT0bDTMYVsSe
                        wcNB5qsI7hDLf1t5FX6KU79/Asn1K3UYHTdN83mghOlM4zh1l1CJdtgaE1jAg4Ml
                        1X8yG+cT+Ks8gCSGQfIAlVFV4fvvzmpokNKfwAI/b3LS2/ft4ZrK+RCTsWsjUu38
                        Zr8jbQMtDznzBHMw1LoaHpwRNjbJZ7uA6x5ikbwz5NAlfCITTta6xYn8qvaBfiYJ
                        YyUFl0kIHm9Kh9V9p54WPMCFCcQx12deovKV82S6zxTeMflDdosJDB/uG9dT2qPt
                        wkpTD6xAOx5h59IhfiY0j4ScTl725GygVzyK378soP3LQ/vBixQLpheALViotodH
                        fJknsrelaISNkrnapZL3QE5C1SUoaUtMG9ovRz5HDpMx5ooElEklq7shFWDhZXbp
                        2ndU5RPRCZO3Szop/Xhn2mNWQoEontFh79WIf+wS8TkJIRXhjtYBt3+s96z0iqSg
                        gDmE8BcP4lP1+TAUY1d7+QEhGCsTJa9TYtfDtNNfuYI9e3mq6LEpHYKWOvECAwEA
                        AQKCAgAlF60HaCGf50lzT6eePQCAdnEtWrMeyDCRgZTLStvCjEhk7d3LssTeP9mp
                        oe8fPomUv6c3BOds2/5LQFockABHd/y/CV9RA973NclAEQlPlhiBrb793Vd4VJJe
                        6331dveDW0+ggVdFjfVzjhqQfnE9ZcsQ2JvjpiTI0Iv2cy7F01tke0GCSMgx8W1p
                        J2jjDOxwNOKGGoIT8S4roHVJnFy3nM4sbNtyDj+zHimP4uBE8m2zSgQAP60E8sia
                        3+Ki1flnkXJRgQWCHR9cg5dkXfFRz56JmcdgxAHGWX2vD9XRuFi5nitPc6iTw8PV
                        u7GvS3+MC0oO+1pRkTAhOGv3RDK3Uqmy2zrMUuWkEsz6TVId6gPl7+biRJcP+aER
                        plJkeC9J9nSizbQPwErGByzoHGLjADgBs9hwqYkPcN38b6jR5S/VDQ+RncCyI87h
                        s/0pIs/fNlfw4LtpBrolP6g++vo6KUufmE3kRNN9dN4lNOoKjUGkcmX6MGnwxiw6
                        NN/uEqf9+CKQele1XeUhRPNJc9Gv+3Ly5y/wEi6FjfVQmCK4hNrl3tvuZw+qkGbq
                        Au9Jhk7wV81An7fbhBRIXrwOY9AbOKNqUfY+wpKi5vyJFS1yzkFaYSTKTBspkuHW
                        pWbohO+KreREwaR5HOMK8tQMTLEAeE3taXGsQMJSJ15lRrLc7QKCAQEA68TV/R8O
                        C4p+vnGJyhcfDJt6+KBKWlroBy75BG7Dg7/rUXaj+MXcqHi+whRNXMqZchSwzUfS
                        B2WK/HrOBye8JLKDeA3B5TumJaF19vV7EY/nBF2QdRmI1r33Cp+RWUvAcjKa/v2u
                        KksV3btnJKXCu/stdAyTK7nU0on4qBzm5WZxuIJv6VMHLDNPFdCk+4gM8LuJ3ITU
                        l7XuZd4gXccPNj0VTeOYiMjIwxtNmE9RpCkTLm92Z7MI+htciGk1xvV0N4m1BXwA
                        7qhl1nBgVuJyux4dEYFIeQNhLpHozkEz913QK2gDAHL9pAeiUYJntq4p8HNvfHiQ
                        vE3wTzil3aUFnwKCAQEA/qQm1Nx5By6an5UunrOvltbTMjsZSDnWspSQbX//j6mL
                        2atQLe3y/Nr7E5SGZ1kFD9tgAHTuTGVqjvTqp5dBPw4uo146K2RJwuvaYUzNK26c
                        VoGfMfsI+/bfMfjFnEmGRARZdMr8cvhU+2m04hglsSnNGxsvvPdsiIbRaVDx+JvN
                        C5C281WlN0WeVd7zNTZkdyUARNXfCxBHQPuYkP5Mz2roZeYlJMWU04i8Cx0/SEuu
                        bhZQDaNTccSdPDFYcyDDlpqp+mN+U7m+yUPOkVpaxQiSYJZ+NOQsNcAVYfjzyY0E
                        /VP3s2GddjCJs0amf9SeW0LiMAHPgTp8vbMSRPVVbwKCAQEAmZsSd+llsys2TEmY
                        pivONN6PjbCRALE9foCiCLtJcmr1m4uaZRg0HScd0UB87rmoo2TLk9L5CYyksr4n
                        wQ2oTJhpgywjaYAlTVsWiiGBXv3MW1HCLijGuHHno+o2PmFWLpC93ufUMwXcZywT
                        lRLR/rs07+jJcbGO8OSnNpAt9sN5z+Zblz5a6/c5zVK0SpRnKehld2CrSXRkr8W6
                        fJ6WUJYXbTmdRXDbLBJ7yYHUBQolzxkboZBJhvmQnec9/DQq1YxIfhw+Vz8rqjxo
                        5/J9IWALPD5owz7qb/bsIITmoIFkgQMxAXfpvJaksEov3Bs4g8oRlpzOX4C/0j1s
                        Ay3irQKCAQEAwRJ/qufcEFkCvjsj1QsS+MC785shyUSpiE/izlO91xTLx+f/7EM9
                        +QCkXK1B1zyE/Qft24rNYDmJOQl0nkuuGfxL2mzImDv7PYMM2reb3PGKMoEnzoKz
                        xi/h/YbNdnm9BvdxSH/cN+QYs2Pr1X5Pneu+622KnbHQphfq0fqg7Upchwdb4Faw
                        5Z6wthVMvK0YMcppUMgEzOOz0w6xGEbowGAkA5cj1KTG+jjzs02ivNM9V5Utb5nF
                        3D4iphAYK3rNMfTlKsejciIlCX+TMVyb9EdSjU+uM7ZJ2xtgWx+i4NA+10GCT42V
                        EZct4TORbN0ukK2+yH2m8yoAiOks0gJemwKCAQAMGROGt8O4HfhpUdOq01J2qvQL
                        m5oUXX8w1I95XcoAwCqb+dIan8UbCyl/79lbqNpQlHbRy3wlXzWwH9aHKsfPlCvk
                        5dE1qrdMdQhLXwP109bRmTiScuU4zfFgHw3XgQhMFXxNp9pze197amLws0TyuBW3
                        fupS4kM5u6HKCeBYcw2WP5ukxf8jtn29tohLBiA2A7NYtml9xTer6BBP0DTh+QUn
                        IJL6jSpuCNxBPKIK7p6tZZ0nMBEdAWMxglYm0bmHpTSd3pgu3ltCkYtDlDcTIaF0
                        Q4k44lxUTZQYwtKUVQXBe4ZvaT/jIEMS7K5bsAy7URv/toaTaiEh1hguwSmf
                        -----END RSA PRIVATE KEY-----
                    "},
                )?;
                jail.create_file(
                    "keys/key2",
                    indoc::indoc! {r"
                        -----BEGIN EC PRIVATE KEY-----
                        MHcCAQEEIKlZz/GnH0idVH1PnAF4HQNwRafgBaE2tmyN1wjfdOQqoAoGCCqGSM49
                        AwEHoUQDQgAEHrgPeG+Mt8eahih1h4qaPjhl7jT25cdzBkg3dbVks6gBR2Rx4ug9
                        h27LAir5RqxByHvua2XsP46rSTChof78uw==
                        -----END EC PRIVATE KEY-----
                    "},
                )?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<SecretsConfig>("secrets")?;

                Handle::current().block_on(async move {
                    // The config should keep a `File` reference, not eagerly
                    // inline the contents
                    assert!(
                        matches!(config.encryption, Encryption::File(ref p) if p == "encryption")
                    );
                    // Hex file decoded into the raw 32-byte secret
                    assert_eq!(
                        config.encryption().await.unwrap(),
                        [
                            0, 0, 17, 17, 34, 34, 51, 51, 68, 68, 85, 85, 102, 102, 119, 119, 136,
                            136, 153, 153, 170, 170, 187, 187, 204, 204, 221, 221, 238, 238, 255,
                            255
                        ]
                    );

                    // Directory listing order is not guaranteed; sort by file
                    // path so the indexed assertions below are deterministic.
                    let mut key_config = config.key_configs().await.unwrap();
                    key_config.sort_by_key(|a| {
                        if let Key::File(p) = &a.key {
                            Some(p.clone())
                        } else {
                            None
                        }
                    });
                    let key_store = config.key_store().await.unwrap();
                    // No explicit kid configured: kids are derived from the
                    // key material (thumbprints)
                    assert!(key_config[0].kid.is_none());
                    assert!(matches!(&key_config[0].key, Key::File(p) if p == "keys/key1"));
                    assert!(key_store.iter().any(|k| k.kid() == Some("xmgGCzGtQFmhEOP0YAqBt-oZyVauSVMXcf4kwcgGZLc")));
                    assert!(key_config[1].kid.is_none());
                    assert!(matches!(&key_config[1].key, Key::File(p) if p == "keys/key2"));
                    assert!(key_store.iter().any(|k| k.kid() == Some("ONUCn80fsiISFWKrVMEiirNVr-QEvi7uQI0QH9q9q4o")));
                });

                Ok(())
            });
        })
        .await
        .unwrap();
    }

    // Secrets provided inline in the YAML: `encryption` as a hex string and
    // `keys` as a list of inline PEM values, one with an explicit kid.
    #[tokio::test]
    async fn load_config_inline_secrets() {
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    indoc::indoc! {r"
                        secrets:
                          encryption: >-
                            0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff
                          keys:
                            - kid: lekid0
                              key: |
                                -----BEGIN EC PRIVATE KEY-----
                                MHcCAQEEIOtZfDuXZr/NC0V3sisR4Chf7RZg6a2dpZesoXMlsPeRoAoGCCqGSM49
                                AwEHoUQDQgAECfpqx64lrR85MOhdMxNmIgmz8IfmM5VY9ICX9aoaArnD9FjgkBIl
                                fGmQWxxXDSWH6SQln9tROVZaduenJqDtDw==
                                -----END EC PRIVATE KEY-----
                            - key: |
                                -----BEGIN EC PRIVATE KEY-----
                                MHcCAQEEIKlZz/GnH0idVH1PnAF4HQNwRafgBaE2tmyN1wjfdOQqoAoGCCqGSM49
                                AwEHoUQDQgAEHrgPeG+Mt8eahih1h4qaPjhl7jT25cdzBkg3dbVks6gBR2Rx4ug9
                                h27LAir5RqxByHvua2XsP46rSTChof78uw==
                                -----END EC PRIVATE KEY-----
                    "},
                )?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<SecretsConfig>("secrets")?;

                Handle::current().block_on(async move {
                    assert_eq!(
                        config.encryption().await.unwrap(),
                        [
                            0, 0, 17, 17, 34, 34, 51, 51, 68, 68, 85, 85, 102, 102, 119, 119, 136,
                            136, 153, 153, 170, 170, 187, 187, 204, 204, 221, 221, 238, 238, 255,
                            255
                        ]
                    );

                    let key_store = config.key_store().await.unwrap();
                    // Explicit kid is honoured; the other key falls back to a
                    // derived thumbprint kid
                    assert!(key_store.iter().any(|k| k.kid() == Some("lekid0")));
                    assert!(key_store.iter().any(|k| k.kid() == Some("ONUCn80fsiISFWKrVMEiirNVr-QEvi7uQI0QH9q9q4o")));
                });

                Ok(())
            });
        })
        .await
        .unwrap();
    }

    // `keys_dir` and inline `keys` can be combined: the resulting key set is
    // the union of both sources.
    #[tokio::test]
    async fn load_config_mixed_key_sources() {
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    indoc::indoc! {r"
                        secrets:
                          encryption_file: encryption
                          keys_dir: keys
                          keys:
                            - kid: lekid0
                              key: |
                                -----BEGIN EC PRIVATE KEY-----
                                MHcCAQEEIOtZfDuXZr/NC0V3sisR4Chf7RZg6a2dpZesoXMlsPeRoAoGCCqGSM49
                                AwEHoUQDQgAECfpqx64lrR85MOhdMxNmIgmz8IfmM5VY9ICX9aoaArnD9FjgkBIl
                                fGmQWxxXDSWH6SQln9tROVZaduenJqDtDw==
                                -----END EC PRIVATE KEY-----
                    "},
                )?;
                jail.create_dir("keys")?;
                jail.create_file(
                    "keys/key_from_file",
                    indoc::indoc! {r"
                        -----BEGIN EC PRIVATE KEY-----
                        MHcCAQEEIKlZz/GnH0idVH1PnAF4HQNwRafgBaE2tmyN1wjfdOQqoAoGCCqGSM49
                        AwEHoUQDQgAEHrgPeG+Mt8eahih1h4qaPjhl7jT25cdzBkg3dbVks6gBR2Rx4ug9
                        h27LAir5RqxByHvua2XsP46rSTChof78uw==
                        -----END EC PRIVATE KEY-----
                    "},
                )?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<SecretsConfig>("secrets")?;

                Handle::current().block_on(async move {
                    let key_config = config.key_configs().await.unwrap();
                    let key_store = config.key_store().await.unwrap();
                    // Keys discovered from `keys_dir` come first and carry no
                    // explicit kid
                    assert!(key_config[0].kid.is_none());
                    assert!(matches!(&key_config[0].key, Key::File(p) if p == "keys/key_from_file"));
                    assert!(key_store.iter().any(|k| k.kid() == Some("ONUCn80fsiISFWKrVMEiirNVr-QEvi7uQI0QH9q9q4o")));
                    assert!(key_store.iter().any(|k| k.kid() == Some("lekid0")));
                });

                Ok(())
            });
        })
        .await
        .unwrap();
    }
}

View file

@ -0,0 +1,221 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::Error as _};
use serde_with::skip_serializing_none;
use url::Url;
use super::ConfigurationSection;
/// Propagation format for incoming and outgoing requests
///
/// Serialized in lowercase: `tracecontext`, `baggage`, `jaeger`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum Propagator {
    /// Propagate according to the W3C Trace Context specification
    TraceContext,
    /// Propagate according to the W3C Baggage specification
    Baggage,
    /// Propagate trace context with Jaeger compatible headers
    Jaeger,
}
/// Default OTLP endpoint advertised in the generated JSON schema.
///
/// Wrapped in `Some` because it feeds a schemars `default` for an
/// `Option<Url>` field.
#[allow(clippy::unnecessary_wraps)]
fn otlp_endpoint_default() -> Option<String> {
    Some(String::from("https://localhost:4318"))
}
/// Exporter to use when exporting traces
///
/// Serialized in lowercase: `none`, `stdout`, `otlp`.
#[skip_serializing_none]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "lowercase")]
pub enum TracingExporterKind {
    /// Don't export traces
    #[default]
    None,
    /// Export traces to the standard output. Only useful for debugging
    Stdout,
    /// Export traces to an OpenTelemetry protocol compatible endpoint
    Otlp,
}
/// Configuration related to exporting traces
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct TracingConfig {
    /// Exporter to use when exporting traces
    #[serde(default)]
    pub exporter: TracingExporterKind,

    /// OTLP exporter: OTLP over HTTP compatible endpoint
    ///
    /// Only meaningful when `exporter` is `otlp`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(url, default = "otlp_endpoint_default")]
    pub endpoint: Option<Url>,

    /// List of propagation formats to use for incoming and outgoing requests
    #[serde(default)]
    pub propagators: Vec<Propagator>,

    /// Sample rate for traces
    ///
    /// Defaults to `1.0` if not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(example = 0.5, range(min = 0.0, max = 1.0))]
    pub sample_rate: Option<f64>,
}
impl TracingConfig {
    /// Returns true if all fields are at their default values
    ///
    /// Used as a `skip_serializing_if` predicate on [`TelemetryConfig`], so
    /// every field must be checked; otherwise a non-default value would be
    /// silently dropped when the configuration is re-serialized.
    fn is_default(&self) -> bool {
        matches!(self.exporter, TracingExporterKind::None)
            && self.endpoint.is_none()
            && self.propagators.is_empty()
            // Previously unchecked: a config with only `sample_rate` set was
            // treated as default and skipped during serialization.
            && self.sample_rate.is_none()
    }
}
/// Exporter to use when exporting metrics
///
/// Serialized in lowercase: `none`, `stdout`, `otlp`, `prometheus`.
#[skip_serializing_none]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "lowercase")]
pub enum MetricsExporterKind {
    /// Don't export metrics
    #[default]
    None,
    /// Export metrics to stdout. Only useful for debugging
    Stdout,
    /// Export metrics to an OpenTelemetry protocol compatible endpoint
    Otlp,
    /// Export metrics via Prometheus. An HTTP listener with the `prometheus`
    /// resource must be set up to expose the Prometheus metrics.
    Prometheus,
}
/// Configuration related to exporting metrics
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct MetricsConfig {
    /// Exporter to use when exporting metrics
    #[serde(default)]
    pub exporter: MetricsExporterKind,

    /// OTLP exporter: OTLP over HTTP compatible endpoint
    ///
    /// Only meaningful when `exporter` is `otlp`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(url, default = "otlp_endpoint_default")]
    pub endpoint: Option<Url>,
}
impl MetricsConfig {
    /// Returns true if all fields are at their default values
    ///
    /// Serves as the `skip_serializing_if` predicate for the `metrics`
    /// section of [`TelemetryConfig`].
    fn is_default(&self) -> bool {
        self.endpoint.is_none() && matches!(self.exporter, MetricsExporterKind::None)
    }
}
/// Configuration related to the Sentry integration
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct SentryConfig {
    /// Sentry DSN
    ///
    /// When unset, the Sentry integration is disabled.
    #[schemars(url, example = &"https://public@host:port/1")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dsn: Option<String>,

    /// Environment to use when sending events to Sentry
    ///
    /// Defaults to `production` if not set.
    #[schemars(example = &"production")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub environment: Option<String>,

    /// Sample rate for event submissions
    ///
    /// Defaults to `1.0` if not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(example = 0.5, range(min = 0.0, max = 1.0))]
    pub sample_rate: Option<f32>,

    /// Sample rate for tracing transactions
    ///
    /// Defaults to `0.0` if not set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(example = 0.5, range(min = 0.0, max = 1.0))]
    pub traces_sample_rate: Option<f32>,
}
impl SentryConfig {
    /// Returns true if all fields are at their default values
    ///
    /// Only `dsn` is checked: without a DSN the Sentry integration is
    /// disabled, so the other fields are inert. NOTE(review): this means
    /// `environment`/sample rates set without a DSN are not round-tripped
    /// through serialization — presumably intentional, confirm.
    fn is_default(&self) -> bool {
        self.dsn.is_none()
    }
}
/// Configuration related to sending monitoring data
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct TelemetryConfig {
    /// Configuration related to exporting traces
    #[serde(default, skip_serializing_if = "TracingConfig::is_default")]
    pub tracing: TracingConfig,

    /// Configuration related to exporting metrics
    #[serde(default, skip_serializing_if = "MetricsConfig::is_default")]
    pub metrics: MetricsConfig,

    /// Configuration related to the Sentry integration
    #[serde(default, skip_serializing_if = "SentryConfig::is_default")]
    pub sentry: SentryConfig,
}
impl TelemetryConfig {
    /// Returns true if all fields are at their default values
    ///
    /// True only when every sub-section reports itself as default.
    pub(crate) fn is_default(&self) -> bool {
        self.sentry.is_default() && self.tracing.is_default() && self.metrics.is_default()
    }
}
impl ConfigurationSection for TelemetryConfig {
const PATH: Option<&'static str> = Some("telemetry");
fn validate(
&self,
_figment: &figment::Figment,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
if let Some(sample_rate) = self.sentry.sample_rate
&& !(0.0..=1.0).contains(&sample_rate)
{
return Err(figment::error::Error::custom(
"Sentry sample rate must be between 0.0 and 1.0",
)
.with_path("sentry.sample_rate")
.into());
}
if let Some(sample_rate) = self.sentry.traces_sample_rate
&& !(0.0..=1.0).contains(&sample_rate)
{
return Err(figment::error::Error::custom(
"Sentry sample rate must be between 0.0 and 1.0",
)
.with_path("sentry.traces_sample_rate")
.into());
}
if let Some(sample_rate) = self.tracing.sample_rate
&& !(0.0..=1.0).contains(&sample_rate)
{
return Err(figment::error::Error::custom(
"Tracing sample rate must be between 0.0 and 1.0",
)
.with_path("tracing.sample_rate")
.into());
}
Ok(())
}
}

View file

@ -0,0 +1,116 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use camino::Utf8PathBuf;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::ConfigurationSection;
#[cfg(not(any(feature = "docker", feature = "dist")))]
fn default_path() -> Utf8PathBuf {
"./templates/".into()
}
#[cfg(feature = "docker")]
fn default_path() -> Utf8PathBuf {
"/usr/local/share/mas-cli/templates/".into()
}
#[cfg(feature = "dist")]
fn default_path() -> Utf8PathBuf {
"./share/templates/".into()
}
fn is_default_path(value: &Utf8PathBuf) -> bool {
*value == default_path()
}
#[cfg(not(any(feature = "docker", feature = "dist")))]
fn default_assets_path() -> Utf8PathBuf {
"./frontend/dist/manifest.json".into()
}
#[cfg(feature = "docker")]
fn default_assets_path() -> Utf8PathBuf {
"/usr/local/share/mas-cli/manifest.json".into()
}
#[cfg(feature = "dist")]
fn default_assets_path() -> Utf8PathBuf {
"./share/manifest.json".into()
}
fn is_default_assets_path(value: &Utf8PathBuf) -> bool {
*value == default_assets_path()
}
#[cfg(not(any(feature = "docker", feature = "dist")))]
fn default_translations_path() -> Utf8PathBuf {
"./translations/".into()
}
#[cfg(feature = "docker")]
fn default_translations_path() -> Utf8PathBuf {
"/usr/local/share/mas-cli/translations/".into()
}
#[cfg(feature = "dist")]
fn default_translations_path() -> Utf8PathBuf {
"./share/translations/".into()
}
fn is_default_translations_path(value: &Utf8PathBuf) -> bool {
*value == default_translations_path()
}
/// Configuration related to templates
///
/// All defaults depend on the build flavour (`docker`/`dist`/source
/// checkout); values equal to the compiled-in default are omitted when
/// serializing.
#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
pub struct TemplatesConfig {
    /// Path to the folder which holds the templates
    #[serde(default = "default_path", skip_serializing_if = "is_default_path")]
    #[schemars(with = "Option<String>")]
    pub path: Utf8PathBuf,

    /// Path to the assets manifest
    #[serde(
        default = "default_assets_path",
        skip_serializing_if = "is_default_assets_path"
    )]
    #[schemars(with = "Option<String>")]
    pub assets_manifest: Utf8PathBuf,

    /// Path to the translations
    #[serde(
        default = "default_translations_path",
        skip_serializing_if = "is_default_translations_path"
    )]
    #[schemars(with = "Option<String>")]
    pub translations_path: Utf8PathBuf,
}
impl Default for TemplatesConfig {
    /// Build a config from the compiled-in, build-flavour-specific defaults.
    fn default() -> Self {
        Self {
            translations_path: default_translations_path(),
            assets_manifest: default_assets_path(),
            path: default_path(),
        }
    }
}
impl TemplatesConfig {
    /// Returns true if all fields are at their default values
    pub(crate) fn is_default(&self) -> bool {
        is_default_translations_path(&self.translations_path)
            && is_default_assets_path(&self.assets_manifest)
            && is_default_path(&self.path)
    }
}
impl ConfigurationSection for TemplatesConfig {
    // This section lives under the top-level `templates` key of the config file.
    const PATH: Option<&'static str> = Some("templates");
}

View file

@ -0,0 +1,803 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::collections::BTreeMap;
use camino::Utf8PathBuf;
use mas_iana::jose::JsonWebSignatureAlg;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::Error};
use serde_with::{serde_as, skip_serializing_none};
use ulid::Ulid;
use url::Url;
use crate::{ClientSecret, ClientSecretRaw, ConfigurationSection};
/// Upstream OAuth 2.0 providers configuration
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct UpstreamOAuth2Config {
    /// List of OAuth 2.0 providers
    pub providers: Vec<Provider>,
}
impl UpstreamOAuth2Config {
    /// Returns true if the configuration is the default one
    ///
    /// i.e. no upstream providers are configured at all.
    pub(crate) fn is_default(&self) -> bool {
        self.providers.is_empty()
    }
}
impl ConfigurationSection for UpstreamOAuth2Config {
    const PATH: Option<&'static str> = Some("upstream_oauth2");

    /// Cross-field validation of every provider entry: discovery requires an
    /// issuer, the client authentication method constrains which secrets and
    /// signing algorithms may be present, and claim-import actions must be
    /// mutually consistent.
    fn validate(
        &self,
        figment: &figment::Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        for (index, provider) in self.providers.iter().enumerate() {
            // Decorate errors with the source metadata and the list index so
            // they point at the exact offending provider entry.
            let annotate = |mut error: figment::Error| {
                error.metadata = figment
                    .find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
                    .cloned();
                error.profile = Some(figment::Profile::Default);
                error.path = vec![
                    Self::PATH.unwrap().to_owned(),
                    "providers".to_owned(),
                    index.to_string(),
                ];
                error
            };

            // Discovery (strict or insecure) needs an issuer to discover from
            if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
                && provider.issuer.is_none()
            {
                return Err(annotate(figment::Error::custom(
                    "The `issuer` field is required when discovery is enabled",
                ))
                .into());
            }

            // `client_secret` is required by the client_secret_* methods and
            // forbidden by the others
            match provider.token_endpoint_auth_method {
                TokenAuthMethod::None
                | TokenAuthMethod::PrivateKeyJwt
                | TokenAuthMethod::SignInWithApple => {
                    if provider.client_secret.is_some() {
                        return Err(annotate(figment::Error::custom(
                            "Unexpected field `client_secret` for the selected authentication method",
                        )).into());
                    }
                }
                TokenAuthMethod::ClientSecretBasic
                | TokenAuthMethod::ClientSecretPost
                | TokenAuthMethod::ClientSecretJwt => {
                    if provider.client_secret.is_none() {
                        return Err(annotate(figment::Error::missing_field("client_secret")).into());
                    }
                }
            }

            // A signing algorithm only makes sense for the JWT-assertion methods
            match provider.token_endpoint_auth_method {
                TokenAuthMethod::None
                | TokenAuthMethod::ClientSecretBasic
                | TokenAuthMethod::ClientSecretPost
                | TokenAuthMethod::SignInWithApple => {
                    if provider.token_endpoint_auth_signing_alg.is_some() {
                        return Err(annotate(figment::Error::custom(
                            "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
                        )).into());
                    }
                }
                TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
                    if provider.token_endpoint_auth_signing_alg.is_none() {
                        return Err(annotate(figment::Error::missing_field(
                            "token_endpoint_auth_signing_alg",
                        ))
                        .into());
                    }
                }
            }

            // The Apple-specific extra config goes hand in hand with the
            // sign_in_with_apple auth method
            match provider.token_endpoint_auth_method {
                TokenAuthMethod::SignInWithApple => {
                    if provider.sign_in_with_apple.is_none() {
                        return Err(
                            annotate(figment::Error::missing_field("sign_in_with_apple")).into(),
                        );
                    }
                }
                _ => {
                    if provider.sign_in_with_apple.is_some() {
                        return Err(annotate(figment::Error::custom(
                            "Unexpected field `sign_in_with_apple` for the selected authentication method",
                        )).into());
                    }
                }
            }

            // Skipping the confirmation screen means the user never gets a
            // chance to edit suggested values, so `suggest` is not allowed
            // and the localpart must be authoritative
            if provider.claims_imports.skip_confirmation {
                if provider.claims_imports.localpart.action != ImportAction::Require {
                    return Err(annotate(figment::Error::custom(
                        "The field `action` must be `require` when `skip_confirmation` is set to `true`",
                    )).with_path("claims_imports.localpart").into());
                }
                if provider.claims_imports.email.action == ImportAction::Suggest {
                    return Err(annotate(figment::Error::custom(
                        "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
                    )).with_path("claims_imports.email").into());
                }
                if provider.claims_imports.displayname.action == ImportAction::Suggest {
                    return Err(annotate(figment::Error::custom(
                        "The field `action` must not be `suggest` when `skip_confirmation` is set to `true`",
                    )).with_path("claims_imports.displayname").into());
                }
            }

            // Conflict resolution other than `fail` only works when the
            // localpart claim is authoritative (`force` or `require`)
            if matches!(
                provider.claims_imports.localpart.on_conflict,
                OnConflict::Add | OnConflict::Replace | OnConflict::Set
            ) && !matches!(
                provider.claims_imports.localpart.action,
                ImportAction::Force | ImportAction::Require
            ) {
                return Err(annotate(figment::Error::custom(
                    "The field `action` must be either `force` or `require` when `on_conflict` is set to `add`, `replace` or `set`",
                )).with_path("claims_imports.localpart").into());
            }
        }

        Ok(())
    }
}
/// The response mode we ask the provider to use for the callback
///
/// Serialized in snake_case: `query`, `form_post`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
    /// `query`: The provider will send the response as a query string in the
    /// URL search parameters
    Query,
    /// `form_post`: The provider will send the response as a POST request with
    /// the response parameters in the request body
    ///
    /// <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html>
    FormPost,
}
/// Authentication methods used against the OAuth 2.0 provider
///
/// The selected method constrains which other provider fields are valid; see
/// the `validate` implementation on `UpstreamOAuth2Config`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum TokenAuthMethod {
    /// `none`: No authentication
    None,
    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
    /// authorization credentials
    ClientSecretBasic,
    /// `client_secret_post`: `client_id` and `client_secret` sent in the
    /// request body
    ClientSecretPost,
    /// `client_secret_jwt`: a `client_assertion` sent in the request body and
    /// signed using the `client_secret`
    ClientSecretJwt,
    /// `private_key_jwt`: a `client_assertion` sent in the request body and
    /// signed by an asymmetric key
    PrivateKeyJwt,
    /// `sign_in_with_apple`: a special method for Signin with Apple
    SignInWithApple,
}
/// How to handle a claim
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
    /// Ignore the claim
    #[default]
    Ignore,
    /// Suggest the claim value, but allow the user to change it
    Suggest,
    /// Force the claim value, but don't fail if it is missing
    Force,
    /// Force the claim value, and fail if it is missing
    Require,
}

impl ImportAction {
    /// Serde skip predicate: true for the default (`Ignore`) action.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    const fn is_default(&self) -> bool {
        matches!(self, ImportAction::Ignore)
    }
}
/// How to handle an existing localpart claim
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum OnConflict {
    /// Fails the upstream OAuth 2.0 login on conflict
    #[default]
    Fail,
    /// Adds the upstream OAuth 2.0 identity link, regardless of whether there
    /// is an existing link or not
    Add,
    /// Replace any existing upstream OAuth 2.0 identity link
    Replace,
    /// Adds the upstream OAuth 2.0 identity link *only* if there is no existing
    /// link for this provider on the matching user
    Set,
}

impl OnConflict {
    /// Serde skip predicate: true for the default (`Fail`) behaviour.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    const fn is_default(&self) -> bool {
        matches!(self, OnConflict::Fail)
    }
}
/// What should be done for the subject attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct SubjectImportPreference {
    /// The Jinja2 template to use for the subject attribute
    ///
    /// If not provided, the default template is `{{ user.sub }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}

impl SubjectImportPreference {
    /// Serde skip predicate: true when no custom template is configured.
    const fn is_default(&self) -> bool {
        self.template.is_none()
    }
}
/// What should be done for the localpart attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct LocalpartImportPreference {
    /// How to handle the attribute
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,
    /// The Jinja2 template to use for the localpart attribute
    ///
    /// If not provided, the default template is `{{ user.preferred_username }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
    /// How to handle conflicts on the claim, default value is `Fail`
    #[serde(default, skip_serializing_if = "OnConflict::is_default")]
    pub on_conflict: OnConflict,
}

impl LocalpartImportPreference {
    /// Returns true if all fields are at their default values.
    ///
    /// Used as a `skip_serializing_if` predicate, so every field must be
    /// checked: `on_conflict` was previously missing from the conjunction,
    /// which could drop a non-default `on_conflict` on re-serialization.
    const fn is_default(&self) -> bool {
        self.action.is_default() && self.template.is_none() && self.on_conflict.is_default()
    }
}
/// What should be done for the displayname attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct DisplaynameImportPreference {
    /// How to handle the attribute
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,
    /// The Jinja2 template to use for the displayname attribute
    ///
    /// If not provided, the default template is `{{ user.name }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}

impl DisplaynameImportPreference {
    /// Serde skip predicate: true when action and template are both default.
    const fn is_default(&self) -> bool {
        self.action.is_default() && self.template.is_none()
    }
}
/// What should be done with the email attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct EmailImportPreference {
    /// How to handle the claim
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,
    /// The Jinja2 template to use for the email address attribute
    ///
    /// If not provided, the default template is `{{ user.email }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}

impl EmailImportPreference {
    /// Serde skip predicate: true when action and template are both default.
    const fn is_default(&self) -> bool {
        self.action.is_default() && self.template.is_none()
    }
}
/// What should be done for the account name attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct AccountNameImportPreference {
    /// The Jinja2 template to use for the account name. This name is only used
    /// for display purposes.
    ///
    /// If not provided, it will be ignored.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}

impl AccountNameImportPreference {
    /// Serde skip predicate: true when no custom template is configured.
    const fn is_default(&self) -> bool {
        self.template.is_none()
    }
}
/// How claims should be imported
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
    /// How to determine the subject of the user
    #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
    pub subject: SubjectImportPreference,

    /// Whether to skip the interactive screen prompting the user to confirm the
    /// attributes that are being imported. This requires `localpart.action` to
    /// be `require` and other attribute actions to be either `ignore`, `force`
    /// or `require`
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub skip_confirmation: bool,

    /// Import the localpart of the MXID
    #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
    pub localpart: LocalpartImportPreference,

    /// Import the displayname of the user.
    #[serde(
        default,
        skip_serializing_if = "DisplaynameImportPreference::is_default"
    )]
    pub displayname: DisplaynameImportPreference,

    /// Import the email address of the user
    #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
    pub email: EmailImportPreference,

    /// Set a human-readable name for the upstream account for display purposes
    #[serde(
        default,
        skip_serializing_if = "AccountNameImportPreference::is_default"
    )]
    pub account_name: AccountNameImportPreference,
}

impl ClaimsImports {
    /// Serde skip predicate: true only when every sub-preference is default
    /// and confirmation is not skipped.
    const fn is_default(&self) -> bool {
        self.subject.is_default()
            && self.localpart.is_default()
            && !self.skip_confirmation
            && self.displayname.is_default()
            && self.email.is_default()
            && self.account_name.is_default()
    }
}
/// How to discover the provider's configuration
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
    /// Use OIDC discovery with strict metadata verification
    #[default]
    Oidc,
    /// Use OIDC discovery with relaxed metadata verification
    Insecure,
    /// Use a static configuration
    Disabled,
}

impl DiscoveryMode {
    /// Serde skip predicate: true for the default (`Oidc`) mode.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    const fn is_default(&self) -> bool {
        matches!(self, DiscoveryMode::Oidc)
    }
}
/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
    /// Use PKCE if the provider supports it
    ///
    /// Defaults to no PKCE if provider discovery is disabled
    #[default]
    Auto,
    /// Always use PKCE with the S256 challenge method
    Always,
    /// Never use PKCE
    Never,
}

impl PkceMethod {
    /// Serde skip predicate: true for the default (`Auto`) method.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    const fn is_default(&self) -> bool {
        matches!(self, PkceMethod::Auto)
    }
}
/// Serde default for boolean fields that default to `true`.
fn default_true() -> bool {
    true
}

/// Serde skip predicate: skip serializing when the value is still `true`.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(value: &bool) -> bool {
    *value == default_true()
}
/// Serde default for the signed-response JWS algorithm: RS256.
#[allow(clippy::unnecessary_wraps)]
fn signed_response_alg_default() -> JsonWebSignatureAlg {
    JsonWebSignatureAlg::Rs256
}

/// Serde skip predicate: true when the algorithm is still the RS256 default.
#[allow(clippy::ref_option)]
fn is_signed_response_alg_default(signed_response_alg: &JsonWebSignatureAlg) -> bool {
    signed_response_alg_default() == *signed_response_alg
}
// Extra parameters needed by the `sign_in_with_apple` token-endpoint auth
// method: Apple requires a client secret that is itself a signed JWT.
// NOTE(review): `private_key_file` and `private_key` look mutually exclusive
// (file-based vs inline key) — confirm which one takes precedence when both
// are set.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SignInWithApple {
    /// The private key file used to sign the `id_token`
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub private_key_file: Option<Utf8PathBuf>,
    /// The private key used to sign the `id_token`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub private_key: Option<String>,
    /// The Team ID of the Apple Developer Portal
    pub team_id: String,
    /// The key ID of the Apple Developer Portal
    pub key_id: String,
}
/// Serde `default` helper: the default OAuth 2.0 scope (`openid`).
fn default_scope() -> String {
    String::from("openid")
}
/// Serde `skip_serializing_if` helper: skip the `scope` field when it is the
/// default `openid` scope.
///
/// Compares against the literal directly instead of calling
/// [`default_scope`], which would allocate a fresh `String` on every
/// serialization just to throw it away (clippy's `cmp_owned` pattern). Must
/// be kept in sync with [`default_scope`].
fn is_default_scope(scope: &str) -> bool {
    scope == "openid"
}
/// What to do when receiving an OIDC Backchannel logout request.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum OnBackchannelLogout {
    /// Do nothing
    #[default]
    DoNothing,
    /// Only log out the MAS 'browser session' started by this OIDC session
    LogoutBrowserOnly,
    /// Log out all sessions started by this OIDC session, including MAS
    /// 'browser sessions' and client sessions
    LogoutAll,
}

impl OnBackchannelLogout {
    /// Whether this is the default (`do_nothing`) behaviour; used as a serde
    /// `skip_serializing_if` helper to omit the field when redundant.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    const fn is_default(&self) -> bool {
        // Exhaustive match: adding a variant forces a review of this helper.
        match *self {
            OnBackchannelLogout::DoNothing => true,
            OnBackchannelLogout::LogoutBrowserOnly | OnBackchannelLogout::LogoutAll => false,
        }
    }
}
/// Configuration for one upstream OAuth 2 provider.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider {
    /// Whether this provider is enabled.
    ///
    /// Defaults to `true`
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub enabled: bool,

    /// An internal unique identifier for this provider
    // The ULID is validated by the JSON schema only; deserialization relies
    // on the `Ulid` type itself.
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub id: Ulid,

    /// The ID of the provider that was used by Synapse.
    /// In order to perform a Synapse-to-MAS migration, this must be specified.
    ///
    /// ## For providers that used OAuth 2.0 or OpenID Connect in Synapse
    ///
    /// ### For `oidc_providers`:
    /// This should be specified as `oidc-` followed by the ID that was
    /// configured as `idp_id` in one of the `oidc_providers` in the Synapse
    /// configuration.
    /// For example, if Synapse's configuration contained `idp_id: wombat` for
    /// this provider, then specify `oidc-wombat` here.
    ///
    /// ### For `oidc_config` (legacy):
    /// Specify `oidc` here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub synapse_idp_id: Option<String>,

    /// The OIDC issuer URL
    ///
    /// This is required if OIDC discovery is enabled (which is the default)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,

    /// A human-readable name for the provider, that will be shown to users
    #[serde(skip_serializing_if = "Option::is_none")]
    pub human_name: Option<String>,

    /// A brand identifier used to customise the UI, e.g. `apple`, `google`,
    /// `github`, etc.
    ///
    /// Values supported by the default template are:
    ///
    /// - `apple`
    /// - `google`
    /// - `facebook`
    /// - `github`
    /// - `gitlab`
    /// - `twitter`
    /// - `discord`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub brand_name: Option<String>,

    /// The client ID to use when authenticating with the provider
    pub client_id: String,

    /// The client secret to use when authenticating with the provider
    ///
    /// Used by the `client_secret_basic`, `client_secret_post`, and
    /// `client_secret_jwt` methods
    // NOTE(review): flattened so that either a `client_secret` or a
    // `client_secret_file` key in the config maps onto this field — confirm
    // against `ClientSecretRaw`'s definition.
    #[schemars(with = "ClientSecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
    #[serde(flatten)]
    pub client_secret: Option<ClientSecret>,

    /// The method to authenticate the client with the provider
    pub token_endpoint_auth_method: TokenAuthMethod,

    /// Additional parameters for the `sign_in_with_apple` method
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sign_in_with_apple: Option<SignInWithApple>,

    /// The JWS algorithm to use when authenticating the client with the
    /// provider
    ///
    /// Used by the `client_secret_jwt` and `private_key_jwt` methods
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

    /// Expected signature for the JWT payload returned by the token
    /// authentication endpoint.
    ///
    /// Defaults to `RS256`.
    #[serde(
        default = "signed_response_alg_default",
        skip_serializing_if = "is_signed_response_alg_default"
    )]
    pub id_token_signed_response_alg: JsonWebSignatureAlg,

    /// The scopes to request from the provider
    ///
    /// Defaults to `openid`.
    #[serde(default = "default_scope", skip_serializing_if = "is_default_scope")]
    pub scope: String,

    /// How to discover the provider's configuration
    ///
    /// Defaults to `oidc`, which uses OIDC discovery with strict metadata
    /// verification
    #[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
    pub discovery_mode: DiscoveryMode,

    /// Whether to use proof key for code exchange (PKCE) when requesting and
    /// exchanging the token.
    ///
    /// Defaults to `auto`, which uses PKCE if the provider supports it.
    #[serde(default, skip_serializing_if = "PkceMethod::is_default")]
    pub pkce_method: PkceMethod,

    /// Whether to fetch the user profile from the userinfo endpoint,
    /// or to rely on the data returned in the `id_token` from the
    /// `token_endpoint`.
    ///
    /// Defaults to `false`.
    #[serde(default)]
    pub fetch_userinfo: bool,

    /// Expected signature for the JWT payload returned by the userinfo
    /// endpoint.
    ///
    /// If not specified, the response is expected to be an unsigned JSON
    /// payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,

    // The four endpoint overrides below take precedence over whatever OIDC
    // discovery returns; they are required when `discovery_mode` is
    // `disabled`.
    /// The URL to use for the provider's authorization endpoint
    ///
    /// Defaults to the `authorization_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authorization_endpoint: Option<Url>,

    /// The URL to use for the provider's userinfo endpoint
    ///
    /// Defaults to the `userinfo_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<Url>,

    /// The URL to use for the provider's token endpoint
    ///
    /// Defaults to the `token_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint: Option<Url>,

    /// The URL to use for getting the provider's public keys
    ///
    /// Defaults to the `jwks_uri` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// The response mode we ask the provider to use for the callback
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mode: Option<ResponseMode>,

    /// How claims should be imported from the `id_token` provided by the
    /// provider
    #[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
    pub claims_imports: ClaimsImports,

    /// Additional parameters to include in the authorization request
    ///
    /// Orders of the keys are not preserved.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_authorization_parameters: BTreeMap<String, String>,

    /// Whether the `login_hint` should be forwarded to the provider in the
    /// authorization request.
    ///
    /// Defaults to `false`.
    #[serde(default)]
    pub forward_login_hint: bool,

    /// What to do when receiving an OIDC Backchannel logout request.
    ///
    /// Defaults to `do_nothing`.
    #[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")]
    pub on_backchannel_logout: OnBackchannelLogout,
}
impl Provider {
    /// Returns the client secret, if one is configured.
    ///
    /// If `client_secret_file` was given, the secret is read from that file.
    ///
    /// # Errors
    ///
    /// Returns an error when the client secret could not be read from file.
    pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
        // No secret configured at all: nothing to resolve.
        let Some(client_secret) = &self.client_secret else {
            return Ok(None);
        };

        Ok(Some(client_secret.value().await?))
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    // Exercises the config deserialization for all token-endpoint auth
    // methods, including both inline (`client_secret`) and file-based
    // (`client_secret_file`) secrets.
    #[tokio::test]
    async fn load_config() {
        // `Jail` does blocking filesystem work, so run it off the async
        // runtime and hop back with `Handle::block_on` for the async asserts.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    r#"
upstream_oauth2:
  providers:
    - id: 01GFWR28C4KNE04WG3HKXB7C9R
      client_id: upstream-oauth2
      token_endpoint_auth_method: none
    - id: 01GFWR32NCQ12B8Z0J8CPXRRB6
      client_id: upstream-oauth2
      client_secret_file: secret
      token_endpoint_auth_method: client_secret_basic
    - id: 01GFWR3WHR93Y5HK389H28VHZ9
      client_id: upstream-oauth2
      client_secret: c1!3n753c237
      token_endpoint_auth_method: client_secret_post
    - id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
      client_id: upstream-oauth2
      client_secret_file: secret
      token_endpoint_auth_method: client_secret_jwt
    - id: 01GFWR4BNFDCC4QDG6AMSP1VRR
      client_id: upstream-oauth2
      token_endpoint_auth_method: private_key_jwt
      jwks:
        keys:
          - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
            kty: "RSA"
            alg: "RS256"
            use: "sig"
            e: "AQAB"
            n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"
          - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
            kty: "RSA"
            alg: "RS256"
            use: "sig"
            e: "AQAB"
            n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
"#,
                )?;
                // File-based secret referenced by `client_secret_file: secret`
                jail.create_file("secret", r"c1!3n753c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<UpstreamOAuth2Config>("upstream_oauth2")?;

                assert_eq!(config.providers.len(), 5);

                assert_eq!(
                    config.providers[1].id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );

                // Each provider should have resolved its secret source (or
                // none) according to the keys present in the YAML above.
                assert!(config.providers[0].client_secret.is_none());
                assert!(matches!(config.providers[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.providers[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.providers[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.providers[4].client_secret.is_none());

                // `client_secret()` is async because of the file read, so
                // block on it from this blocking context.
                Handle::current().block_on(async move {
                    assert_eq!(config.providers[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        }).await.unwrap();
    }
}

View file

@ -0,0 +1,76 @@
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use figment::Figment;
use serde::de::DeserializeOwned;
/// Trait implemented by all configuration section to help loading specific part
/// of the config and generate the sample config.
pub trait ConfigurationSection: Sized + DeserializeOwned {
    /// Specify where this section should live relative to the root.
    const PATH: Option<&'static str> = None;

    /// Validate the configuration section
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is invalid
    fn validate(
        &self,
        _figment: &Figment,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        Ok(())
    }

    /// Extract configuration from a Figment instance.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration could not be loaded
    fn extract(
        figment: &Figment,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync + 'static>> {
        // Sections with a `PATH` are nested under that key; all others are
        // deserialized straight from the configuration root.
        let section: Self = match Self::PATH {
            Some(path) => figment.extract_inner(path)?,
            None => figment.extract()?,
        };

        section.validate(figment)?;
        Ok(section)
    }
}
/// Extension trait for [`ConfigurationSection`] to allow extracting the
/// configuration section from a [`Figment`] or return the default value if the
/// section is not present.
pub trait ConfigurationSectionExt: ConfigurationSection + Default {
    /// Extract the configuration section from the given [`Figment`], or return
    /// the default value if the section is not present.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration section is invalid.
    fn extract_or_default(
        figment: &Figment,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let section: Self = match Self::PATH {
            // Nested section which is absent: short-circuit with the default
            // value, deliberately skipping validation.
            Some(path) if !figment.contains(path) => return Ok(Self::default()),
            Some(path) => figment.extract_inner(path)?,
            None => figment.extract()?,
        };

        section.validate(figment)?;
        Ok(section)
    }
}

impl<T: ConfigurationSection + Default> ConfigurationSectionExt for T {}

View file

@ -0,0 +1,29 @@
# Copyright 2025 New Vector Ltd.
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
# Please see LICENSE files in the repository root for full details.
[package]
name = "mas-context"
# Shared crate metadata is inherited from the workspace root manifest.
version.workspace = true
authors.workspace = true
edition.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
publish.workspace = true

[lints]
# Lint configuration is centralised in the workspace manifest.
workspace = true

[dependencies]
# All dependency versions are pinned at the workspace level.
console.workspace = true
pin-project-lite.workspace = true
quanta.workspace = true
tokio.workspace = true
tower-service.workspace = true
tower-layer.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
tracing-opentelemetry.workspace = true
opentelemetry.workspace = true

View file

@ -0,0 +1,143 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use console::{Color, Style};
use opentelemetry::TraceId;
use tracing::{Level, Subscriber};
use tracing_opentelemetry::OtelData;
use tracing_subscriber::{
fmt::{
FormatEvent, FormatFields,
format::{DefaultFields, Writer},
time::{FormatTime, SystemTime},
},
registry::LookupSpan,
};
use crate::LogContext;
/// An event formatter usable by the [`tracing_subscriber`] crate, which
/// includes the log context and the OTEL trace ID.
// Unit struct: the formatter is stateless; all output state is carried by the
// `Writer` handed to `format_event`.
#[derive(Debug, Default)]
pub struct EventFormatter;
struct FmtLevel<'a> {
level: &'a Level,
ansi: bool,
}
impl<'a> FmtLevel<'a> {
pub(crate) fn new(level: &'a Level, ansi: bool) -> Self {
Self { level, ansi }
}
}
const TRACE_STR: &str = "TRACE";
const DEBUG_STR: &str = "DEBUG";
const INFO_STR: &str = " INFO";
const WARN_STR: &str = " WARN";
const ERROR_STR: &str = "ERROR";
const TRACE_STYLE: Style = Style::new().fg(Color::Magenta);
const DEBUG_STYLE: Style = Style::new().fg(Color::Blue);
const INFO_STYLE: Style = Style::new().fg(Color::Green);
const WARN_STYLE: Style = Style::new().fg(Color::Yellow);
const ERROR_STYLE: Style = Style::new().fg(Color::Red);
impl std::fmt::Display for FmtLevel<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let msg = match *self.level {
Level::TRACE => TRACE_STYLE.force_styling(self.ansi).apply_to(TRACE_STR),
Level::DEBUG => DEBUG_STYLE.force_styling(self.ansi).apply_to(DEBUG_STR),
Level::INFO => INFO_STYLE.force_styling(self.ansi).apply_to(INFO_STR),
Level::WARN => WARN_STYLE.force_styling(self.ansi).apply_to(WARN_STR),
Level::ERROR => ERROR_STYLE.force_styling(self.ansi).apply_to(ERROR_STR),
};
write!(f, "{msg}")
}
}
/// Displays an event's target module, with the source line appended as
/// `target:line` when it is known.
struct TargetFmt<'a> {
    target: &'a str,
    line: Option<u32>,
}

impl<'a> TargetFmt<'a> {
    pub(crate) fn new(metadata: &tracing::Metadata<'a>) -> Self {
        Self {
            target: metadata.target(),
            line: metadata.line(),
        }
    }
}

impl std::fmt::Display for TargetFmt<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.line {
            Some(line) => write!(f, "{}:{}", self.target, line),
            None => write!(f, "{}", self.target),
        }
    }
}
impl<S, N> FormatEvent<S, N> for EventFormatter
where
    S: Subscriber + for<'a> LookupSpan<'a>,
    N: for<'writer> FormatFields<'writer> + 'static,
{
    /// Formats one event as a single log line:
    /// `<timestamp> <level> <target|name> [<log context> - ]<fields>[ trace.id=<id>]`
    fn format_event(
        &self,
        ctx: &tracing_subscriber::fmt::FmtContext<'_, S, N>,
        mut writer: Writer<'_>,
        event: &tracing::Event<'_>,
    ) -> std::fmt::Result {
        let ansi = writer.has_ansi_escapes();
        let metadata = event.metadata();

        // Timestamp first, then the (possibly coloured) level label
        SystemTime.format_time(&mut writer)?;
        let level = FmtLevel::new(metadata.level(), ansi);
        write!(&mut writer, " {level} ")?;

        // If there is no explicit 'name' set in the event macro, it will have the
        // 'event (unknown):{line}' value. In this case, we want to display the target:
        // the module from where it was emitted. In other cases, we want to
        // display the explicit name of the event we have set.
        let style = Style::new().dim().force_styling(ansi);
        if metadata.name().starts_with("event ") {
            write!(&mut writer, "{} ", style.apply_to(TargetFmt::new(metadata)))?;
        } else {
            write!(&mut writer, "{} ", style.apply_to(metadata.name()))?;
        }

        // Prefix the message with the current log context tag, if any
        LogContext::maybe_with(|log_context| {
            let log_context = Style::new()
                .bold()
                .force_styling(ansi)
                .apply_to(log_context);
            write!(&mut writer, "{log_context} - ")
        })
        .transpose()?;

        let field_formatter = DefaultFields::new();
        field_formatter.format_fields(writer.by_ref(), event)?;

        // If we have a OTEL span, we can add the trace ID to the end of the log line
        if let Some(span) = ctx.lookup_current()
            && let Some(otel) = span.extensions().get::<OtelData>()
            && let Some(trace_id) = otel.trace_id()
            && trace_id != TraceId::INVALID
        {
            let label = Style::new()
                .italic()
                .force_styling(ansi)
                .apply_to("trace.id");
            write!(&mut writer, " {label}={trace_id}")?;
        }

        writeln!(&mut writer)
    }
}

View file

@ -0,0 +1,59 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{
pin::Pin,
sync::atomic::Ordering,
task::{Context, Poll},
};
use quanta::Instant;
use tokio::task::futures::TaskLocalFuture;
use crate::LogContext;
/// The future type returned by [`LogContext::wrap_future`]: the inner future
/// wrapped in a poll recorder, scoped to a task-local log context.
pub type LogContextFuture<F> = TaskLocalFuture<crate::LogContext, PollRecordingFuture<F>>;

impl LogContext {
    /// Wrap a future with the given log context
    pub(crate) fn wrap_future<F: Future>(&self, future: F) -> LogContextFuture<F> {
        crate::CURRENT_LOG_CONTEXT.scope(self.clone(), PollRecordingFuture::new(future))
    }
}

pin_project_lite::pin_project! {
    /// A future which records the elapsed time and the number of polls in the
    /// active log context
    pub struct PollRecordingFuture<F> {
        #[pin]
        inner: F,
    }
}

impl<F: Future> PollRecordingFuture<F> {
    pub(crate) fn new(inner: F) -> Self {
        Self { inner }
    }
}

impl<F: Future> Future for PollRecordingFuture<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let poll_started = Instant::now();
        let output = self.project().inner.poll(cx);

        // Fold this poll into the task-local context, if one is set: bump the
        // poll counter and add the time spent inside the inner `poll` to the
        // accumulated CPU time (saturating at `u64::MAX` nanoseconds).
        let spent = poll_started
            .elapsed()
            .as_nanos()
            .try_into()
            .unwrap_or(u64::MAX);
        let _ = crate::CURRENT_LOG_CONTEXT.try_with(|context| {
            context.inner.polls.fetch_add(1, Ordering::Relaxed);
            context.inner.cpu_time.fetch_add(spent, Ordering::Relaxed);
        });

        output
    }
}

View file

@ -0,0 +1,41 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::borrow::Cow;
use tower_layer::Layer;
use tower_service::Service;
use crate::LogContextService;
/// A layer which creates a log context for each request.
pub struct LogContextLayer<R> {
    // Plain function pointer (not a boxed closure), so cloning is trivial and
    // no allocation is needed per layer.
    tagger: fn(&R) -> Cow<'static, str>,
}

// Manual `Clone` impl: a derived one would add an unnecessary `R: Clone`
// bound, while the function pointer itself is always trivially copyable.
impl<R> Clone for LogContextLayer<R> {
    fn clone(&self) -> Self {
        Self {
            tagger: self.tagger,
        }
    }
}

impl<R> LogContextLayer<R> {
    /// Create a new layer, using `tagger` to derive the log-context tag from
    /// each incoming request.
    pub fn new(tagger: fn(&R) -> Cow<'static, str>) -> Self {
        Self { tagger }
    }
}

impl<S, R> Layer<S> for LogContextLayer<R>
where
    S: Service<R>,
{
    type Service = LogContextService<S, R>;

    fn layer(&self, inner: S) -> Self::Service {
        // Wrap the inner service; it will open one `LogContext` per request
        // using the shared tagger.
        LogContextService::new(inner, self.tagger)
    }
}

View file

@ -0,0 +1,152 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
mod fmt;
mod future;
mod layer;
mod service;
use std::{
borrow::Cow,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use quanta::Instant;
use tokio::task_local;
pub use self::{
fmt::EventFormatter,
future::{LogContextFuture, PollRecordingFuture},
layer::LogContextLayer,
service::LogContextService,
};
/// A counter which increments each time we create a new log context
/// It will wrap around if we create more than [`u64::MAX`] contexts
static LOG_CONTEXT_INDEX: AtomicU64 = AtomicU64::new(0);

task_local! {
    // The log context attached to the currently running task. It is set by
    // `LogContext::run`/`run_sync` and read back by the event formatter and
    // the poll-recording future.
    pub static CURRENT_LOG_CONTEXT: LogContext;
}
/// A log context saves information about the current task, such as the
/// elapsed time, the number of polls, and the poll time.
// Cheap to clone: just an `Arc` bump; all counters are shared.
#[derive(Clone)]
pub struct LogContext {
    inner: Arc<LogContextInner>,
}

struct LogContextInner {
    /// A user-defined tag for the log context
    tag: Cow<'static, str>,

    /// A unique index for the log context
    index: u64,

    /// The time when the context was created
    start: Instant,

    /// The number of [`Future::poll`] recorded
    polls: AtomicU64,

    /// An approximation of the total CPU time spent in the context, in
    /// nanoseconds
    cpu_time: AtomicU64,
}
impl LogContext {
    /// Create a new log context with the given tag
    pub fn new(tag: impl Into<Cow<'static, str>>) -> Self {
        Self {
            inner: Arc::new(LogContextInner {
                tag: tag.into(),
                // Wraps around on overflow; uniqueness is best-effort.
                index: LOG_CONTEXT_INDEX.fetch_add(1, Ordering::Relaxed),
                start: Instant::now(),
                polls: AtomicU64::new(0),
                cpu_time: AtomicU64::new(0),
            }),
        }
    }

    /// Run a closure with the current log context, if any
    pub fn maybe_with<F, R>(f: F) -> Option<R>
    where
        F: FnOnce(&Self) -> R,
    {
        CURRENT_LOG_CONTEXT.try_with(f).ok()
    }

    /// Run the async function `f` with the given log context. It will wrap the
    /// output future to record poll and CPU statistics.
    pub fn run<F: FnOnce() -> Fut, Fut: Future>(&self, f: F) -> LogContextFuture<Fut> {
        // Building the future itself counts as CPU time, hence `run_sync`.
        self.wrap_future(self.run_sync(f))
    }

    /// Run the sync function `f` with the given log context, recording the CPU
    /// time spent.
    pub fn run_sync<F: FnOnce() -> R, R>(&self, f: F) -> R {
        let started = Instant::now();
        let result = CURRENT_LOG_CONTEXT.sync_scope(self.clone(), f);

        // Saturate at u64::MAX nanoseconds rather than panic on overflow
        let spent = started.elapsed().as_nanos().try_into().unwrap_or(u64::MAX);
        self.inner.cpu_time.fetch_add(spent, Ordering::Relaxed);

        result
    }

    /// Create a snapshot of the log context statistics
    #[must_use]
    pub fn stats(&self) -> LogContextStats {
        LogContextStats {
            polls: self.inner.polls.load(Ordering::Relaxed),
            cpu_time: Duration::from_nanos(self.inner.cpu_time.load(Ordering::Relaxed)),
            elapsed: self.inner.start.elapsed(),
        }
    }
}
impl std::fmt::Display for LogContext {
    /// Renders the context as `{tag}-{index}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}-{}", self.inner.tag, self.inner.index)
    }
}
/// A snapshot of a log context statistics
// Plain `Copy` value type: taking a snapshot detaches it from the live
// counters of the originating context.
#[derive(Debug, Clone, Copy)]
pub struct LogContextStats {
    /// How many times the context was polled
    pub polls: u64,

    /// The approximate CPU time spent in the context
    pub cpu_time: Duration,

    /// How much time elapsed since the context was created
    pub elapsed: Duration,
}

impl std::fmt::Display for LogContextStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Convert a duration to fractional milliseconds for display.
        #[expect(clippy::cast_precision_loss)]
        fn millis(duration: Duration) -> f64 {
            duration.as_nanos() as f64 / 1_000_000.
        }

        write!(
            f,
            "polls: {}, cpu: {:.1}ms, elapsed: {:.1}ms",
            self.polls,
            millis(self.cpu_time),
            millis(self.elapsed),
        )
    }
}

View file

@ -0,0 +1,54 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::{
borrow::Cow,
task::{Context, Poll},
};
use tower_service::Service;
use crate::{LogContext, LogContextFuture};
/// A service which wraps another service and creates a log context for
/// each request.
pub struct LogContextService<S, R> {
    inner: S,
    // Plain function pointer deriving a log-context tag from each request.
    tagger: fn(&R) -> Cow<'static, str>,
}

// Manual `Clone` impl: a derived one would add an unnecessary `R: Clone`
// bound, while the function pointer itself is always trivially copyable.
impl<S: Clone, R> Clone for LogContextService<S, R> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            tagger: self.tagger,
        }
    }
}

impl<S, R> LogContextService<S, R> {
    /// Wrap `inner`, using `tagger` to name the per-request log context.
    pub fn new(inner: S, tagger: fn(&R) -> Cow<'static, str>) -> Self {
        Self { inner, tagger }
    }
}
impl<S, R> Service<R> for LogContextService<S, R>
where
    S: Service<R>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = LogContextFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated entirely to the wrapped service.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: R) -> Self::Future {
        // Derive a tag from the request, open a fresh log context, and run
        // the inner call inside it so poll/CPU statistics are attributed to
        // this request.
        let context = LogContext::new((self.tagger)(&req));
        context.run(|| self.inner.call(req))
    }
}

Some files were not shown because too many files have changed in this diff Show more