diff --git a/README.md b/README.md index 56ff2c5..f752d97 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,11 @@ if provider is not None: provider_models = list(provider.iter_models()) ``` +Provider display names are not unique upstream, so `get_provider_by_name` +returns `None` when a name is shared by more than one provider. Use +`get_providers_by_name` to retrieve every match, or `get_provider_by_id` for an +unambiguous lookup. + ## Development Install dependencies and run the test suite with uv: diff --git a/src/modelsdotdev/__init__.py b/src/modelsdotdev/__init__.py index d1eccc0..fd6796d 100644 --- a/src/modelsdotdev/__init__.py +++ b/src/modelsdotdev/__init__.py @@ -16,6 +16,7 @@ get_model_by_id, get_provider_by_id, get_provider_by_name, + get_providers_by_name, iter_models, iter_providers, parse_model_id, @@ -37,6 +38,7 @@ "get_model_by_id", "get_provider_by_id", "get_provider_by_name", + "get_providers_by_name", "iter_models", "iter_providers", "parse_model_id", diff --git a/src/modelsdotdev/_internal/data.py b/src/modelsdotdev/_internal/data.py index d0873bc..d2eb9d3 100644 --- a/src/modelsdotdev/_internal/data.py +++ b/src/modelsdotdev/_internal/data.py @@ -239,15 +239,31 @@ def qualified_id(self) -> str: DB_PATH = Path(__file__).parents[1] / "_db.sqlite" -def get_provider_by_name(name: str) -> Provider | None: - """Return a provider by display name, using case-insensitive matching.""" +def get_providers_by_name(name: str) -> list[Provider]: + """Return all providers with the given display name, case-insensitively. + + Upstream no longer guarantees provider names are unique, so a name may + resolve to more than one provider. Results are ordered by provider ID. + """ with closing(_connect()) as connection: - row = connection.execute( + rows = connection.execute( f"SELECT {PROVIDER_COLUMNS} FROM providers " - "WHERE name = ? COLLATE NOCASE", + "WHERE name = ? COLLATE NOCASE ORDER BY id", (name,), - ).fetchone() - return None if row is None else _provider_from_row(row) + ).fetchall() + return [_provider_from_row(row) for row in rows] + + +def get_provider_by_name(name: str) -> Provider | None: + """Return the sole provider with the given display name, else ``None``. + + Matching is case-insensitive. When the name is ambiguous (shared by two or + more providers) ``None`` is returned, since there is no single correct + match; use :func:`get_providers_by_name` or look up by ID via + :func:`get_provider_by_id` instead. + """ + providers = get_providers_by_name(name) + return providers[0] if len(providers) == 1 else None def get_provider_by_id(provider_id: str) -> Provider | None: diff --git a/src/modelsdotdev/_internal/schema.py b/src/modelsdotdev/_internal/schema.py index 48260e7..318465d 100644 --- a/src/modelsdotdev/_internal/schema.py +++ b/src/modelsdotdev/_internal/schema.py @@ -87,7 +87,7 @@ def _schema_sql(tables: tuple[Table, ...]) -> str: Column(name="env", definition="TEXT NOT NULL"), ), indexes=( - "CREATE UNIQUE INDEX providers_name_nocase_idx " + "CREATE INDEX providers_name_nocase_idx " "ON providers(name COLLATE NOCASE);", ), ) diff --git a/tests/test_api.py b/tests/test_api.py index 14c3905..e7c0a03 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,4 +1,5 @@ import sqlite3 +from collections import Counter from contextlib import closing import pytest @@ -19,6 +20,7 @@ get_model_by_id, get_provider_by_id, get_provider_by_name, + get_providers_by_name, iter_models, iter_providers, parse_model_id, @@ -34,17 +36,41 @@ def test_provider_iteration_and_lookup_use_real_database() -> None: key=lambda provider: provider.name.lower(), ) assert len({provider.id for provider in providers}) == len(providers) - assert len({provider.name.lower() for provider in providers}) == len( - providers, + + # Upstream no longer guarantees provider names are unique. + name_counts = Counter(provider.name.lower() for provider in providers) + + # A provider whose name is unique (case-insensitively) resolves by name. + unique = next( + provider + for provider in providers + if name_counts[provider.name.lower()] == 1 + ) + assert isinstance(unique, Provider) + assert get_provider_by_id(unique.id) == unique + assert get_provider_by_name(unique.name) == unique + assert get_provider_by_name(unique.name.upper()) == unique + assert get_providers_by_name(unique.name) == [unique] + + # An ambiguous name (shared by 2+ providers) resolves to None via the + # singular getter but is fully enumerated by the plural getter. + ambiguous = next( + ( + provider.name + for provider in providers + if name_counts[provider.name.lower()] > 1 + ), + None, ) + if ambiguous is not None: + matches = get_providers_by_name(ambiguous) + assert len(matches) == name_counts[ambiguous.lower()] + assert matches == get_providers_by_name(ambiguous.upper()) + assert get_provider_by_name(ambiguous) is None - provider = providers[0] - assert isinstance(provider, Provider) - assert get_provider_by_id(provider.id) == provider - assert get_provider_by_name(provider.name) == provider - assert get_provider_by_name(provider.name.upper()) == provider assert get_provider_by_id("missing-provider") is None assert get_provider_by_name("missing provider") is None + assert get_providers_by_name("missing provider") == [] def test_model_iteration_and_lookup_use_real_database() -> None: