Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@ if provider is not None:
provider_models = list(provider.iter_models())
```

Provider display names are not unique upstream, so `get_provider_by_name`
returns `None` when a name is shared by more than one provider. Use
`get_providers_by_name` to retrieve every match, or `get_provider_by_id` for an
unambiguous lookup.

## Development

Install dependencies and run the test suite with uv:
Expand Down
2 changes: 0 additions & 2 deletions src/modelsdotdev/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
get_model_by_id,
get_provider_by_id,
get_provider_by_name,
get_providers_by_name,
iter_models,
iter_providers,
parse_model_id,
Expand All @@ -38,7 +37,6 @@
"get_model_by_id",
"get_provider_by_id",
"get_provider_by_name",
"get_providers_by_name",
"iter_models",
"iter_providers",
"parse_model_id",
Expand Down
28 changes: 6 additions & 22 deletions src/modelsdotdev/_internal/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,31 +239,15 @@ def qualified_id(self) -> str:
DB_PATH = Path(__file__).parents[1] / "_db.sqlite"


def get_providers_by_name(name: str) -> list[Provider]:
"""Return all providers with the given display name, case-insensitively.

Upstream no longer guarantees provider names are unique, so a name may
resolve to more than one provider. Results are ordered by provider ID.
"""
def get_provider_by_name(name: str) -> Provider | None:
"""Return a provider by display name, using case-insensitive matching."""
with closing(_connect()) as connection:
rows = connection.execute(
row = connection.execute(
f"SELECT {PROVIDER_COLUMNS} FROM providers "
"WHERE name = ? COLLATE NOCASE ORDER BY id",
"WHERE name = ? COLLATE NOCASE",
(name,),
).fetchall()
return [_provider_from_row(row) for row in rows]


def get_provider_by_name(name: str) -> Provider | None:
"""Return the sole provider with the given display name, else ``None``.

Matching is case-insensitive. When the name is ambiguous (shared by two or
more providers) ``None`` is returned, since there is no single correct
match; use :func:`get_providers_by_name` or look up by ID via
:func:`get_provider_by_id` instead.
"""
providers = get_providers_by_name(name)
return providers[0] if len(providers) == 1 else None
).fetchone()
return None if row is None else _provider_from_row(row)


def get_provider_by_id(provider_id: str) -> Provider | None:
Expand Down
2 changes: 1 addition & 1 deletion src/modelsdotdev/_internal/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _schema_sql(tables: tuple[Table, ...]) -> str:
Column(name="env", definition="TEXT NOT NULL"),
),
indexes=(
"CREATE INDEX providers_name_nocase_idx "
"CREATE UNIQUE INDEX providers_name_nocase_idx "
"ON providers(name COLLATE NOCASE);",
),
)
Expand Down
40 changes: 7 additions & 33 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sqlite3
from collections import Counter
from contextlib import closing

import pytest
Expand All @@ -20,7 +19,6 @@
get_model_by_id,
get_provider_by_id,
get_provider_by_name,
get_providers_by_name,
iter_models,
iter_providers,
parse_model_id,
Expand All @@ -36,41 +34,17 @@ def test_provider_iteration_and_lookup_use_real_database() -> None:
key=lambda provider: provider.name.lower(),
)
assert len({provider.id for provider in providers}) == len(providers)

# Upstream no longer guarantees provider names are unique.
name_counts = Counter(provider.name.lower() for provider in providers)

# A provider whose name is unique (case-insensitively) resolves by name.
unique = next(
provider
for provider in providers
if name_counts[provider.name.lower()] == 1
)
assert isinstance(unique, Provider)
assert get_provider_by_id(unique.id) == unique
assert get_provider_by_name(unique.name) == unique
assert get_provider_by_name(unique.name.upper()) == unique
assert get_providers_by_name(unique.name) == [unique]

# An ambiguous name (shared by 2+ providers) resolves to None via the
# singular getter but is fully enumerated by the plural getter.
ambiguous = next(
(
provider.name
for provider in providers
if name_counts[provider.name.lower()] > 1
),
None,
assert len({provider.name.lower() for provider in providers}) == len(
providers,
)
if ambiguous is not None:
matches = get_providers_by_name(ambiguous)
assert len(matches) == name_counts[ambiguous.lower()]
assert matches == get_providers_by_name(ambiguous.upper())
assert get_provider_by_name(ambiguous) is None

provider = providers[0]
assert isinstance(provider, Provider)
assert get_provider_by_id(provider.id) == provider
assert get_provider_by_name(provider.name) == provider
assert get_provider_by_name(provider.name.upper()) == provider
assert get_provider_by_id("missing-provider") is None
assert get_provider_by_name("missing provider") is None
assert get_providers_by_name("missing provider") == []


def test_model_iteration_and_lookup_use_real_database() -> None:
Expand Down