diff --git a/README.md b/README.md index f752d97..56ff2c5 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,6 @@ if provider is not None: provider_models = list(provider.iter_models()) ``` -Provider display names are not unique upstream, so `get_provider_by_name` -returns `None` when a name is shared by more than one provider. Use -`get_providers_by_name` to retrieve every match, or `get_provider_by_id` for an -unambiguous lookup. - ## Development Install dependencies and run the test suite with uv: diff --git a/src/modelsdotdev/__init__.py b/src/modelsdotdev/__init__.py index fd6796d..d1eccc0 100644 --- a/src/modelsdotdev/__init__.py +++ b/src/modelsdotdev/__init__.py @@ -16,7 +16,6 @@ get_model_by_id, get_provider_by_id, get_provider_by_name, - get_providers_by_name, iter_models, iter_providers, parse_model_id, @@ -38,7 +37,6 @@ "get_model_by_id", "get_provider_by_id", "get_provider_by_name", - "get_providers_by_name", "iter_models", "iter_providers", "parse_model_id", diff --git a/src/modelsdotdev/_internal/data.py b/src/modelsdotdev/_internal/data.py index d2eb9d3..d0873bc 100644 --- a/src/modelsdotdev/_internal/data.py +++ b/src/modelsdotdev/_internal/data.py @@ -239,31 +239,15 @@ def qualified_id(self) -> str: DB_PATH = Path(__file__).parents[1] / "_db.sqlite" -def get_providers_by_name(name: str) -> list[Provider]: - """Return all providers with the given display name, case-insensitively. - - Upstream no longer guarantees provider names are unique, so a name may - resolve to more than one provider. Results are ordered by provider ID. - """ +def get_provider_by_name(name: str) -> Provider | None: + """Return a provider by display name, using case-insensitive matching.""" with closing(_connect()) as connection: - rows = connection.execute( + row = connection.execute( f"SELECT {PROVIDER_COLUMNS} FROM providers " - "WHERE name = ? COLLATE NOCASE ORDER BY id", + "WHERE name = ? COLLATE NOCASE", (name,), - ).fetchall() - return [_provider_from_row(row) for row in rows] - - -def get_provider_by_name(name: str) -> Provider | None: - """Return the sole provider with the given display name, else ``None``. - - Matching is case-insensitive. When the name is ambiguous (shared by two or - more providers) ``None`` is returned, since there is no single correct - match; use :func:`get_providers_by_name` or look up by ID via - :func:`get_provider_by_id` instead. - """ - providers = get_providers_by_name(name) - return providers[0] if len(providers) == 1 else None + ).fetchone() + return None if row is None else _provider_from_row(row) def get_provider_by_id(provider_id: str) -> Provider | None: diff --git a/src/modelsdotdev/_internal/schema.py b/src/modelsdotdev/_internal/schema.py index 318465d..48260e7 100644 --- a/src/modelsdotdev/_internal/schema.py +++ b/src/modelsdotdev/_internal/schema.py @@ -87,7 +87,7 @@ def _schema_sql(tables: tuple[Table, ...]) -> str: Column(name="env", definition="TEXT NOT NULL"), ), indexes=( - "CREATE INDEX providers_name_nocase_idx " + "CREATE UNIQUE INDEX providers_name_nocase_idx " "ON providers(name COLLATE NOCASE);", ), ) diff --git a/tests/test_api.py b/tests/test_api.py index e7c0a03..14c3905 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,5 +1,4 @@ import sqlite3 -from collections import Counter from contextlib import closing import pytest @@ -20,7 +19,6 @@ get_model_by_id, get_provider_by_id, get_provider_by_name, - get_providers_by_name, iter_models, iter_providers, parse_model_id, @@ -36,41 +34,17 @@ def test_provider_iteration_and_lookup_use_real_database() -> None: key=lambda provider: provider.name.lower(), ) assert len({provider.id for provider in providers}) == len(providers) - - # Upstream no longer guarantees provider names are unique. - name_counts = Counter(provider.name.lower() for provider in providers) - - # A provider whose name is unique (case-insensitively) resolves by name. - unique = next( - provider - for provider in providers - if name_counts[provider.name.lower()] == 1 - ) - assert isinstance(unique, Provider) - assert get_provider_by_id(unique.id) == unique - assert get_provider_by_name(unique.name) == unique - assert get_provider_by_name(unique.name.upper()) == unique - assert get_providers_by_name(unique.name) == [unique] - - # An ambiguous name (shared by 2+ providers) resolves to None via the - # singular getter but is fully enumerated by the plural getter. - ambiguous = next( - ( - provider.name - for provider in providers - if name_counts[provider.name.lower()] > 1 - ), - None, + assert len({provider.name.lower() for provider in providers}) == len( + providers, ) - if ambiguous is not None: - matches = get_providers_by_name(ambiguous) - assert len(matches) == name_counts[ambiguous.lower()] - assert matches == get_providers_by_name(ambiguous.upper()) - assert get_provider_by_name(ambiguous) is None + provider = providers[0] + assert isinstance(provider, Provider) + assert get_provider_by_id(provider.id) == provider + assert get_provider_by_name(provider.name) == provider + assert get_provider_by_name(provider.name.upper()) == provider assert get_provider_by_id("missing-provider") is None assert get_provider_by_name("missing provider") is None - assert get_providers_by_name("missing provider") == [] def test_model_iteration_and_lookup_use_real_database() -> None: