ci : switch from pyright to ty (#20826)

* type fixes

* switch to ty

* tweak rules

* tweak more rules

* more tweaks

* final tweak

* use common import-not-found rule
This commit is contained in:
Sigbjørn Skjæret 2026-03-21 08:54:34 +01:00 committed by GitHub
parent cea560f483
commit 29b28a9824
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 181 additions and 124 deletions

View file

@@ -4,15 +4,17 @@ on:
push: push:
paths: paths:
- '.github/workflows/python-type-check.yml' - '.github/workflows/python-type-check.yml'
- 'pyrightconfig.json' - 'ty.toml'
- '**.py' - '**.py'
- '**/requirements*.txt' - '**/requirements*.txt'
# - 'pyrightconfig.json'
pull_request: pull_request:
paths: paths:
- '.github/workflows/python-type-check.yml' - '.github/workflows/python-type-check.yml'
- 'pyrightconfig.json' - 'ty.toml'
- '**.py' - '**.py'
- '**/requirements*.txt' - '**/requirements*.txt'
# - 'pyrightconfig.json'
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
@@ -20,8 +22,8 @@ concurrency:
jobs: jobs:
python-type-check: python-type-check:
runs-on: ubuntu-latest runs-on: ubuntu-slim
name: pyright type-check name: python type-check
steps: steps:
- name: Check out source repository - name: Check out source repository
uses: actions/checkout@v6 uses: actions/checkout@v6
@@ -29,10 +31,13 @@ jobs:
uses: actions/setup-python@v6 uses: actions/setup-python@v6
with: with:
python-version: "3.11" python-version: "3.11"
pip-install: -r requirements/requirements-all.txt pip-install: -r requirements/requirements-all.txt ty==0.0.24
- name: Type-check with Pyright # - name: Type-check with Pyright
uses: jakebailey/pyright-action@v2 # uses: jakebailey/pyright-action@v2
with: # with:
version: 1.1.382 # version: 1.1.382
level: warning # level: warning
warnings: true # warnings: true
- name: Type-check with ty
run: |
ty check --output-format=github

View file

@@ -31,10 +31,10 @@ import gguf
from gguf.vocab import MistralTokenizerType, MistralVocab from gguf.vocab import MistralTokenizerType, MistralVocab
try: try:
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.base import TokenizerVersion # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
SentencePieceTokenizer, SentencePieceTokenizer,
) )
@@ -45,9 +45,9 @@ except ImportError:
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) _MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
_mistral_common_installed = False _mistral_common_installed = False
TokenizerVersion = None TokenizerVersion: Any = None
Tekkenizer = None Tekkenizer: Any = None
SentencePieceTokenizer = None SentencePieceTokenizer: Any = None
_mistral_import_error_msg = ( _mistral_import_error_msg = (
"Mistral format requires `mistral-common` to be installed. Please run " "Mistral format requires `mistral-common` to be installed. Please run "
"`pip install mistral-common[image,audio]` to install it." "`pip install mistral-common[image,audio]` to install it."
@@ -220,7 +220,7 @@ class ModelBase:
if weight_map is None or not isinstance(weight_map, dict): if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}") raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
tensor_names_from_index.update(weight_map.keys()) tensor_names_from_index.update(weight_map.keys())
part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None) part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None) # ty: ignore[invalid-assignment]
part_names = sorted(part_dict.keys()) part_names = sorted(part_dict.keys())
else: else:
weight_map = {} weight_map = {}
@@ -5882,7 +5882,7 @@ class InternLM2Model(TextModel):
logger.error(f'Error: Missing {tokenizer_path}') logger.error(f'Error: Missing {tokenizer_path}')
sys.exit(1) sys.exit(1)
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
@@ -6203,7 +6203,7 @@ class BertModel(TextModel):
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size) vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
else: else:
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
@@ -8880,7 +8880,7 @@ class T5Model(TextModel):
if not tokenizer_path.is_file(): if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}") raise FileNotFoundError(f"File not found: {tokenizer_path}")
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram # some models like Pile-T5 family use BPE tokenizer instead of Unigram
@@ -9017,7 +9017,7 @@ class T5EncoderModel(TextModel):
if not tokenizer_path.is_file(): if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}") raise FileNotFoundError(f"File not found: {tokenizer_path}")
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
# some models like Pile-T5 family use BPE tokenizer instead of Unigram # some models like Pile-T5 family use BPE tokenizer instead of Unigram
@@ -12279,6 +12279,7 @@ class LazyTorchTensor(gguf.LazyBase):
kwargs = {} kwargs = {}
if func is torch.Tensor.numpy: if func is torch.Tensor.numpy:
assert len(args)
return args[0].numpy() return args[0].numpy()
return cls._wrap_fn(func)(*args, **kwargs) return cls._wrap_fn(func)(*args, **kwargs)

