diff --git a/.github/workflows/rust-check.yml b/.github/workflows/rust-check.yml index 97e1029c..d9fc54d8 100644 --- a/.github/workflows/rust-check.yml +++ b/.github/workflows/rust-check.yml @@ -19,7 +19,7 @@ jobs: profile: minimal toolchain: stable override: true - components: rustfmt + components: rustfmt, clippy - name: Check formatting run: cargo fmt -- --check diff --git a/Cargo.toml b/Cargo.toml index f57c2f63..f766da4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "openai-api-rs" -version = "6.0.8" +version = "8.0.1" edition = "2021" authors = ["Dongri Jin "] license = "MIT" @@ -17,7 +17,7 @@ default-tls = ["reqwest/default-tls", "tokio-tungstenite/native-tls"] [dependencies.reqwest] version = "0.12" default-features = false -features = ["charset", "http2", "json", "multipart", "socks"] +features = ["charset", "http2", "json", "multipart", "socks", "stream"] @@ -34,7 +34,7 @@ version = "1" version = "1.7.1" [dependencies.tokio-tungstenite] -version = "0.24.0" +version = "0.28.0" features = ["connect"] [dependencies.futures-util] diff --git a/README.md b/README.md index 3caee96c..46671eeb 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,24 @@ # OpenAI API client library for Rust (unofficial) + The OpenAI API client Rust library provides convenient access to the OpenAI API from Rust applications. Check out the [docs.rs](https://docs.rs/openai-api-rs/). ## Installation: + Cargo.toml + ```toml [dependencies] -openai-api-rs = "6.0.8" +openai-api-rs = "8.0.1" ``` ## Usage + The library needs to be configured with your account's secret key, which is available on the [website](https://platform.openai.com/account/api-keys). We recommend setting it as an environment variable. Here's an example of initializing the library with the API key loaded from an environment variable and creating a completion: ### Set OPENAI_API_KEY or OPENROUTER_API_KEY to environment variable + ```bash $ export OPENAI_API_KEY=sk-xxxxxxx or @@ -21,12 +26,14 @@ $ export OPENROUTER_API_KEY=sk-xxxxxxx ``` ### Create OpenAI client + ```rust let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; ``` ### Create OpenRouter client + ```rust let api_key = env::var("OPENROUTER_API_KEY").unwrap().to_string(); let mut client = OpenAIClient::builder() @@ -36,6 +43,7 @@ let mut client = OpenAIClient::builder() ``` ### Create request + ```rust let req = ChatCompletionRequest::new( GPT4_O.to_string(), @@ -50,6 +58,7 @@ let req = ChatCompletionRequest::new( ``` ### Send request + ```rust let result = client.chat_completion(req)?; println!("Content: {:?}", result.choices[0].message.content); @@ -60,11 +69,13 @@ for (key, value) in client.headers.unwrap().iter() { ``` ### Set OPENAI_API_BASE to environment variable (optional) + ```bash $ export OPENAI_API_BASE=https://api.openai.com/v1 ``` ## Example of chat completion + ```rust use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; @@ -99,6 +110,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { ``` ## Example for OpenRouter + ```rust use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; @@ -126,7 +138,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { let result = client.chat_completion(req).await?; println!("Content: {:?}", result.choices[0].message.content); - + for (key, value) in client.headers.unwrap().iter() { println!("{}: {:?}", key,
value); } @@ -140,6 +152,7 @@ More Examples: [examples](https://github.com/dongri/openai-api-rs/tree/main/exam Check out the [full API documentation](https://platform.openai.com/docs/api-reference/completions) for examples of all the available functions. ## Supported APIs + - [x] [Completions](https://platform.openai.com/docs/api-reference/completions) - [x] [Chat](https://platform.openai.com/docs/api-reference/chat) - [x] [Edits](https://platform.openai.com/docs/api-reference/edits) @@ -153,6 +166,8 @@ Check out the [full API documentation](https://platform.openai.com/docs/api-refe - [x] [Assistants](https://platform.openai.com/docs/assistants/overview) - [x] [Batch](https://platform.openai.com/docs/api-reference/batch) - [x] [Realtime](https://platform.openai.com/docs/api-reference/realtime) +- [x] [Responses](https://platform.openai.com/docs/api-reference/responses) ## License + This project is licensed under [MIT license](https://github.com/dongri/openai-api-rs/blob/main/LICENSE). diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index 635add58..3556de08 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -1,5 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self}; use openai_api_rs::v1::common::GPT4_O_MINI; use std::env; diff --git a/examples/chat_completion_stream.rs b/examples/chat_completion_stream.rs new file mode 100644 index 00000000..010e3f6a --- /dev/null +++ b/examples/chat_completion_stream.rs @@ -0,0 +1,44 @@ +use futures_util::StreamExt; +use openai_api_rs::v1::api::OpenAIClient; +use openai_api_rs::v1::chat_completion::chat_completion_stream::{ + ChatCompletionStreamRequest, ChatCompletionStreamResponse, +}; +use openai_api_rs::v1::chat_completion::{self}; +use openai_api_rs::v1::common::GPT4_O_MINI; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; + + let req = ChatCompletionStreamRequest::new( + GPT4_O_MINI.to_string(), + vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: chat_completion::Content::Text(String::from("What is bitcoin?")), + name: None, + tool_calls: None, + tool_call_id: None, + }], + ); + + let mut result = client.chat_completion_stream(req).await?; + while let Some(response) = result.next().await { + match response.clone() { + ChatCompletionStreamResponse::ToolCall(toolcalls) => { + println!("Tool Call: {:?}", toolcalls); + } + ChatCompletionStreamResponse::Content(content) => { + println!("Content: {:?}", content); + } + ChatCompletionStreamResponse::Done => { + println!("Done"); + } + } + } + + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example chat_completion_stream diff --git a/examples/function_call.rs b/examples/function_call.rs index 3935599e..b18769f6 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -1,5 +1,10 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::{ + chat_completion::ChatCompletionRequest, ChatCompletionMessage, +}; +use openai_api_rs::v1::chat_completion::{ + Content, FinishReason, MessageRole, Tool, ToolChoiceType, ToolType,
+}; use openai_api_rs::v1::common::GPT4_O; use openai_api_rs::v1::types; use serde::{Deserialize, Serialize}; @@ -32,16 +37,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { let req = ChatCompletionRequest::new( GPT4_O.to_string(), - vec![chat_completion::ChatCompletionMessage { - role: chat_completion::MessageRole::user, - content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), + vec![ChatCompletionMessage { + role: MessageRole::user, + content: Content::Text(String::from("What is the price of Ethereum?")), name: None, tool_calls: None, tool_call_id: None, }], ) - .tools(vec![chat_completion::Tool { - r#type: chat_completion::ToolType::Function, + .tools(vec![Tool { + r#type: ToolType::Function, function: types::Function { name: String::from("get_coin_price"), description: Some(String::from("Get the price of a cryptocurrency")), @@ -52,7 +57,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { }, }, }]) - .tool_choice(chat_completion::ToolChoiceType::Auto); + .tool_choice(ToolChoiceType::Auto); // debug request json // let serialized = serde_json::to_string(&req).unwrap(); @@ -65,14 +70,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { println!("No finish_reason"); println!("{:?}", result.choices[0].message.content); } - Some(chat_completion::FinishReason::stop) => { + Some(FinishReason::stop) => { println!("Stop"); println!("{:?}", result.choices[0].message.content); } - Some(chat_completion::FinishReason::length) => { + Some(FinishReason::length) => { println!("Length"); } - Some(chat_completion::FinishReason::tool_calls) => { + Some(FinishReason::tool_calls) => { println!("ToolCalls"); #[derive(Deserialize, Serialize)] struct Currency { @@ -90,10 +95,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { } } } - Some(chat_completion::FinishReason::content_filter) => { + Some(FinishReason::content_filter) => { println!("ContentFilter"); } - Some(chat_completion::FinishReason::null) => { + Some(FinishReason::null) => { println!("Null"); } } diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index 901d5d7f..a18d355c 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -1,5 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self}; use openai_api_rs::v1::common::GPT4_O; use openai_api_rs::v1::types; use serde::{Deserialize, Serialize}; diff --git a/examples/openrouter.rs b/examples/openrouter.rs index 5295bf41..79b8ec01 100644 --- a/examples/openrouter.rs +++ b/examples/openrouter.rs @@ -1,5 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self}; use openai_api_rs::v1::common::GPT4_O_MINI; use std::env; diff --git a/examples/openrouter_models.rs b/examples/openrouter_models.rs new file mode 100644 index 00000000..4223b2b6 --- /dev/null +++ b/examples/openrouter_models.rs @@ -0,0 +1,22 @@ +use openai_api_rs::v1::api::OpenAIClient; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + let api_key = env::var("OPENROUTER_API_KEY").unwrap().to_string(); + let mut client = OpenAIClient::builder() + .with_endpoint("https://openrouter.ai/api/v1") + .with_api_key(api_key) + .build()?; + + let result = client.list_models().await?; + let models =
result.data; + + for model in models { + println!("Model id: {:?}", model.id); + } + + Ok(()) +} + +// OPENROUTER_API_KEY=xxxx cargo run --package openai-api-rs --example openrouter_models diff --git a/examples/openrouter_reasoning.rs b/examples/openrouter_reasoning.rs index 9bfac3dd..9dd3c125 100644 --- a/examples/openrouter_reasoning.rs +++ b/examples/openrouter_reasoning.rs @@ -1,7 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{ - self, ChatCompletionRequest, Reasoning, ReasoningEffort, ReasoningMode, -}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self, Reasoning, ReasoningEffort, ReasoningMode}; use std::env; #[tokio::main] diff --git a/examples/responses.rs b/examples/responses.rs new file mode 100644 index 00000000..8bc24fec --- /dev/null +++ b/examples/responses.rs @@ -0,0 +1,23 @@ +use openai_api_rs::v1::api::OpenAIClient; +use openai_api_rs::v1::common::GPT4_1_MINI; +use openai_api_rs::v1::responses::CreateResponseRequest; +use serde_json::json; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + let api_key = env::var("OPENAI_API_KEY").unwrap(); + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; + + let mut req = CreateResponseRequest::new(); + req.model = Some(GPT4_1_MINI.to_string()); + req.input = Some(json!( + "Tell me a three sentence bedtime story about a unicorn." + )); + req.extra.insert("temperature".to_string(), json!(0.7)); + + let resp = client.create_response(req).await?; + println!("response id: {} status: {:?}", resp.id, resp.status); + println!("response output: {:?}", resp.output); + Ok(()) +} diff --git a/examples/vision.rs b/examples/vision.rs index 7bad362b..67c9af5d 100644 --- a/examples/vision.rs +++ b/examples/vision.rs @@ -1,5 +1,6 @@ use openai_api_rs::v1::api::OpenAIClient; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::chat_completion::ChatCompletionRequest; +use openai_api_rs::v1::chat_completion::{self}; use openai_api_rs::v1::common::GPT4_O; use std::env; diff --git a/src/realtime/client_event.rs b/src/realtime/client_event.rs index 53805381..1c43fd09 100644 --- a/src/realtime/client_event.rs +++ b/src/realtime/client_event.rs @@ -92,7 +92,7 @@ pub enum ClientEvent { impl From<ClientEvent> for Message { fn from(value: ClientEvent) -> Self { - Message::Text(String::from(&value)) + Message::Text(String::from(&value).into()) } } diff --git a/src/realtime/types.rs b/src/realtime/types.rs index a90ff27f..4cc433da 100644 --- a/src/realtime/types.rs +++ b/src/realtime/types.rs @@ -51,8 +51,11 @@ pub enum AudioFormat { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct AudioTranscription { + #[serde(skip_serializing_if = "Option::is_none")] pub language: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] pub model: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] pub prompt: Option<String>, } diff --git a/src/v1/api.rs b/src/v1/api.rs index ce0b1c4e..5d8a13d4 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -7,7 +7,10 @@ use crate::v1::audio::{ AudioTranslationRequest, AudioTranslationResponse, }; use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse}; -use crate::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse}; +use crate::v1::chat_completion::chat_completion::{ChatCompletionRequest, ChatCompletionResponse}; +use crate::v1::chat_completion::chat_completion_stream::{ +
ChatCompletionStream, ChatCompletionStreamRequest, ChatCompletionStreamResponse, +}; use crate::v1::common; use crate::v1::completion::{CompletionRequest, CompletionResponse}; use crate::v1::edit::{EditRequest, EditResponse}; @@ -32,6 +35,9 @@ use crate::v1::message::{ }; use crate::v1::model::{ModelResponse, ModelsResponse}; use crate::v1::moderation::{CreateModerationRequest, CreateModerationResponse}; +use crate::v1::responses::{ + CountTokensRequest, CountTokensResponse, CreateResponseRequest, ListResponses, ResponseObject, +}; use crate::v1::run::{ CreateRunRequest, CreateThreadAndRunRequest, ListRun, ListRunStep, ModifyRunRequest, RunObject, RunStepObject, @@ -39,11 +45,12 @@ use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject}; use bytes::Bytes; +use futures_util::Stream; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::multipart::{Form, Part}; use reqwest::{Client, Method, Response}; use serde::Serialize; -use serde_json::Value; +use serde_json::{to_value, Value}; use url::Url; use std::error::Error; @@ -334,6 +341,40 @@ impl OpenAIClient { self.post("chat/completions", &req).await } + pub async fn chat_completion_stream( + &mut self, + req: ChatCompletionStreamRequest, + ) -> Result<ChatCompletionStream<impl Stream<Item = Result<Bytes, reqwest::Error>>>, APIError> { + let mut payload = to_value(&req).map_err(|err| APIError::CustomError { + message: format!("Failed to serialize request: {}", err), + })?; + + if let Some(obj) = payload.as_object_mut() { + obj.insert("stream".into(), Value::Bool(true)); + } + + let request = self.build_request(Method::POST, "chat/completions").await; + let request = request.json(&payload); + let response = request.send().await?; + + if response.status().is_success() { + Ok(ChatCompletionStream { + response: Box::pin(response.bytes_stream()), + buffer: String::new(), + first_chunk: true, + }) + } else { + let error_text = response + .text() + .await + .unwrap_or_else(|_| String::from("Unknown error")); + + Err(APIError::CustomError { + message: error_text, + }) + } + } + pub async fn audio_transcription( + &mut self, + req: AudioTranscriptionRequest, @@ -781,6 +822,70 @@ impl OpenAIClient { self.get(&url).await } + // Responses API + pub async fn create_response( + &mut self, + req: CreateResponseRequest, + ) -> Result<ResponseObject, APIError> { + self.post("responses", &req).await + } + + pub async fn retrieve_response( + &mut self, + response_id: String, + ) -> Result<ResponseObject, APIError> { + self.get(&format!("responses/{response_id}")).await + } + + pub async fn delete_response( + &mut self, + response_id: String, + ) -> Result<ResponseObject, APIError> { + self.delete(&format!("responses/{response_id}")).await + } + + pub async fn cancel_response( + &mut self, + response_id: String, + ) -> Result<ResponseObject, APIError> { + self.post( + &format!("responses/{response_id}/cancel"), + &common::EmptyRequestBody {}, + ) + .await + } + + pub async fn list_response_input_items( + &mut self, + response_id: String, + after: Option<String>, + limit: Option<i64>, + order: Option<String>, + ) -> Result<ListResponses, APIError> { + let mut url = format!("responses/{}/input_items", response_id); + let mut params = vec![]; + if let Some(after) = after { + params.push(format!("after={}", after)); + } + if let Some(limit) = limit { + params.push(format!("limit={}", limit)); + } + if let Some(order) = order { + params.push(format!("order={}", order)); + } + if !params.is_empty() { + url = format!("{}?{}", url, params.join("&")); + } + self.get(&url).await + } + + pub async fn count_response_input_tokens( + &mut self, + req: CountTokensRequest, + ) -> Result<CountTokensResponse, APIError> { + self.post("responses/input_tokens",
&req).await + } + pub async fn list_models(&mut self) -> Result<ModelsResponse, APIError> { self.get("models").await } diff --git a/src/v1/audio.rs b/src/v1/audio.rs index 4ab93f4e..3ea2b3bd 100644 --- a/src/v1/audio.rs +++ b/src/v1/audio.rs @@ -5,6 +5,13 @@ use crate::impl_builder_methods; pub const WHISPER_1: &str = "whisper-1"; +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum TimestampGranularity { + Word, + Segment, +} + #[derive(Debug, Serialize, Clone)] pub struct AudioTranscriptionRequest { pub model: String, @@ -19,6 +26,8 @@ pub struct AudioTranscriptionRequest { pub temperature: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")] pub language: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub timestamp_granularities: Option<Vec<TimestampGranularity>>, } impl AudioTranscriptionRequest { @@ -31,6 +40,7 @@ impl AudioTranscriptionRequest { response_format: None, temperature: None, language: None, + timestamp_granularities: None, } } @@ -43,6 +53,7 @@ impl AudioTranscriptionRequest { response_format: None, temperature: None, language: None, + timestamp_granularities: None, } } } @@ -52,7 +63,8 @@ impl_builder_methods!( prompt: String, response_format: String, temperature: f32, - language: String + language: String, + timestamp_granularities: Vec<TimestampGranularity> ); #[derive(Debug, Deserialize, Serialize)] diff --git a/src/v1/chat_completion/chat_completion.rs b/src/v1/chat_completion/chat_completion.rs new file mode 100644 index 00000000..bd2807a9 --- /dev/null +++ b/src/v1/chat_completion/chat_completion.rs @@ -0,0 +1,233 @@ +use crate::v1::chat_completion::{ChatCompletionChoice, Reasoning, Tool, ToolChoiceType}; +use crate::v1::common; +use crate::{ + impl_builder_methods, + v1::chat_completion::{serialize_tool_choice, ChatCompletionMessage}, +}; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec<ChatCompletionMessage>, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option<f64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option<f64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option<i64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option<Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option<Vec<String>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option<i64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option<f64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option<f64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option<HashMap<String, i32>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option<i64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option<Vec<Tool>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option<bool>, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(serialize_with = "serialize_tool_choice")] + pub tool_choice: Option<ToolChoiceType>, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option<Reasoning>, + /// Optional list of transforms to apply to the chat completion request. + /// + /// Transforms allow modifying the request before it's sent to the API, + /// enabling features like prompt rewriting, content filtering, or other + /// preprocessing steps. When None, no transforms are applied.
+ #[serde(skip_serializing_if = "Option::is_none")] + pub transforms: Option<Vec<String>>, +} + +impl ChatCompletionRequest { + pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self { + Self { + model, + messages, + temperature: None, + top_p: None, + n: None, + response_format: None, + stop: None, + max_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + seed: None, + tools: None, + parallel_tool_calls: None, + tool_choice: None, + reasoning: None, + transforms: None, + } + } +} + +impl_builder_methods!( + ChatCompletionRequest, + temperature: f64, + top_p: f64, + n: i64, + response_format: Value, + stop: Vec<String>, + max_tokens: i64, + presence_penalty: f64, + frequency_penalty: f64, + logit_bias: HashMap<String, i32>, + user: String, + seed: i64, + tools: Vec<Tool>, + parallel_tool_calls: bool, + tool_choice: ToolChoiceType, + reasoning: Reasoning, + transforms: Vec<String> +); + +#[derive(Debug, Deserialize, Serialize)] +pub struct ChatCompletionResponse { + pub id: Option<String>, + pub object: Option<String>, + pub created: i64, + pub model: String, + pub choices: Vec<ChatCompletionChoice>, + pub usage: common::Usage, + pub system_fingerprint: Option<String>, +} + +#[cfg(test)] +mod tests { + use crate::v1::chat_completion::{ReasoningEffort, ReasoningMode}; + + use super::*; + use serde_json::json; + + #[test] + fn test_reasoning_effort_serialization() { + let reasoning = Reasoning { + mode: Some(ReasoningMode::Effort { + effort: ReasoningEffort::High, + }), + exclude: Some(false), + enabled: None, + }; + + let serialized = serde_json::to_value(&reasoning).unwrap(); + let expected = json!({ + "effort": "high", + "exclude": false + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn test_reasoning_max_tokens_serialization() { + let reasoning = Reasoning { + mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }), + exclude: None, + enabled: Some(true), + }; + + let serialized = serde_json::to_value(&reasoning).unwrap(); + let expected = json!({ + "max_tokens": 2000, + "enabled": true + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn test_reasoning_deserialization() { + let json_str = r#"{"effort": "medium", "exclude": true}"#; + let reasoning: Reasoning = serde_json::from_str(json_str).unwrap(); + + match reasoning.mode { + Some(ReasoningMode::Effort { effort }) => { + assert_eq!(effort, ReasoningEffort::Medium); + } + _ => panic!("Expected effort mode"), + } + assert_eq!(reasoning.exclude, Some(true)); + } + + #[test] + fn test_chat_completion_request_with_reasoning() { + let mut req = ChatCompletionRequest::new("gpt-4".to_string(), vec![]); + + req.reasoning = Some(Reasoning { + mode: Some(ReasoningMode::Effort { + effort: ReasoningEffort::Low, + }), + exclude: None, + enabled: None, + }); + + let serialized = serde_json::to_value(&req).unwrap(); + assert_eq!(serialized["reasoning"]["effort"], "low"); + } + + #[test] + fn test_transforms_none_serialization() { + let req = ChatCompletionRequest::new("gpt-4".to_string(), vec![]); + let serialised = serde_json::to_value(&req).unwrap(); + // Verify that the transforms field is completely omitted from JSON output + assert!(!serialised.as_object().unwrap().contains_key("transforms")); + } + + #[test] + fn test_transforms_some_serialization() { + let mut req = ChatCompletionRequest::new("gpt-4".to_string(), vec![]); + req.transforms = Some(vec!["transform1".to_string(), "transform2".to_string()]); + let serialised = serde_json::to_value(&req).unwrap(); + // Verify that the transforms field is included as a proper JSON array + assert_eq!( +
serialised["transforms"], + serde_json::json!(["transform1", "transform2"]) + ); + } + + #[test] + fn test_transforms_some_deserialization() { + let json_str = + r#"{"model": "gpt-4", "messages": [], "transforms": ["transform1", "transform2"]}"#; + let req: ChatCompletionRequest = serde_json::from_str(json_str).unwrap(); + // Verify that the transforms field is properly populated with Some(vec) + assert_eq!( + req.transforms, + Some(vec!["transform1".to_string(), "transform2".to_string()]) + ); + } + + #[test] + fn test_transforms_none_deserialization() { + let json_str = r#"{"model": "gpt-4", "messages": []}"#; + let req: ChatCompletionRequest = serde_json::from_str(json_str).unwrap(); + // Verify that the transforms field is properly set to None when absent + assert_eq!(req.transforms, None); + } + + #[test] + fn test_transforms_builder_method() { + let transforms = vec!["transform1".to_string(), "transform2".to_string()]; + let req = + ChatCompletionRequest::new("gpt-4".to_string(), vec![]).transforms(transforms.clone()); + // Verify that the transforms field is properly set through the builder method + assert_eq!(req.transforms, Some(transforms)); + } +} diff --git a/src/v1/chat_completion/chat_completion_stream.rs b/src/v1/chat_completion/chat_completion_stream.rs new file mode 100644 index 00000000..f5b3283e --- /dev/null +++ b/src/v1/chat_completion/chat_completion_stream.rs @@ -0,0 +1,361 @@ +use crate::v1::chat_completion::{Reasoning, Tool, ToolCall, ToolChoiceType}; +use crate::{ + impl_builder_methods, + v1::chat_completion::{serialize_tool_choice, ChatCompletionMessage}, +}; + +use futures_util::Stream; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ChatCompletionStreamRequest { + pub model: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(serialize_with = "serialize_tool_choice")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + /// Optional list of transforms to apply to the chat completion request. + /// + /// Transforms allow modifying the request before it's sent to the API, + /// enabling features like prompt rewriting, content filtering, or other + /// preprocessing steps. When None, no transforms are applied. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub transforms: Option<Vec<String>>, +} + +impl ChatCompletionStreamRequest { + pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self { + Self { + model, + messages, + temperature: None, + top_p: None, + n: None, + response_format: None, + stop: None, + max_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + seed: None, + tools: None, + parallel_tool_calls: None, + tool_choice: None, + reasoning: None, + transforms: None, + } + } +} + +impl_builder_methods!( + ChatCompletionStreamRequest, + temperature: f64, + top_p: f64, + n: i64, + response_format: Value, + stop: Vec<String>, + max_tokens: i64, + presence_penalty: f64, + frequency_penalty: f64, + logit_bias: HashMap<String, i32>, + user: String, + seed: i64, + tools: Vec<Tool>, + parallel_tool_calls: bool, + tool_choice: ToolChoiceType, + reasoning: Reasoning, + transforms: Vec<String> +); + +#[derive(Debug, Clone)] +pub enum ChatCompletionStreamResponse { + Content(String), + ToolCall(Vec<ToolCall>), + Done, +} + +pub struct ChatCompletionStream<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> { + pub response: S, + pub buffer: String, + pub first_chunk: bool, +} + +impl<S> ChatCompletionStream<S> +where + S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin, +{ + fn find_event_delimiter(buffer: &str) -> Option<(usize, usize)> { + let carriage_idx = buffer.find("\r\n\r\n"); + let newline_idx = buffer.find("\n\n"); + + match (carriage_idx, newline_idx) { + (Some(r_idx), Some(n_idx)) => { + if r_idx <= n_idx { + Some((r_idx, 4)) + } else { + Some((n_idx, 2)) + } + } + (Some(r_idx), None) => Some((r_idx, 4)), + (None, Some(n_idx)) => Some((n_idx, 2)), + (None, None) => None, + } + } + + fn next_response_from_buffer(&mut self) -> Option<ChatCompletionStreamResponse> { + while let Some((idx, delimiter_len)) = Self::find_event_delimiter(&self.buffer) { + let event = self.buffer[..idx].to_owned(); + self.buffer = self.buffer[idx + delimiter_len..].to_owned(); + + let mut data_payload = String::new(); + for line in event.lines() { + let trimmed_line = line.trim_end_matches('\r'); + if let Some(content) = trimmed_line + .strip_prefix("data: ") + .or_else(|| trimmed_line.strip_prefix("data:")) + { + if !content.is_empty() { + if !data_payload.is_empty() { + data_payload.push('\n'); + } + data_payload.push_str(content); + } + } + } + + if data_payload.is_empty() { + continue; + } + + if data_payload == "[DONE]" { + return Some(ChatCompletionStreamResponse::Done); + } + + match serde_json::from_str::<Value>(&data_payload) { + Ok(json) => { + if let Some(delta) = json + .get("choices") + .and_then(|choices| choices.get(0)) + .and_then(|choice| choice.get("delta")) + { + if let Some(tool_call_response) = delta + .get("tool_calls") + .and_then(|tool_calls| tool_calls.as_array()) + .map(|tool_calls_array| { + tool_calls_array + .iter() + .filter_map(|v| serde_json::from_value(v.clone()).ok()) + .collect::<Vec<ToolCall>>() + }) + .filter(|tool_calls_vec| !tool_calls_vec.is_empty()) + .map(ChatCompletionStreamResponse::ToolCall) + { + return Some(tool_call_response); + } + + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + let output = content.replace("\\n", "\n"); + return Some(ChatCompletionStreamResponse::Content(output)); + } + } + } + Err(error) => { + eprintln!("Failed to parse SSE chunk as JSON: {}", error); + } + } + } + + None + } +} + +impl<S: Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Unpin> Stream + for ChatCompletionStream<S> +{ + type Item = ChatCompletionStreamResponse; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + loop { + if let Some(response) = self.next_response_from_buffer() { + return
Poll::Ready(Some(response)); + } + + match Pin::new(&mut self.as_mut().response).poll_next(cx) { + Poll::Ready(Some(Ok(chunk))) => { + let chunk_str = String::from_utf8_lossy(&chunk).to_string(); + + if self.first_chunk { + self.first_chunk = false; + } + self.buffer.push_str(&chunk_str); + } + Poll::Ready(Some(Err(error))) => { + eprintln!("Error in stream: {:?}", error); + return Poll::Ready(None); + } + Poll::Ready(None) => { + return Poll::Ready(None); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::v1::chat_completion::{ReasoningEffort, ReasoningMode}; + + use super::*; + use serde_json::json; + + #[test] + fn test_reasoning_effort_serialization() { + let reasoning = Reasoning { + mode: Some(ReasoningMode::Effort { + effort: ReasoningEffort::High, + }), + exclude: Some(false), + enabled: None, + }; + + let serialized = serde_json::to_value(&reasoning).unwrap(); + let expected = json!({ + "effort": "high", + "exclude": false + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn test_reasoning_max_tokens_serialization() { + let reasoning = Reasoning { + mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }), + exclude: None, + enabled: Some(true), + }; + + let serialized = serde_json::to_value(&reasoning).unwrap(); + let expected = json!({ + "max_tokens": 2000, + "enabled": true + }); + + assert_eq!(serialized, expected); + } + + #[test] + fn test_reasoning_deserialization() { + let json_str = r#"{"effort": "medium", "exclude": true}"#; + let reasoning: Reasoning = serde_json::from_str(json_str).unwrap(); + + match reasoning.mode { + Some(ReasoningMode::Effort { effort }) => { + assert_eq!(effort, ReasoningEffort::Medium); + } + _ => panic!("Expected effort mode"), + } + assert_eq!(reasoning.exclude, Some(true)); + } + + #[test] + fn test_chat_completion_request_with_reasoning() { + let mut req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]); + + req.reasoning = Some(Reasoning { + mode: Some(ReasoningMode::Effort { + effort: ReasoningEffort::Low, + }), + exclude: None, + enabled: None, + }); + + let serialized = serde_json::to_value(&req).unwrap(); + assert_eq!(serialized["reasoning"]["effort"], "low"); + } + + #[test] + fn test_transforms_none_serialization() { + let req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]); + let serialised = serde_json::to_value(&req).unwrap(); + // Verify that the transforms field is completely omitted from JSON output + assert!(!serialised.as_object().unwrap().contains_key("transforms")); + } + + #[test] + fn test_transforms_some_serialization() { + let mut req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]); + req.transforms = Some(vec!["transform1".to_string(), "transform2".to_string()]); + let serialised = serde_json::to_value(&req).unwrap(); + // Verify that the transforms field is included as a proper JSON array + assert_eq!( + serialised["transforms"], + serde_json::json!(["transform1", "transform2"]) + ); + } + + #[test] + fn test_transforms_some_deserialization() { + let json_str = + r#"{"model": "gpt-4", "messages": [], "transforms": ["transform1", "transform2"]}"#; + let req: ChatCompletionStreamRequest = serde_json::from_str(json_str).unwrap(); + // Verify that the transforms field is properly populated with Some(vec) + assert_eq!( + req.transforms, + Some(vec!["transform1".to_string(), "transform2".to_string()]) + ); + } + + #[test] + fn test_transforms_none_deserialization() { + let json_str = r#"{"model": "gpt-4", 
"messages": []}"#; + let req: ChatCompletionStreamRequest = serde_json::from_str(json_str).unwrap(); + // Verify that the transforms field is properly set to None when absent + assert_eq!(req.transforms, None); + } + + #[test] + fn test_transforms_builder_method() { + let transforms = vec!["transform1".to_string(), "transform2".to_string()]; + let req = ChatCompletionStreamRequest::new("gpt-4".to_string(), vec![]) + .transforms(transforms.clone()); + // Verify that the transforms field is properly set through the builder method + assert_eq!(req.transforms, Some(transforms)); + } +} diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion/mod.rs similarity index 57% rename from src/v1/chat_completion.rs rename to src/v1/chat_completion/mod.rs index 35b891b8..757e5052 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion/mod.rs @@ -1,12 +1,13 @@ -use super::{common, types}; -use crate::impl_builder_methods; - +use crate::v1::types; use serde::de::{self, MapAccess, SeqAccess, Visitor}; use serde::ser::SerializeMap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use serde_json::Value; -use std::collections::HashMap; use std::fmt; + +#[allow(clippy::module_inception)] +pub mod chat_completion; +pub mod chat_completion_stream; + #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub enum ToolChoiceType { None, @@ -40,90 +41,6 @@ pub struct Reasoning { pub enabled: Option, } -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ChatCompletionRequest { - pub model: String, - pub messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub n: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stream: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub stop: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub presence_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub frequency_penalty: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub logit_bias: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub user: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub seed: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub parallel_tool_calls: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(serialize_with = "serialize_tool_choice")] - pub tool_choice: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning: Option, -} - -impl ChatCompletionRequest { - pub fn new(model: String, messages: Vec) -> Self { - Self { - model, - messages, - temperature: None, - top_p: None, - stream: None, - n: None, - response_format: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - seed: None, - tools: None, - parallel_tool_calls: None, - tool_choice: None, - reasoning: None, - } - } -} - -impl_builder_methods!( - ChatCompletionRequest, - temperature: f64, - top_p: f64, - n: i64, - response_format: Value, - stream: bool, - stop: Vec, - max_tokens: i64, - presence_penalty: f64, - frequency_penalty: f64, - logit_bias: HashMap, - user: String, 
- seed: i64, - tools: Vec<Tool>, - parallel_tool_calls: bool, - tool_choice: ToolChoiceType, - reasoning: Reasoning -); - #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] #[allow(non_camel_case_types)] pub enum MessageRole { @@ -272,17 +189,6 @@ pub struct ChatCompletionChoice { pub finish_details: Option<FinishDetails>, } -#[derive(Debug, Deserialize, Serialize)] -pub struct ChatCompletionResponse { - pub id: Option<String>, - pub object: String, - pub created: i64, - pub model: String, - pub choices: Vec<ChatCompletionChoice>, - pub usage: common::Usage, - pub system_fingerprint: Option<String>, -} - #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] #[allow(non_camel_case_types)] pub enum FinishReason { @@ -315,7 +221,7 @@ pub struct ToolCallFunction { pub arguments: Option<String>, } -fn serialize_tool_choice<S>( +pub fn serialize_tool_choice<S>( value: &Option<ToolChoiceType>, serializer: S, ) -> Result<S::Ok, S::Error> @@ -347,75 +253,3 @@ pub struct Tool { pub enum ToolType { Function, } - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_reasoning_effort_serialization() { - let reasoning = Reasoning { - mode: Some(ReasoningMode::Effort { - effort: ReasoningEffort::High, - }), - exclude: Some(false), - enabled: None, - }; - - let serialized = serde_json::to_value(&reasoning).unwrap(); - let expected = json!({ - "effort": "high", - "exclude": false - }); - - assert_eq!(serialized, expected); - } - - #[test] - fn test_reasoning_max_tokens_serialization() { - let reasoning = Reasoning { - mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }), - exclude: None, - enabled: Some(true), - }; - - let serialized = serde_json::to_value(&reasoning).unwrap(); - let expected = json!({ - "max_tokens": 2000, - "enabled": true - }); - - assert_eq!(serialized, expected); - } - - #[test] - fn test_reasoning_deserialization() { - let json_str = r#"{"effort": "medium", "exclude": true}"#; - let reasoning: Reasoning = serde_json::from_str(json_str).unwrap(); - - match reasoning.mode { - Some(ReasoningMode::Effort { effort }) => { - assert_eq!(effort, ReasoningEffort::Medium); - } - _ => panic!("Expected effort mode"), - } - assert_eq!(reasoning.exclude, Some(true)); - } - - #[test] - fn test_chat_completion_request_with_reasoning() { - let mut req = ChatCompletionRequest::new("gpt-4".to_string(), vec![]); - - req.reasoning = Some(Reasoning { - mode: Some(ReasoningMode::Effort { - effort: ReasoningEffort::Low, - }), - exclude: None, - enabled: None, - }); - - let serialized = serde_json::to_value(&req).unwrap(); - assert_eq!(serialized["reasoning"]["effort"], "low"); - } -} diff --git a/src/v1/common.rs b/src/v1/common.rs index ab16946f..934df91c 100644 --- a/src/v1/common.rs +++ b/src/v1/common.rs @@ -31,60 +31,144 @@ macro_rules!
impl_builder_methods { #[derive(Debug, Serialize, Deserialize)] pub struct EmptyRequestBody {} -// https://platform.openai.com/docs/models/o3 +// O-series models +pub const O1: &str = "o1"; +pub const O1_2024_12_17: &str = "o1-2024-12-17"; +pub const O1_MINI: &str = "o1-mini"; +pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12"; +pub const O1_PREVIEW: &str = "o1-preview"; +pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12"; +pub const O1_PRO: &str = "o1-pro"; +pub const O1_PRO_2025_03_19: &str = "o1-pro-2025-03-19"; + pub const O3: &str = "o3"; pub const O3_2025_04_16: &str = "o3-2025-04-16"; pub const O3_MINI: &str = "o3-mini"; pub const O3_MINI_2025_01_31: &str = "o3-mini-2025-01-31"; -// https://platform.openai.com/docs/models#gpt-4-5 -pub const GPT4_5_PREVIEW: &str = "gpt-4.5-preview"; -pub const GPT4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27"; +pub const O4_MINI: &str = "o4-mini"; +pub const O4_MINI_2025_04_16: &str = "o4-mini-2025-04-16"; +pub const O4_MINI_DEEP_RESEARCH: &str = "o4-mini-deep-research"; +pub const O4_MINI_DEEP_RESEARCH_2025_06_26: &str = "o4-mini-deep-research-2025-06-26"; -// https://platform.openai.com/docs/models/o1 -pub const O1_PREVIEW: &str = "o1-preview"; -pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12"; -pub const O1_MINI: &str = "o1-mini"; -pub const O1_MINI_2024_09_12: &str = "o1-mini-2024-09-12"; +// GPT-5 models +pub const GPT5: &str = "gpt-5"; +pub const GPT5_2025_08_07: &str = "gpt-5-2025-08-07"; +pub const GPT5_CHAT_LATEST: &str = "gpt-5-chat-latest"; +pub const GPT5_CODEX: &str = "gpt-5-codex"; +pub const GPT5_MINI: &str = "gpt-5-mini"; +pub const GPT5_MINI_2025_08_07: &str = "gpt-5-mini-2025-08-07"; +pub const GPT5_NANO: &str = "gpt-5-nano"; +pub const GPT5_NANO_2025_08_07: &str = "gpt-5-nano-2025-08-07"; -// https://platform.openai.com/docs/models/gpt-4o-mini -pub const GPT4_O_MINI: &str = "gpt-4o-mini"; -pub const GPT4_O_MINI_2024_07_18: &str = "gpt-4o-mini-2024-07-18"; +// GPT-4.1 models +pub const GPT4_1: &str = "gpt-4.1"; +pub const GPT4_1_2025_04_14: &str = "gpt-4.1-2025-04-14"; +pub const GPT4_1_MINI: &str = "gpt-4.1-mini"; +pub const GPT4_1_MINI_2025_04_14: &str = "gpt-4.1-mini-2025-04-14"; +pub const GPT4_1_NANO: &str = "gpt-4.1-nano"; +pub const GPT4_1_NANO_2025_04_14: &str = "gpt-4.1-nano-2025-04-14"; -// https://platform.openai.com/docs/models/gpt-4o +// GPT-4o models pub const GPT4_O: &str = "gpt-4o"; pub const GPT4_O_2024_05_13: &str = "gpt-4o-2024-05-13"; pub const GPT4_O_2024_08_06: &str = "gpt-4o-2024-08-06"; +pub const GPT4_O_2024_11_20: &str = "gpt-4o-2024-11-20"; pub const GPT4_O_LATEST: &str = "chatgpt-4o-latest"; -// https://platform.openai.com/docs/models/gpt-3-5 -pub const GPT3_5_TURBO_1106: &str = "gpt-3.5-turbo-1106"; -pub const GPT3_5_TURBO: &str = "gpt-3.5-turbo"; -pub const GPT3_5_TURBO_16K: &str = "gpt-3.5-turbo-16k"; -pub const GPT3_5_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct"; -// - legacy -pub const GPT3_5_TURBO_0613: &str = "gpt-3.5-turbo-0613"; -pub const GPT3_5_TURBO_16K_0613: &str = "gpt-3.5-turbo-16k-0613"; -pub const GPT3_5_TURBO_0301: &str = "gpt-3.5-turbo-0301"; +pub const GPT4_O_MINI: &str = "gpt-4o-mini"; +pub const GPT4_O_MINI_2024_07_18: &str = "gpt-4o-mini-2024-07-18"; -// https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo -pub const GPT4_0125_PREVIEW: &str = "gpt-4-0125-preview"; -pub const GPT4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview"; -pub const GPT4_1106_PREVIEW: &str = "gpt-4-1106-preview"; -pub const GPT4_VISION_PREVIEW: &str = 
"gpt-4-vision-preview"; +// GPT-4o search models +pub const GPT4_O_SEARCH_PREVIEW: &str = "gpt-4o-search-preview"; +pub const GPT4_O_SEARCH_PREVIEW_2025_03_11: &str = "gpt-4o-search-preview-2025-03-11"; +pub const GPT4_O_MINI_SEARCH_PREVIEW: &str = "gpt-4o-mini-search-preview"; +pub const GPT4_O_MINI_SEARCH_PREVIEW_2025_03_11: &str = "gpt-4o-mini-search-preview-2025-03-11"; + +// GPT-4o realtime models +pub const GPT4_O_REALTIME_PREVIEW: &str = "gpt-4o-realtime-preview"; +pub const GPT4_O_REALTIME_PREVIEW_2024_10_01: &str = "gpt-4o-realtime-preview-2024-10-01"; +pub const GPT4_O_REALTIME_PREVIEW_2024_12_17: &str = "gpt-4o-realtime-preview-2024-12-17"; +pub const GPT4_O_REALTIME_PREVIEW_2025_06_03: &str = "gpt-4o-realtime-preview-2025-06-03"; +pub const GPT4_O_MINI_REALTIME_PREVIEW: &str = "gpt-4o-mini-realtime-preview"; +pub const GPT4_O_MINI_REALTIME_PREVIEW_2024_12_17: &str = "gpt-4o-mini-realtime-preview-2024-12-17"; + +// GPT-4o audio models +pub const GPT4_O_AUDIO_PREVIEW: &str = "gpt-4o-audio-preview"; +pub const GPT4_O_AUDIO_PREVIEW_2024_10_01: &str = "gpt-4o-audio-preview-2024-10-01"; +pub const GPT4_O_AUDIO_PREVIEW_2024_12_17: &str = "gpt-4o-audio-preview-2024-12-17"; +pub const GPT4_O_AUDIO_PREVIEW_2025_06_03: &str = "gpt-4o-audio-preview-2025-06-03"; +pub const GPT4_O_MINI_AUDIO_PREVIEW: &str = "gpt-4o-mini-audio-preview"; +pub const GPT4_O_MINI_AUDIO_PREVIEW_2024_12_17: &str = "gpt-4o-mini-audio-preview-2024-12-17"; + +// GPT-4o transcription models +pub const GPT4_O_TRANSCRIBE: &str = "gpt-4o-transcribe"; +pub const GPT4_O_MINI_TRANSCRIBE: &str = "gpt-4o-mini-transcribe"; + +// GPT-4 and GPT-4 Turbo models pub const GPT4: &str = "gpt-4"; -pub const GPT4_32K: &str = "gpt-4-32k"; pub const GPT4_0613: &str = "gpt-4-0613"; +pub const GPT4_32K: &str = "gpt-4-32k"; pub const GPT4_32K_0613: &str = "gpt-4-32k-0613"; -// - legacy pub const GPT4_0314: &str = "gpt-4-0314"; pub const GPT4_32K_0314: &str = "gpt-4-32k-0314"; -// https://platform.openai.com/docs/api-reference/images/object +pub const GPT4_TURBO: &str = "gpt-4-turbo"; +pub const GPT4_TURBO_2024_04_09: &str = "gpt-4-turbo-2024-04-09"; +pub const GPT4_TURBO_PREVIEW: &str = "gpt-4-turbo-preview"; +pub const GPT4_0125_PREVIEW: &str = "gpt-4-0125-preview"; +pub const GPT4_1106_PREVIEW: &str = "gpt-4-1106-preview"; +pub const GPT4_VISION_PREVIEW: &str = "gpt-4-vision-preview"; + +// GPT-3.5 Turbo models +pub const GPT3_5_TURBO: &str = "gpt-3.5-turbo"; +pub const GPT3_5_TURBO_0125: &str = "gpt-3.5-turbo-0125"; +pub const GPT3_5_TURBO_1106: &str = "gpt-3.5-turbo-1106"; +pub const GPT3_5_TURBO_16K: &str = "gpt-3.5-turbo-16k"; +pub const GPT3_5_TURBO_0613: &str = "gpt-3.5-turbo-0613"; +pub const GPT3_5_TURBO_16K_0613: &str = "gpt-3.5-turbo-16k-0613"; +pub const GPT3_5_TURBO_0301: &str = "gpt-3.5-turbo-0301"; + +pub const GPT3_5_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct"; +pub const GPT3_5_TURBO_INSTRUCT_0914: &str = "gpt-3.5-turbo-instruct-0914"; + +// Audio models +pub const GPT_AUDIO: &str = "gpt-audio"; +pub const GPT_AUDIO_2025_08_28: &str = "gpt-audio-2025-08-28"; +pub const GPT_REALTIME: &str = "gpt-realtime"; +pub const GPT_REALTIME_2025_08_28: &str = "gpt-realtime-2025-08-28"; + +// Text-to-Speech models +pub const TTS_1: &str = "tts-1"; +pub const TTS_1_HD: &str = "tts-1-hd"; +pub const TTS_1_1106: &str = "tts-1-1106"; +pub const TTS_1_HD_1106: &str = "tts-1-hd-1106"; +pub const GPT4_O_MINI_TTS: &str = "gpt-4o-mini-tts"; + +// Speech-to-Text models +pub const WHISPER_1: &str = "whisper-1"; + +// Image generation models 
pub const DALL_E_2: &str = "dall-e-2"; pub const DALL_E_3: &str = "dall-e-3"; +pub const GPT_IMAGE_1: &str = "gpt-image-1"; -// https://platform.openai.com/docs/guides/embeddings/embedding-models +// Embedding models pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small"; pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large"; pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002"; + +// Moderation models +pub const OMNI_MODERATION_LATEST: &str = "omni-moderation-latest"; +pub const OMNI_MODERATION_2024_09_26: &str = "omni-moderation-2024-09-26"; + +// Legacy models +pub const DAVINCI_002: &str = "davinci-002"; +pub const BABBAGE_002: &str = "babbage-002"; + +// Code models +pub const CODEX_MINI_LATEST: &str = "codex-mini-latest"; + +// Preview models (GPT-4.5) +pub const GPT4_5_PREVIEW: &str = "gpt-4.5-preview"; +pub const GPT4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27"; diff --git a/src/v1/embedding.rs b/src/v1/embedding.rs index 3f68054c..67df9c0f 100644 --- a/src/v1/embedding.rs +++ b/src/v1/embedding.rs @@ -21,6 +21,7 @@ pub enum EncodingFormat { pub struct EmbeddingRequest { pub model: String, pub input: Vec<String>, + #[serde(skip_serializing_if = "Option::is_none")] pub encoding_format: Option<EncodingFormat>, #[serde(skip_serializing_if = "Option::is_none")] pub dimensions: Option<i32>, diff --git a/src/v1/mod.rs b/src/v1/mod.rs index d44ed319..0dcbcbb6 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -13,6 +13,7 @@ pub mod fine_tuning; pub mod image; pub mod model; pub mod moderation; +pub mod responses; // beta pub mod assistant; diff --git a/src/v1/model.rs b/src/v1/model.rs index 2b0a044d..5aa2bc61 100644 --- a/src/v1/model.rs +++ b/src/v1/model.rs @@ -2,14 +2,51 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize)] pub struct ModelsResponse { - pub object: String, + pub object: Option<String>, pub data: Vec<ModelResponse>, } #[derive(Debug, Deserialize, Serialize)] pub struct ModelResponse { - pub id: String, - pub object: String, - pub created: i64, - pub owned_by: String, + pub id: Option<String>, + pub name: Option<String>, + pub created: Option<i64>, + pub description: Option<String>, + pub architecture: Option<Architecture>, + pub top_provider: Option<TopProvider>, + pub pricing: Option<Pricing>, + pub canonical_slug: Option<String>, + pub context_length: Option<i64>, + pub hugging_face_id: Option<String>, + pub per_request_limits: Option<serde_json::Value>, + pub supported_parameters: Option<Vec<String>>, + pub object: Option<String>, + pub owned_by: Option<String>, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Architecture { + pub input_modalities: Option<Vec<String>>, + pub output_modalities: Option<Vec<String>>, + pub tokenizer: Option<String>, + pub instruct_type: Option<String>, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TopProvider { + pub is_moderated: Option<bool>, + pub context_length: Option<i64>, + pub max_completion_tokens: Option<i64>, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Pricing { + pub prompt: Option<String>, + pub completion: Option<String>, + pub image: Option<String>, + pub request: Option<String>, + pub web_search: Option<String>, + pub internal_reasoning: Option<String>, + pub input_cache_read: Option<String>, + pub input_cache_write: Option<String>, } diff --git a/src/v1/responses.rs b/src/v1/responses.rs new file mode 100644 index 00000000..348b1fea --- /dev/null +++ b/src/v1/responses.rs @@ -0,0 +1,312 @@ +use crate::v1::types::Tools; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::BTreeMap; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct CreateResponseRequest { + // background + #[serde(skip_serializing_if = "Option::is_none")] + pub background: Option<bool>, + + //
conversation + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation: Option<Value>, + + // include + #[serde(skip_serializing_if = "Option::is_none")] + pub include: Option<Vec<String>>, + + // input + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option<Value>, + + // instructions + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option<String>, + + // max_output_tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option<i64>, + + // max_tool_calls + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tool_calls: Option<i64>, + + // metadata + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option<BTreeMap<String, String>>, + + // model + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option<String>, + + // parallel_tool_calls + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option<bool>, + + // previous_response_id + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option<String>, + + // prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option<Value>, + + // prompt_cache_key + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_cache_key: Option<String>, + + // reasoning + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option<Value>, + + // safety_identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub safety_identifier: Option<String>, + + // service_tier + #[serde(skip_serializing_if = "Option::is_none")] + pub service_tier: Option<String>, + + // store + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option<bool>, + + // stream + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option<bool>, + + // stream_options + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option<Value>, + + // temperature + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option<f64>, + + // text + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option<Value>, + + // tool_choice + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option<Value>, + + // tools + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option<Vec<Tools>>, + + // top_logprobs + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs: Option<i64>, + + // top_p + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option<f64>, + + // truncation + #[serde(skip_serializing_if = "Option::is_none")] + pub truncation: Option<String>, + + // user (deprecated) + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option<String>, + + // Future-proof + #[serde(flatten)] + pub extra: BTreeMap<String, Value>, +} + +impl CreateResponseRequest { + pub fn new() -> Self { + Self { + background: None, + conversation: None, + include: None, + input: None, + instructions: None, + max_output_tokens: None, + max_tool_calls: None, + metadata: None, + model: None, + parallel_tool_calls: None, + previous_response_id: None, + prompt: None, + prompt_cache_key: None, + reasoning: None, + safety_identifier: None, + service_tier: None, + store: None, + stream: None, + stream_options: None, + temperature: None, + text: None, + tool_choice: None, + tools: None, + top_logprobs: None, + top_p: None, + truncation: None, + user: None, + extra: BTreeMap::new(), + } + } +} + +impl Default for CreateResponseRequest { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseObject { + pub id: String, + pub object: String, + + // Core + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option<i64>, +
#[serde(skip_serializing_if = "Option::is_none")] + pub model: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option<String>, + + // Output + #[serde(skip_serializing_if = "Option::is_none")] + pub output: Option<Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_text: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_audio: Option<Value>, + + // Control / reasons + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_reason: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub refusal: Option<Value>, + + // Tools + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option<Value>, + + // Misc + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option<Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option<Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub service_tier: Option<String>, + + // Errors / details + #[serde(skip_serializing_if = "Option::is_none")] + pub status_details: Option<Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub incomplete_details: Option<Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<Value>, + + // Future-proof + #[serde(flatten)] + pub extra: BTreeMap<String, Value>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ListResponses { + pub object: String, + pub data: Vec<ResponseObject>, + #[serde(skip_serializing_if = "Option::is_none")] + pub first_id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub last_id: Option<String>, + pub has_more: bool, +} + +// Get input token counts (POST /v1/responses/input_tokens) +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct CountTokensRequest { + // conversation + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation: Option<Value>, + + // input + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option<Value>, + + // instructions + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option<String>, + + // model + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option<String>, + + // parallel_tool_calls + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option<bool>, + + // previous_response_id + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option<String>, + + // reasoning + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option<Value>, + + // text + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option<Value>, + + // tool_choice + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option<Value>, + + // tools + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option<Vec<Tools>>, + + // truncation + #[serde(skip_serializing_if = "Option::is_none")] + pub truncation: Option<String>, + + // Future-proof + #[serde(flatten)] + pub extra: BTreeMap<String, Value>, +} + +impl CountTokensRequest { + pub fn new() -> Self { + Self { + conversation: None, + input: None, + instructions: None, + model: None, + parallel_tool_calls: None, + previous_response_id: None, + reasoning: None, + text: None, + tool_choice: None, + tools: None, + truncation: None, + extra: BTreeMap::new(), + } + } +} + +impl Default for CountTokensRequest { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct CountTokensResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_tokens: Option<i64>, + #[serde(flatten)] +
pub extra: BTreeMap, +}