View file

@@ -112,11 +112,11 @@ class Tensor:
(n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12]) (n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12])
assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}' assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}'
assert name_len < 4096, 'Absurd tensor name length' assert name_len < 4096, 'Absurd tensor name length'
quant = gguf.GGML_QUANT_SIZES.get(dtype) self.dtype = gguf.GGMLQuantizationType(dtype)
quant = gguf.GGML_QUANT_SIZES.get(self.dtype)
assert quant is not None, 'Unknown tensor type' assert quant is not None, 'Unknown tensor type'
(blksize, tysize) = quant (blksize, tysize) = quant
offset += 12 offset += 12
self.dtype= gguf.GGMLQuantizationType(dtype)
self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)]) self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
offset += 4 * n_dims offset += 4 * n_dims
self.name = bytes(data[offset:offset + name_len]) self.name = bytes(data[offset:offset + name_len])

View file

@@ -199,10 +199,13 @@ class LoraTorchTensor:
kwargs = {} kwargs = {}
if func is torch.permute: if func is torch.permute:
assert len(args)
return type(args[0]).permute(*args, **kwargs) return type(args[0]).permute(*args, **kwargs)
elif func is torch.reshape: elif func is torch.reshape:
assert len(args)
return type(args[0]).reshape(*args, **kwargs) return type(args[0]).reshape(*args, **kwargs)
elif func is torch.stack: elif func is torch.stack:
assert len(args)
assert isinstance(args[0], Sequence) assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0) dim = kwargs.get("dim", 0)
assert dim == 0 assert dim == 0
@@ -211,6 +214,7 @@ class LoraTorchTensor:
torch.stack([b._lora_B for b in args[0]], dim), torch.stack([b._lora_B for b in args[0]], dim),
) )
elif func is torch.cat: elif func is torch.cat:
assert len(args)
assert isinstance(args[0], Sequence) assert isinstance(args[0], Sequence)
dim = kwargs.get("dim", 0) dim = kwargs.get("dim", 0)
assert dim == 0 assert dim == 0
@@ -362,7 +366,7 @@ if __name__ == '__main__':
logger.error(f"Model {hparams['architectures'][0]} is not supported") logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1) sys.exit(1)
class LoraModel(model_class): class LoraModel(model_class): # ty: ignore[unsupported-base]
model_arch = model_class.model_arch model_arch = model_class.model_arch
lora_alpha: float lora_alpha: float

View file

@@ -28,9 +28,6 @@ def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
return f'({result})?' if min_items == 0 else result return f'({result})?' if min_items == 0 else result
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True): def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
has_min = min_value != None
has_max = max_value != None
def digit_range(from_char: str, to_char: str): def digit_range(from_char: str, to_char: str):
out.append("[") out.append("[")
if from_char == to_char: if from_char == to_char:
@@ -106,7 +103,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
out.append(to_str[i]) out.append(to_str[i])
out.append("]") out.append("]")
if has_min and has_max: if min_value is not None and max_value is not None:
if min_value < 0 and max_value < 0: if min_value < 0 and max_value < 0:
out.append("\"-\" (") out.append("\"-\" (")
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True) _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
@@ -133,7 +130,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
less_decimals = max(decimals_left - 1, 1) less_decimals = max(decimals_left - 1, 1)
if has_min: if min_value is not None:
if min_value < 0: if min_value < 0:
out.append("\"-\" (") out.append("\"-\" (")
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False) _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
@@ -177,7 +174,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
more_digits(length - 1, less_decimals) more_digits(length - 1, less_decimals)
return return
if has_max: if max_value is not None:
if max_value >= 0: if max_value >= 0:
if top_level: if top_level:
out.append("\"-\" [1-9] ") out.append("\"-\" [1-9] ")

View file

@@ -64,7 +64,7 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
print("Using SentenceTransformer to apply all numbered layers") print("Using SentenceTransformer to apply all numbered layers")
model = SentenceTransformer(model_path) model = SentenceTransformer(model_path)
tokenizer = model.tokenizer tokenizer = model.tokenizer
config = model[0].auto_model.config # type: ignore config = model[0].auto_model.config
else: else:
tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
@@ -108,8 +108,8 @@ def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device
print(f"Model file: {type(model).__module__}") print(f"Model file: {type(model).__module__}")
# Verify the model is using the correct sliding window # Verify the model is using the correct sliding window
if hasattr(model.config, 'sliding_window'): # type: ignore if hasattr(model.config, 'sliding_window'):
print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore print(f"Model's sliding_window: {model.config.sliding_window}")
else: else:
print("Model config does not have sliding_window attribute") print("Model config does not have sliding_window attribute")
@@ -152,7 +152,7 @@ def main():
device = next(model.parameters()).device device = next(model.parameters()).device
else: else:
# For SentenceTransformer, get device from the underlying model # For SentenceTransformer, get device from the underlying model
device = next(model[0].auto_model.parameters()).device # type: ignore device = next(model[0].auto_model.parameters()).device
model_name = os.path.basename(model_path) model_name = os.path.basename(model_path)
@@ -177,7 +177,7 @@ def main():
print(f"{token_id:6d} -> '{token_str}'") print(f"{token_id:6d} -> '{token_str}'")
print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}") print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")
else: else:
# Standard approach: use base model output only # Standard approach: use base model output only
encoded = tokenizer( encoded = tokenizer(
@@ -205,12 +205,12 @@
print(f"Embedding dimension: {all_embeddings.shape[1]}") print(f"Embedding dimension: {all_embeddings.shape[1]}")
if len(all_embeddings.shape) == 1: if len(all_embeddings.shape) == 1:
n_embd = all_embeddings.shape[0] # type: ignore n_embd = all_embeddings.shape[0]
n_embd_count = 1 n_embd_count = 1
all_embeddings = all_embeddings.reshape(1, -1) all_embeddings = all_embeddings.reshape(1, -1)
else: else:
n_embd = all_embeddings.shape[1] # type: ignore n_embd = all_embeddings.shape[1]
n_embd_count = all_embeddings.shape[0] # type: ignore n_embd_count = all_embeddings.shape[0]
print() print()

View file

@@ -2,7 +2,7 @@
import argparse import argparse
import sys import sys
from common import compare_tokens # type: ignore from common import compare_tokens # type: ignore[import-not-found]
def parse_arguments(): def parse_arguments():

View file

@@ -6,7 +6,7 @@ import re
from copy import copy from copy import copy
from enum import Enum from enum import Enum
from inspect import getdoc, isclass from inspect import getdoc, isclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
from docstring_parser import parse from docstring_parser import parse
from pydantic import BaseModel, create_model from pydantic import BaseModel, create_model
@@ -1158,7 +1158,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
# Assert that the parameter has a type annotation # Assert that the parameter has a type annotation
if param.annotation == inspect.Parameter.empty: if param.annotation == inspect.Parameter.empty:
raise TypeError(f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation") raise TypeError(f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a type annotation""")
# Find the parameter's description in the docstring # Find the parameter's description in the docstring
param_doc = next((d for d in docstring.params if d.arg_name == param.name), None) param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
@@ -1166,7 +1166,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
# Assert that the parameter has a description # Assert that the parameter has a description
if not param_doc or not param_doc.description: if not param_doc or not param_doc.description:
raise ValueError( raise ValueError(
f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring") f"""Parameter '{param.name}' in function '{getattr(func, "__name__", "")}' lacks a description in the docstring""")
# Add parameter details to the schema # Add parameter details to the schema
param_docs.append((param.name, param_doc)) param_docs.append((param.name, param_doc))
@@ -1177,7 +1177,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
dynamic_fields[param.name] = ( dynamic_fields[param.name] = (
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value) param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
# Creating the dynamic model # Creating the dynamic model
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) dynamic_model = create_model(f"{getattr(func, '__name__')}", **dynamic_fields)
for name, param_doc in param_docs: for name, param_doc in param_docs:
dynamic_model.model_fields[name].description = param_doc.description dynamic_model.model_fields[name].description = param_doc.description
@@ -1285,7 +1285,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
if items != {}: if items != {}:
array = {"properties": items} array = {"properties": items}
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items") array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
fields[field_name] = (List[array_type], ...) fields[field_name] = (list[array_type], ...) # ty: ignore[invalid-type-form]
else: else:
fields[field_name] = (list, ...) fields[field_name] = (list, ...)
elif field_type == "object": elif field_type == "object":

View file

@@ -1300,7 +1300,7 @@ class GGUFWriter:
else: else:
raise ValueError("Invalid GGUF metadata value type or value") raise ValueError("Invalid GGUF metadata value type or value")
return kv_data return bytes(kv_data)
@staticmethod @staticmethod
def format_n_bytes_to_str(num: int) -> str: def format_n_bytes_to_str(num: int) -> str:

View file

@@ -138,7 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
if isinstance(meta_noop, tuple): if isinstance(meta_noop, tuple):
dtype, shape = meta_noop dtype, shape = meta_noop
assert callable(shape) assert callable(shape)
res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape)) res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape)) # ty: ignore[call-top-callable]
else: else:
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape) res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

View file

@@ -91,11 +91,11 @@ class __Quant(ABC):
def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None: def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
cls.qtype = qtype cls.qtype = qtype
cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype] cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
cls.__quantize_lazy = LazyNumpyTensor._wrap_fn( cls.__quantize_lazy: Any = LazyNumpyTensor._wrap_fn(
cls.__quantize_array, cls.__quantize_array,
meta_noop=(np.uint8, cls.__shape_to_bytes) meta_noop=(np.uint8, cls.__shape_to_bytes)
) )
cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn( cls.__dequantize_lazy: Any = LazyNumpyTensor._wrap_fn(
cls.__dequantize_array, cls.__dequantize_array,
meta_noop=(np.float32, cls.__shape_from_bytes) meta_noop=(np.float32, cls.__shape_from_bytes)
) )

View file

@@ -11,33 +11,33 @@ from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVa
try: try:
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
except ImportError: except ImportError:
SentencePieceProcessor = None SentencePieceProcessor: Any = None
try: try:
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.tekken import Tekkenizer # type: ignore[import-not-found]
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
_filter_valid_tokenizer_files, _filter_valid_tokenizer_files,
) )
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.sentencepiece import ( # type: ignore[import-not-found]
SentencePieceTokenizer, SentencePieceTokenizer,
) )
except ImportError: except ImportError:
_mistral_common_installed = False _mistral_common_installed = False
MistralTokenizer = None MistralTokenizer: Any = None
Tekkenizer = None Tekkenizer: Any = None
SentencePieceTokenizer = None SentencePieceTokenizer: Any = None
_filter_valid_tokenizer_files = None _filter_valid_tokenizer_files: Any = None
else: else:
_mistral_common_installed = True _mistral_common_installed = True
try: try:
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports] from mistral_common.tokens.tokenizers.utils import ( # type: ignore[import-not-found]
get_one_valid_tokenizer_file, get_one_valid_tokenizer_file,
) )
except ImportError: except ImportError:
# We still want the conversion to work with older mistral-common versions. # We still want the conversion to work with older mistral-common versions.
get_one_valid_tokenizer_file = None get_one_valid_tokenizer_file: Any = None
import gguf import gguf
@@ -703,7 +703,7 @@ class MistralVocab(Vocab):
tokenizer_file_path = base_path / tokenizer_file tokenizer_file_path = base_path / tokenizer_file
self.tokenizer = MistralTokenizer.from_file( self.tokenizer: Any = MistralTokenizer.from_file(
tokenizer_file_path tokenizer_file_path
).instruct_tokenizer.tokenizer ).instruct_tokenizer.tokenizer
self.tokenizer_type = ( self.tokenizer_type = (

View file

@@ -1,5 +1,5 @@
{ {
"extraPaths": ["gguf-py", "examples/model-conversion/scripts"], "extraPaths": ["gguf-py", "examples/model-conversion/scripts", "examples/model-conversion/scripts/utils"],
"pythonVersion": "3.9", "pythonVersion": "3.9",
"pythonPlatform": "All", "pythonPlatform": "All",
"reportUnusedImport": "warning", "reportUnusedImport": "warning",

View file

@@ -684,6 +684,7 @@ else:
sys.exit(1) sys.exit(1)
assert isinstance(hexsha8_baseline, str)
name_baseline = bench_data.get_commit_name(hexsha8_baseline) name_baseline = bench_data.get_commit_name(hexsha8_baseline)
hexsha8_compare = name_compare = None hexsha8_compare = name_compare = None
@@ -717,6 +718,7 @@ else:
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
assert isinstance(hexsha8_compare, str)
name_compare = bench_data.get_commit_name(hexsha8_compare) name_compare = bench_data.get_commit_name(hexsha8_compare)
# Get tool-specific configuration # Get tool-specific configuration

View file

@@ -241,10 +241,10 @@ class CodeEditor(QPlainTextEdit):
if not self.isReadOnly(): if not self.isReadOnly():
selection = QTextEdit.ExtraSelection() selection = QTextEdit.ExtraSelection()
line_color = QColorConstants.Yellow.lighter(160) line_color = QColorConstants.Yellow.lighter(160)
selection.format.setBackground(line_color) # pyright: ignore[reportAttributeAccessIssue] selection.format.setBackground(line_color) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True) # pyright: ignore[reportAttributeAccessIssue] selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
selection.cursor = self.textCursor() # pyright: ignore[reportAttributeAccessIssue] selection.cursor = self.textCursor() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
selection.cursor.clearSelection() # pyright: ignore[reportAttributeAccessIssue] selection.cursor.clearSelection() # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
extra_selections.append(selection) extra_selections.append(selection)
self.setExtraSelections(extra_selections) self.setExtraSelections(extra_selections)
@@ -262,8 +262,8 @@ class CodeEditor(QPlainTextEdit):
) )
extra = QTextEdit.ExtraSelection() extra = QTextEdit.ExtraSelection()
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
self.setExtraSelections(self.extraSelections() + [extra]) self.setExtraSelections(self.extraSelections() + [extra])
@@ -274,8 +274,8 @@ class CodeEditor(QPlainTextEdit):
cursor.select(QTextCursor.SelectionType.LineUnderCursor) cursor.select(QTextCursor.SelectionType.LineUnderCursor)
extra = QTextEdit.ExtraSelection() extra = QTextEdit.ExtraSelection()
extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
self.setExtraSelections(self.extraSelections() + [extra]) self.setExtraSelections(self.extraSelections() + [extra])
@@ -395,8 +395,8 @@ class JinjaTester(QMainWindow):
ensure_ascii=ensure_ascii, ensure_ascii=ensure_ascii,
) )
) )
env.globals["strftime_now"] = lambda format: datetime.now().strftime(format) env.globals["strftime_now"] = lambda format: datetime.now().strftime(format) # ty: ignore[invalid-assignment]
env.globals["raise_exception"] = raise_exception env.globals["raise_exception"] = raise_exception # ty: ignore[invalid-assignment]
try: try:
template = env.from_string(template_str) template = env.from_string(template_str)
output = template.render(context) output = template.render(context)

View file

@@ -189,6 +189,7 @@ def benchmark(
data: list[dict] = [] data: list[dict] = []
assert isinstance(prompts, list)
for i, p in enumerate(prompts): for i, p in enumerate(prompts):
if seed_offset >= 0: if seed_offset >= 0:
random.seed(3 * (seed_offset + 1000 * i) + 1) random.seed(3 * (seed_offset + 1000 * i) + 1)

View file

@@ -16,8 +16,7 @@ import random
import unicodedata import unicodedata
from pathlib import Path from pathlib import Path
from typing import Any, Iterator, cast from typing import Any, Iterator
from typing_extensions import Buffer
import cffi import cffi
from transformers import AutoTokenizer, PreTrainedTokenizer from transformers import AutoTokenizer, PreTrainedTokenizer
@@ -114,7 +113,7 @@ class LibLlamaModel:
while num < 0 and len(self.text_buff) < (16 << 20): while num < 0 and len(self.text_buff) < (16 << 20):
self.text_buff = self.ffi.new("uint8_t[]", -2 * num) self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special) num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD' return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace") # replace errors with '\uFFFD' # pyright: ignore[reportArgumentType]
class Tokenizer: class Tokenizer:
@@ -438,7 +437,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
decode_errors = 0 decode_errors = 0
MAX_ERRORS = 10 MAX_ERRORS = 10
logger.info("%s: %s" % (generator.__qualname__, "ini")) logger.info("%s: %s" % (getattr(generator, "__qualname__", ""), "ini"))
for text in generator: for text in generator:
# print(repr(text), text.encode()) # print(repr(text), text.encode())
# print(repr(text), hex(ord(text[0])), text.encode()) # print(repr(text), hex(ord(text[0])), text.encode())
@@ -477,7 +476,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
break break
t_total = time.perf_counter() - t_start t_total = time.perf_counter() - t_start
logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}") logger.info(f"{getattr(generator, '__qualname__', '')}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
def main(argv: list[str] | None = None): def main(argv: list[str] | None = None):

View file

@@ -285,7 +285,7 @@ def start_server_background(args):
} }
server_process = subprocess.Popen( server_process = subprocess.Popen(
args, args,
**pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue] **pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue] # ty: ignore[no-matching-overload]
def server_log(in_stream, out_stream): def server_log(in_stream, out_stream):
for line in iter(in_stream.readline, b''): for line in iter(in_stream.readline, b''):

View file

@@ -9,6 +9,7 @@ sys.path.insert(0, str(path))
from utils import * from utils import *
from enum import Enum from enum import Enum
from typing import TypedDict
server: ServerProcess server: ServerProcess
@@ -29,56 +30,73 @@ class CompletionMode(Enum):
NORMAL = "normal" NORMAL = "normal"
STREAMED = "streamed" STREAMED = "streamed"
TEST_TOOL = { class ToolParameters(TypedDict):
"type":"function", type: str
"function": { properties: dict[str, dict]
"name": "test", required: list[str]
"description": "",
"parameters": {
"type": "object",
"properties": {
"success": {"type": "boolean", "const": True},
},
"required": ["success"]
}
}
}
PYTHON_TOOL = { class ToolFunction(TypedDict):
"type": "function", name: str
"function": { description: str
"name": "python", parameters: ToolParameters
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters": { class ToolDefinition(TypedDict):
"type": "object", type: str
"properties": { function: ToolFunction
TEST_TOOL = ToolDefinition(
type = "function",
function = ToolFunction(
name = "test",
description = "",
parameters = ToolParameters(
type = "object",
properties = {
"success": {
"type": "boolean",
"const": True,
},
},
required = ["success"],
),
),
)
PYTHON_TOOL = ToolDefinition(
type = "function",
function = ToolFunction(
name = "python",
description = "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
parameters = ToolParameters(
type = "object",
properties = {
"code": { "code": {
"type": "string", "type": "string",
"description": "The code to run in the ipython interpreter." "description": "The code to run in the ipython interpreter.",
} },
}, },
"required": ["code"] required = ["code"],
} ),
} ),
} )
WEATHER_TOOL = { WEATHER_TOOL = ToolDefinition(
"type":"function", type = "function",
"function":{ function = ToolFunction(
"name":"get_current_weather", name = "get_current_weather",
"description":"Get the current weather in a given location", description = "Get the current weather in a given location",
"parameters":{ parameters = ToolParameters(
"type":"object", type = "object",
"properties":{ properties = {
"location":{ "location": {
"type":"string", "type": "string",
"description":"The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'" "description": "The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'",
} },
}, },
"required":["location"] required = ["location"],
} ),
} ),
} )
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs): def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
body = server.make_any_request("POST", "/v1/chat/completions", data={ body = server.make_any_request("POST", "/v1/chat/completions", data={

30
ty.toml Normal file
View file

@@ -0,0 +1,30 @@
[environment]
extra-paths = ["./gguf-py", "./examples/model-conversion/scripts", "./tools/server/tests"]
python-version = "3.10"
[rules]
deprecated = "warn"
[src]
exclude = [
"./tools/mtmd/legacy-models/**",
]
[[overrides]]
include = [
"./tools/server/tests/**",
]
[overrides.rules]
unresolved-reference = "ignore"
unresolved-import = "ignore"
unresolved-attribute = "ignore"
[[overrides]]
include = [
"./examples/pydantic_models_to_grammar.py",
]
[overrides.rules]
unsupported-operator = "ignore"
not-subscriptable = "ignore"